In [1]:
import torch
import torchvision
from torch import optim
import torch.nn as nn
from torch.utils.data import DataLoader

import time
import numpy as np
import os

from utils import *
from dataset import *
from model import *

In [2]:
# --- Configuration: tune these rather than editing the code below. ---
# Dataset root; override with the CITYSCAPES_ROOT environment variable instead
# of editing a hard-coded absolute path.
DATA_ROOT = os.environ.get('CITYSCAPES_ROOT', '/data/Cityscapes')
NUM_CLASSES = 34          # passed to UNET as n_classes
SEED = 0
# Loader worker processes, capped at the CPUs actually available
# (os.cpu_count() may return None, hence the `or 1`).
NUM_WORKERS = min(16, os.cpu_count() or 1)
LEARNING_RATE = 2e-4
epochs = 100

# Lowercase alias read by the training cell when sizing its IoU history.
num_classes = NUM_CLASSES

# Seed before building the model so weight init and loader shuffling are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

model = UNET(n_classes=NUM_CLASSES, padding='valid')
use_cuda = torch.cuda.is_available()
if use_cuda:
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print("Detected {} GPUs! Training's about to get hella fast".format(n_gpus))
        model = nn.DataParallel(model)
    else:
        print('Detected {} GPU. Loading model onto single GPU.'.format(n_gpus))
    batch_size = n_gpus  # To make sure all GPUs are utilized
    model.cuda()
else:
    batch_size = 1
    print('No GPU detected - training on cpu.')

cityscapes_train = Cityscapes(root=DATA_ROOT)
cityscapes_val = Cityscapes(root=DATA_ROOT, split='val')
cityscapes_test = Cityscapes(root=DATA_ROOT, split='test')

# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
train_loader = DataLoader(cityscapes_train, batch_size=batch_size, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=use_cuda)
val_loader = DataLoader(cityscapes_val, batch_size=batch_size,
                        num_workers=NUM_WORKERS, pin_memory=use_cuda)

criterion = pixelwise_loss  # wrapper for nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

No GPU detected - training on cpu.


RuntimeError: Dataset not found or incomplete. Please make sure all required folders for the specified "split" and "mode" are inside the "root" directory

In [None]:
%matplotlib notebook
dataset = cityscapes_val
i = np.random.randint(low=0, high=len(dataset)-1)

print("Plotting image {} from the '{}' dataset.".format(i, dataset.split))
plot_cityscape(i, dataset=dataset, class_ids=dataset.class_ids)

In [None]:
# Training loop: validate once before training so index 0 of every validation
# series is the untrained baseline, then train for `epochs` epochs, validating
# after each one and checkpointing periodically.

CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_EVERY = 5  # save every N epochs; the final epoch is always saved


def _append_metrics(history, prefix, acc, iou, loss):
    """Append one epoch of (acc, per-class iou, loss) to history['<prefix>_*']."""
    history[prefix + '_acc'].append(acc)
    # Per-class IoU is stored as a 2-D array with one row per recorded epoch.
    history[prefix + '_iou'] = np.concatenate(
        (history[prefix + '_iou'], iou.reshape(1, -1)), axis=0)
    history[prefix + '_loss'].append(loss)


def _save_checkpoint(epoch_num, model, optimizer, history):
    """Save model/optimizer state and metric history to CHECKPOINT_DIR/epoch_<epoch_num>.pth."""
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    filename = os.path.join(CHECKPOINT_DIR, 'epoch_{}.pth'.format(epoch_num))
    print('Saving checkpoint to: {}'.format(filename))
    # NOTE(review): when `model` is wrapped in nn.DataParallel its state_dict
    # keys carry a 'module.' prefix; whoever loads this must account for it.
    state = {'epoch': epoch_num,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'history': history}
    torch.save(state, filename)


# Run validation set before training the model
with torch.no_grad():
    val_acc, val_iou, val_loss = run_model(model, val_loader, criterion, mode='val')

# Width of the per-class IoU vector, derived from the validation result so the
# NaN placeholder row for training IoU matches the validation rows (no `num_classes`
# global is guaranteed to exist on a fresh kernel).
n_iou_classes = val_iou.reshape(1, -1).shape[1]
history = {'val_loss': [val_loss],
           'val_acc': [val_acc],
           'val_iou': val_iou.reshape((1, -1)),
           # Nothing has been trained yet: pad the training series with NaN so
           # every series is indexed by the same epoch number.
           'train_loss': [np.nan],
           'train_acc': [np.nan],
           'train_iou': np.full(shape=(1, n_iou_classes), fill_value=np.nan)}
print('Before training: val_loss: {:.3f}. val_acc: {:.3f}. val_iou: {:.4f}.'.format(
    val_loss,
    val_acc,
    np.nanmean(val_iou)))

start_time = time.time()
for epoch in range(epochs):
    epoch_num = epoch + 1

    # train
    train_acc, train_iou, train_loss = run_model(model, train_loader, criterion, optimizer)
    _append_metrics(history, 'train', train_acc, train_iou, train_loss)

    # validate
    with torch.no_grad():
        val_acc, val_iou, val_loss = run_model(model, val_loader, criterion, mode='val')
    _append_metrics(history, 'val', val_acc, val_iou, val_loss)

    # summary: ETA extrapolates the mean epoch time over the remaining epochs
    time_since_start = time.time() - start_time
    time_remaining = (epochs - epoch_num) * time_since_start / epoch_num
    print('Completed epoch {}/{}. val_loss: {:.3f}. val_acc: {:.3f}. val_iou: {:.4f}. ETA: {:.1f} mins remaining.'.format(
        epoch_num,
        epochs,
        val_loss,
        val_acc,
        np.nanmean(val_iou),
        time_remaining/60))

    # Periodic checkpoint, plus the last epoch so the final weights are never lost
    if epoch_num % CHECKPOINT_EVERY == 0 or epoch_num == epochs:
        _save_checkpoint(epoch_num, model, optimizer, history)

In [None]:
%matplotlib notebook
# Plot loss vs epochs
plt.figure(1)
plt.plot(history['val_loss'], label='Validation')
plt.plot(history['train_loss'], label='Training')
plt.legend()
plt.grid()
plt.xlabel('# epochs', fontsize=15)
plt.title('Loss', fontsize=18)

plt.figure(2)
plt.plot(history['val_acc'], label='Validation')
plt.plot(history['train_acc'], label='Training')
plt.legend()
plt.grid()
plt.xlabel('# epochs', fontsize=15)
plt.title('Total accuracy', fontsize=18)
plt.ylim([0, 1])

plt.figure(3)
for i in range(history['val_iou'].shape[1]):
    plt.plot(history['val_iou'][:, i], label=cityscapes_val.class_ids[i])
plt.legend()
plt.grid()
plt.ylim([0, 1])