In [1]:
# Reload imported modules (e.g. the project's utils/ and Models/ packages) before each
# cell executes, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

# Project-local helpers: datasets/loaders, plotting, loss, training loop, and models.
# NOTE(review): several names imported here (e.g. np, F, tavr_torch, TAVR_1_Frame,
# z_stretch, visualize_frame, save, two_layer_resnet) are not used in this notebook's cells.
import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss
from utils.run_model import train, test, save, load, get_loss_history
from Models.basic_models import average_model, two_layer_basic, post_process
from Models.two_layer_resnet import two_layer_resnet

# Default figure size for the plots in this notebook (argument order presumably
# width, height — confirm in utils.visualization.set_figsize).
set_figsize(20, 15)

Training directory found, 36 series
Validation directory found, 6 series
Testing directory found, 10 series


In [2]:
# Run on the GPU when it is requested and available; otherwise fall back to the CPU.
USE_GPU = True
dtype = torch.float32  # float32 throughout; not referenced again in the cells of this notebook
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')
print('using device:', device)

# Preprocessing mode shared by the datasets and `post_process`.
# The original note listed "Pixl", "Slice", or "None", but this run passes lowercase
# "slice" — confirm the exact accepted spellings in utils.tavr_torch.
preproc_type = "slice"

# Validation / training splits of 3-frame samples (preload=False — presumably frames
# are read on demand rather than cached up front). Both loaders share the same settings.
loader_kwargs = dict(batch_size=4, shuffle=True, num_workers=2)
validation = TAVR_3_Frame("__valid", preproc=preproc_type, preload=False)
val_loader = tavr_dataloader(validation, **loader_kwargs)
training = TAVR_3_Frame("__train", preproc=preproc_type, preload=False)
train_loader = tavr_dataloader(training, **loader_kwargs)

# `ave_model` is only evaluated (via `test`) as a comparison baseline; `model` is the
# network that gets trained. Construction order is kept as-is (parameter init uses the RNG).
ave_model = average_model()
model = two_layer_basic()
post_proc = post_process(kind=preproc_type).to(device=device)
loss_fn = batch_l2_loss()

# Run name — set this to the notebook's name. In the training log of this notebook,
# checkpoints are written under model_checkpoints/<model_name>/.
# NOTE(review): the name says "2_layer_res" and `two_layer_resnet` is imported, yet the
# model built above is `two_layer_basic()` — confirm which architecture this run should train.
model_name = "Model 20 (2_layer_res+SliceNorm) Run 0 [V100]"

using device: cuda


In [3]:
# Optimizer hyperparameters.
learning_rate = 3e-3
reg = 1e-7  # L2 penalty strength

# `reg` was previously defined but never used, so no regularization was applied.
# Pass it as Adam's weight_decay (an L2 penalty added to the gradients) so the
# configured value actually takes effect.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=reg)

In [4]:
# Optionally resume a previous run from a saved checkpoint.
LOAD = False
# Checkpoint iteration to restore; -1 presumably selects the latest — confirm in
# utils.run_model.load.
iteration_num = -1

if LOAD:
    load(model_name, iteration_num, model, optimizer)
    loss_history = get_loss_history(model_name)
    model.to(device=device)
    # Why this is needed: assuming `load` restores the optimizer with
    # `optimizer.load_state_dict`, PyTorch places the saved state on the device the
    # parameters were on at that moment — here, *before* `model.to(device)`. Adam's
    # moment buffers can therefore end up on a different device from the parameters.
    # Move them to match.
    # Only tensors are moved. Older PyTorch stores 'step' as a plain int (no `.to`).
    # Newer PyTorch deliberately keeps the 'step' tensor on the CPU for non-fused,
    # non-capturable optimizers such as this Adam. So 'step' is left where it is.
    # Moving is a no-op for tensors already on `device`, so this is safe on CPU too.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v) and k != 'step':
                state[k] = v.to(device=device)
else:
    loss_history = None

In [None]:
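# If multiple GPUs are available, uncomment to wrap the model for data-parallel training.
# Note: DataParallel prefixes state_dict keys with "module.", which affects saving/loading checkpoints.
# model = nn.DataParallel(model)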

In [5]:
# Train for 2 epochs. print_every / print_level control the logging shown below
# (periodic validation loss, checkpoint saves, per-parameter update/norm ratios).
# lr_decay=1 presumably means no learning-rate decay — confirm in utils.run_model.train.
train(
    model, post_proc, optimizer,
    train_loader, val_loader,
    loss_fn, device,
    model_name, loss_history,
    epochs=2, print_every=30, print_level=4, lr_decay=1,
)


Iteration 0, loss = 84.4442, corrected loss = 1154.2776
Validation loss 1133.0815 over 81 frames
model saved to model_checkpoints/Model 20 (2_layer_res+SliceNorm) Run 0 [V100]/Model 20 (2_layer_res+SliceNorm) Run 0 [V100]-0
conv_a1.weight,   	norm: 4.0929e+00, 	update norm: 4.4091e-02 	Update/norm: 1.0773e-02
conv_a1.bias,   	norm: 2.9950e-01, 	update norm: 8.4852e-03 	Update/norm: 2.8331e-02
conv_b1.weight,   	norm: 4.2055e+00, 	update norm: 4.4091e-02 	Update/norm: 1.0484e-02
conv_b1.bias,   	norm: 3.3332e-01, 	update norm: 8.4851e-03 	Update/norm: 2.5457e-02
final.weight,   	norm: 1.4559e+00, 	update norm: 1.2000e-02 	Update/norm: 8.2423e-03
final.bias,   	norm: 1.1771e-01, 	update norm: 3.0000e-03 	Update/norm: 2.5486e-02

... 75.9345... 50.0907... 39.8378
Iter 10... 26.6186... 14.0146... 14.6056
Iter 20... 17.6530... 16.6465... 16.0619
Iteration 30, loss = 12.7612, corrected loss = 174.1712
Validation loss 166.1599 over 81 frames
model saved to model_checkpoints/Model 20 (2_layer

... 3.6607... 4.7263... 3.5886
Iter 280... 4.3976... 4.6851... 3.1586
Iter 290... 5.7807... 3.0533... 4.5724
Iteration 300, loss = 3.2529, corrected loss = 41.4472
Validation loss 47.0838 over 81 frames
model saved to model_checkpoints/Model 20 (2_layer_res+SliceNorm) Run 0 [V100]/Model 20 (2_layer_res+SliceNorm) Run 0 [V100]-300
conv_a1.weight,   	norm: 4.0223e+00, 	update norm: 3.3617e-03 	Update/norm: 8.3576e-04
conv_a1.bias,   	norm: 5.0003e-01, 	update norm: 6.7965e-04 	Update/norm: 1.3592e-03
conv_b1.weight,   	norm: 4.1666e+00, 	update norm: 5.2034e-03 	Update/norm: 1.2488e-03
conv_b1.bias,   	norm: 5.3136e-01, 	update norm: 1.8394e-03 	Update/norm: 3.4616e-03
final.weight,   	norm: 1.1054e+00, 	update norm: 9.9183e-04 	Update/norm: 8.9727e-04
final.bias,   	norm: 5.5382e-02, 	update norm: 4.6473e-05 	Update/norm: 8.3914e-04

... 4.1936... 4.3406... 5.4137
Iter 310... 3.5049... 4.2053... 5.4395
Iter 320... 4.4374
model saved to model_checkpoints/Model 20 (2_layer_res+SliceNorm) 

In [None]:
# One last test + visualize results on 1 validation sequence
val_seq = TAVR_Sequence("__valid")
# Validation loss of the trained model, then of the average_model baseline for comparison.
test(model, post_proc, val_loader, loss_fn, device)
test(ave_model, post_proc, val_loader, loss_fn, device)

# NOTE(review): unlike the training/validation datasets, this sequence is built without
# preproc=preproc_type, and the model output below is not passed through post_proc —
# confirm these frames match the input format the model was trained on.
with torch.no_grad():
    model.eval()
    # Fetch the sequence once; each `val_seq[0]` access may re-read the whole series.
    seq_frames = val_seq[0]
    # Frames to display, in order: real frames 0, 2, 4, 6, 8 interleaved with the model's
    # prediction of each odd frame i-1 from its neighbours i-2 and i, then real frame 9.
    ave_frames = [seq_frames[0]]
    for i in range(2, 9, 2):
        # [None, :] adds a leading batch dimension. The inputs must live on the same
        # device as the model's parameters (presumably moved to `device` by `train`).
        before = seq_frames[i - 2][None, :].to(device=device)
        after = seq_frames[i][None, :].to(device=device)
        ave_frame = model((before, after))
        # [0][0]: first batch item, first channel (assumed output layout — TODO confirm).
        # Bring the prediction back to the CPU so it can be plotted with the real frames.
        ave_frames += [ave_frame[0][0].cpu(), seq_frames[i]]
    ave_frames += [seq_frames[9]]
    ave_frames_slices = []
    for f in ave_frames:
        ave_frames_slices += get_central_slices(f)
    set_figsize(6, 20)
    # 10 frames, one per row, 3 columns (presumably 3 central slices per frame).
    display_grid(10, 3, ave_frames_slices)