In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from skimage import io
import numpy as np
class ContextNet(nn.Module):
    """Encoder-decoder network built from max-pool / max-unpool pairs.

    Encoder: three stages of unpadded 3x3 convolutions, each followed by
    batch norm and ReLU, and each stage closed by a 2x2 max-pool that also
    returns the indices of the maxima.

    Decoder: mirrors the encoder.  Every stage first un-pools with the indices
    and pre-pool size recorded by the matching encoder stage, then applies 3x3
    convolutions with ``padding=2``.  Such a convolution grows each spatial
    dimension by two pixels, undoing the shrinkage of one unpadded encoder
    convolution, so the output has the same spatial size as the input.

    Args:
        in_channels: number of channels of the input image (default 3).
        num_classes: number of channels of the output score map (default 3).
    """

    def __init__(self, in_channels=3, num_classes=3):
        super(ContextNet, self).__init__()

        # ---- encoder stage 1: in_channels -> 64 ----
        self.conv1_a = nn.Conv2d(in_channels, 64, 3)
        self.bnorm1_a = nn.BatchNorm2d(64)
        self.conv1_b = nn.Conv2d(64, 64, 3)
        self.bnorm1_b = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, return_indices=True)

        # ---- encoder stage 2: 64 -> 128 ----
        self.conv2_a = nn.Conv2d(64, 128, 3)
        self.bnorm2_a = nn.BatchNorm2d(128)
        self.conv2_b = nn.Conv2d(128, 128, 3)
        self.bnorm2_b = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, return_indices=True)

        # ---- encoder stage 3: 128 -> 256 ----
        self.conv3_a = nn.Conv2d(128, 256, 3)
        self.bnorm3_a = nn.BatchNorm2d(256)
        self.conv3_b = nn.Conv2d(256, 256, 3)
        self.bnorm3_b = nn.BatchNorm2d(256)
        self.conv3_c = nn.Conv2d(256, 256, 3)
        self.bnorm3_c = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(2, return_indices=True)

        # ---- decoder stage 4 (mirrors stage 3): 256 -> 128 ----
        self.unpool4 = nn.MaxUnpool2d(2)
        self.conv4_a = nn.Conv2d(256, 256, 3, padding=2)
        self.bnorm4_a = nn.BatchNorm2d(256)
        self.conv4_b = nn.Conv2d(256, 256, 3, padding=2)
        self.bnorm4_b = nn.BatchNorm2d(256)
        self.conv4_c = nn.Conv2d(256, 128, 3, padding=2)
        self.bnorm4_c = nn.BatchNorm2d(128)

        # ---- decoder stage 5 (mirrors stage 2): 128 -> 64 ----
        self.unpool5 = nn.MaxUnpool2d(2)
        self.conv5_a = nn.Conv2d(128, 128, 3, padding=2)
        self.bnorm5_a = nn.BatchNorm2d(128)
        self.conv5_b = nn.Conv2d(128, 64, 3, padding=2)
        self.bnorm5_b = nn.BatchNorm2d(64)

        # ---- decoder stage 6 (mirrors stage 1): 64 -> num_classes ----
        self.unpool6 = nn.MaxUnpool2d(2)
        self.conv6_a = nn.Conv2d(64, 64, 3, padding=2)
        self.bnorm6_a = nn.BatchNorm2d(64)
        self.conv6_b = nn.Conv2d(64, num_classes, 3, padding=2)
        self.bnorm6_b = nn.BatchNorm2d(num_classes)
        # No softmax layer is applied here; forward() returns the raw
        # (batch-normed, ReLU-ed) score map.

    def forward(self, x):
        """Run the encoder and the mirrored decoder.

        Args:
            x: batch of images with ``in_channels`` channels.  Each spatial
                dimension must be at least 44 pixels, otherwise the feature
                map reaching the third max-pool is smaller than its 2x2
                window.

        Returns:
            Tensor with ``num_classes`` channels and the same batch and
            spatial dimensions as ``x``.  Values are non-negative because the
            last layer is followed by a ReLU.
        """
        # Encoder.  The size just before each pool is remembered so that the
        # matching un-pool can restore it exactly, including odd sizes that
        # the pool rounds down.
        y = F.relu(self.bnorm1_a(self.conv1_a(x)))
        y = F.relu(self.bnorm1_b(self.conv1_b(y)))
        pool_size1 = y.size()
        y, ind1 = self.pool1(y)

        y = F.relu(self.bnorm2_a(self.conv2_a(y)))
        y = F.relu(self.bnorm2_b(self.conv2_b(y)))
        pool_size2 = y.size()
        y, ind2 = self.pool2(y)

        y = F.relu(self.bnorm3_a(self.conv3_a(y)))
        y = F.relu(self.bnorm3_b(self.conv3_b(y)))
        y = F.relu(self.bnorm3_c(self.conv3_c(y)))
        pool_size3 = y.size()
        y, ind3 = self.pool3(y)

        # Decoder.
        y = self.unpool4(y, ind3, output_size=pool_size3)
        y = F.relu(self.bnorm4_a(self.conv4_a(y)))
        y = F.relu(self.bnorm4_b(self.conv4_b(y)))
        y = F.relu(self.bnorm4_c(self.conv4_c(y)))

        y = self.unpool5(y, ind2, output_size=pool_size2)
        y = F.relu(self.bnorm5_a(self.conv5_a(y)))
        y = F.relu(self.bnorm5_b(self.conv5_b(y)))

        # Use the un-pool layer that belongs to this stage (the original code
        # called unpool4 here and left unpool6 unused).
        y = self.unpool6(y, ind1, output_size=pool_size1)
        y = F.relu(self.bnorm6_a(self.conv6_a(y)))
        # NOTE(review): batch norm + ReLU on the final scores forces them to
        # be non-negative.  If they are fed to a softmax / cross-entropy loss
        # as logits, confirm this is intended.
        y = F.relu(self.bnorm6_b(self.conv6_b(y)))

        return y

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,utils
import re,glob
import matplotlib.pyplot as plt
from skimage import io
from skimage.transform import resize
import torch
class ObstacleDataset(Dataset):
    """Dataset of (image, segmentation mask) file pairs read lazily from disk.

    Args:
        files: indexable collection; each item holds the image path at
            position 0 and the mask path at position 1.
        transforms: optional callable applied to every sample dict.
    """

    def __init__(self, files, transforms=None):
        self.transforms = transforms
        self.files = files

    def __len__(self):
        """Number of file pairs in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return it as a dict with keys 'rgb' and 'mask'.

        When a transform was supplied, its result for that dict is returned
        instead.
        """
        pair = self.files[idx]
        sample = {
            'rgb': io.imread(pair[0]),
            'mask': io.imread(pair[1]),
        }
        if not self.transforms:
            return sample
        return self.transforms(sample)

class RandomHorizontalFlip(object):
    """Randomly flips the images of a sample in the horizontal direction.

    Args:
        prob: probability of horizontal flip
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        """Return ``sample`` either untouched or with 'rgb' and 'mask' mirrored.

        Args:
            sample: dict with ndarray entries 'rgb' and 'mask'; axis 1 of
                both is treated as the horizontal axis.

        Returns:
            The input dict itself when no flip happens, otherwise a new dict
            whose 'rgb' and 'mask' are flipped copies (any other keys are
            carried over unchanged).
        """
        # np.random.rand() is uniform on [0, 1), so the flip branch below is
        # taken with probability `prob`.  (The original test was inverted and
        # flipped with probability 1 - prob.)
        if np.random.rand() >= self.prob:
            return sample

        flipped = dict(sample)
        # np.flip returns a negative-stride view; copy so that consumers which
        # need contiguous memory (e.g. torch.from_numpy) can use the result.
        flipped['rgb'] = np.flip(sample['rgb'], 1).copy()
        flipped['mask'] = np.flip(sample['mask'], 1).copy()
        return flipped
    
class ToTensor(object):
    """Converts the ndarrays in a sample to Tensors."""

    def __call__(self, sample):
        """Build tensors from the 'rgb' and 'mask' arrays of ``sample``.

        Args:
            sample: dict with a 3-D ndarray 'rgb' (assumed H x W x C -- TODO
                confirm) and an ndarray 'mask'.

        Returns:
            dict with 'rgb' as a FloatTensor whose axes are permuted to
            (2, 0, 1) and whose first axis is reversed, and 'mask' as a tensor
            of the mask's own dtype with every 255 replaced by 0.  The input
            arrays are left unmodified.
        """
        rgb, mask = sample['rgb'], sample['mask']

        # Move the last axis to the front, then reverse it (for an RGB input
        # this yields BGR channel order -- presumably intentional, confirm).
        # np.flip gives a negative-stride view, hence the copy.
        chw = np.flip(rgb.transpose((2, 0, 1)), axis=0).copy()

        # Work on a private copy so the caller's mask is not overwritten.
        # NOTE(review): 255 is presumably the "ignore" label; it is merged
        # into label 0 here -- confirm that is the intended mapping.
        labels = mask.copy()
        labels[labels == 255] = 0
        # perform one-hot encoding on mask later if required

        return {'rgb': torch.from_numpy(chw).type("torch.FloatTensor"),
                'mask': torch.from_numpy(labels)}
class Normalize(object):
    """Normalizes the 'rgb' tensor of a sample; the mask passes through as is."""

    def __call__(self, sample):
        # Per-channel mean and std of 0.5 for each of the three channels.
        channel_mean = [0.5, 0.5, 0.5]
        channel_std = [0.5, 0.5, 0.5]
        normalizer = transforms.Normalize(mean=channel_mean, std=channel_std)
        return {'rgb': normalizer(sample['rgb']), 'mask': sample['mask']}
class Resize(object):
    """Shrinks the 'rgb' and 'mask' arrays of a sample by an integer factor.

    Args:
        factor: divisor applied to the first two dimensions (default 4).
    """

    def __init__(self, factor=4):
        self.factor = factor

    def __call__(self, sample):
        """Return a new sample dict holding the down-scaled arrays.

        The image goes through ``resize`` with its default settings, exactly
        as before.  The mask is resized with nearest-neighbour sampling, no
        anti-aliasing and its value range preserved, then cast back to its
        original dtype, so that label ids are never blended or rescaled.
        """
        rgb, mask = sample['rgb'], sample['mask']

        # Integer division: `/` yields floats, which are not valid pixel
        # counts when the size is not a multiple of the factor.
        rgb_shape = (rgb.shape[0] // self.factor, rgb.shape[1] // self.factor)
        mask_shape = (mask.shape[0] // self.factor, mask.shape[1] // self.factor)

        rgb_small = resize(rgb, rgb_shape)
        mask_small = resize(mask, mask_shape, order=0, preserve_range=True,
                            anti_aliasing=False).astype(mask.dtype)
        return {'rgb': rgb_small, 'mask': mask_small}
        
def show_examples(sample_batch):
    """Draw a batch as two stacked panels: image grid on top, mask grid below.

    Args:
        sample_batch: dict with batched tensors under 'rgb' and 'mask'.
    """
    rgb_grid = utils.make_grid(sample_batch['rgb'])
    mask_grid = utils.make_grid(sample_batch['mask'])

    # Top panel: the images, moved from channel-first to channel-last layout.
    top = plt.subplot(211)
    top.axis("off")
    rgb_image = rgb_grid.numpy().transpose((1, 2, 0))
    plt.imshow(rgb_image)

    # Bottom panel: the masks, scaled by 100 before display.
    bottom = plt.subplot(212)
    bottom.axis("off")
    mask_image = mask_grid.numpy().transpose((1, 2, 0))
    plt.imshow(100 * mask_image)
    

# Locations and naming scheme of the images and their coarse label masks.
# NOTE(review): these are absolute, machine-specific paths; adjust them (or
# derive them from a configurable data directory) before running elsewhere.
IMAGE_ROOT = "/home/mohit/leftImg8bit/"
MASK_ROOT = "/home/mohit/gtCoarse/"
IMAGE_SUFFIX = "_leftImg8bit.png"
MASK_SUFFIX = "_gtCoarse_labelTrainIds.png"
SPLITS = ["train", "test"]

# Every PNG two directory levels below each split folder, in a stable order.
img_files = {split: sorted(glob.glob(IMAGE_ROOT + split + "/*/*.png"))
             for split in SPLITS}

# Captures the "<split>/<dir>/<name>" stem of an image path.  The literal
# parts are escaped so that '.' only matches a dot.
_stem_pattern = re.compile(re.escape(IMAGE_ROOT) + "(.*)" + re.escape(IMAGE_SUFFIX) + "$")

# files[split] is a list of [image_path, mask_path] pairs.
files = {split: [] for split in SPLITS}
for split in SPLITS:
    for image_path in img_files[split]:
        match = _stem_pattern.search(image_path)
        if match is None:
            # A PNG that does not follow the image naming scheme has no
            # derivable mask path; skip it instead of failing on None.
            continue
        files[split].append([image_path, MASK_ROOT + match.group(1) + MASK_SUFFIX])

tnfs = transforms.Compose([RandomHorizontalFlip(), ToTensor()])

# NOTE(review): the test split shares the random flip and is shuffled as well;
# confirm that augmentation/shuffling is wanted for evaluation.
obs_dataset = {split: ObstacleDataset(files[split], tnfs) for split in SPLITS}

dataloader = {split: DataLoader(obs_dataset[split], batch_size=2,
                                shuffle=True, num_workers=1)
              for split in SPLITS}

In [None]:
import time
def train_model(model, criterion, optimizer, scheduler=None, num_epochs=100,
                train_loader=None):
    """Epoch loop skeleton; currently only inspects the first batch per epoch.

    The forward/backward/update steps are still disabled (see the TODO in the
    batch loop), so each epoch puts the model in training mode, clears the
    gradients, prints the distinct label values of the first batch's mask and
    stops iterating.

    Args:
        model: module to train; ``model.train()`` is called every epoch.
        criterion: loss function (unused until the TODO below is enabled).
        optimizer: optimizer whose gradients are zeroed for each batch.
        scheduler: optional learning-rate scheduler, stepped once at the end
            of every epoch.
        num_epochs: number of epochs to run.
        train_loader: iterable of sample dicts with keys 'rgb' and 'mask'.
            Defaults to the notebook-level ``dataloader['train']``.

    Returns:
        The ``model`` that was passed in.
    """
    if train_loader is None:
        # Fall back to the loaders built in the data cell of this notebook.
        train_loader = dataloader['train']

    start = time.time()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 12)

        model.train()

        for sample in train_loader:
            masks = sample['mask']

            optimizer.zero_grad()

            # Debug aid: show which label ids occur in this batch.
            print(np.unique(masks.numpy()))
            # TODO: enable the actual training step and drop the `break`:
            #     outputs = model(sample['rgb'])
            #     loss = criterion(outputs, masks)
            #     print("loss: ", loss)
            #     loss.backward()
            #     optimizer.step()
            break

        # Step the schedule after the epoch's optimizer updates, which is the
        # order the PyTorch lr_scheduler documentation prescribes.
        if scheduler is not None:
            scheduler.step()

        now = time.time()
        # Truncate to whole seconds so rounding can never print ":60".
        minutes, seconds = divmod(int(now - start), 60)
        start = now
        print("Time elapsed: {:.0f}:{:.0f}".format(minutes, seconds))
    return model

In [None]:
# Build the network first, then the loss and the optimizer that acts on its
# parameters, and hand all three to the training loop.
model = ContextNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())  # Adam with its default settings
train_model(model=model, criterion=criterion, optimizer=optimizer)

In [None]:
# Draw a single batch from the training loader for interactive inspection.
# The loader is created with shuffle=True, so which batch comes back varies
# from run to run.
sample = next(iter(dataloader['train']))

In [None]:
# Smoke-test the forward pass on the batch held in `sample`.
# NOTE: this rebinds `model` to a freshly initialised network, replacing the
# instance that was passed to train_model in the earlier cell.
model=ContextNet()
outputs=model(sample['rgb'])