In [None]:
# Libraries

import torch
import torchvision
import re
import os
import glob
import sys
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
import torch.nn as nn
from scipy import io
import torchsummary
from load_scannet_vp_dataset import load_dataset
from EarlyStopping import EarlyStopping
import time
# Constant variables
# Upper bound on training epochs (loop in train_model); also the x-axis limit of the saved curves.
EPOCHS = 100
# Forwarded to EarlyStopping(patience=...) in train_model.
PATIENCE = 5
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batch size shared by the train and test DataLoaders.
BATCH_SIZE = 128
LR = 0.01  # learning rate handed to the Adam optimizer in train_model
# True: fit on the 'train' split and monitor on 'val'.
# False: fit on 'train' + 'val' and monitor on 'test'.
EVAL = True
# Input width of the first Linear layer in VGG16Model's head.
# NOTE(review): 131072 = 512 * 16 * 16, presumably the flattened VGG16 feature map for
# 3x512x512 images (see the commented-out summary call) - confirm against the dataset.
in_features = 131072
# Paths

# Directory receiving the early-stopping checkpoint and the periodic state_dict snapshots.
CHECKPOINT_DIR = '/home4/shubham/MTML_Pth/checkpoints/vanishing_points'
# Directory where draw_training_curves saves its PNG figures.
VIS_RESULTS_PATH = '/home4/shubham/MTML_Pth/results/vanishing_points'


In [None]:
class VGG16Model(nn.Module):
    """Frozen pretrained VGG16 feature extractor followed by a trainable MLP head.

    The head flattens the feature map, passes it through two hidden
    Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) stages (4096 and 1024 units)
    and ends with a 9-unit linear output layer.
    """

    def __init__(self):
        super().__init__()

        # Keep only the convolutional stack of the pretrained network and
        # exclude its weights from gradient updates.
        backbone = torchvision.models.vgg16(pretrained=True).features
        for weight in backbone.parameters():
            weight.requires_grad = False
        self.vgg16 = backbone

        # Layer order is significant: it fixes the state_dict keys of the head.
        head_layers = [
            nn.Flatten(),
            nn.Linear(in_features, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(4096, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 9),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, image):
        """Return the head's 9 output scores for a batch of images."""
        return self.head(self.vgg16(image))

# model = VGG16Model().to(DEVICE)
# summary(model, (3,512,512))

In [None]:
class DatasetLoader(Dataset):
    """Dataset pairing image files with vanishing-point label files.

    A sample is the image opened with PIL (optionally passed through
    ``transform``) together with a NumPy array built by concatenating the
    ``'x'``, ``'y'`` and ``'z'`` entries of the matching label file.
    """

    def __init__(self, image, v_points, transform = None):
        """
        params: image - sequence of image file paths
                v_points - sequence of label file paths, index-aligned with image
                transform - optional callable applied to the opened image
        """
        self.image = image
        self.v_points = v_points
        self.length = len(image)
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from the length of the image list."""
        return self.length

    def __getitem__(self, idx):
        """Return (image, labels) for sample ``idx``."""
        image = Image.open(self.image[idx])

        vps = np.load(self.v_points[idx])
        try:
            # Same ordering as before: all 'x' entries, then 'y', then 'z'.
            classes = [*vps['x'], *vps['y'], *vps['z']]
        finally:
            # np.load on an .npz archive returns an object that keeps the file
            # open until it is closed; release it explicitly so long training
            # runs do not accumulate open descriptors. Plain arrays have no
            # close() and are left alone.
            close = getattr(vps, 'close', None)
            if close is not None:
                close()

        if self.transform:
            image = self.transform(image)

        return image, np.array(classes)

In [None]:

def train_loop(model, trainloader, testloader, criterion, optimizer):
    """
    runs one training pass over trainloader followed by one evaluation pass
    over testloader.
    params: model -  vgg16
          trainloader - train dataset
          testloader - test dataset
          criterion - loss function
          optimizer - Adam optimizer
    returns: (train_epoch_loss, test_epoch_loss), each the unweighted average
             of the per-batch losses of the corresponding pass
    """
    model.to(DEVICE)

    # ---- training pass: weights are updated after every batch ----
    model.train()
    train_losses = []
    for image, points in trainloader:
        image = image.to(DEVICE)
        points = points.to(DEVICE)

        optimizer.zero_grad()

        output = model(image)

        # Targets may arrive as float64 (NumPy default); cast both sides to float32.
        loss = criterion(output.float(), points.float())
        train_losses.append(loss.item())

        loss.backward()
        optimizer.step()

    train_epoch_loss = np.average(train_losses)

    # ---- evaluation pass: no gradient tracking, no weight updates ----
    model.eval()
    test_losses = []
    with torch.no_grad():
        for image, points in testloader:
            image = image.to(DEVICE)
            points = points.to(DEVICE)

            output = model(image)
            loss = criterion(output.float(), points.float())
            test_losses.append(loss.item())

    test_epoch_loss = np.average(test_losses)

    return train_epoch_loss, test_epoch_loss

In [None]:
def train_model(trainloader, testloader):
    """
    trains a fresh VGG16Model for up to EPOCHS epochs with early stopping and
    a state_dict snapshot every 5 epochs.
    params: trainloader = train dataset
            testloader = validation dataset
    returns: losses = [train_loss, test_loss], two lists with one entry per completed epoch
             model = the model as it stands after the last executed epoch
             flag = True when early stopping ended training, False otherwise
    """
    # Checkpoints are written below; create the directory up front so a missing
    # folder does not abort the run only after several epochs have been trained.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    flag = False
    model = VGG16Model().to(DEVICE)
    criterion = nn.MSELoss().to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr = LR)

    train_loss = []
    test_loss = []

    early_stop = EarlyStopping(patience=PATIENCE,path=CHECKPOINT_DIR+'/early_stopping_vgg16_model.pth')

    for epoch in range(EPOCHS):
        print("Running Epoch {}".format(epoch+1))
        start = time.time()
        epoch_train_loss, epoch_test_loss = train_loop(model, trainloader, testloader, criterion, optimizer)
        train_loss.append(epoch_train_loss)
        test_loss.append(epoch_test_loss)

        # Elapsed time is reported in minutes.
        print("Time taken: {:.2f}".format((time.time()-start)/60.))
        print("Training loss: {0:.4f}  Testing loss: {1:0.4f}".format(epoch_train_loss, epoch_test_loss))
        print("--------------------------------------------------------")

        # The early-stopping helper is driven by the evaluation loss.
        early_stop(epoch_test_loss, model)

        if early_stop.early_stop:
            print("Early stopping")
            flag = True
            break

        # Periodic snapshot, independent of the early-stopping checkpoint.
        if (epoch+1)%5 == 0:
            torch.save(model.state_dict(), CHECKPOINT_DIR + "/vgg16_epoch_" + str(epoch+1) + ".pth")

    print("Training completed!")
    losses = [train_loss, test_loss]

    return losses, model, flag

In [None]:


    
def draw_training_curves(train_losses, test_losses, curve_name):
    """
    draws the training and testing series on one figure and saves it as a PNG
    under VIS_RESULTS_PATH
    params: train_losses = per-epoch training values
            test_losses = per-epoch testing values
            curve_name = metric name, used in the legend labels and the file name
    """
    # Start from an empty current figure and fix the x-range to the epoch budget.
    plt.clf()
    plt.xlim([0,EPOCHS])

    labelled_series = (
        ('Training {}', train_losses),
        ('Testing {}', test_losses),
    )
    for label_template, series in labelled_series:
        plt.plot(series, label=label_template.format(curve_name))

    plt.legend(frameon=False)

    file_part = "/{}_vgg16.png".format(curve_name)
    plt.savefig(VIS_RESULTS_PATH + file_part)


    
def get_transformations(flag):
    """
    returns the torchvision transform pipeline for one split
    params: flag = "train" adds brightness augmentation before tensor
                   conversion; any other value yields tensor conversion only
    """
    transfrms = []

    if flag == "train":
        # Brightness factor is drawn from [1.2, 2.0] for every sample.
        transfrms.append(torchvision.transforms.ColorJitter(brightness=(1.2, 2.0)))
        # No RandomHorizontalFlip here: the transform only sees the image, so the
        # vanishing-point targets would stay unflipped and no longer match the
        # mirrored picture. A flip has to be applied jointly to image and labels
        # (inside the dataset) if that augmentation is wanted.

    transfrms.append(torchvision.transforms.ToTensor())

    return torchvision.transforms.Compose(transfrms)


def get_data_loader(train_dataset, test_dataset):
    """
    returns (trainloader, testloader), both batching with BATCH_SIZE
    params: train_dataset = dataset used for fitting (reshuffled every epoch)
            test_dataset = dataset used for evaluation (kept in fixed order)
    """

    trainloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)
    # Evaluation needs no shuffling; a fixed order keeps the batch composition,
    # and therefore the averaged per-batch loss used for early stopping,
    # identical from one epoch to the next.
    testloader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle=False)

    return trainloader, testloader
    

In [None]:

# Build one transform pipeline per split.
train_transformation = get_transformations("train")
test_transformation = get_transformations("test")

# NOTE(review): data / label are indexed by 'train', 'val' and 'test' and are
# concatenated with '+', so they are presumably dicts of path lists - confirm
# against load_scannet_vp_dataset.
data, label = load_dataset()


if EVAL:
    # Development mode: fit on 'train', monitor on 'val'.
    train_dataset = DatasetLoader(data['train'], label['train'], train_transformation)
    test_dataset = DatasetLoader(data['val'], label['val'], test_transformation)
else:
    # Final mode: fit on 'train' + 'val', monitor on 'test'.
    train_dataset = DatasetLoader(data['train'] + data['val'], label['train'] + label['val'], train_transformation)
    test_dataset = DatasetLoader(data['test'], label['test'], test_transformation)


train_loader, test_loader = get_data_loader(train_dataset, test_dataset)


# train model
# train_model returns (losses, model, flag); the evaluation loader built above
# is test_loader in both modes.
losses, model, early_stopped = train_model(train_loader, test_loader)

# # plot trained metrics
# loss_curve = "loss"
# draw_training_curves(losses[0], losses[1],loss_curve)


 