In [1]:
import torch
from torch import nn
from torchvision import models, transforms
import torchvision
import os
from time import time
import onnx
import onnxruntime
from onnx import numpy_helper
import numpy as np
from matplotlib import pyplot as plt
import cv2

import os    
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
PATH = "models/"
input_size = 224
batch_size = 128
n_channel = 3
test_size = 5

In [3]:
label_colors_list = [
        (64, 128, 64), # animal
        (192, 0, 128), # archway
        (0, 128, 192), # bicyclist
        (0, 128, 64), #bridge
        (128, 0, 0), # building
        (64, 0, 128), #car
        (64, 0, 192), # car luggage pram...???...
        (192, 128, 64), # child
        (192, 192, 128), # column pole
        (64, 64, 128), # fence
        (128, 0, 192), # lane marking driving
        (192, 0, 64), # lane maring non driving
        (128, 128, 64), # misc text
        (192, 0, 192), # motor cycle scooter
        (128, 64, 64), # other moving
        (64, 192, 128), # parking block
        (64, 64, 0), # pedestrian
        (128, 64, 128), # road
        (128, 128, 192), # road shoulder
        (0, 0, 192), # sidewalk
        (192, 128, 128), # sign symbol
        (128, 128, 128), # sky
        (64, 128, 192), # suv pickup truck
        (0, 0, 64), # traffic cone
        (0, 64, 64), # traffic light
        (192, 64, 128), # train
        (128, 128, 0), # tree
        (192, 128, 192), # truck/bus
        (64, 0, 64), # tunnel
        (192, 192, 0), # vegetation misc.
        (0, 0, 0),  # 0=background/void
        (64, 192, 0), # wall
    ]

CLASSES_TO_TRAIN = [
        'animal', 'archway', 'bicyclist', 'bridge', 'building', 'car', 
        'cartluggagepram', 'child', 'columnpole', 'fence', 'lanemarkingdrve', 
        'lanemarkingnondrve', 'misctext', 'motorcyclescooter', 'othermoving',
        'parkingblock', 'pedestrian', 'road', 'road shoulder', 'sidewalk',
        'signsymbol', 'sky', 'suvpickuptruck', 'trafficcone', 'trafficlight', 
        'train', 'tree', 'truckbase', 'tunnel', 'vegetationmisc', 'void',
        'wall'
        ]

ALL_CLASSES = ['animal', 'archway', 'bicyclist', 'bridge', 'building', 'car', 
        'cartluggagepram', 'child', 'columnpole', 'fence', 'lanemarkingdrve', 
        'lanemarkingnondrve', 'misctext', 'motorcyclescooter', 'othermoving', 
        'parkingblock', 'pedestrian', 'road', 'road shoulder', 'sidewalk', 
        'signsymbol', 'sky', 'suvpickuptruck', 'trafficcone', 'trafficlight', 
        'train', 'tree', 'truckbase', 'tunnel', 'vegetationmisc', 'void',
        'wall']

class_values = [ALL_CLASSES.index(cls.lower()) for cls in CLASSES_TO_TRAIN]

In [4]:
def get_test_data(dataloader, size):
    X_test, Y_test = next(iter(dataloader))
    batch_size = len(X_test)
    n = size//batch_size
    for i, batch in enumerate(dataloader):
        if i < n:
            X_tmp, Y_tmp = batch
            X_test = torch.cat((X_test, X_tmp), 0)
            Y_test = torch.cat((Y_test, Y_tmp), 0)
    return X_test, Y_test

def draw_test_segmentation_map(outputs):
    """
    This function will apply color mask as per the output that we
    get when executing `test.py` or `test_vid.py` on a single image 
    or a video. NOT TO BE USED WHILE TRAINING OR VALIDATING.
    """
    labels = torch.argmax(outputs.squeeze(), dim=0).detach().cpu().numpy()
    red_map = np.zeros_like(labels).astype(np.uint8)
    green_map = np.zeros_like(labels).astype(np.uint8)
    blue_map = np.zeros_like(labels).astype(np.uint8)
    
    for label_num in range(0, len(label_colors_list)):
        if label_num in class_values:
            idx = labels == label_num
            red_map[idx] = np.array(label_colors_list)[label_num, 0]
            green_map[idx] = np.array(label_colors_list)[label_num, 1]
            blue_map[idx] = np.array(label_colors_list)[label_num, 2]
        
    segmented_image = np.stack([red_map, green_map, blue_map], axis=2)
    return segmented_image

def image_overlay(image, segmented_image):
    """
    This function will apply an overlay of the output segmentation
    map on top of the orifinal input image. MAINLY TO BE USED WHEN
    EXECUTING `test.py` or `test_vid.py`.
    """
    alpha = 0.6 # how much transparency to apply
    beta = 1 - alpha # alpha + beta should equal 1
    gamma = 0 # scalar added to each sum
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
    cv2.addWeighted(segmented_image, alpha, image, beta, gamma, image)
    return image

In [5]:
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
transformfcnTrain=transforms.Compose([ # Cette fois on utilise pas de grayscale car nous avons un gros modele pré-entrainé
        transforms.RandomResizedCrop(input_size), # selection aléatoire d'une zone de la taille voulue (augmentation des données en apprentissage)
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
transformfcnTest=transforms.Compose([
        transforms.Resize(input_size), # selection de la zone centrale de la taille voulue
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
fcn_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transformfcnTrain)
fcn_trainloader = torch.utils.data.DataLoader(fcn_trainset, batch_size=batch_size, pin_memory=True, shuffle=True)
fcn_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transformfcnTest)
fcn_testloader = torch.utils.data.DataLoader(fcn_testset, batch_size=batch_size, pin_memory=True, shuffle=True)
X_test, Y_test = get_test_data(fcn_testloader, 300)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
model = models.segmentation.fcn_resnet50(pretrained=True)
model.load_state_dict(torch.load(os.path.join(PATH,"fcnresnet.pth"), map_location='cpu'))
model.eval()
dummy_input = torch.randn(test_size, n_channel, input_size, input_size)  
torch.onnx.export(model,   
                  dummy_input, 
                  str(PATH+"fcnresnet.onnx"),
                  export_params=True,
                  do_constant_folding=True, 
                  input_names = ['modelInput'],
                  output_names = ['modelOutput'])
X_test = X_test[:test_size]
sess = onnxruntime.InferenceSession(str(PATH+"fcnresnet.onnx"))
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

t0 = time()
pred = sess.run([output_name], {input_name: np.array(X_test).astype(np.float32)})[0]
print("Time for ",test_size," images with ONNX inference", (time() - t0))
print(pred.shape)

ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).
We recommend using opset 11 and above for models using this operator.


Time for  5  images with ONNX inference 0.6331841945648193
(5, 21, 224, 224)


In [13]:
segmented_image = draw_test_segmentation_map(model(X_test)["out"])[0]
print(segmented_image.shape)
segmented_image = segmented_image.transpose((1, 0, 2))
print(segmented_image.shape)
print(X_test[0].shape)

(224, 3, 224)
(3, 224, 224)
torch.Size([3, 224, 224])


In [14]:
cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
final_image = image_overlay(X_test[0], segmented_image)
cv2.imshow('image', final_image)

error: OpenCV(4.5.1) c:\users\appveyor\appdata\local\temp\1\pip-req-build-kh7iq4w7\opencv\modules\imgproc\src\color.simd_helpers.hpp:92: error: (-2:Unspecified error) in function '__cdecl cv::impl::`anonymous-namespace'::CvtHelper<struct cv::impl::`anonymous namespace'::Set<3,4,-1>,struct cv::impl::A0x206ccf44::Set<3,4,-1>,struct cv::impl::A0x206ccf44::Set<0,2,5>,2>::CvtHelper(const class cv::_InputArray &,const class cv::_OutputArray &,int)'
> Invalid number of channels in input image:
>     'VScn::contains(scn)'
> where
>     'scn' is 224


In [None]:
# test = pred[0]
# plt.imshow(X_test[0].mean(axis=0))
# plt.show()

# test_sum = np.median(test, axis=0)
# plt.imshow(test_sum)
# plt.show()