In [2]:
import matplotlib.pyplot as plt
import torch
import argparse
import os
import cv2
import sys
import numpy as np
from collections import OrderedDict
from models import get_model

ModuleNotFoundError: No module named 'torch'

In [None]:
# Enable the cuDNN autotuner: it times candidate convolution algorithms on the
# first call for a given input shape and reuses the fastest afterwards, which
# pays off when every inference uses the same fixed input size.
setattr(torch.backends.cudnn, "benchmark", True)

In [None]:
def convert_state_dict(state_dict):
    """Strip the ``module.`` prefix that ``nn.DataParallel`` adds to parameter names.

    A checkpoint saved from a DataParallel-wrapped model stores every key as
    ``module.<name>``, while an unwrapped model expects ``<name>``. Keys that do
    not carry the prefix are passed through unchanged, so this is also safe on
    checkpoints saved from an unwrapped model.

    :param state_dict: mapping of parameter names to values (typically tensors),
        e.g. the ``"model_state"`` entry of a loaded checkpoint. Not modified.
    :return: a new ``OrderedDict`` with the same values in the same order and
        the ``module.`` prefix removed from every key that had it.
    """
    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Only strip when the prefix is really there; blindly dropping the first
        # 7 characters would mangle keys from a non-DataParallel checkpoint.
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict
def decode_segmap(temp, label_colours, n_classes=None):
    """Convert a 2-D map of class ids into an RGB image with values in [0, 1].

    :param temp: 2-D integer array of per-pixel class ids (e.g. the argmax of
        the network output).
    :param label_colours: indexable ``class_id -> (R, G, B)`` with 0-255
        channel values; a dict keyed by class id or a plain list both work.
    :param n_classes: colour class ids ``0 .. n_classes - 1``. Defaults to
        ``len(label_colours)`` (previously hard-coded to 19).
    :return: float64 array of shape ``(H, W, 3)``, channels in RGB order,
        scaled to [0, 1]. Pixels whose id is not coloured keep their raw id
        value (divided by 255) in all three channels.
    """
    if n_classes is None:
        n_classes = len(label_colours)
    r = temp.copy()
    g = temp.copy()
    b = temp.copy()
    for label in range(n_classes):
        # Compute each class mask once and reuse it for all three channels.
        mask = temp == label
        colour = label_colours[label]
        r[mask] = colour[0]
        g[mask] = colour[1]
        b[mask] = colour[2]
    rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    return rgb
def init_model(model_path, n_classes=19, arch="hardnet"):
    """Build the segmentation network, load trained weights, and switch to eval mode.

    :param model_path: path to a checkpoint whose loaded object has a
        ``"model_state"`` entry (a state dict, possibly DataParallel-prefixed).
    :param n_classes: number of output classes; 19 is the Cityscapes class
        count used by the bundled checkpoint.
    :param arch: architecture name passed to ``get_model``.
    :return: ``(device, model)`` where ``device`` is CUDA if available, else
        CPU, and ``model`` has been moved to that device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model({"arch": arch}, n_classes)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    state = convert_state_dict(checkpoint["model_state"])
    model.load_state_dict(state)
    model.eval()
    model.to(device)
    return device, model

In [None]:
# Colour palette (RGB, 0-255) for the 19 class ids 0..18 of the
# Cityscapes-trained model; index = class id.
_CITYSCAPES_COLORS = [
    [128, 64, 128],
    [244, 35, 232],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [0, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
]


def _preprocess(img, size):
    """Resize and normalise a 3-channel image into a (1, 3, H, W) float32 tensor."""
    # cv2.resize takes dsize as (width, height); `size` is (height, width).
    resized = cv2.resize(img, (size[1], size[0]))
    # cv2.imread yields BGR. These per-channel stats appear to be the ImageNet
    # RGB mean/std listed in B, G, R order to match -- confirm against training.
    mean = np.array([0.406, 0.456, 0.485], dtype=np.float32) * 255
    std = np.array([0.225, 0.224, 0.229], dtype=np.float32) * 255
    normalised = (resized.astype(np.float32) - mean) / std
    chw = normalised.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(np.ascontiguousarray(chw)).unsqueeze(0)  # add batch dim


def process_img(img, size, device, model):
    """Run semantic segmentation on one image and return a colour-coded mask.

    :param img: H x W x 3 uint8 image as returned by ``cv2.imread`` (BGR order).
    :param size: ``(height, width)`` the image is resized to before inference.
    :param device: torch device the model lives on.
    :param model: segmentation network in eval mode; its output is assumed to be
        ``(1, n_classes, H', W')`` per-class scores -- TODO confirm for other archs.
    :return: float RGB array in [0, 1] from ``decode_segmap``; its spatial size
        is the model's output size (presumably ``size`` for this network).
    """
    images = _preprocess(img, size).to(device)
    # Inference only: disable autograd so no graph / activations are retained.
    with torch.no_grad():
        outputs = model(images)
    # Per-pixel class id = index of the highest score along the class axis.
    pred = outputs.argmax(dim=1).squeeze(0).cpu().numpy()
    label_colours = dict(enumerate(_CITYSCAPES_COLORS))
    return decode_segmap(temp=pred, label_colours=label_colours)

In [None]:
def plot_segmendet_image(img, img_decoded):
    """Display the input image stacked above its colour-coded segmentation.

    :param img: H x W x 3 uint8 image in BGR order, as returned by ``cv2.imread``.
    :param img_decoded: float RGB array in [0, 1] (output of ``process_img``);
        its height/width may differ from ``img``.
    """
    height, width = img_decoded.shape[:2]
    # The mask was predicted on a resized copy; resize the input to match so the
    # two can be stacked (np.concatenate raises on a width mismatch otherwise).
    resized = cv2.resize(img, (width, height))
    # cv2 loads BGR but matplotlib expects RGB: reverse the channel axis.
    image_rgb = resized[:, :, ::-1].astype(np.float32) / 255.0
    fig, ax = plt.subplots()
    ax.imshow(np.concatenate((image_rgb, img_decoded), axis=0))
    ax.set_title("input (top) / segmentation (bottom)")
    ax.axis("off")
    plt.show()
    # Release the figure so looping over many images does not accumulate memory.
    plt.close(fig)


In [None]:
def load_images_from_folder(folder):
    """Read every decodable image file directly inside ``folder`` with OpenCV.

    Files are visited in sorted filename order so results are reproducible
    (``os.listdir`` returns entries in arbitrary order).

    :param folder: directory to scan (not recursive).
    :return: list of arrays as returned by ``cv2.imread`` (BGR order), one per
        readable file; entries ``cv2.imread`` cannot decode are skipped.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img = cv2.imread(os.path.join(folder, filename))
        # cv2.imread returns None (rather than raising) for unreadable files.
        if img is not None:
            images.append(img)
    return images

In [None]:
images_path = './data'
MODEL_PATH = "pretrained/hardnet70_cityscapes_model.pkl"
INPUT_SIZE = (375, 1242)  # (height, width) fed to the network

images = load_images_from_folder(images_path)
assert images, f"no readable images found in {images_path!r}"
device, model = init_model(MODEL_PATH)

for image in images:
    # init_model already moved the model to `device`; forcing model.cuda() here
    # would crash on CPU-only machines and mismatch the input tensor's device.
    img_decoded = process_img(image, INPUT_SIZE, device, model)
    plot_segmendet_image(image, img_decoded)