#### Required Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import models, transforms
import matplotlib
import matplotlib.pyplot as plt
import time
import os
import copy
import random
from tensorboardX import SummaryWriter
from tqdm import tqdm
from torchsummary import summary

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box, compute_ts_road_map
from modelzoo import *

#### VAE Model Summary

In [2]:
model = vae()
inp = torch.rand([3,6,3,256,306])
model.summarize(inp)

Class: vae
Passed Input Size:torch.Size([3, 6, 3, 256, 306])
----
     Class: encoder
     resnet_style: 18, pretrained: False
     Passed Input Size:torch.Size([3, 3, 256, 306])
     Output Size:torch.Size([3, 512, 8, 8])
----
Number of encoded states: 6, each of size: torch.Size([3, 512, 8, 8])
Concatenated encoded states shape: torch.Size([3, 3072, 8, 8])
----
     Class: encoder_after_resnet
     Passed Input Size:torch.Size([3, 3072, 8, 8])
     Convolved Encoded state shape: torch.Size([3, 512, 4, 4])
     Output Mean Size:torch.Size([3, 4096])
     Output Var Size:torch.Size([3, 4096])
----
Output Mean Size:torch.Size([3, 4096])
Output Var Size:torch.Size([3, 4096])
Reparameterized Hidden State size: torch.Size([3, 4096])
----
     Class: vae_decoder
     Passed Input Size:torch.Size([3, 4096])
     Input recast into shape: torch.Size([3, 64, 8, 8])
     Output Size:torch.Size([3, 800, 800])
----
Output Size:torch.Size([3, 800, 800])


#### AE Model Summary

In [3]:
model = autoencoder()
inp = torch.rand([3,6,3,256,306])
model.summarize(inp)

Class: autoencoder
Passed Input Size:torch.Size([3, 6, 3, 256, 306])
----
     Class: encoder
     resnet_style: 18, pretrained: False
     Passed Input Size:torch.Size([3, 3, 256, 306])
     Output Size:torch.Size([3, 512, 8, 8])
----
Number of encoded states: 6, each of size: torch.Size([3, 512, 8, 8])
Concatenated Hidden State size: torch.Size([3, 3072, 8, 8])
----
     Class: decoder
     Passed Input Size:torch.Size([3, 3072, 8, 8])
     Output Size:torch.Size([3, 800, 800])
----
Output Size:torch.Size([3, 800, 800])


#### Experiment and Model details

In [9]:
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200

seed = 8964
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


image_folder = '../data'
annotation_csv = '../data/annotation.csv'

#unlabeled_scene_index = np.arange(106)
#labeled_scene_index = np.arange(106, 134)
labeled_scene_index = np.arange(106, 132)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter()

transform = torchvision.transforms.ToTensor()

labeled_set = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=labeled_scene_index,
                                  transform=transform,
                                  extra_info=True
                                 )
train_val_split_ratio = 0.85
n=len(labeled_set)
n_train=int(train_val_split_ratio*n)+1
n_val=int((1-train_val_split_ratio)*n)
assert n_train+n_val==n

threshold = 0.5
train_set, val_set = torch.utils.data.random_split(labeled_set, [n_train, n_val])

batch_size = 16

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
dataloaders = {'train': train_loader, 'val': val_loader}


model_types = ['ae','vae']
model_choice = 0
model_type = model_types[model_choice]

print('Model Type: {0}'.format(model_type))

resnet_style = '18'
weights = 'random'

if weights == 'random':
    pretrained = False
elif weights == 'imagenet':
    pretrained = True
elif weights == 'ssl':
    pretrained = False
    pass

if model_type == 'ae':
    model = autoencoder(resnet_style=resnet_style,pretrained=pretrained)
elif model_type == 'vae':
    model = vae(resnet_style=resnet_style,pretrained=pretrained)

sample_input = torch.rand([3,6,3,256,306])
print('{} Model Summary'.format(model_type))
model.summarize(sample_input)

model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
criterion = nn.BCELoss()


num_epochs = 40
restore = True
checkpoint_tag = '4-24-res{0}-{1}'.format(resnet_style,weights)
checkpoint_path = './checkpoints/{0}_checkpoint_{1}.pth.tar'.format(model_type,checkpoint_tag)
best_checkpoint_path = './checkpoints/{0}_checkpoint{1}_best.pth.tar'.format(model_type,checkpoint_tag)

if restore:
    print('Checking to restore...')
    if os.path.isfile(checkpoint_path):
        print('Restoring checkpoint from {}'.format(checkpoint_path))
        state = torch.load(checkpoint_path)
        epoch = state['epoch']
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
    else:
        print('No checkpoint found at {}'.format(checkpoint_path))
        epoch = 0
else:
    print('Not checking to restore')
    epoch = 0

best_loss = np.inf

Model Type: ae
ae Model Summary
Class: autoencoder
Passed Input Size:torch.Size([3, 6, 3, 256, 306])
----
     Class: encoder
     resnet_style: 18, pretrained: False
     Passed Input Size:torch.Size([3, 3, 256, 306])
     Output Size:torch.Size([3, 512, 8, 8])
----
Number of encoded states: 6, each of size: torch.Size([3, 512, 8, 8])
Concatenated Hidden State size: torch.Size([3, 3072, 8, 8])
----
     Class: decoder
     Passed Input Size:torch.Size([3, 3072, 8, 8])
     Output Size:torch.Size([3, 800, 800])
----
Output Size:torch.Size([3, 800, 800])
Checking to restore...
No checkpoint found at ./checkpoints/ae_checkpoint_4-24-res18-random.pth.tar


#### Training and Validation

In [None]:
while epoch < num_epochs:
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            scheduler.step()
            model.train()  # Set model to training mode
        else:
            model.eval()  # Set model to evaluate mode

        running_loss = 0.0
        threat_score = 0.0

        # Iterate over data.
        for i, temp_batch in tqdm(enumerate(dataloaders[phase])):
            samples, targets, road_images, extras  = temp_batch
            samples = torch.stack(samples).to(device)
            road_images = torch.stack(road_images).to(device)
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                if model_type == 'ae':
                    pred_maps = model(samples)
                    loss = criterion(pred_maps, road_images.float())
                elif model_type == 'vae':
                    pred_maps, mu, logvar = model(samples,phase == 'train')
                    loss, CE, KLD = loss_function(pred_maps, road_images, mu, logvar)
                
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                else:
                    for pred_map,road_image in zip(pred_maps,road_images):
                        ts_road_map = compute_ts_road_map(pred_map > threshold, road_image)
                        threat_score += ts_road_map

            running_loss += loss.item()#*batch_size

            # tensorboardX logging
            if phase == 'train':
                writer.add_scalar(phase+'_loss', loss.item(), epoch * len(train_set) / batch_size + i)
                if model_type == 'vae':
                    writer.add_scalar(phase+'_loss_CE', CE.item(), epoch * len(train_set) / batch_size + i)
                    writer.add_scalar(phase+'_loss_KLD', KLD.item(), epoch * len(train_set) / batch_size + i)

            # statistics
        if phase == 'train':
            running_loss = running_loss / len(train_set) # per batch, per sample
            print(phase, 'running_loss:', running_loss)
        else:
            running_loss = running_loss / len(val_set)
            print(phase,'running_loss:', running_loss, 'cumulative_threat_score:', threat_score.item(), 'val_len:',len(val_set), 'mean_threat_score:',threat_score.item() / len(val_set))#, iou / len(val_set))
            writer.add_scalar(phase+'_ts', threat_score.item()/len(val_set), (epoch + 1) * len(train_set) / batch_size)

    # Saving best model so far
    if running_loss < best_loss:
        best_loss = running_loss
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
            }, best_checkpoint_path)
        print('best model after %d epoch saved...' % (epoch+1))

    # save model per epoch
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
        }, checkpoint_path)
    print('model after %d epoch saved...' % (epoch+1))
    epoch += 1

writer.close()

Epoch 0/39
----------


175it [01:53,  1.54it/s]

train running_loss: 0.025727465293993855



31it [00:15,  1.95it/s]


val running_loss: 0.022149322838987205 cumulative_threat_score: 348.3143310546875 val_len: 491 mean_threat_score: 0.7093978229219705
best model after 1 epoch saved...
model after 1 epoch saved...
Epoch 1/39
----------


175it [01:49,  1.61it/s]

train running_loss: 0.018870074097301202



31it [00:15,  2.05it/s]


val running_loss: 0.017541053460231867 cumulative_threat_score: 372.8281555175781 val_len: 491 mean_threat_score: 0.7593241456569819
best model after 2 epoch saved...
model after 2 epoch saved...
Epoch 2/39
----------


175it [01:48,  1.61it/s]

train running_loss: 0.013742169919733085



31it [00:15,  2.03it/s]


val running_loss: 0.016348334685605315 cumulative_threat_score: 387.1137390136719 val_len: 491 mean_threat_score: 0.7884190203944438
best model after 3 epoch saved...
model after 3 epoch saved...
Epoch 3/39
----------


175it [01:48,  1.61it/s]

train running_loss: 0.010811657071969667



31it [00:15,  2.03it/s]


val running_loss: 0.013942611860889038 cumulative_threat_score: 396.41058349609375 val_len: 491 mean_threat_score: 0.8073535305419425
best model after 4 epoch saved...
model after 4 epoch saved...
Epoch 4/39
----------


175it [01:48,  1.61it/s]

train running_loss: 0.008639877202587881



31it [00:15,  2.03it/s]


val running_loss: 0.011291452647469438 cumulative_threat_score: 412.43194580078125 val_len: 491 mean_threat_score: 0.8399835963356034
best model after 5 epoch saved...
model after 5 epoch saved...
Epoch 5/39
----------


175it [01:49,  1.59it/s]

train running_loss: 0.007388100750356431



31it [00:15,  2.04it/s]


val running_loss: 0.011425864672223807 cumulative_threat_score: 413.3017578125 val_len: 491 mean_threat_score: 0.8417551075610998
model after 6 epoch saved...
Epoch 6/39
----------


175it [01:49,  1.60it/s]

train running_loss: 0.006844504324709277



31it [00:15,  2.05it/s]


val running_loss: 0.010210585706467056 cumulative_threat_score: 421.17572021484375 val_len: 491 mean_threat_score: 0.8577916908652622
best model after 7 epoch saved...
model after 7 epoch saved...
Epoch 7/39
----------


53it [00:34,  1.62it/s]

#### Testing

In [8]:
test_labeled_scene_index = np.arange(132, 134)
# The labeled dataset can only be retrieved by sample.
# And all the returned data are tuple of tensors, since bounding boxes may have different size
# You can choose whether the loader returns the extra_info. It is optional. You don't have to use it.
test_labeled_set = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=test_labeled_scene_index,
                                  transform=transform,
                                  extra_info=True
                                 )
testloader = torch.utils.data.DataLoader(test_labeled_set, batch_size=1, shuffle=True, num_workers=4, collate_fn=collate_fn)

In [41]:
print(model_type)
print('Restoring Best checkpoint from {}'.format(best_checkpoint_path))
state = torch.load(best_checkpoint_path)
epoch = state['epoch']
model.load_state_dict(state['state_dict'])

model.eval()

threat_score = 0.0
total = 0.0
display_images = False

for i, temp_batch in enumerate(testloader):
#     if i == 30:
#         break
        
    total += len(temp_batch[0])
    samples, targets, road_images, extra = temp_batch
    samples = torch.stack(samples)

    if model_type == 'ae':
        pred_maps = model(samples)
    elif model_type == 'vae':
        pred_maps, mu, logvar = model(samples.to(device),is_training=False)
    
    ts_road_map = compute_ts_road_map(pred_maps[0].to(device) > threshold, road_images[0].to(device))
    threat_score += ts_road_map
    if display_images:
        print('Test Sample: {}'.format(i))
        plt.imshow(torchvision.utils.make_grid(samples[0], nrow=3).numpy().transpose(1, 2, 0))
            # CAM_FRONT_LEFT, CAM_FRONT, CAM_FRONT_RIGHT, CAM_BACK_LEFT, CAM_BACK, CAM_BACK_RIGHT
#     plt.axis('off');
        fig, (ax1, ax2) = plt.subplots(1, 2)
        fig.suptitle('Road Map Comparison')
        ax1.imshow(road_images[0].detach().cpu(), cmap='binary');
        ax1.set_title('Original Road Map')
        ax2.imshow((pred_maps[0] > threshold).detach().cpu(), cmap='binary');
        ax2.set_title('Predicted Road Map')
        plt.show()
        print('-'*20)
    
threat_score /= total
print('Total samples: {}, Total Threat Score: {}'.format(total,total*threat_score))
print('Mean Threat Score is: {}'.format(threat_score))

vae
Restoring Best checkpoint from ./checkpoints/vae_checkpoint_random_132_p100_best.pth.tar
Total Threat Score is: 0.8383897542953491


### Verifying Performance on Val set

In [39]:
phase = 'val'
threat_score = 0.0
total = 0.0
display_images = False

for i, temp_batch in enumerate(dataloaders[phase]):
#     if i==5:
#         break
    total += len(temp_batch[0])
    samples, targets, road_images, extras  = temp_batch
    samples = torch.stack(samples).to(device)
    road_images = torch.stack(road_images).to(device)

    with torch.set_grad_enabled(False):
        if model_type == 'ae':
            pred_maps = model(samples)
            loss = criterion(pred_maps, road_images.float())
        elif model_type == 'vae':
            pred_maps, mu, logvar = model(samples,False)
            loss, CE, KLD = loss_function(pred_maps, road_images, mu, logvar)

        for pred_map,road_image in zip(pred_maps,road_images):
            ts_road_map = compute_ts_road_map(pred_map > threshold, road_image)
            threat_score += ts_road_map

        if display_images:
            plt.imshow(torchvision.utils.make_grid(samples[0], nrow=3).numpy().transpose(1, 2, 0))
            fig, (ax1, ax2) = plt.subplots(1, 2)
            fig.suptitle('Road Map Comparison')
            ax1.imshow(road_image.detach().cpu(), cmap='binary');
            ax1.set_title('Original Road Map')
            ax2.imshow((pred_map > threshold).detach().cpu(), cmap='binary');
            ax2.set_title('Predicted Road Map')
            plt.show()
            print('-'*20)
                
threat_score /= total
print('Total samples: {}, Total Threat Score: {}'.format(total,total*threat_score))
print('Mean Threat Score is: {}'.format(threat_score))

Total Threat Score is: 0.9356419444084167


### Verifying Performance on Train set

In [45]:
phase = 'train'
threat_score = 0.0
total = 0.0
display_images = False

for i, temp_batch in enumerate(dataloaders[phase]):
    if i==5:
        break
    total += len(temp_batch[0])
    samples, targets, road_images, extras  = temp_batch
    samples = torch.stack(samples).to(device)
    road_images = torch.stack(road_images).to(device)

    with torch.set_grad_enabled(False):
        if model_type == 'ae':
            pred_maps = model(samples)
            loss = criterion(pred_maps, road_images.float())
        elif model_type == 'vae':
            pred_maps, mu, logvar = model(samples,False)
            loss, CE, KLD = loss_function(pred_maps, road_images, mu, logvar)

        for pred_map,road_image in zip(pred_maps,road_images):
            ts_road_map = compute_ts_road_map(pred_map > threshold, road_image)
            threat_score += ts_road_map

        if display_images:
            plt.imshow(torchvision.utils.make_grid(samples[0].cpu(), nrow=3).numpy().transpose(1, 2, 0))
            fig, (ax1, ax2) = plt.subplots(1, 2)
            fig.suptitle('Road Map Comparison')
            ax1.imshow(road_image.detach().cpu(), cmap='binary');
            ax1.set_title('Original Road Map')
            ax2.imshow((pred_map > threshold).detach().cpu(), cmap='binary');
            ax2.set_title('Predicted Road Map')
            plt.show()
            print('-'*20)
                
threat_score /= total
print('Total samples: {}, Total Threat Score: {}'.format(total,total*threat_score))
print('Mean Threat Score is: {}'.format(threat_score))

Total samples: 2785.0, Total Threat Score: 2664.673095703125
Mean Threat Score is: 0.9567946195602417
