In [9]:
# Libraries

import torch
import torchvision
import os
import glob
import sys
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
import torch.nn as nn
import torch.nn.functional as F

from networks import SegNet
# from load_nyuv2_dataset import load_dataset

# Constant variables
EPOCHS = 100                # upper bound; EarlyStopping in train_model may stop sooner
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 20
LR = 0.0001                 # Adam learning rate
INPUT_CHANNELS = 3          # image channels fed to SegNet
OUTPUT_CHANNELS = 14        # number of classes SegNet predicts / ConfMatrix tracks
# NOTE(review): this tuple is passed unchanged to PIL's Image.resize, which reads
# it as (width, height). If (height, width) was intended, swap the elements.
IMG_SIZE = (288, 384)
SEED = 0                    # fixed seed so runs are repeatable

# Seed the global RNGs: torch drives weight init / shuffling, np.random drives
# the random flips in train_transformation.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Paths (overridable through environment variables; defaults are the original locations)
CHECKPOINT_DIR = os.environ.get('MTML_CHECKPOINT_DIR', '/home4/shubham/MTML_Pth/checkpoints/')
VIS_RESULTS_PATH = os.environ.get('MTML_RESULTS_DIR', '/home4/shubham/MTML_Pth/results/')

In [2]:
class DatasetLoader(Dataset):
    """Paired image / segmentation-mask dataset backed by lists of file paths.

    Args:
        data: sequence of input-image paths.
        ground_truth: sequence of mask paths, aligned index-for-index with ``data``.
        transform: optional callable ``(img, gt) -> (img, gt)`` applied jointly to
            the resized image and mask (e.g. a synchronized random flip).

    Each item is ``(img, gt)``. ``gt`` is always returned as a NumPy array. ``img``
    is whatever ``transform`` returns, or a PIL image when no transform is set.
    """

    def __init__(self, data, ground_truth, transform = None):
        self.data = data
        self.gt = ground_truth
        self.length = len(data)
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # NOTE(review): PIL's resize takes (width, height); IMG_SIZE is passed
        # unchanged, so confirm its element order matches that convention.
        with Image.open(self.data[idx]) as img_file:
            # convert("RGB") guarantees 3 channels even for grayscale / RGBA files.
            img = img_file.convert("RGB").resize(IMG_SIZE, Image.BILINEAR)
        with Image.open(self.gt[idx]) as gt_file:
            # Nearest-neighbour keeps mask values as existing class ids (no blending).
            gt = gt_file.resize(IMG_SIZE, Image.NEAREST)

        if self.transform:
            img, gt = self.transform(img, gt)
        # Convert the mask regardless of whether a transform ran, so the return
        # type does not depend on ``transform``.
        gt = np.array(gt)

        return img, gt


In [3]:
class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.

    Each call compares the new validation loss with the best seen so far. On an
    improvement, the model's ``state_dict`` is saved to ``path`` and the counter is reset.
    Otherwise the counter grows, and once it reaches ``patience``,
    ``early_stop`` is set to True.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path=CHECKPOINT_DIR+'early_stopping_sgd_segmentation_model.pth'):
        """
        Args:
            patience (int): How many consecutive non-improving calls to tolerate.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss
                            improvement and for each non-improving call.
                            Default: False
            delta (float): Minimum decrease in validation loss that counts as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: CHECKPOINT_DIR + 'early_stopping_sgd_segmentation_model.pth'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works in every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        # Negate so that a higher score means a better (lower) loss.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """
        Saves the model's state_dict to ``self.path`` and records ``val_loss``
        as the best validation loss seen so far.
        """
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
        torch.save(model.state_dict(), self.path)
        # Previously written to a misspelled attribute (vall_loss_min), so
        # val_loss_min never tracked the best loss.
        self.val_loss_min = val_loss
        

In [4]:
class ConfMatrix(object):
    """Accumulates a confusion matrix over batches of segmentation predictions.

    Rows index the ground-truth class and columns the predicted class. Target
    values outside ``[0, num_classes)`` are ignored.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None  # allocated on the device of the first ``pred`` passed to update()

    def update(self, pred, target):
        """Add one batch of predictions to the matrix.

        Args:
            pred: integer tensor of predicted class ids, same shape as ``target``.
            target: tensor of ground-truth class ids; out-of-range ids are skipped.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
        with torch.no_grad():
            k = (target >= 0) & (target < n)
            # Encode each (target, pred) pair as one flat index into the n*n matrix.
            inds = n * target[k].to(torch.int64) + pred[k].to(torch.int64)
            self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

    def get_metrics(self):
        """Return ``(mean_iou, pixel_accuracy)`` as 0-dim float tensors.

        A class that never occurs in either predictions or targets has an
        undefined IoU (0/0). Such classes are excluded from the mean instead of
        turning the whole mIoU into NaN.

        Raises:
            RuntimeError: if ``update`` has never been called.
        """
        if self.mat is None:
            raise RuntimeError("ConfMatrix.get_metrics() called before any update()")
        h = self.mat.float()
        acc = torch.diag(h).sum() / h.sum()
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        valid = ~torch.isnan(iu)
        return torch.mean(iu[valid]), acc
    

In [5]:
def _show_segmentation_samples(image, label, predicted, num_samples=2):
    """Plot input / ground truth / prediction side by side for the first samples of a batch."""
    # min() guards against batches smaller than num_samples (e.g. a short last batch).
    for i in range(min(num_samples, image.shape[0])):
        fig, axes = plt.subplots(1, 3, figsize=(11, 11))
        axes[0].imshow(image[i].cpu().numpy().transpose((1, 2, 0)))
        axes[0].set_title("Input")
        axes[1].imshow(torch.squeeze(label[i].cpu(), 0))
        axes[1].set_title("Ground truth")
        axes[2].imshow(predicted[i].cpu())
        axes[2].set_title("Prediction")
        plt.show()
        # Close explicitly so figures do not pile up across epochs on non-inline backends.
        plt.close(fig)


def train_loop(model, tloader, vloader, criterion, optimizer):
    """
    Runs one training epoch followed by one validation pass.

    params: model     - segmentation network; its forward returns a pair whose
                        first element is the per-class output scored by ``criterion``
            tloader   - DataLoader over the training split
            vloader   - DataLoader over the validation split
            criterion - loss applied to (output, labels)
            optimizer - optimizer stepping ``model``'s parameters

    returns: (model, train_loss, val_loss, train_accuracy, val_accuracy,
              train_mIoU, val_mIoU). The losses are means of the per-batch losses;
              accuracy / mIoU are the 0-dim tensors from ConfMatrix.get_metrics().
    """
    model.to(DEVICE)
    train_cm = ConfMatrix(OUTPUT_CHANNELS)
    val_cm = ConfMatrix(OUTPUT_CHANNELS)
    train_losses = []
    valid_losses = []

    # ---- training ----
    model.train()
    for image, label in tloader:
        image = image.to(DEVICE)
        label = label.to(DEVICE, dtype=torch.long)

        optimizer.zero_grad()
        output, _ = model(image)
        # squeeze(1) drops a singleton channel dim if the masks carry one.
        loss = criterion(output, torch.squeeze(label, 1))
        train_losses.append(loss.item())
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output, 1)
        train_cm.update(predicted, label)

    # ---- validation ----
    model.eval()
    with torch.no_grad():
        for batch_idx, (image, label) in enumerate(vloader):
            image = image.to(DEVICE)
            label = label.to(DEVICE, dtype=torch.long)
            output, _ = model(image)
            loss = criterion(output, torch.squeeze(label, 1))
            valid_losses.append(loss.item())

            _, predicted = torch.max(output, 1)
            val_cm.update(predicted, label)

            if batch_idx == 0:
                # Visual sanity check on the first validation batch of every epoch.
                _show_segmentation_samples(image, label, predicted)

    t_epoch_loss = np.average(train_losses)
    v_epoch_loss = np.average(valid_losses)
    t_epoch_iou, train_accuracy = train_cm.get_metrics()
    v_epoch_iou, val_accuracy = val_cm.get_metrics()

    return model, t_epoch_loss, v_epoch_loss, train_accuracy, val_accuracy, t_epoch_iou, v_epoch_iou



In [6]:
def train_model(trainloader, valloader):
    """
    Trains a fresh SegNet with Adam and cross-entropy, using early stopping on validation loss.

    params: trainloader = DataLoader over the training split
            valloader   = DataLoader over the validation split
    returns: losses     = [train_loss_per_epoch, val_loss_per_epoch]
             accuracies = [train_acc_per_epoch, val_acc_per_epoch]
             model      = the model after the last epoch that ran. The best
                          weights are the ones EarlyStopping saved to disk.
    All per-epoch values are stored as plain Python floats.
    """
    # torch.save does not create missing directories.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    model = SegNet(INPUT_CHANNELS, OUTPUT_CHANNELS).to(DEVICE)
    criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    early_stop = EarlyStopping(patience=7)

    train_loss, val_loss = [], []
    train_acc, val_acc = [], []

    for epoch in range(EPOCHS):
        print("Running Epoch {}".format(epoch+1))

        model, epoch_train_loss, epoch_val_loss, train_ac, val_ac, train_iou, val_iou = train_loop(
            model, trainloader, valloader, criterion, optimizer)

        # Metrics may be 0-dim (possibly CUDA) tensors. Store floats so the
        # histories can be plotted later without device/conversion errors.
        train_loss.append(float(epoch_train_loss))
        val_loss.append(float(epoch_val_loss))
        train_acc.append(float(train_ac))
        val_acc.append(float(val_ac))

        print("Training loss: {0:.4f}   Training accuracy: {1:.4f}   Training mIoU: {2:.4f}".format(epoch_train_loss, train_ac, train_iou))
        print("Validation loss: {0:.4f} Validation accuracy: {1:.4f} Validation mIoU: {2:.4f}".format(epoch_val_loss, val_ac, val_iou))
        print("--------------------------------------------------------")

        early_stop(epoch_val_loss, model)
        if early_stop.early_stop:
            print("Early stopping")
            break

        # Periodic snapshot every 5 epochs, independent of early stopping.
        if (epoch+1) % 5 == 0:
            torch.save(model.state_dict(),
                       os.path.join(CHECKPOINT_DIR, "segnet_epoch_{}.pth".format(epoch+1)))

    print("Training completed!")
    losses = [train_loss, val_loss]
    accuracies = [train_acc, val_acc]

    return losses, accuracies, model



In [7]:
def run_inference(testloader, checkpoint_path=None, num_samples=6):
    """
    Evaluates saved SegNet weights on a test DataLoader.

    Loads the checkpoint and accumulates a confusion matrix over every batch.
    It plots input / ground truth / prediction for the first ``num_samples``
    images of the first batch, and prints mean IoU and pixel accuracy.

    params: testloader      = DataLoader over the test split
            checkpoint_path = weights file to load; defaults to the file that
                              EarlyStopping writes by default during training
            num_samples     = how many images of the first batch to plot
    returns: (mean_iou, accuracy) as Python floats
    """
    if checkpoint_path is None:
        # Must match EarlyStopping's default save path, otherwise a stale or
        # missing file would be loaded.
        checkpoint_path = CHECKPOINT_DIR + 'early_stopping_sgd_segmentation_model.pth'

    model = SegNet(INPUT_CHANNELS, OUTPUT_CHANNELS).to(DEVICE)
    print("Loading pre-trained weights...")
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    print("Weights loaded!")
    model.eval()

    test_cm = ConfMatrix(OUTPUT_CHANNELS)
    with torch.no_grad():
        for batch_idx, (image, label) in enumerate(testloader):
            image = image.to(DEVICE)
            label = label.to(DEVICE, dtype=torch.long)
            output, _ = model(image)
            _, predicted = torch.max(output, 1)
            test_cm.update(predicted, label)

            if batch_idx == 0:
                # min() guards against a first batch smaller than num_samples.
                for i in range(min(num_samples, image.shape[0])):
                    fig, axes = plt.subplots(1, 3, figsize=(11, 11))
                    axes[0].imshow(image[i].cpu().numpy().transpose((1, 2, 0)))
                    axes[0].set_title("Input")
                    axes[1].imshow(torch.squeeze(label[i].cpu(), 0))
                    axes[1].set_title("Ground truth")
                    axes[2].imshow(predicted[i].cpu())
                    axes[2].set_title("Prediction")
                    plt.show()
                    plt.close(fig)

    iou, accuracy = test_cm.get_metrics()
    iou, accuracy = iou.item(), accuracy.item()
    print("Mean Pixel IoU: ", iou)
    print("Accuracy: ", accuracy)
    return iou, accuracy
   
    

def train_transformation(image, gt):
    """
    Joint augmentation for a training (image, mask) pair.

    Makes one draw from np.random.uniform(0, 1). When it is <= 0.5, both the image
    and the mask are mirrored left-right so they stay pixel-aligned. The image
    is then converted with torchvision's ToTensor; the mask is returned as-is.
    """
    to_tensor = torchvision.transforms.ToTensor()
    should_flip = np.random.uniform(0, 1) <= 0.5
    if should_flip:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
        gt = gt.transpose(Image.FLIP_LEFT_RIGHT)
    return to_tensor(image), gt

def test_transformation(image, gt):
    """Evaluation-time transform: tensorize the image and pass the mask through untouched."""
    tensor_image = torchvision.transforms.ToTensor()(image)
    return tensor_image, gt


def get_data_loader(data, label, flag, shuffle=None):
    """
    Builds the DataLoader for one split.

    params: data    = indexable by split name, giving a list of image paths
                      (presumably the dict returned by load_dataset — confirm)
            label   = indexable by split name, giving the matching mask paths
            flag    = split name. "train" uses train_transformation (random
                      flips); any other split uses test_transformation
            shuffle = override for DataLoader shuffling. By default only the
                      "train" split is shuffled
    returns: torch DataLoader with batch_size=BATCH_SIZE and num_workers=4
    """
    is_train = flag == "train"
    transform = train_transformation if is_train else test_transformation
    if shuffle is None:
        # Metrics do not depend on evaluation order, and a fixed order makes
        # the plotted val/test samples comparable across epochs and runs.
        shuffle = is_train

    dataset = DatasetLoader(data[flag], label[flag], transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=4)

def draw_training_curves(train_losses, test_losses,curve_name):
    """
    Plots per-epoch training vs. validation values of one metric and saves the
    figure as ``<VIS_RESULTS_PATH>/<curve_name>_sgd_segmentation.png``.

    params: train_losses = per-epoch training values (numbers or 0-dim tensors)
            test_losses  = per-epoch held-out values (the validation history
                           when called with train_model's output)
            curve_name   = metric name used in the legend, axis label and filename
    """
    # float() accepts Python numbers, NumPy scalars and 0-dim tensors (incl. CUDA),
    # which matplotlib cannot always convert on its own.
    train_values = [float(v) for v in train_losses]
    held_out_values = [float(v) for v in test_losses]

    # A fresh figure per call, so a second call does not wipe the first plot.
    fig, ax = plt.subplots()
    ax.plot(train_values, label='Training {}'.format(curve_name))
    ax.plot(held_out_values, label='Validation {}'.format(curve_name))
    ax.set_xlabel("Epoch")
    ax.set_ylabel(curve_name)
    ax.set_title("Training vs. validation {}".format(curve_name))
    ax.legend(frameon=False)

    os.makedirs(VIS_RESULTS_PATH, exist_ok=True)
    fig.savefig(os.path.join(VIS_RESULTS_PATH, "{}_sgd_segmentation.png".format(curve_name)))

In [8]:

# load_dataset is not imported anywhere else (the import in the first cell is
# commented out), so bring it into scope here for a fresh-kernel Run All.
from load_nyuv2_dataset import load_dataset

data, labels = load_dataset("segmentation")
train_loader = get_data_loader(data, labels, "train")
val_loader = get_data_loader(data, labels, "val")
test_loader = get_data_loader(data, labels, "test")

# train model
losses, accuracies, model = train_model(train_loader, val_loader)



In [None]:

# Plot training vs. validation curves for each metric history returned by train_model.
for metric_name, history in (("loss", losses), ("accuracy", accuracies)):
    draw_training_curves(history[0], history[1], metric_name)

In [None]:
# Rebuild the test loader so this cell does not depend on an earlier cell's
# loader, then evaluate the saved checkpoint on the test split.
test_loader = get_data_loader(data, labels,"test")
run_inference(test_loader)