In [1]:
# Libraries

import torch
import torchvision
import os
import glob
import sys
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
import torch.nn as nn
import torch.nn.functional as F
from networks import SegNet

# Constant variables
EPOCHS = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 8
LR = 0.001
INPUT_CHANNELS = 3
OUTPUT_CHANNELS = 13
IMG_SIZE = (288, 384)
# Paths

TRAIN_RGB_PATH = "/home/shubham/MTML_Pth/pytorch-nyuv2/nyuv2/train_rgb/"
TRAIN_SEG_PATH = "/home/shubham/MTML_Pth/pytorch-nyuv2/nyuv2/train_seg13/"

TEST_RGB_PATH = "/home/shubham/MTML_Pth/pytorch-nyuv2/nyuv2/test_rgb/"
TEST_SEG_PATH = "/home/shubham/MTML_Pth/pytorch-nyuv2/nyuv2/test_seg13/"

CHECKPOINT_DIR = '/home/shubham/MTML_Pth/checkpoints/'


In [2]:
def load_dataset(flag):
    """
    returns dictionary of images and their corresponding annotations split into train, val and test
    :params: flag - (task) segmentation, depth or surface normal
    """ 
    data = {}
    label = {}
    
    TRAIN_PATH_IMG = None
    TRAIN_PATH_LAB = None
    TEST_PATH_IMG = None
    TEST_PATH_LAB = None
    
    if flag == "segmentation":
        TRAIN_PATH_IMG =  TRAIN_RGB_PATH
        TRAIN_PATH_LAB = TRAIN_SEG_PATH
        TEST_PATH_IMG = TEST_RGB_PATH
        TEST_PATH_LAB = TEST_SEG_PATH
        
    elif flag == "depth":
        TRAIN_PATH_IMG = None
        TRAIN_PATH_LAB = None # fill in later
        
    train_images = glob.glob(TRAIN_PATH_IMG + "*.png")
    train_labels = glob.glob(TRAIN_PATH_LAB + "*.png")
    
    index = np.random.permutation(len(train_images))
    images = np.array(train_images)[index]
    labels = np.array(train_labels)[index]
    
    length = int(len(images)*0.85)
   
    data["train"], data["val"] = images[:length], images[length:]
    label["train"], label["val"] = labels[:length], labels[length:]
    data["test"] = glob.glob(TEST_PATH_IMG + "*.png")
    label["test"] = glob.glob(TEST_PATH_LAB + "*.png")
    
    return data, label


In [3]:
class DatasetLoader(Dataset):
    
    def __init__(self, data, ground_truth, transform = None):
        self.data = data
        self.gt = ground_truth
        self.length = len(data)
        self.transform = transform

    def __len__(self):
        return self.length
  
    def __getitem__(self, idx):
        img = Image.open(self.data[idx])
        img = img.resize(IMG_SIZE, Image.BILINEAR)
        img = np.array(img)
        gt = Image.open(self.gt[idx])
        gt = gt.resize(IMG_SIZE, Image.BILINEAR)
        gt = np.array(gt)
        
        if self.transform:
            img = self.transform(img)
            
        return img, gt


In [8]:
def train_loop(model, tloader, vloader, criterion, optimizer):
    """
    returns loss and accuracy of the model for 1 epoch.
    params: model -  vgg16
          tloader - train dataset
          vloader - val dataset
          criterion - loss function
          optimizer - Adam optimizer
    """
    total = 0
    correct = 0
    train_losses = []
    valid_losses = []
    t_epoch_accuracy = 0

    v_epoch_accuracy = 0
    
    model.train()
    model.to(DEVICE)
    
    for ind, (image, label) in enumerate(tloader):
     
        image = image.to(DEVICE)
        label = label.type(torch.LongTensor)
        label = label.to(DEVICE)
        
        optimizer.zero_grad()

        output, _= model(image)
        print(output.shape, type(output), label.shape, type(label))
        loss = criterion(output, label)
        train_losses.append(loss.item())
        loss.backward()
        optimizer.step()

    t_epoch_loss = np.average(train_losses)
    
    total = 0
    correct = 0
    
    model.eval()
    with torch.no_grad():
        for ind, (image, label) in enumerate(vloader):
            image = image.to(DEVICE)
            label = label.type(torch.LongTensor)
            label = label.to(DEVICE)
            output,_ = model(image)
            loss = criterion(output, label)
            valid_losses.append(loss.item())
    
    
    v_epoch_loss = np.average(valid_losses)
        
    
    return t_epoch_loss, v_epoch_loss



In [9]:
def train_model(trainloader, valloader):
    """
    returns losses (train and val), accuracies (train and val), trained_model
    params: trainloader = train dataset
            valloader = validation dataset
    """
    
    model = SegNet(INPUT_CHANNELS, OUTPUT_CHANNELS).to(DEVICE)
    
    criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    
    train_loss = []
    val_loss = []
    train_acc = []
    val_acc = []
    
    
    for epoch in range(EPOCHS):
        print("Running Epoch {}".format(epoch+1))

        epoch_train_loss,  epoch_val_loss = train_loop(model, trainloader, valloader, criterion, optimizer)
        train_loss.append(epoch_train_loss)   
        val_loss.append(epoch_val_loss)
  

        print("Training loss: {:.4f}".format(epoch_train_loss))
        print("Validation loss: {:.4f}".format(epoch_val_loss))
        print("--------------------------------------------------------")
        
        
        if (epoch+1)%5 == 0:
            torch.save(model.state_dict(), CHECKPOINT_DIR + "/segnet_epoch_" + str(epoch+1) + ".pth")

    print("Training completed!")
    losses = [train_loss, val_loss]
    accuracies = [train_acc, val_acc]
    
    return losses, accuracies, model



In [10]:
def run_inference(model, testloader):
    """
    returns performance of the model on test dataset
    """
    total = 0
    correct = 0
    
    model.eval()
    with torch.no_grad():
        for ind, (image, label) in enumerate(testloader):
            image = image.to(DEVICE)
            label = label.to(DEVICE, dtype=torch.long)

            output = model(image)
            _, predicted = torch.max(output.data, 1)
            total += label.size(0)
            correct += (predicted==label).sum().item()
    
    
    accuracy = 100*correct/total
    print("Test Accuracy: {}".format(accuracy))
    

    


def get_data_loader(data, label, flag):
    """
    returns train/test/val dataloaders
    params: flag = train/test/val
    """

    dataset = DatasetLoader(data[flag], label[flag], transform=torchvision.transforms.ToTensor()) 
    dataloader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle=True, num_workers=4)

    return dataloader


In [11]:
def main():
    
    data, labels = load_dataset("segmentation")
    train_loader = get_data_loader(data, labels, "train")
    val_loader = get_data_loader(data, labels,"val")
    test_loader = get_data_loader(data, labels,"test")
 
    # train model
    losses, accuracies, model = train_model(train_loader, val_loader)
    
#     # plot trained metrics
#     loss_curve = "loss"
#     draw_training_curves(losses[0], losses[1],loss_curve)
#     acc_curve = "accuracy"
#     draw_training_curves(accuracies[0], accuracies[1] ,acc_curve)
    
#     # inference on test data and results
#     create_confusion_matrix(model, val_loader)
    
#     # save the final model
#     torch.save(model.state_dict(), CHECKPOINT_DIR + "/cnn_model_final_" + str(EPOCHS) + ".pth")
    
    return

In [12]:
if __name__ == "__main__":
    main()

Running Epoch 1
torch.Size([8, 13, 384, 288]) <class 'torch.Tensor'>


RuntimeError: CUDA error: device-side assert triggered