In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss
from utils.run_model import train, test, save, load, get_loss_history
from Models.basic_models import average_model, two_layer_basic, post_process
from Models.nm_layer import nm_layer_net, Parallel_Residual

# Initial figure size for the plots in this notebook (a later cell calls
# set_figsize again before drawing its grid).
# NOTE(review): arguments are presumably (width, height) — confirm against
# utils.visualization.set_figsize.
set_figsize(20, 15)

Training directory found, 36 series
Validation directory found, 6 series
Testing directory found, 10 series


In [2]:
# --- Device selection ---------------------------------------------------------
USE_GPU = True
dtype = torch.float32  # declared here; not referenced by the cells visible below
# CUDA is used only when it is both requested and actually available.
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')
print('using device:', device)

# --- Data ---------------------------------------------------------------------
# Preprocessing mode, shared by both datasets and by post_process below.
# NOTE(review): an earlier note listed the options as "Pixl", "Slice", or
# "None", which does not match the lowercase value used here — confirm the
# accepted spellings against TAVR_3_Frame / post_process.
preproc_type = "pixel"

# Both loaders use identical batching options.
loader_opts = dict(batch_size=2, shuffle=True, num_workers=2)

validation = TAVR_3_Frame("__valid", preproc=preproc_type, preload=False)
val_loader = tavr_dataloader(validation, **loader_opts)
training = TAVR_3_Frame("__train", preproc=preproc_type, preload=False)
train_loader = tavr_dataloader(training, **loader_opts)

# --- Models, post-processing and loss -----------------------------------------
# ave_model is evaluated alongside `model` in the test/comparison cells below.
ave_model = average_model()
model = Parallel_Residual(4, [4, 4], [4, 1])
# post_proc is applied to network outputs and targets before the loss is computed.
post_proc = post_process(kind=preproc_type).to(device=device)
loss_fn = batch_l2_loss()

# CHANGE TO NAME OF JUPYTER NOTEBOOK
# (passed to load/get_loss_history/train; it also appears in the checkpoint
# paths printed during training).
model_name = "Model 25 (Wide Parallel Residual) Run 0"

using device: cuda


In [3]:
# Adam hyper-parameters: step size and weight-decay coefficient
# (reg = 0 means no weight decay is applied).
learning_rate, reg = 1e-3, 0

optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=reg,
)

In [4]:
# Set LOAD = True to resume from a saved checkpoint instead of starting fresh.
LOAD = False
# Passed straight to load().
# NOTE(review): the meaning of -1 is defined by utils.run_model.load — confirm.
iteration_num = -1

if LOAD:
    load(model_name, iteration_num, model, optimizer)
    loss_history = get_loss_history(model_name)
    model.to(device=device)
    # The model has just been moved to `device`, so the optimizer's per-parameter
    # state tensors are moved there as well.
    # NOTE(review): presumably load() restores that state on a different device
    # than the one the parameters end up on — confirm against utils.run_model.load.
    if device.type == 'cuda':
        for state in optimizer.state.values():
            for k, v in state.items():
                # Only tensors can be moved between devices; optimizer state may
                # also hold plain Python values (e.g. an integer step counter in
                # some PyTorch versions), which are left as they are.
                if torch.is_tensor(v):
                    state[k] = v.to(device=device)
else:
    # Fresh run: no previous loss history to continue from.
    loss_history = None

In [None]:
# If multiple GPUs are available, uncomment the next line to wrap the model:
# model = nn.DataParallel(model)

In [None]:
# Run training. As the captured output shows, train() logs the loss, the
# validation loss, checkpoint paths and per-parameter update norms.
# NOTE(review): lr_decay=1 presumably means "no decay" — confirm against
# utils.run_model.train.
train(
    model, post_proc, optimizer,
    train_loader, val_loader,
    loss_fn, device,
    model_name, loss_history,
    epochs=2,
    print_every=60,
    print_level=4,
    lr_decay=1,
)


Iteration 0, loss = 1.4253, corrected loss = 663.2491
Validation loss 510.3521 over 81 frames
model saved to model_checkpoints/Model 25 (Wide Parallel Residual) Run 0/Model 25 (Wide Parallel Residual) Run 0-0
para_0.conv_a1.weight,   	norm: 2.8934e+00, 	update norm: 1.0336e-02 	Update/norm: 3.5723e-03
para_0.conv_a1.bias,   	norm: 2.1399e-01, 	update norm: 1.9978e-03 	Update/norm: 9.3357e-03
para_0.conv_b1.weight,   	norm: 2.7072e+00, 	update norm: 1.0320e-02 	Update/norm: 3.8121e-03
para_0.conv_b1.bias,   	norm: 1.1737e-01, 	update norm: 1.9890e-03 	Update/norm: 1.6946e-02
para_0.conv_a2.weight,   	norm: 2.8016e+00, 	update norm: 2.0533e-02 	Update/norm: 7.3292e-03
para_0.conv_a2.bias,   	norm: 1.0215e-01, 	update norm: 1.9988e-03 	Update/norm: 1.9568e-02
para_0.conv_b2.weight,   	norm: 2.7705e+00, 	update norm: 2.0444e-02 	Update/norm: 7.3791e-03
para_0.conv_b2.bias,   	norm: 1.3266e-01, 	update norm: 1.9929e-03 	Update/norm: 1.5022e-02
para_0.conv_ab1.weight,   	norm: 2.7111e+00, 	

... 0.0499... 0.0847... 0.0622
Iter 70... 0.0607... 0.0601... 0.0656
Iter 80... 0.0851... 0.0846... 0.1035
Iter 90... 0.0778... 0.1040... 0.0559
Iter 100... 0.0519... 0.0715... 0.0837
Iter 110... 0.0993... 0.0929... 0.0558
Iteration 120, loss = 0.0893, corrected loss = 41.5502
Validation loss 37.7454 over 81 frames
model saved to model_checkpoints/Model 25 (Wide Parallel Residual) Run 0/Model 25 (Wide Parallel Residual) Run 0-120
para_0.conv_a1.weight,   	norm: 2.8845e+00, 	update norm: 4.4377e-05 	Update/norm: 1.5385e-05
para_0.conv_a1.bias,   	norm: 2.3582e-01, 	update norm: 9.9711e-07 	Update/norm: 4.2282e-06
para_0.conv_b1.weight,   	norm: 2.6967e+00, 	update norm: 7.9512e-05 	Update/norm: 2.9485e-05
para_0.conv_b1.bias,   	norm: 1.3094e-01, 	update norm: 3.0227e-05 	Update/norm: 2.3084e-04
para_0.conv_a2.weight,   	norm: 2.7994e+00, 	update norm: 1.2453e-04 	Update/norm: 4.4485e-05
para_0.conv_a2.bias,   	norm: 1.0629e-01, 	update norm: 2.8480e-06 	Update/norm: 2.6794e-05
para_0.c

In [None]:
# Final evaluation on the validation loader: the trained model first, then the
# averaging model, both with the same post-processing, loss and device.
eval_args = (post_proc, val_loader, loss_fn, device)

test(model, *eval_args)
test(ave_model, *eval_args)

In [None]:
# For every validation batch, report how far the model's prediction is from
# (a) the post-processed target and (b) the averaging model's prediction.
with torch.no_grad():
    model.eval()
    for batch_idx, batch in enumerate(val_loader):
        # Each batch unpacks to (x1, y, x2, mask, max_z); move all of it to the
        # target device (e.g. GPU) in one go.
        x1, y, x2, mask, max_z = (item.to(device=device) for item in batch)
        y = post_proc(y)

        inputs = (x1, x2)
        y_hat = post_proc(model(inputs))
        dist_to_truth = loss_fn((y, y_hat, mask, max_z))
        ave_pred = post_proc(ave_model(inputs))
        dist_to_ave = loss_fn((ave_pred, y_hat, mask, max_z))
        print("Batch %d. Prediction-Real dist: %.5f, Prediction-Ave dist:%.5f"%(batch_idx, dist_to_truth.item(), dist_to_ave.item()))

In [None]:
# Visualise one validation sequence with model predictions interleaved.
# Resulting frame order (10 frames in total):
#   frame 0, pred(0,2), frame 2, pred(2,4), frame 4, pred(4,6), frame 6,
#   pred(6,8), frame 8, frame 9
# Despite the name, the in-between entries of `ave_frames` come from `model`,
# not from `ave_model`.
val_seq = TAVR_Sequence("__valid", preproc=preproc_type)
with torch.no_grad():
    model.eval()
    ave_frames = [post_proc(val_seq[0][0].to(device=device))]
    for right in range(2, 9, 2):
        left = right - 2
        # [None, :] adds a leading batch dimension before the frame is sent to the device.
        prev_in = val_seq[0][left][None, :].to(device=device)
        next_in = val_seq[0][right][None, :].to(device=device)
        predicted = model((prev_in, next_in))
        # Strip the leading two dimensions of the model output before post-processing.
        ave_frames.append(post_proc(predicted[0][0]))
        ave_frames.append(post_proc(val_seq[0][right].to(device=device)))
    ave_frames.append(post_proc(val_seq[0][9].to(device=device)))

    # Flatten the central slices of every frame into a single list for the grid.
    ave_frames_slices = [
        central
        for frame in ave_frames
        for central in get_central_slices(frame)
    ]
    set_figsize(6, 20)
    # NOTE(review): presumably 10 rows (one per frame) by 3 slices per frame —
    # confirm against display_grid / get_central_slices.
    display_grid(10, 3, ave_frames_slices)