In [1]:
# TODO: define the training and validation loops here.
from datasets.cityscapes import CityScapes
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from models.bisenet.build_bisenet import BiSeNet
from utils import poly_lr_scheduler
import pandas as pd
import numpy as np

def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class intersection, union and target pixel counts for a segmentation.

    Args:
        output: Predicted class indices, shape N, N*L or N*H*W; values in [0, K-1].
        target: Ground-truth class indices with the same shape as ``output``.
            Pixels equal to ``ignore_index`` are excluded from every count.
        K: Number of classes (histogram bins).
        ignore_index: Label value marking void pixels.

    Returns:
        (area_intersection, area_union, area_target): float tensors of length K.

    The caller's ``output`` tensor is not modified.
    """
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    # reshape handles non-contiguous inputs; clone so the masking below does not
    # write into the caller's prediction tensor (view(-1) shared its storage).
    output = output.reshape(-1).clone()
    target = target.reshape(-1)
    # Void pixels get the ignore value in the prediction too; histc drops values outside [0, K-1].
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]

    # histc needs floating-point input on CPU, so cast; the bins are unaffected.
    area_intersection = torch.histc(intersection.float(), bins=K, min=0, max=K - 1)
    area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
    area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target

def train(model, optimizer, train_loader, criterion):
    """Run one training epoch of ``model`` and return the mean loss per batch.

    Args:
        model: Segmentation network. In training mode it must return a 3-tuple
            whose first element is the main logits tensor (N, num_classes, H, W);
            the two auxiliary outputs are not used in the loss.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        train_loader: Iterable of ``(inputs, targets)`` batches, where ``targets``
            holds per-pixel class indices shaped (N, H, W) or (N, 1, H, W).
        criterion: Loss called as ``criterion(logits, targets)``
            (e.g. ``CrossEntropyLoss(ignore_index=255)``).

    Returns:
        float: average of the per-batch losses (0.0 for an empty loader).
    """
    model.train()
    # Send batches to wherever the model lives instead of relying on a global `device`.
    model_device = next(model.parameters()).device
    running_loss = 0.0
    num_batches = 0

    for inputs, targets in train_loader:
        inputs = inputs.to(model_device).float()
        targets = targets.to(model_device)

        # Forward pass; only the main head contributes to the loss.
        outputs, _, _ = model(inputs)

        # Drop only the singleton channel dim: a bare .squeeze() would also drop
        # the batch dim when a batch contains a single image.
        if targets.dim() == 4:
            targets = targets.squeeze(1)
        # Use the `criterion` argument (it was previously ignored in favour of a global).
        loss = criterion(outputs, targets.long())

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    return running_loss / max(num_batches, 1)

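# Validation (test) loop: returns the mean loss over the loader.
# Note: an earlier version took a `calculate_label_prediction` flag to decide whether to also collect the ground-truth and predicted tensors; the current signature has no such flag.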
def test(model, test_loader, loss_fn):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx,(inputs, targets) in enumerate(test_loader):
            #ground_truth.append(targets)
            
            inputs, targets = inputs.to(device), targets.to(device)
            inputs = inputs.float()
            #targets = targets.int()
            
            #Compute prediction and loss
            outputs = model(inputs)
            print(batch_idx)
            #print("size targets before",targets.size())
            encoded_torch=[]
            op=targets.permute(0,2,3,1)
            #print("op",op.size())
            for i, x in enumerate(op):
                x=x.to(dtype=torch.int64)
                partial_encoded=RGBtoOneHot(x,colorDict)
                #print("partial",partial_encoded)
                #for numpy
                #partial_encoded=RGBtoOneHot(x,colorDict)
                #partial_encoded_torch=torch.from_numpy(partial_encoded)
                #print(partial_encoded_torch.shape)
                partial_encoded_torch =  partial_encoded.unsqueeze(0)
                if i==0:
                    encoded_torch=partial_encoded_torch
                else:
                    encoded_torch=torch.cat((encoded_torch,partial_encoded_torch))
            #print(encoded_torch[0,:,:])
            #print("size targets",encoded_torch.size())
            #print("size outputs",outputs.size())
            #preds = preds.int()
            loss = loss_fn(outputs.to(dtype=torch.float64), encoded_torch.to(dtype=torch.int64))
            #loss = loss_fn(outputs, targets)
    
            test_loss += loss.item()
            #probability, predicted = outputs.max(1)
            #We need to convert to probabilities with softmax
            #soft_outputs = torch.nn.functional.softmax(outputs, dim=1) #pass through softmax
            #probability, predicted = soft_outputs.topk(1, dim = 1) # select top probability as prediction
            #probability=torch.squeeze(probability)
            #predicted=torch.squeeze(predicted)
            #total += targets.size(0)
            #correct += predicted.eq(targets).sum().item()
    test_loss = test_loss / len(test_loader)
    #test_accuracy = 100. * correct / total
    return test_loss

# Dataset root (note: the last folder name is spelled "Cityspaces").
dataset_path='datasets/Cityscapes/Cityscapes/Cityspaces/'
annotation_train=dataset_path+'gtFine/train'
image_train=dataset_path+'images/train'

annotation_val=dataset_path+'gtFine/val'
image_val=dataset_path+'images/val'

# Setup device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_SIZE = (512, 1024)  # (height, width) passed to transforms.Resize for both splits
BATCH_SIZE = 4
cityscapes_train = CityScapes(annotations_dir=annotation_train, images_dir=image_train, transform=transforms.Resize(size=IMAGE_SIZE))
cityscapes_val = CityScapes(annotations_dir=annotation_val, images_dir=image_val, transform=transforms.Resize(size=IMAGE_SIZE))

train_loader = DataLoader(cityscapes_train, batch_size=BATCH_SIZE, shuffle=True)
# The validation loss is an average over all batches, so order is irrelevant; keep it deterministic.
val_loader = DataLoader(cityscapes_val, batch_size=BATCH_SIZE, shuffle=False)

# Define the model and load it to the device
bisenet = BiSeNet(num_classes=19, context_path='resnet18')
bisenet.to(device)
epochs = 10
INIT_LR = 0.001
optimizer = torch.optim.Adam(bisenet.parameters(), lr=INIT_LR)
# NOTE(review): the previous one-off call
#     poly_lr_scheduler(optimizer, 0.01, 1, lr_decay_iter=1, max_iter=300, power=0.9)
# ran only once, here, with a different initial lr than Adam's 0.001. Presumably it set
# lr = 0.01 * (1 - 1/300) ** 0.9 (about 0.00997), replacing INIT_LR, and never decayed it
# afterwards; confirm against utils.poly_lr_scheduler. For poly decay, call
# poly_lr_scheduler(optimizer, INIT_LR, epoch, max_iter=epochs) at the start of each epoch.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255)
print(len(cityscapes_train))
print(len(cityscapes_train.map_index_to_image))

1572
1572


In [None]:
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=19, ignore_index=255):
    """Mean intersection-over-union between predicted logits and a label map.

    Args:
        pred_mask: Logits with the class dimension at dim 1, e.g. (N, C, H, W).
        mask: Ground-truth class indices with as many pixels as
            ``argmax(pred_mask, dim=1)``.
        smooth: Added to numerator and denominator to avoid division by zero.
        n_classes: Classes ``0 .. n_classes-1`` that are scored.
        ignore_index: Ground-truth value marking void pixels. These pixels are
            dropped before scoring, so predictions on them do not inflate any
            class's union. Pass ``None`` to score every pixel.

    Returns:
        float: ``np.nanmean`` of the per-class IoUs. Classes absent from the
        (non-void) ground truth are NaN and skipped.
    """
    with torch.no_grad():
        # softmax is monotonic, so argmax over the raw logits gives the same labels.
        pred_labels = torch.argmax(pred_mask, dim=1).reshape(-1)
        gt_labels = mask.to(pred_labels.device).reshape(-1)

        if ignore_index is not None:
            valid = gt_labels != ignore_index
            pred_labels = pred_labels[valid]
            gt_labels = gt_labels[valid]

        iou_per_class = []
        for clas in range(n_classes):
            true_class = pred_labels == clas
            true_label = gt_labels == clas

            if true_label.long().sum().item() == 0:  # class absent from the ground truth
                iou_per_class.append(np.nan)
                continue

            intersect = torch.logical_and(true_class, true_label).sum().float().item()
            union = torch.logical_or(true_class, true_label).sum().float().item()
            iou_per_class.append((intersect + smooth) / (union + smooth))
        return np.nanmean(iou_per_class)
def pixel_accuracy(output, mask, ignore_index=255):
    """Fraction of non-void pixels whose argmax prediction equals the ground truth.

    Args:
        output: Logits with the class dimension at dim 1, e.g. (N, C, H, W).
        mask: Ground-truth class indices, broadcastable to ``argmax(output, dim=1)``.
        ignore_index: Ground-truth value marking void pixels; these pixels are
            excluded from both numerator and denominator. ``None`` scores every pixel.

    Returns:
        float: accuracy in [0, 1], or NaN when no pixel is scorable.
    """
    with torch.no_grad():
        # softmax is monotonic, so argmax over the raw logits gives the same labels.
        predicted = torch.argmax(output, dim=1)
        target = mask.to(predicted.device)
        correct = torch.eq(predicted, target)
        if ignore_index is None:
            valid = torch.ones_like(correct)
        else:
            # Broadcast like `correct` so an (H, W) mask also works for a batch of one.
            valid = (target != ignore_index).expand_as(correct)
        n_valid = int(valid.sum())
        if n_valid == 0:
            return float("nan")
        n_correct = int((correct & valid).sum())
    return n_correct / n_valid

def convert_tensor_to_image(tensor):
    """Reorder a channel-first (C, H, W) tensor to channel-last (H, W, C) for imshow."""
    height_dim, width_dim, channel_dim = 1, 2, 0
    return tensor.permute(height_dim, width_dim, channel_dim)
def RGBtoOneHot(rgb, colorDict, device=None, ignore_value=255):
    """Convert an (H, W, >=3) RGB colour annotation into an (H, W) int32 map of train ids.

    Despite the name, the result is NOT one-hot. Each pixel whose first three
    channels exactly match a key of ``colorDict`` gets ``colorDict[color][2]``
    (the ``trainId`` field of the notebook's ``Label`` records). Every other
    pixel (void or unlisted colours) gets ``ignore_value``. It used to get 0,
    which is the train id of "road".

    Args:
        rgb: Tensor of shape (H, W, C) with C >= 3; extra channels (alpha) are ignored.
        colorDict: Mapping from an (R, G, B) tuple to a record whose index 2 is the train id.
        device: Where the result is created; defaults to CUDA when available,
            otherwise CPU (the same rule as the notebook's global ``device``).
        ignore_value: Fill value for pixels that match no colour.

    Returns:
        torch.Tensor of shape (H, W), dtype int32, on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hoisted out of the loop: the same RGB slice is compared against every colour.
    rgb_channels = rgb[:, :, :3].to(device)
    train_ids = torch.full(rgb_channels.shape[:2], ignore_value, dtype=torch.int32, device=device)
    for color, label in colorDict.items():
        color_tensor = torch.tensor(color, device=device)
        matches = torch.all(rgb_channels == color_tensor, dim=-1)
        train_ids[matches] = label[2]  # index 2 = trainId
    return train_ids

# Plotting the dataset


In [None]:
# Peek at one training batch: the first image and its annotation.
# Named images/targets: `input` shadowed the builtin and the second item is the target, not a model output.
images, targets = next(iter(train_loader))
fig, axes = plt.subplots(2, 1)
input_transpose = convert_tensor_to_image(images[0])
output_transpose = convert_tensor_to_image(targets[0])  # used by the RGBtoOneHot check below
axes[0].imshow(input_transpose)
axes[1].imshow(output_transpose)
axes[0].set_title('Image', fontsize=10)
axes[1].set_title('Annotation', fontsize=10)
plt.show()
targets.size()

In [None]:
from collections import namedtuple
# Record type for one Cityscapes class; `color` is the (R, G, B) tuple used as the colorDict key.
Label = namedtuple(
    'Label',
    'name id trainId category categoryId hasInstances ignoreInEval color',
)
# The 19 classes used for training and evaluation, one Label per class:
# (name, id, trainId, category, categoryId, hasInstances, ignoreInEval, color).
# trainId (0-18) is the index the 19-class BiSeNet predicts and the value RGBtoOneHot
# writes (it reads element [2]); color is the RGB key used to build colorDict.
labels = [
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) )]
    

In [None]:
# Map each class colour (an (R, G, B) tuple) to its Label record; used by RGBtoOneHot.
colorDict = {entry.color: entry for entry in labels}

In [None]:
"""def RGBtoOneHot(rgb, colorDict):
  #arr = np.zeros(rgb.shape[:2],dtype=np.int32) ## rgb shape: (h,w,3); arr shape: (h,w)
  arr_torch=torch.zeros(rgb.to(device).shape[:2],dtype=torch.int32).to(device)
  for label, color in enumerate(colorDict.keys()):
    #color_np = np.array(color)
    color_torch=torch.tensor(color).to(device)
    if label < len(colorDict.keys()):
      #arr[np.all(rgb[:,:,:2].to('cpu').numpy() == color_np, axis=-1)] = label
      #arr[np.all(rgb[:,:,:3].to('cpu').numpy() == color_np, axis=-1)] = colorDict[color][1] #1 = 'id'
      arr_torch[torch.all(rgb[:,:,:3].to(device) == color_torch, axis=-1)] = colorDict[color][2] #1 = 'id'
  return arr_torch
"""

In [None]:
# Sanity check: encode the previewed annotation and display the resulting map's shape.
encoded_preview = RGBtoOneHot(output_transpose, colorDict)
encoded_preview.shape

In [None]:
#to load the model
bisenet = BiSeNet(num_classes=19, context_path='resnet18')
bisenet.to(device)
bisenet.load_state_dict(torch.load('bisenet_epoch_37_weights.pth'))
epoch_beginning=5

In [None]:
for epoch in range(0,epochs):
    train_loss = train(bisenet, optimizer, train_loader, loss_fn)
    # Checkpoint after every epoch. The index restarts at 0 each time this cell runs,
    # so re-running it overwrites earlier checkpoints with the same names.
    file_name = 'bisenet_epoch_' + str(epoch) + '_weights.pth'
    torch.save(bisenet.state_dict(), file_name)
    test_loss = test(bisenet, val_loader, loss_fn)
    # Report both losses; train_loss was previously computed and then discarded.
    print(f"Epoch n.{epoch} - Train loss: {train_loss}")
    print(f"Epoch n.{epoch} - Test loss: {test_loss}")

# Make a prediction

In [None]:
from torchvision.io import read_image 
import posixpath
import torch
import torch.nn as nn
import numpy as np
dataset_path='datasets/Cityscapes/Cityscapes/Cityspaces/'
annotation_val=dataset_path+'gtFine/val'
# Use the same (height, width) as the training datasets in the first cell. The previous
# 512x512 resize gave inference inputs a different aspect ratio than the model was trained on.
transform = transforms.Resize((512, 1024))
def make_prediction(model,image_path):
    """Run ``model`` on one validation image and score it against its colour annotation.

    The annotation is found by swapping ``_leftImg8bit.png`` for
    ``_gtFine_color.png`` under ``annotation_val/<city>/``. Both image and
    annotation are resized to the global ``transform``'s size; the annotation
    uses nearest-neighbour resizing so that its colours stay exact palette
    entries.

    Returns:
        tuple ``(input, image, annotation, annotation_encoded, preds, iou_score, accuracy)``:
        the resized float input on ``device`` (no batch dim), the raw image from
        ``read_image``, the resized (H, W, 3) colour annotation, its train-id map
        from ``RGBtoOneHot``, the (1, H, W) argmax prediction, and the ``mIoU`` and
        ``pixel_accuracy`` scores.
    """
    model.eval()

    # Retrieve the image and its colour annotation.
    image = read_image(image_path)
    print("image size",image.shape)
    path_parts = image_path.split('/')
    image_name = posixpath.join(path_parts[-2], path_parts[-1])  # "<city>/<file>"
    annotation_path = posixpath.join(annotation_val, image_name.replace("_leftImg8bit.png", "_gtFine_color.png"))
    annotation = read_image(annotation_path)[0:3, :, :]  # keep the first three (RGB) channels

    input_tensor = transform(image)
    # Nearest-neighbour for the annotation: the default (bilinear) resize blends neighbouring
    # class colours into new colours that match no palette entry, corrupting the encoded labels.
    annotation_resize = transforms.Resize(transform.size, interpolation=transforms.InterpolationMode.NEAREST)
    annotation = annotation_resize(annotation).permute(1, 2, 0)
    annotation_encoded = RGBtoOneHot(annotation, colorDict)

    # Generate the prediction without tracking gradients.
    with torch.no_grad():
        input_tensor = input_tensor.float().to(device)
        print("")
        print("generating prediction..")
        # unsqueeze adds the batch dimension the model expects.
        output = model(input_tensor.unsqueeze(0))
        iou_score = mIoU(output, annotation_encoded)
        accuracy = pixel_accuracy(output, annotation_encoded)
        # argmax over classes; softmax first would not change the result.
        preds = torch.argmax(output, dim=1)
    return input_tensor, image, annotation, annotation_encoded, preds, iou_score, accuracy
    
	# turn off gradient tracking
	
# Score one Frankfurt validation frame with the current `bisenet` weights.
sample_image_path = 'datasets/Cityscapes/Cityscapes/Cityspaces/images/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png'
(input, image, annotation, annotation_encoded,
 preds, iou_score, accuracy) = make_prediction(bisenet, sample_image_path)

In [None]:
print("annotation size after permutation",annotation.shape)
print("annotation encoded size",annotation_encoded.shape)
#print("prediction size",output.shape)
print("prediction size after softmax",preds.shape)
#visualization
preds_custom=preds.squeeze()
print("scueezed preds size",preds_custom.shape)
fig, axes = plt.subplots(4, 1)
fig.tight_layout()
#axes[0].imshow(image.permute(1, 2, 0))
axes[0].imshow(transform(image).permute(1, 2, 0).cpu())
axes[1].imshow(annotation_encoded.cpu())
axes[2].imshow(annotation)
axes[3].imshow(preds_custom.cpu())
axes[0].set_title('Image',fontsize=10)
axes[1].set_title('Annotation encoded',fontsize=10)
axes[2].set_title('Annotation',fontsize=10)
axes[3].set_title('prediction',fontsize=10)
plt.show()

In [None]:
# Pixel accuracy and mIoU returned by the make_prediction call above.
print(accuracy,iou_score)
# NOTE(review): hard-coded label; it only describes the scores if the epoch-37 checkpoint
# was the last thing loaded into `bisenet` before make_prediction ran.
print('bisenet_epoch_37_weights')

In [None]:
%matplotlib notebook
%matplotlib inline

In [None]:
!pip install ipympl

In [None]:
from matplotlib import cm
from torchvision.io import read_image 
from PIL import Image 
# Inspect the train-id values stored in one labelTrainIds annotation.
test_path='datasets/Cityscapes/Cityscapes/Cityspaces/gtFine/train/hanover/hanover_000000_000164_gtFine_labelTrainIds.png'
with Image.open(test_path) as image:
    image_array = np.array(image)
unique_values = np.unique(image_array)
# Named test_label_map rather than `test`: assigning to `test` overwrote the test()
# validation function from the first cell and broke the training loop on re-run.
test_label_map = read_image(test_path)
fig, axes = plt.subplots(1, 1)
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
new_inferno = plt.get_cmap('hsv', 13)
axes.imshow(test_label_map.squeeze(0))
plt.show()
print(unique_values)
print(torch.unique(test_label_map))

# Trash (leftover scratch cells, not part of the pipeline)

In [None]:
# Re-save the current weights under the last checkpoint name written by the epoch loop.
torch.save(obj=bisenet.state_dict(), f=file_name)

In [None]:
# Shape of the target of the first training sample.
train_loader.dataset[0][1].shape

In [None]:
#TODO: remove from dict all the labels with train_id = 255
# "R_G_B" colour key -> Cityscapes label id for the 19 training classes.
# Renamed from `dict`, which shadowed the builtin for the rest of the notebook.
rgb_key_to_label_id={
     "128_64_128"  :7,
     "244_35_232"  :8,
     "70_70_70" :11,
     "102_102_156" :12,
     "190_153_153" :13,
     "153_153_153" :17,
     "250_170_30" :19,
     "220_220_0" :20,
     "107_142_35" :21,
     "152_251_152" :22,
     "70_130_180" :23,
     "220_20_60" :24,
     "255_0_0" :25,
     "0_0_142" :26,
     "0_0_70" :27,
     "0_60_100" :28,
     "0_80_100" :31,
     "0_0_230" :32,
     "119_11_32" :33}
"""
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),

        #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,(111, 74,  0) ),
    Label(  'ground'               ,  6 ,( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,(128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,(244, 35,232) ),
    Label(  'parking'              ,  9 ,(250,170,160) ),
    Label(  'rail track'           , 10 ,(230,150,140) ),
    Label(  'building'             , 11 ,( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,(102,102,156) ),
    Label(  'fence'                , 13 ,(190,153,153) ),
    Label(  'guard rail'           , 14 ,(180,165,180) ),
    Label(  'bridge'               , 15 ,(150,100,100) ),
    Label(  'tunnel'               , 16 ,(150,120, 90) ),
    Label(  'pole'                 , 17 ,(153,153,153) ),
    Label(  'polegroup'            , 18 ,(153,153,153) ),
    Label(  'traffic light'        , 19 ,(250,170, 30) ),
    Label(  'traffic sign'         , 20 ,(220,220,  0) ),
    Label(  'vegetation'           , 21 ,(107,142, 35) ),
    Label(  'terrain'              , 22 ,(152,251,152) ),
    Label(  'sky'                  , 23 ,( 70,130,180) ),
    Label(  'person'               , 24 ,(220, 20, 60) ),
    Label(  'rider'                , 25 ,(255,  0,  0) ),
    Label(  'car'                  , 26 ,(  0,  0,142) ),
    Label(  'truck'                , 27 ,(  0,  0, 70) ),
    Label(  'bus'                  , 28 ,(  0, 60,100) ),
    Label(  'caravan'              , 29 ,(  0,  0, 90) ),
    Label(  'trailer'              , 30 ,(  0,  0,110) ),
    Label(  'train'                , 31 ,(  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,(  0,  0,230) ),
    Label(  'bicycle'              , 33 ,(119, 11, 32) ),
    Label(  'license plate'        , -1 ,(  0,  0,142) ),

     {"111_74_0"  :5,
     "81_0_81"  :6,
     "128_64_128"  :7,
     "244_35_232"  :8,
     "250_170_160"  :9,
     "230_150_140" :10,
     "70_70_70" :11,
     "102_102_156" :12,
     "190_153_153" :13,
     "180_165_180" :14,
     "150_100_100" :15,
     "150_120_90" :16,
     "153_153_153" :17,
     "153_153_153" :18,
     "250_170_30" :19,
     "220_220_0" :20,
     "107_142_35" :21,
     "152_251_152" :22,
     "70_130_180" :23,
     "220_20_60" :24,
     "255_0_0" :25,
     "0_0_142" :26,
     "0_0_70" :27,
     "0_60_100" :28,
     "0_0_ 90" :29,
     "0_0_110" :30,
     "0_80_100" :31,
     "0_0_230" :32,
     "119_11_32" :33,
     "0_0_142" :-1}
"""

In [None]:
# Report whether a CUDA device is visible to PyTorch.
cuda_available = torch.cuda.is_available()
print(cuda_available)

In [None]:
#TODO: crop input to 224x224
#TODO: create dictionary to map each label to color
#dict={"12864128":10} # 0=road, the key is concatenation of rgb values
# "R_G_B" colour key -> Cityscapes label id (the key is the concatenation of the RGB values).
# Renamed from `dict`, which shadowed the builtin for the rest of the notebook.
rgb_key_to_label_id={
     "128_64_128"  :7,
     "244_35_232"  :8,
     "70_70_70" :11,
     "102_102_156" :12,
     "190_153_153" :13,
     "153_153_153" :17,
     "250_170_30" :19,
     "220_220_0" :20,
     "107_142_35" :21,
     "152_251_152" :22,
     "70_130_180" :23,
     "220_20_60" :24,
     "255_0_0" :25,
     "0_0_142" :26,
     "0_0_70" :27,
     "0_60_100" :28,
     "0_80_100" :31,
     "0_0_230" :32,
     "119_11_32" :33}
def map_value_to_label():
    """Unimplemented placeholder: meant to use the "r_g_b" concatenation as the key that gives the label."""
    return None
def map_labels(a):
    """Return the "R_G_B" string key built from the first three entries of ``a``.

    A lookup of this key in the colour -> id mapping was sketched but left
    disabled, so the key itself is returned.
    """
    return "_".join(str(a[channel]) for channel in range(3))

"""
print(op.size())
x=op.to('cpu').numpy()#-->(r,g,b,a)
#y=pd.DataFrame(x)
x.shape
#x=x.reshape(32,1024,2048,4)
x_rgb=x[:,:,:,0:3]
#print(x[:,:,:,0:3])
y=np.apply_along_axis(map_labels, -1, x_rgb)
y.shape
y[0][0][0]
"""
"""
x[:,:,:,3]= 1
#x[:,:,:,3] =
from operator import itemgetter
x[:,:,:,3] =itemgetter(str(x[:,:,:,0])+"_"+str(x[:,:,:,1])+"_"+str(x[:,:,:,2]))(dict)
#itemgetter("128_64_128","244_35_232")(dict)
#dict.get("128_64_128","244_35_232")
"""

In [None]:
# Full Cityscapes label table, including void / ignored classes (trainId 255 or -1).
# Named all_labels so it no longer overwrites the 19-class `labels` that colorDict is built
# from. With the full table, duplicate colours would make 'license plate' (trainId -1)
# replace 'car' for (0, 0, 142) and add many void entries.
all_labels = [
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) )]

In [None]:
from PIL import Image 
all_unique_values = set()

# Collect every pixel value that occurs in any training annotation.
# NOTE(review): assumes iterating map_index_to_annotation yields annotation file paths
# (each item is passed to Image.open); if it is a dict keyed by index, iterate .values().
for image_path in cityscapes_train.map_index_to_annotation:
    with Image.open(image_path) as image:
        image_array = np.array(image)
    unique_values = np.unique(image_array)
    all_unique_values.update(unique_values)

print(sorted(all_unique_values))


In [2]:
# Label values present in one training batch of targets.
# Named images/targets: `input` shadowed the builtin and the second item is the target.
images, targets = next(iter(train_loader))
unique_values = np.unique(targets)
print(unique_values)

AttributeError: 'PngImageFile' object has no attribute 'shape'