In [1]:
from collections import namedtuple
from Segnet import network
from data_loader import data_loader_seg
# from LoadWeights import preload_encoder_weights

import torch 
import numpy as np 
from torch.autograd import Variable
import torch.nn as nn
from torchvision import datasets,models,transforms
import torch.optim as optim
from PIL import Image
import pickle

In [2]:
#--------------------------------------------------------------------------------
# Label record
#--------------------------------------------------------------------------------

# One row of the label table. Fields, in positional order:
#   name         - unique identifier of the label, e.g. 'car', 'person'.
#   id           - integer id used to represent the label in ground-truth images.
#                  An id of -1 means the label has no id and is ignored when
#                  ground-truth images are created (e.g. 'license plate').
#                  Do not change these ids: the evaluation server expects exactly them.
#   trainId      - id used for training; may be remapped freely. Several labels may
#                  share one trainId, in which case they become one training class and
#                  the inverse mapping uses the first label defined in the table.
#                  Max value is 255. Validate/submit with `id`, never with trainId.
#   category     - name of the category this label belongs to.
#   categoryId   - id of that category, used for category-level ground truth.
#   hasInstances - whether this label distinguishes single instances.
#   ignoreInEval - whether pixels with this ground-truth class are ignored in evaluation.
#   color        - 3-tuple color of this label (presumably RGB).
Label = namedtuple('Label', 'name id trainId category categoryId hasInstances ignoreInEval color')


In [3]:
#--------------------------------------------------------------------------------
# The full label table
#--------------------------------------------------------------------------------

# Every row is turned into a Label via Label._make, in table order; order matters
# because code that wants "the first label defined" for a shared trainId or color
# relies on it.
labels = list(map(Label._make, (
    # name                    id  trainId  category        catId  hasInst  ignoreEval  color
    ('unlabeled',              0,  255,    'void',          0,    False,   True,       (  0,   0,   0)),
    ('ego vehicle',            1,  255,    'void',          0,    False,   True,       (  0,   0,   0)),
    ('rectification border',   2,  255,    'void',          0,    False,   True,       (  0,   0,   0)),
    ('out of roi',             3,  255,    'void',          0,    False,   True,       (  0,   0,   0)),
    ('static',                 4,  255,    'void',          0,    False,   True,       (  0,   0,   0)),
    ('dynamic',                5,  255,    'void',          0,    False,   True,       (111,  74,   0)),
    ('ground',                 6,  255,    'void',          0,    False,   True,       ( 81,   0,  81)),
    ('road',                   7,    0,    'flat',          1,    False,   False,      (128,  64, 128)),
    ('sidewalk',               8,    1,    'flat',          1,    False,   False,      (244,  35, 232)),
    ('parking',                9,  255,    'flat',          1,    False,   True,       (250, 170, 160)),
    ('rail track',            10,  255,    'flat',          1,    False,   True,       (230, 150, 140)),
    ('building',              11,    2,    'construction',  2,    False,   False,      ( 70,  70,  70)),
    ('wall',                  12,    3,    'construction',  2,    False,   False,      (102, 102, 156)),
    ('fence',                 13,    4,    'construction',  2,    False,   False,      (190, 153, 153)),
    ('guard rail',            14,  255,    'construction',  2,    False,   True,       (180, 165, 180)),
    ('bridge',                15,  255,    'construction',  2,    False,   True,       (150, 100, 100)),
    ('tunnel',                16,  255,    'construction',  2,    False,   True,       (150, 120,  90)),
    ('pole',                  17,    5,    'object',        3,    False,   False,      (153, 153, 153)),
    ('polegroup',             18,  255,    'object',        3,    False,   True,       (153, 153, 153)),
    ('traffic light',         19,    6,    'object',        3,    False,   False,      (250, 170,  30)),
    ('traffic sign',          20,    7,    'object',        3,    False,   False,      (220, 220,   0)),
    ('vegetation',            21,    8,    'nature',        4,    False,   False,      (107, 142,  35)),
    ('terrain',               22,    9,    'nature',        4,    False,   False,      (152, 251, 152)),
    ('sky',                   23,   10,    'sky',           5,    False,   False,      ( 70, 130, 180)),
    ('person',                24,   11,    'human',         6,    True,    False,      (220,  20,  60)),
    ('rider',                 25,   12,    'human',         6,    True,    False,      (255,   0,   0)),
    ('car',                   26,   13,    'vehicle',       7,    True,    False,      (  0,   0, 142)),
    ('truck',                 27,   14,    'vehicle',       7,    True,    False,      (  0,   0,  70)),
    ('bus',                   28,   15,    'vehicle',       7,    True,    False,      (  0,  60, 100)),
    ('caravan',               29,  255,    'vehicle',       7,    True,    True,       (  0,   0,  90)),
    ('trailer',               30,  255,    'vehicle',       7,    True,    True,       (  0,   0, 110)),
    ('train',                 31,   16,    'vehicle',       7,    True,    False,      (  0,  80, 100)),
    ('motorcycle',            32,   17,    'vehicle',       7,    True,    False,      (  0,   0, 230)),
    ('bicycle',               33,   18,    'vehicle',       7,    True,    False,      (119,  11,  32)),
    ('license plate',         -1,   -1,    'vehicle',       7,    False,   True,       (  0,   0, 142)),
)))

In [4]:
#--------------------------------------------------------------------------------
# Lookup dictionaries
#--------------------------------------------------------------------------------

# color to label object.
# Several labels share one color: the void labels all use (0, 0, 0),
# 'pole'/'polegroup' share (153, 153, 153), and 'car'/'license plate' share
# (0, 0, 142). A plain `{label.color: label for label in labels}` keeps the LAST
# label for each color, which silently turns car pixels into 'license plate'
# (id -1) and pole pixels into 'polegroup'. Iterating in reverse makes the
# FIRST-defined label win instead, the same convention used for trainId lookups.
color2label = { label.color : label for label in reversed(labels) }

In [5]:
# Segmentation network with 35 output channels.
# NOTE(review): 35 equals len(labels) (ids 0..33 plus the id -1 row) -- confirm
# that this is the intended class count.
model_ft = network(35)
# Keep the model on the GPU when one is present, otherwise leave it on the CPU.
model_ft = model_ft.cuda() if torch.cuda.is_available() else model_ft

In [6]:
# preload_encoder_weights(model_ft)

# Optional input transform. It is not applied by the dataset below, which is
# built with trans=None.
# torchvision's CenterCrop takes a sequence size as (height, width). The input
# frames are expected to be roughly 375 x 1242 (H x W), so a landscape crop is
# (350, 1200); the previous (1200, 350) requested a 1200-pixel-tall crop.
trans = transforms.Compose([
    transforms.CenterCrop((350, 1200)),
    transforms.ToTensor(),
])

In [7]:
# Training split. No transform is applied (trans=None); c2id receives the
# color -> label lookup, presumably to convert colored masks to label ids.
dsets_train = data_loader_seg('./Dataset/data_semantics/training/', trans=None, c2id=color2label)
# shuffle=True so each epoch visits the training samples in a fresh random
# order; the DataLoader default (shuffle=False) replays one fixed order.
dsets_enqueuer_training = torch.utils.data.DataLoader(
    dsets_train, batch_size=1, shuffle=True, num_workers=10, drop_last=False)

In [8]:
#for idx, data in enumerate(dsets_enqueuer_training, 1):
       #image,image_seg = data['image'], data['image_seg']

In [9]:
# Per-pixel multi-class loss. The network outputs (N, 35, H, W) scores and the
# target is meant to be a (N, H, W) map of class indices -- exactly the input
# format of CrossEntropyLoss. BCEWithLogitsLoss instead requires a float target
# with the same (N, 35, H, W) shape as the output, so it cannot consume an index map.
# NOTE(review): the target tensor must be integer (long) class indices -- confirm
# in data_loader_seg. CrossEntropyLoss applies log-softmax itself and expects raw
# logits; if `network` ends in a softmax, bypass it for training.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)

if torch.cuda.is_available():
    criterion = criterion.cuda()

# Running loss accumulators for the training / testing loops.
loss_data = 0.0
loss_data_testing = 0.0

loss_per_epoch_lst = []

In [10]:
# Smoke test of the training loop: the two `break`s stop after the first batch
# of the first epoch, and loss/backward/step are still disabled, so this cell
# only checks that one batch flows through model_ft.
print("\n\n\n......Training......")
loss_lst_train = []
loss_lst_test = []

for Epoch in range(100):
    for idx, data in enumerate(dsets_enqueuer_training, 1):
        # The loader batches samples, so tensors carry a leading batch dimension:
        # the image arrives as (1, 375, 1242, 3) and the target should be
        # (1, 375, 1242). NOTE(review): confirm against data_loader_seg.
        image = data['image']
        image_seg = data['image_seg']

        print(type(image))
        print(type(image_seg))

        # Move the batch to the GPU when available, then wrap both tensors as
        # autograd Variables that do not require gradients.
        if torch.cuda.is_available():
            image = image.cuda()
            image_seg = image_seg.cuda()
        image = Variable(image, requires_grad = False)
        image_seg = Variable(image_seg, requires_grad = False)

        print(type(image))
        model_ft.train(True)
        output = model_ft(image)
        optimizer.zero_grad()

        # Disabled optimisation step:
        # loss = criterion(output, image_seg)
        # loss.backward()
        # optimizer.step()
        # loss_data += loss.data

        break
    break
print("\n\nDONE......")




......Training......
<class 'torch.FloatTensor'>
<class 'torch.FloatTensor'>
<class 'torch.autograd.variable.Variable'>

Layer1...
torch.Size([1, 64, 173, 613])

Layer2...
torch.Size([1, 128, 84, 304])

Layer3...
torch.Size([1, 256, 39, 149])

Layer4...
torch.Size([1, 512, 16, 71])

Layer7...
torch.Size([1, 256, 39, 149])

Layer8...
torch.Size([1, 128, 84, 304])

Layer9...
torch.Size([1, 64, 173, 613])

Layer10...
torch.Size([1, 64, 350, 1230])

conv1x1
size of out_conv1x1: torch.Size([1, 35, 350, 1230])

Softmax Layer...


DONE......


In [11]:
# Show which type the forward pass returned.
print(type(output))

<class 'torch.autograd.variable.Variable'>
