In [1]:
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from utils.data_load import KittiDataset
from model.ensemblenet_model import EnsembleNet
import time
# Force synchronous CUDA kernel launches. The variable is read when CUDA is
# initialised, so it has to be set before the first CUDA call in this kernel.
# NOTE(review): presumably left over from debugging; it slows GPU execution and
# therefore affects every timing reported below - confirm it is intended.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Turn off the cuDNN backend for all subsequent PyTorch ops.
# NOTE(review): this also changes the measured inference times - confirm.
torch.backends.cudnn.enabled = False

In [4]:
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Predict a segmentation mask for one image at the image's own resolution.

    The image is preprocessed with ``KittiDataset.preprocess`` (using the
    notebook-level ``Image_Size``), run through ``net`` as a batch of one, and
    the logits are bilinearly resized to ``(full_img.size[1], full_img.size[0])``
    before being turned into a mask.

    Args:
        net: Model to evaluate. It must expose ``n_classes``. Its forward pass
            may return a single logits tensor or a sequence of logits tensors
            (one per ensemble member).
        full_img: Input image; ``full_img.size`` is read as ``(width, height)``.
        device: Device the input tensor is moved to before the forward pass.
        scale_factor: Forwarded unchanged to ``KittiDataset.preprocess``.
        out_threshold: Probability cut-off applied to the sigmoid output when
            ``net.n_classes`` is not greater than 1.

    Returns:
        numpy.ndarray: Integer mask for the single input image. Class indices
        when ``net.n_classes > 1``, otherwise 0/1 from thresholding.

    Raises:
        ValueError: If the model returns an empty sequence, or returns several
            heads while ``net.n_classes`` is not greater than 1 (averaging is
            only defined here for multi-class softmax outputs).
    """
    net.eval()
    img = torch.from_numpy(KittiDataset.preprocess(None, full_img, Image_Size, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    # F.interpolate expects (height, width); full_img.size is (width, height).
    target_size = (full_img.size[1], full_img.size[0])

    with torch.no_grad():
        output = net(img)

        # Tell a single tensor apart from a sequence of heads explicitly.
        # Branching on len(output) only worked for tensors because the batch
        # dimension happens to be 1.
        heads = [output] if torch.is_tensor(output) else list(output)
        if not heads:
            raise ValueError('predict_img: model returned no outputs')

        logits = [F.interpolate(head, target_size, mode='bilinear') for head in heads]

        if net.n_classes > 1:
            if len(logits) == 1:
                mask = logits[0].argmax(dim=1)
            else:
                # Soft voting: average the per-head class probabilities.
                vot = sum(F.softmax(head, dim=1) for head in logits) / len(logits)
                mask = vot.argmax(dim=1)
        elif len(logits) == 1:
            mask = torch.sigmoid(logits[0]) > out_threshold
        else:
            # Previously this path printed 'error' and then crashed on an
            # unbound `mask`; fail with an explicit message instead.
            raise ValueError(
                'predict_img: {} output heads with n_classes={} is not supported'.format(
                    len(logits), net.n_classes))

    return mask[0].cpu().long().squeeze().numpy()

In [5]:
# --- Benchmark configuration -------------------------------------------------
# Third and second positional arguments passed to EnsembleNet below.
Num_Class = 2
Num_Channel = 3
# Model identifier handed to EnsembleNet and echoed in the timing message.
# NOTE(review): the recorded outputs below show 'enet', 'unet' and 'segnet',
# so this value was evidently edited by hand between runs.
Model_Name = 'ensemble_voting'
# Forwarded to predict_img as scale_factor / out_threshold.
Scale = 0.5
Threshold = 0.5
# Read as a global by predict_img and passed to KittiDataset.preprocess.
# Presumably [height, width] - TODO confirm against KittiDataset.
Image_Size = [384, 1216]

imgdir = 'data/data_road/testing/image_2/'
# os.listdir returns entries in arbitrary order; sort for a reproducible run.
in_files = sorted(os.listdir(imgdir))

### E-Net

In [4]:
# Time E-Net inference over every file in `in_files`.
# The recorded output of this cell reports the model name 'enet', so the name
# is pinned here instead of relying on Model_Name having been edited by hand
# in the config cell before this cell was run.
model_name = 'enet'
net = EnsembleNet(model_name, Num_Channel, Num_Class)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net.to(device=device)
state_dict = torch.load('../trained_enet/checkpoint_epoch51.pth', map_location=device)
# Pop 'mask_values' (if present) so the strict load below does not see it as
# an unexpected key.
mask_values = state_dict.pop('mask_values', [0, 1])
net.load_state_dict(state_dict)

print('model loaded!')

# perf_counter is the clock intended for measuring intervals. The measured
# span covers image opening, preprocessing, the forward pass and the copy of
# the mask back to the CPU done inside predict_img.
start = time.perf_counter()
for filename in in_files:
    # Context manager closes the underlying file handle after each image.
    with Image.open(os.path.join(imgdir, filename)) as img:
        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=Scale,
                           out_threshold=Threshold,
                           device=device)
elapsed = time.perf_counter() - start
# Report the real number of processed images instead of a hard-coded 290.
print('{} Predict Image {} total Inference Time : {}'.format(model_name, len(in_files), elapsed))

model loaded!
enet Predict Image 290 total Inference Time : 19.632378339767456


### U-Net

In [9]:
# Time U-Net inference over every file in `in_files`.
# The recorded output of this cell reports the model name 'unet', so the name
# is pinned here instead of relying on Model_Name having been edited by hand
# in the config cell before this cell was run.
model_name = 'unet'
net = EnsembleNet(model_name, Num_Channel, Num_Class)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net.to(device=device)
state_dict = torch.load('../trained_unet/checkpoint_epoch51.pth', map_location=device)
# Pop 'mask_values' (if present) so the strict load below does not see it as
# an unexpected key.
mask_values = state_dict.pop('mask_values', [0, 1])
net.load_state_dict(state_dict)

print('model loaded!')

# perf_counter is the clock intended for measuring intervals. The measured
# span covers image opening, preprocessing, the forward pass and the copy of
# the mask back to the CPU done inside predict_img.
start = time.perf_counter()
for filename in in_files:
    # Context manager closes the underlying file handle after each image.
    with Image.open(os.path.join(imgdir, filename)) as img:
        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=Scale,
                           out_threshold=Threshold,
                           device=device)
elapsed = time.perf_counter() - start
# Report the real number of processed images instead of a hard-coded 290.
print('{} Predict Image {} total Inference Time : {}'.format(model_name, len(in_files), elapsed))

model loaded!
unet Predict Image 290 total Inference Time : 13.551630020141602


### SegNet

In [4]:
# Time SegNet inference over every file in `in_files`.
# The recorded output of this cell reports the model name 'segnet', so the
# name is pinned here instead of relying on Model_Name having been edited by
# hand in the config cell before this cell was run.
model_name = 'segnet'
net = EnsembleNet(model_name, Num_Channel, Num_Class)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net.to(device=device)
state_dict = torch.load('../trained_segnet/checkpoint_epoch51.pth', map_location=device)
# Pop 'mask_values' (if present) so the strict load below does not see it as
# an unexpected key.
mask_values = state_dict.pop('mask_values', [0, 1])
net.load_state_dict(state_dict)

print('model loaded!')

# perf_counter is the clock intended for measuring intervals. The measured
# span covers image opening, preprocessing, the forward pass and the copy of
# the mask back to the CPU done inside predict_img.
start = time.perf_counter()
for filename in in_files:
    # Context manager closes the underlying file handle after each image.
    with Image.open(os.path.join(imgdir, filename)) as img:
        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=Scale,
                           out_threshold=Threshold,
                           device=device)
elapsed = time.perf_counter() - start
# Report the real number of processed images instead of a hard-coded 290.
print('{} Predict Image {} total Inference Time : {}'.format(model_name, len(in_files), elapsed))

model loaded!
segnet Predict Image 290 total Inference Time : 22.73341178894043


### Ensemble-Net Fusion

In [6]:
# Time the fusion-ensemble checkpoint over every file in `in_files`.
# NOTE(review): the recorded output shows this cell ran with
# Model_Name == 'ensemble_voting' while loading the *fusion* checkpoint -
# confirm that this is the intended architecture for these weights.
net = EnsembleNet(Model_Name, Num_Channel, Num_Class)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net.to(device=device)
state_dict = torch.load('../trained_ensemble_fusion/checkpoint_epoch40.pth', map_location=device)
# Pop 'mask_values' (if present) so it is not handed to load_state_dict.
mask_values = state_dict.pop('mask_values', [0, 1])
# strict=False tolerates key mismatches silently; surface them so a partially
# initialised network is not benchmarked unnoticed.
load_result = net.load_state_dict(state_dict, strict = False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('warning: {} missing and {} unexpected keys while loading checkpoint'.format(
        len(load_result.missing_keys), len(load_result.unexpected_keys)))

print('model loaded!')

# perf_counter is the clock intended for measuring intervals. The measured
# span covers image opening, preprocessing, the forward pass and the copy of
# the mask back to the CPU done inside predict_img.
start = time.perf_counter()
for filename in in_files:
    # Context manager closes the underlying file handle after each image.
    with Image.open(os.path.join(imgdir, filename)) as img:
        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=Scale,
                           out_threshold=Threshold,
                           device=device)
elapsed = time.perf_counter() - start
# Report the real number of processed images instead of a hard-coded 290.
print('{} Predict Image {} total Inference Time : {}'.format(Model_Name, len(in_files), elapsed))

model loaded!
ensemble_voting Predict Image 290 total Inference Time : 36.52558779716492


### Ensemble-Net Voting

In [4]:
# Time the voting-ensemble checkpoint over every file in `in_files`.
# Uses Model_Name from the config cell ('ensemble_voting' in the recorded run).
net = EnsembleNet(Model_Name, Num_Channel, Num_Class)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net.to(device=device)
state_dict = torch.load('../trained_ensemble_voting/checkpoint_epoch2.pth', map_location=device)
# Pop 'mask_values' (if present) so it is not handed to load_state_dict.
mask_values = state_dict.pop('mask_values', [0, 1])
# strict=False tolerates key mismatches silently; surface them so a partially
# initialised network is not benchmarked unnoticed.
load_result = net.load_state_dict(state_dict, strict = False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('warning: {} missing and {} unexpected keys while loading checkpoint'.format(
        len(load_result.missing_keys), len(load_result.unexpected_keys)))

print('model loaded!')

# perf_counter is the clock intended for measuring intervals. The measured
# span covers image opening, preprocessing, the forward pass and the copy of
# the mask back to the CPU done inside predict_img.
start = time.perf_counter()
for filename in in_files:
    # Context manager closes the underlying file handle after each image.
    with Image.open(os.path.join(imgdir, filename)) as img:
        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=Scale,
                           out_threshold=Threshold,
                           device=device)
elapsed = time.perf_counter() - start

# Report the real number of processed images instead of a hard-coded 290.
print('{} Predict Image {} total Inference Time : {}'.format(Model_Name, len(in_files), elapsed))

model loaded!
ensemble_voting Predict Image 290 total Inference Time : 35.61333513259888
