In [1]:
# Load IPython's autoreload extension and use mode 2, which re-imports
# modules before each cell runs so edits to the project's .py files
# (utils/, Models/) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures directly in the notebook output.
%matplotlib inline

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss
from utils.run_model import train, test
from Models.basic_models import average_model
from Models.two_layer_concat import two_layer_concat

# Set the figure size used for later plots via the project helper
# (presumably width, height -- confirm in utils.visualization).
set_figsize(20, 15)

Training directory found, 36 series
Validation directory found, 6 series
Testing directory found, 10 series


In [2]:
# --- Device selection -------------------------------------------------------
# Set USE_GPU to False to force CPU execution even when CUDA is available.
USE_GPU = True
dtype = torch.float32  # 32-bit floats; not referenced by the cells visible here
# Use CUDA only when it is both requested and actually available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')
print('using device:', device)

# --- Datasets and loaders ---------------------------------------------------
# Validation split is built first, then the training split.
validation = TAVR_3_Frame("__valid")
val_loader = tavr_dataloader(validation, batch_size=4, shuffle=True, num_workers=2)
training = TAVR_3_Frame("__train")
train_loader = tavr_dataloader(training, batch_size=8, shuffle=True, num_workers=2)

# --- Models and loss --------------------------------------------------------
# ave_model is evaluated next to the trained model later in the notebook.
ave_model = average_model()
model = two_layer_concat(standardize_slice=True)
loss_fn = batch_l2_loss()

using device: cpu


In [3]:
# SGD hyperparameters, kept as named values so later cells can reuse them.
learning_rate = 1e-5
momentum = 0.90
reg = 1e-7  # handed to the optimizer as weight_decay

# Nesterov-momentum SGD over every parameter of the model.
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=reg,
    nesterov=True,
)

In [4]:
# First training pass: 3 epochs with lr_decay=0.8, reporting every 30
# iterations at print_level=3.
train(
    model, optimizer, train_loader, val_loader, loss_fn, device,
    epochs=3,
    print_every=30,
    print_level=3,
    lr_decay=0.8,
)


****Epoch 0 Iteration 0, loss = 304.7264
Validation loss 250.0459 over 81 frames
conv_a1.weight,   	norm: 3.9261e+00, 	update norm: 2.1340e-02 	Update/norm: 5.4355e-03
conv_a1.bias,   	norm: 2.9904e-01, 	update norm: 1.4460e-04 	Update/norm: 4.8354e-04
conv_b1.weight,   	norm: 3.7868e+00, 	update norm: 1.2698e-02 	Update/norm: 3.3531e-03
conv_b1.bias,   	norm: 2.9479e-01, 	update norm: 9.7706e-05 	Update/norm: 3.3144e-04
final.weight,   	norm: 1.5098e+00, 	update norm: 1.5807e-02 	Update/norm: 1.0470e-02
final.bias,   	norm: 9.5120e-02, 	update norm: 2.5293e-04 	Update/norm: 2.6591e-03

Iter 0... ... ... 
Iter 10... ... ... 
Iter 20... ... ... 
Iteration 30, loss = 75.6211
Validation loss 74.6139 over 81 frames
conv_a1.weight,   	norm: 3.8945e+00, 	update norm: 4.6399e-03 	Update/norm: 1.1914e-03
conv_a1.bias,   	norm: 2.9987e-01, 	update norm: 7.7487e-05 	Update/norm: 2.5841e-04
conv_b1.weight,   	norm: 3.7779e+00, 	update norm: 4.3268e-03 	Update/norm: 1.1453e-03
conv_b1.bias,   	no

In [None]:
# Reset the learning rate, stop lr decay, train 3 more epochs.
#
# The rate is restored from `learning_rate` (the value the optimizer was
# created with) instead of a duplicated literal, and it is applied to every
# parameter group so the reset stays correct if more groups are ever added.
for param_group in optimizer.param_groups:
    param_group['lr'] = learning_rate

# lr_decay is deliberately not passed here, so train() falls back to its own
# default -- presumably no decay; confirm in utils.run_model.
train(
    model, optimizer, train_loader, val_loader, loss_fn, device,
    epochs=3,
    print_every=30,
    print_level=3,
)


****Epoch 0 Iteration 0, loss = 46.0644
Validation loss 46.0817 over 81 frames
conv_a1.weight,   	norm: 3.8056e+00, 	update norm: 1.2793e-03 	Update/norm: 3.3617e-04
conv_a1.bias,   	norm: 3.0098e-01, 	update norm: 1.2813e-05 	Update/norm: 4.2570e-05
conv_b1.weight,   	norm: 3.7459e+00, 	update norm: 6.7003e-04 	Update/norm: 1.7887e-04
conv_b1.bias,   	norm: 2.9588e-01, 	update norm: 4.8381e-06 	Update/norm: 1.6352e-05
final.weight,   	norm: 1.0468e+00, 	update norm: 2.0218e-03 	Update/norm: 1.9314e-03
final.bias,   	norm: 1.0103e-01, 	update norm: 1.8157e-05 	Update/norm: 1.7972e-04

Iter 0... ... ... 
Iter 10... ... ... 
Iter 20... ... ... 
Iteration 30, loss = 44.9104
Validation loss 44.4603 over 81 frames
conv_a1.weight,   	norm: 3.7962e+00, 	update norm: 1.1447e-03 	Update/norm: 3.0154e-04
conv_a1.bias,   	norm: 3.0102e-01, 	update norm: 1.0483e-05 	Update/norm: 3.4825e-05
conv_b1.weight,   	norm: 3.7424e+00, 	update norm: 5.8040e-04 	Update/norm: 1.5509e-04
conv_b1.bias,   	norm

In [None]:
# One last test + visualize results on 1 validation sequence
val_seq = TAVR_Sequence("__valid")
test(model, val_loader, loss_fn, device)
test(ave_model, val_loader, loss_fn, device)

with torch.no_grad():
    model.eval()
    # Run inference on whichever device the model's weights currently live on,
    # so this cell works whether the model is on the CPU or on CUDA.
    model_device = next(model.parameters()).device
    # Fetch the first validation sequence once instead of re-indexing the
    # dataset for every frame access.
    first_seq = val_seq[0]

    # Build the list of frames to display: each model output is placed
    # between the two real frames it was computed from.
    ave_frames = [first_seq[0]]
    for i in range(2, 9, 2):
        # [None, :] prepends a batch axis of size 1 before the forward pass.
        frame_pair = (
            first_seq[i - 2][None, :].to(model_device),
            first_seq[i][None, :].to(model_device),
        )
        ave_frame = model(frame_pair)
        # Take the first item/channel of the output and bring it back to the
        # CPU so it can sit alongside the dataset frames for plotting.
        ave_frames += [ave_frame[0][0].cpu(), first_seq[i]]
    ave_frames += [first_seq[9]]

    # Flatten every frame into its central slices for the display grid.
    ave_frames_slices = []
    for f in ave_frames:
        ave_frames_slices += get_central_slices(f)
    set_figsize(6,20)
    display_grid(10, 3, ave_frames_slices)

In [None]:
# Continue training for 3 more epochs with the optimizer exactly as the
# previous cells left it (no learning-rate reset, no lr_decay argument).
train(
    model, optimizer, train_loader, val_loader, loss_fn, device,
    epochs=3,
    print_every=30,
    print_level=3,
)