#### Required Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import models, transforms
import matplotlib
import matplotlib.pyplot as plt
import time
import os
import copy
import random
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torchsummary import summary
from collections import OrderedDict
import re

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box, compute_ts_road_map
from modelzoo import *
from simclr_transforms import *
import contrastive_loss

In [2]:
# --- Plotting defaults: small square figures rendered at high resolution ---
matplotlib.rcParams['figure.dpi'] = 200
matplotlib.rcParams['figure.figsize'] = [5, 5]

# --- Reproducibility ---
# Seed Python, NumPy and PyTorch (CPU and the current CUDA device).
# cudnn.benchmark lets cuDNN pick the fastest kernels, which can make runs
# non-deterministic; set torch.backends.cudnn.deterministic = True if exact
# reruns matter more than speed.
seed = 8964
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = True

# --- Device selection and a short hardware report ---
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print("Current Device - %s" % device)

if device.type == 'cuda':
    print("CUDA Device Count - %s" % torch.cuda.device_count())
    print("CUDA Device Name - %s" % torch.cuda.get_device_name())
    print("CUDA Device Memory - %0.2f GB" % (float(torch.cuda.get_device_properties(0).total_memory) / 1024 ** 3))

Current Device - cuda:0
CUDA Device Count - 1
CUDA Device Name - Tesla P40
CUDA Device Memory - 22.38 GB


#### SimCLR Model Summary

In [3]:
# Build a default SimCLR model, then print its layer-by-layer shape summary
# for a random batch of 3 RGB images of size 256x306.
# (The model is created before the random batch so the RNG draw order is unchanged.)
model = simclr_model()
inp = torch.rand(3, 3, 256, 306)
model.summarize(inp)

Class: simclr_model
Passed Input Size:torch.Size([3, 3, 256, 306])
----
     Class: encoder
     resnet_style: 18, pretrained: False
     Passed Input Size:torch.Size([3, 3, 256, 306])
     Output Size:torch.Size([3, 512, 8, 8])
----
Hidden state shape: torch.Size([3, 512, 8, 8])
Pooled Hidden state shape: torch.Size([3, 512, 1, 1])
Reshaped Hidden state shape: torch.Size([3, 512])
Output shape:torch.Size([3, 16])


### Data Setup

In [6]:
# Augmentation pipelines. Earlier variants, kept for reference:
#   transform1: RandomHorizontalFlip -> ToTensor
#   transform2: RandomGrayscale(p=0.5) -> RandomHorizontalFlip -> get_color_distortion(s=0.5)
#               -> RandomGaussianBluring(kernel_size=5) -> ToTensor
# transform3 (current): transform2 without the RandomGrayscale step.
eval_transform = torchvision.transforms.ToTensor()

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    get_color_distortion(s=0.5),
    RandomGaussianBluring(kernel_size=5),
    transforms.ToTensor(),
])

image_folder = '../data'
annotation_csv = '../data/annotation.csv'

# Unlabeled scenes used here occupy indices [start_index, end_index).
start_index = 0
end_index = 106
unlabeled_scene_index = np.arange(start_index, end_index)

# Scene split sizes as (train, val, test). Earlier splits V1 = (21, 5, 2) and
# V2 = (26, 1, 1) sum to 28 scenes and would not pass the assertion below.
V3 = (90, 16, 0)

n_train_samples, n_val_samples, n_test_samples = V3

# Derive the scene count from the index range instead of repeating 106, so the
# check stays correct if start_index/end_index change.
n_scenes = end_index - start_index
assert n_train_samples + n_val_samples + n_test_samples == n_scenes, \
    'split sizes must add up to the {} unlabeled scenes'.format(n_scenes)

scene_buffer_length = 128

# Contiguous, non-overlapping scene ranges for each split.
val_start = start_index + n_train_samples
test_start = val_start + n_val_samples
train_unlabeled_scene_index = np.arange(start_index, val_start)
val_unlabeled_scene_index = np.arange(val_start, test_start)
test_unlabeled_scene_index = np.arange(test_start, end_index)


def make_unlabeled_set(scene_index):
    """Create an UnlabeledDataset over `scene_index` with this notebook's shared settings."""
    # Every split uses the stochastic train_transform (as in the original setup);
    # presumably the dataset needs random augmentations to form contrastive pairs.
    return UnlabeledDataset(image_folder=image_folder, scene_index=scene_index, first_dim='image', transform=train_transform, scene_buffer_length=scene_buffer_length)


train_unlabeled_set = make_unlabeled_set(train_unlabeled_scene_index)
val_unlabeled_set = make_unlabeled_set(val_unlabeled_scene_index)
# An empty test split gets no dataset; downstream cells check for None.
test_unlabeled_set = make_unlabeled_set(test_unlabeled_scene_index) if n_test_samples != 0 else None



In [7]:
# from importlib import reload
# import contrastive_loss
# reload(contrastive_loss)


In [8]:
def prepare_experiment(model_tag, batch_size = 64, resnet_style = '18', weights ='random', model_type ='simclr', checkpoint_folder = '/scratch/sc6957/dlproject/checkpoints/'):
    """Build the data loaders, model, optimizer, scheduler, loss and checkpoint paths for one run.

    Args:
        model_tag: suffix appended to the date-based (day-month) checkpoint tag.
        batch_size: batch size of the train and val loaders; the test loader uses 1.
        resnet_style: ResNet variant forwarded to ``simclr_model`` (e.g. '18').
        weights: 'random' for random initialisation, 'imagenet' for pretrained weights.
        model_type: label that this function only prints.
        checkpoint_folder: directory the checkpoint files live in (trailing slash optional).

    Returns:
        Tuple ``(model, optimizer, scheduler, criterion, dataloaders,
        checkpoint_path, best_checkpoint_path)``. ``dataloaders['test']`` is None
        when the module-level ``test_unlabeled_set`` is None.

    Raises:
        ValueError: if ``weights`` is neither 'random' nor 'imagenet'.
    """
    # Validate up front: an unknown value used to leave `pretrained` unbound and
    # crash later with UnboundLocalError, after the loaders were already built.
    pretrained_by_weights = {'random': False, 'imagenet': True}
    if weights not in pretrained_by_weights:
        raise ValueError("weights must be 'random' or 'imagenet', got {!r}".format(weights))
    pretrained = pretrained_by_weights[weights]

    train_loader = torch.utils.data.DataLoader(train_unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
    val_loader = torch.utils.data.DataLoader(val_unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
    test_loader = None
    if test_unlabeled_set is not None:
        test_loader = torch.utils.data.DataLoader(test_unlabeled_set, batch_size=1, shuffle=True, num_workers=4, collate_fn=collate_fn)

    dataloaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

    print('Model Type: {0}'.format(model_type))

    model = simclr_model(resnet_style=resnet_style, pretrained=pretrained)

    # Dummy batch used only to print the layer-by-layer shape summary.
    sample_input = torch.rand([3,3,256,306])

    print('{} Model Summary'.format(model_type))
    model.summarize(sample_input)

    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
    # Multiplies the LR by 0.97 on every scheduler.step() call.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)

    criterion = contrastive_loss.contrastive_loss(tau=0.1, normalize=True)

    today = datetime.now().date()

    checkpoint_tag = '{0}-{1}{2}'.format(today.day, today.month, model_tag)
    # os.path.join works whether or not checkpoint_folder ends with a separator.
    checkpoint_path = os.path.join(checkpoint_folder, 'checkpoint_{0}.pth.tar'.format(checkpoint_tag))
    best_checkpoint_path = os.path.join(checkpoint_folder, 'checkpoint_{0}_best.pth.tar'.format(checkpoint_tag))

    return model, optimizer, scheduler, criterion, dataloaders, checkpoint_path, best_checkpoint_path

#### Training and Validation

In [9]:
def _save_checkpoint(path, epoch, model, optimizer, scheduler, best_loss, best_ts):
    """Write model/optimizer/scheduler state plus the best metrics so far to `path`."""
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_loss': best_loss,
        'best_ts': best_ts
        }, path)


def _best_acc_checkpoint_path(best_checkpoint_path):
    """Sibling path for the best-accuracy checkpoint: 'x_best.pth.tar' -> 'x_best_acc.pth.tar'."""
    suffix = '.pth.tar'
    if best_checkpoint_path.endswith(suffix):
        return best_checkpoint_path[:-len(suffix)] + '_acc' + suffix
    return best_checkpoint_path + '_acc'


def _run_phase(phase, model, optimizer, criterion, loader, writer, epoch):
    """Make one pass over `loader`: optimise when phase == 'train', otherwise only score.

    Returns ``(total_loss, mean_batch_loss, mean_batch_acc)``. The means are 0.0
    for an empty loader instead of raising ZeroDivisionError.
    """
    is_train = phase == 'train'
    model.train(is_train)  # same as model.eval() when is_train is False
    steps_per_epoch = len(loader)
    running_loss = 0.0
    running_acc = 0.0

    for i, (x_i, x_j) in enumerate(tqdm(loader, desc=phase)):
        x_i = torch.stack(x_i).to(device)
        x_j = torch.stack(x_j).to(device)
        if is_train:
            optimizer.zero_grad()

        # Track autograd history only while training.
        with torch.set_grad_enabled(is_train):
            # criterion scores the two views' embeddings and returns (loss, batch accuracy).
            loss, batch_acc = criterion(model(x_i), model(x_j))
            if is_train:
                loss.backward()
                optimizer.step()

        running_loss += loss.item()
        running_acc += batch_acc.item()

        if is_train:
            # Integer global step based on the real number of batches per epoch
            # (the last batch may be partial).
            global_step = epoch * steps_per_epoch + i
            writer.add_scalar(phase+'_loss', loss.item(), global_step)
            writer.add_scalar(phase+'_acc', batch_acc.item(), global_step)

    n_batches = max(steps_per_epoch, 1)
    return running_loss, running_loss / n_batches, running_acc / n_batches


def train(num_epochs, model, model_type, model_tag, optimizer, scheduler, criterion, restore, threshold, dataloaders, checkpoint_path, best_checkpoint_path, tensorboard_log_dir):
    """Train `model` with the contrastive `criterion`, validating and checkpointing each epoch.

    Args:
        num_epochs: train until the epoch counter reaches this value (a restored
            checkpoint resumes from its saved epoch).
        model, optimizer, scheduler, criterion: as returned by ``prepare_experiment``.
        model_type: only 'simclr' is supported.
        model_tag: passed to SummaryWriter as ``comment``.
        restore: if True and ``checkpoint_path`` exists, resume from it.
        threshold: accepted for interface compatibility; not used.
        dataloaders: dict with 'train' and 'val' loaders yielding (views_i, views_j) pairs.
        checkpoint_path: latest-epoch checkpoint, rewritten every epoch.
        best_checkpoint_path: checkpoint with the lowest validation loss. The
            highest-validation-accuracy checkpoint goes to a sibling path with an
            '_acc' suffix, so the two no longer overwrite each other.
        tensorboard_log_dir: TensorBoard log directory.

    Raises:
        ValueError: if ``model_type`` is not 'simclr'.
    """
    # Previously an unknown model_type left `loss` unbound mid-training.
    if model_type != 'simclr':
        raise ValueError("Unsupported model_type {!r}; only 'simclr' is implemented".format(model_type))

    epoch, best_loss, best_ts = 0, np.inf, 0
    if restore:
        print('Checking to restore...')
        if os.path.isfile(checkpoint_path):
            print('Restoring checkpoint from {}'.format(checkpoint_path))
            # map_location lets a checkpoint saved on another device load here.
            state = torch.load(checkpoint_path, map_location=device)
            epoch = state['epoch']
            best_loss = state['best_loss']
            best_ts = state['best_ts']
            model.load_state_dict(state['state_dict'])
            optimizer.load_state_dict(state['optimizer'])
            scheduler.load_state_dict(state['scheduler'])
        else:
            print('No checkpoint found at {}'.format(checkpoint_path))
    else:
        print('Not checking to restore')

    best_acc_checkpoint_path = _best_acc_checkpoint_path(best_checkpoint_path)

    # NOTE(review): SummaryWriter ignores `comment` when log_dir is given, so runs
    # with different tags share one directory — consider a per-run subdirectory.
    writer = SummaryWriter(log_dir=tensorboard_log_dir, comment=model_tag)
    try:
        while epoch < num_epochs:
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # Each epoch has a training and a validation phase.
            for phase in ['train', 'val']:
                running_loss, mean_batch_loss, mean_batch_acc = _run_phase(
                    phase, model, optimizer, criterion, dataloaders[phase], writer, epoch)
                if phase == 'train':
                    # Step the LR schedule after the epoch's optimizer steps;
                    # stepping first (as before) skipped the initial LR.
                    scheduler.step()

                print(phase, 'mean_batch_loss:', round(mean_batch_loss,4))
                print(phase, 'total_loss:', round(running_loss,4))
                print(phase, 'mean_batch_acc:', round(mean_batch_acc,4))
                writer.add_scalar(phase+'_mean_batch_loss', mean_batch_loss, epoch)
                writer.add_scalar(phase+'_mean_batch_acc', mean_batch_acc, epoch)

            # After the loop, running_loss / mean_batch_acc hold the validation values.
            if running_loss < best_loss:
                best_loss = running_loss
                _save_checkpoint(best_checkpoint_path, epoch + 1, model, optimizer, scheduler, best_loss, best_ts)
                print('best_loss model after %d epoch saved...' % (epoch+1))

            if mean_batch_acc > best_ts:
                best_ts = mean_batch_acc
                _save_checkpoint(best_acc_checkpoint_path, epoch + 1, model, optimizer, scheduler, best_loss, best_ts)
                print('best_ts model after %d epoch saved...' % (epoch+1))

            # Latest state, used by `restore`.
            _save_checkpoint(checkpoint_path, epoch + 1, model, optimizer, scheduler, best_loss, best_ts)
            print('model after %d epoch saved...' % (epoch+1))
            epoch += 1
    finally:
        # Flush and close event files even if training is interrupted.
        writer.close()

### Running Experiment

**Get Parameters**

In [10]:
# Output locations.
checkpoint_folder = '/scratch/sc6957/dlproject/checkpoints/'
tensorboard_log_dir = '/scratch/sc6957/dlproject/tb_logs'

# Run hyper-parameters.
batch_size = scene_buffer_length  # tied to the dataset's scene buffer length
resnet_style = '18'
weights = 'random'
model_type = 'simclr'
num_epochs = 20
threshold = 0.5  # forwarded to train()
restore = False  # resume from the latest checkpoint if one exists

# Tag identifying this run in checkpoint filenames.
model_tag = f'_Batch:{batch_size}_Model:{model_type}{resnet_style}{weights}'
if test_unlabeled_set is None:
    model_tag += '_dataV3'
model_tag += '_scenewise_16dim_acc'
print(f'model_tag:{model_tag}')

parameters = prepare_experiment(
    model_tag=model_tag,
    batch_size=batch_size,
    resnet_style=resnet_style,
    weights=weights,
    model_type=model_type,
    checkpoint_folder=checkpoint_folder,
)
(model, optimizer, scheduler, criterion, dataloaders,
 checkpoint_path, best_checkpoint_path) = parameters

model_tag:_Batch:128_Model:simclr18random_dataV3_scenewise_16dim_acc
Model Type: simclr
simclr Model Summary
Class: simclr_model
Passed Input Size:torch.Size([3, 3, 256, 306])
----
     Class: encoder
     resnet_style: 18, pretrained: False
     Passed Input Size:torch.Size([3, 3, 256, 306])
     Output Size:torch.Size([3, 512, 8, 8])
----
Hidden state shape: torch.Size([3, 512, 8, 8])
Pooled Hidden state shape: torch.Size([3, 512, 1, 1])
Reshaped Hidden state shape: torch.Size([3, 512])
Output shape:torch.Size([3, 16])


**Train**

In [None]:
train(
    num_epochs=num_epochs,
    model=model,
    model_type=model_type,
    model_tag=model_tag,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    restore=restore,
    threshold=threshold,
    dataloaders=dataloaders,
    checkpoint_path=checkpoint_path,
    best_checkpoint_path=best_checkpoint_path,
    tensorboard_log_dir=tensorboard_log_dir,
)

Not checking to restore
Epoch 0/19
----------


0it [00:00, ?it/s]

### Evaluation & Analysis

In [11]:
def evaluate(model, model_type, model_path, dataloaders, phase, display_images = False, n_batch_to_display = 3, threshold = 0.5, criterion = None):
    """Restore weights from a checkpoint and report mean contrastive loss/accuracy on one split.

    Args:
        model: model whose ``state_dict`` matches the checkpoint.
        model_type: only 'simclr' is supported.
        model_path: checkpoint file holding 'state_dict' and 'epoch'.
        dataloaders: dict of loaders keyed by phase, yielding (views_i, views_j) pairs.
        phase: key into ``dataloaders``; must not be 'unlabeled'.
        display_images: if True, stop after ``n_batch_to_display`` batches.
        n_batch_to_display: batch limit used when ``display_images`` is True.
        threshold: accepted for interface compatibility; not used.
        criterion: callable returning ``(loss, batch_acc)``. If None, a fresh
            ``contrastive_loss(tau=0.1, normalize=True)`` is built, matching
            ``prepare_experiment``. This replaces the old hidden global lookup.

    Returns:
        ``(mean_batch_loss, mean_batch_acc)`` over the batches actually evaluated.

    Raises:
        ValueError: for an unsupported ``model_type`` or a phase whose loader is None.
    """
    assert phase != 'unlabeled'
    if model_type != 'simclr':
        raise ValueError("Unsupported model_type {!r}; only 'simclr' is implemented".format(model_type))
    loader = dataloaders[phase]
    if loader is None:
        raise ValueError('No dataloader available for phase {!r}'.format(phase))
    if criterion is None:
        criterion = contrastive_loss.contrastive_loss(tau=0.1, normalize=True)

    print('Estimating performance on {} set'.format(phase))
    print('Restoring Best checkpoint from {}'.format(model_path))
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['state_dict'])
    print('Checkpoint epoch: {}'.format(state['epoch']))

    model.eval()

    total = 0
    n_batches = 0
    total_loss = 0.0
    total_acc = 0.0

    # No autograd graphs are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for i, (x_i, x_j) in enumerate(tqdm(loader, desc=phase)):
            if display_images and i == n_batch_to_display:
                break
            total += len(x_i)
            x_i = torch.stack(x_i).to(device)
            x_j = torch.stack(x_j).to(device)
            # criterion returns a (loss, batch_acc) tuple; the old code called
            # .item() on the tuple itself.
            loss, batch_acc = criterion(model(x_i), model(x_j))
            total_loss += loss.item()
            total_acc += batch_acc.item()
            n_batches += 1

    # Average over the batches actually processed (early stop included).
    mean_loss = total_loss / max(n_batches, 1)
    mean_acc = total_acc / max(n_batches, 1)
    print('Total samples: {}, Batches evaluated: {}'.format(total, n_batches))
    print('Mean Loss is: {}'.format(mean_loss))
    print('Mean Accuracy is: {}'.format(mean_acc))
    return mean_loss, mean_acc
    
    

### Verifying Performance on the Validation Set

In [11]:
# NOTE(review): this overrides the best_checkpoint_path returned by prepare_experiment
# with a checkpoint from another run (tag says Batch:64 / 64dim, while the current
# model was built as Batch:128 / 16dim). Confirm the architectures match before
# evaluate() loads it into `model`.
best_checkpoint_path = '/scratch/sc6957/dlproject/checkpoints/checkpoint_6-5_Batch:64_Model:simclr18random_dataV3_scenewise_64dim_acc.pth.tar'
# Only three scalars are read here, so map everything to CPU. This avoids
# allocating the full state dict on the GPU and also works without CUDA.
state = torch.load(best_checkpoint_path, map_location='cpu')
state['epoch'], state['best_loss'], state['best_ts']

(20, 19.265180371701717, 99.83597883597884)

In [12]:
# Other splits can be checked the same way, e.g.:
#   evaluate(model, model_type, best_checkpoint_path, dataloaders, 'test', display_images=True, n_batch_to_display=10, threshold=0.5)
#   evaluate(model, model_type, best_checkpoint_path, dataloaders, 'train', display_images=True, n_batch_to_display=1, threshold=0.5)
evaluate(
    model=model,
    model_type=model_type,
    model_path=best_checkpoint_path,
    dataloaders=dataloaders,
    phase='val',
    display_images=False,
    n_batch_to_display=1,
    threshold=0.5,
)

Estimating performance on val set
Restoring Best checkpoint from /scratch/sc6957/dlproject/checkpoints/checkpoint_4-5_Batch:64_Model:simclr18random_dataV3transforms3.pth.tar


KeyboardInterrupt: 