In [6]:
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import random

# Custom code
from load_model import *
from utils import *
from augmentation import *
from dataset import *

In [7]:
# Path where dataset is stored
root_path = '/data/ASL'

# Checkpoints get saved to './checkpoints/{run_name}_epoch{epoch}.pth.tar'
checkpoint_every_N_epochs = 15 # You put N here
run_name = '14classes' 

# Determines size of input volume into NN
desired_frames = 16
desired_size = (112, 112)

# Computation settings
num_threads = 16
use_cuda = True

# Val/Train parameters
val_frac = .1  # Fraction of dataset set aside for validation set
num_classes = 15  # Number of classes in output layer
epochs = 100
lower_lr_every_N_epochs = 20  # Divide LR by 2 every N epochs (you put N here)
batch_size = 16
learning_rate = .2

# Settings for SGD optimizer
momentum = .9
dampening = .9
weight_decay = .001

# Set this to a filepath string if you want to continue training a pretrained model
#load_checkpoint = "checkpoints/14classes_epoch30.pth"
load_checkpoint = None

In [8]:
# Load in pretrained mobilenetv2
model = load_model(name="mobilenetv2",
                  num_classes=num_classes,
                  sample_size=desired_size[0],
                  width_mult=1.0,
                  pretrain_path=None,
                  use_cuda=use_cuda,
                  transfer_learning='last_layer')

optimizer = optim.SGD(model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

if load_checkpoint:
    print('>> Loading checkpoint: {}'.format(load_checkpoint))
    save_info = torch.load(load_checkpoint)
    start_epoch = save_info['epoch']
    optimizer.load_state_dict(save_info['optimizer'])
    model.load_state_dict(save_info['state_dict'])
    history = save_info['history']
else:
    start_epoch = 0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}


>> Creating model architecture for mobilenetv2
>> Successfully created model architecture for mobilenetv2.


ValueError: use_cuda was specified in generate_model(), but no cuda device was found.

In [None]:
# Verify that all layers are on cuda (if desired) and required_grad is set to true
# for the layers you want to train
for name, param in model.named_parameters():
    device = param.device
    param.requires_grad = False
    
    if name == 'module.classifier.weight':
        param.requires_grad = True
    if name == 'module.classifier.bias':
        param.requires_grad = True
        
    # Manual unfreeze the last couple layers for fine-tuning
    if name == 'module.features.18.0.weight':
        param.requires_grad = True
    if name == 'module.features.18.1.weight':
        param.requires_grad = True
    elif name == 'module.features.18.1.bias':
        param.requires_grad = True


In [None]:
# For validation dataset, pick out a fraction of videos
val_dataset = ASLDataset(root_path=root_path, desired_frames=desired_frames, desired_size=desired_size, val_frac=.1)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=min(64, len(val_dataset)), num_workers=num_threads,
            pin_memory=True)

# Feed validation metadata into the training dataset, so it knows to exclude videos found in validation set
train_dataset = ASLDataset(root_path=root_path, desired_frames=16, desired_size=(112, 112), val_metadata=val_dataset.metadata)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_threads,
            pin_memory=True)

In [None]:
# SANITY CHECK: Load in random training vid and make sure label matches video
i = np.random.randint(0, len(val_dataset))
x, y = val_dataset[i]
%matplotlib notebook
inv_codex = {v:k for k,v in train_dataset.codex.items()}
label = inv_codex[y]
print("Playing animation for '{}'.".format(label))
animate_movie(np.array(x))


In [None]:
def train(model, dataloader, criterion, optimizer, history):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('train_loss', ':.4e')
    accuracies = []
    for i in range(num_classes):
        accuracies.append(AverageMeter('train_acc'+str(i+1), ':6.2f'))
    
    model.train()
    start_time = time.time()
    for i, (x, y) in enumerate(dataloader):

        x = x.float()
        y = y.long()
        
        if torch.cuda.is_available():
            x = x.to(device)
            y = y.to(device)
            
        yhat = model(x)
        loss = criterion(yhat, y)
        acc = calculate_accuracy(yhat, y, topk=tuple(range(1,num_classes+1)))
        
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update tracking parameters
        losses.update(loss.item(), x.size(0))
        batch_time.update(time.time() - start_time)
        for i in range(num_classes):
            accuracies[i].update(acc[i], x.size(0))

        # Print progress
        print('{num_batch}/{total_batches} batches complete. train_loss: {loss.avg:.3f}. train_acc: {top1.avg:.3f}. Time: {batch_time.avg:.1f} sec. '.format(
                  num_batch=i+1,
                  total_batches=len(train_loader),
                  batch_time=batch_time,
                  loss=losses,
                  top1=accuracies[0]), end='\r')
        
    avg_accuracies = [acc.avg for acc in accuracies]
    return avg_accuracies, losses.avg

def save_checkpoint(state, filename):
    torch.save(state, filename)

def validate(model, dataloader, criterion):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    losses = AverageMeter('val_loss', ':.4e')
    accuracies = []
    for i in range(num_classes):
        accuracies.append(AverageMeter('val_acc'+str(i+1), ':6.2f'))
    
    model.eval()
    start_time = time.time()
    for i, (x, y) in enumerate(dataloader):

        x = x.float()
        y = y.long()
        
        if torch.cuda.is_available():
            x = x.to(device)
            y = y.to(device)
            
        yhat = model(x)
        loss = criterion(yhat, y)
        acc = calculate_accuracy(yhat, y, topk=tuple(range(1,num_classes+1)))
        
        # Update tracking parameters
        losses.update(loss.item(), x.size(0))
        for i in range(num_classes):
            accuracies[i].update(acc[i], x.size(0))

    avg_accuracies = [acc.avg for acc in accuracies]
    return avg_accuracies, losses.avg

In [None]:
# Calculate validation loss/acc before any training occurs, as a reference point
with torch.no_grad():
    val_acc, val_loss = validate(model, val_loader, criterion)
print('BEFORE TRAINING: val_loss = {:.4f}, val_acc = {:.4f} (top 50% = {:.4f})'.format(val_loss, val_acc[0], val_acc[int(num_classes/2)]))

# Main training loop
start_time = time.time()
for epoch in range(start_epoch, epochs):
    
    epoch_start = time.time()

    if epoch % lower_lr_every_N_epochs == 0:
        new_lr = learning_rate/2**int(epoch/lower_lr_every_N_epochs)
        adjust_learning_rate(new_lr, optimizer)
        print('Adjusted learning rate to: {}'.format(new_lr))
    
    train_acc, train_loss = train(model, train_loader, criterion, optimizer, history)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    
    # validate
    with torch.no_grad():
        val_acc, val_loss = validate(model, val_loader, criterion)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
    
    time_since_start = time.time() - start_time
    avg_time_per_epoch = time_since_start/(epoch+1)
    time_remaining = epochs*avg_time_per_epoch - time_since_start
    print('\n Completed epoch {}/{}. val_loss: {:.3f}. val_acc: {:.3f} (top 50%: {:.3f}). ETA: {:.1f} mins remaining.'.format(epoch+1, epochs, val_loss, val_acc[0], val_acc[int(num_classes/2)], time_remaining/60))
    
    if (epoch + 1) % checkpoint_every_N_epochs == 0:
        filename = 'checkpoints/'+run_name+'_epoch'+str(epoch+1)+'.pth'
        print('Saving checkpoint to: {}'.format(filename))
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'val_acc1': val_acc, 'optimizer' : optimizer.state_dict(), 'history':history}, filename=filename)

        

In [None]:
# Plot the loss/accuracy vs epoch

%matplotlib notebook
train_loss = history['train_loss']
val_loss = history['val_loss'] 

plt.plot(train_loss, label='Train')
plt.plot(val_loss, label='Val')
plt.legend()
plt.title('Loss vs epoch')
plt.xlabel('epoch')
plt.grid()

plt.figure(2)
train_acc = [item[0] for item in history['train_acc']]
val_acc = [item[0] for item in history['val_acc']]
plt.plot(train_acc, label='Train')
plt.plot(val_acc, label='Val')
plt.legend()
plt.title('Accuracy (top 1)')
plt.xlabel('epoch')
plt.grid()

plt.figure(3)
i = 3
train_acc = [item[i] for item in history['train_acc']]
val_acc = [item[i] for item in history['val_acc']]
plt.plot(train_acc, label='Train')
plt.plot(val_acc, label='Val')
plt.legend()
plt.title('Accuracy (top 50%)')
plt.xlabel('epoch')
plt.grid()