In [12]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss
from utils.run_model import train, test, save, load, get_loss_history
from Models.basic_models import average_model, two_layer_basic, post_process, two_d_two_layer, two_d_three_layer

set_figsize(20, 15)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
USE_GPU = False
dtype = torch.float32 # we will be using float throughout this tutorial
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using device:', device)

# "Pixl", "Slice", or "None"
preproc_type = "pixel"

validation = TAVR_3_Frame("__valid", preproc=preproc_type)
val_loader = tavr_dataloader(validation, batch_size=4, shuffle=True, num_workers=2)
training = TAVR_3_Frame("__train", preproc=preproc_type)
train_loader = tavr_dataloader(training,batch_size=8, shuffle=True, num_workers=2)


ave_model = average_model()
model = two_d_three_layer()
post_proc = post_process(kind=preproc_type).to(device=device)
loss_fn = batch_l2_loss()

# CHANGE TO NAME OF JUPYTER NOTEBOOK
model_name = "Model 10 (three layer - 2d) Run 0" # Forgot to change this

using device: cpu


In [14]:
learning_rate = 3e-3
momentum = 0.90
reg = 1e-7

optimizer = optim.Adam(model.parameters())

In [15]:
# Run cell to load model
LOAD = False
iteration_num = -1

if LOAD:
    
    load(model_name, iteration_num, model, optimizer)
    loss_history = get_loss_history(model_name)
    model.to(device=device)
else:
    loss_history = None

In [None]:
train(model, post_proc, optimizer, train_loader, val_loader, loss_fn, device, 
         model_name, loss_history, epochs=3, print_every=30, print_level=4, lr_decay=1)


Iteration 0, loss = 0.9672, corrected loss = 450.0585
Validation loss 402.8865 over 81 frames
model saved to model_checkpoints/Model 10 (three layer - 2d) Run 0/Model 10 (three layer - 2d) Run 0-0
conv_a1.weight,   	norm: 3.6168e+00, 	update norm: 1.4142e-02 	Update/norm: 3.9100e-03
conv_a1.bias,   	norm: 2.0042e-01, 	update norm: 2.8284e-03 	Update/norm: 1.4112e-02
conv_b1.weight,   	norm: 3.9876e+00, 	update norm: 1.4142e-02 	Update/norm: 3.5465e-03
conv_b1.bias,   	norm: 3.6043e-01, 	update norm: 2.8284e-03 	Update/norm: 7.8473e-03
conv_a2.weight,   	norm: 3.9018e+00, 	update norm: 2.3946e-02 	Update/norm: 6.1372e-03
conv_a2.bias,   	norm: 2.1735e-01, 	update norm: 2.8279e-03 	Update/norm: 1.3010e-02
conv_b2.weight,   	norm: 4.0361e+00, 	update norm: 2.3871e-02 	Update/norm: 5.9145e-03
conv_b2.bias,   	norm: 2.0664e-01, 	update norm: 2.8268e-03 	Update/norm: 1.3680e-02
conv_a3.weight,   	norm: 4.0054e+00, 	update norm: 2.2948e-02 	Update/norm: 5.7291e-03
conv_a3.bias,   	norm: 8.13

In [None]:
no_post_proc = post_process(kind="None").to(device=device)

In [None]:
# One last test + visualize results on 1 validation sequence
val_seq = TAVR_Sequence("__valid")
test(ave_model, post_proc, val_loader, loss_fn, device)
print('finished average')
test(model, post_proc, val_loader, loss_fn, device)

In [None]:
with torch.no_grad():
    model.eval()
    ave_frames = [post_proc(val_seq[0][0])]
    for i in range(2,9,2):
        ave_frame = model((val_seq[0][i-2][None,:], val_seq[0][i][None,:]))
        print(ave_frame.shape)
        ave_frames += [post_proc(ave_frame[0]),  post_proc(val_seq[0][i])]
    ave_frames += [post_proc(val_seq[0][9])]
    ave_frames_slices = []
    for f in ave_frames:
        print(f.shape)
        ave_frames_slices += get_central_slices(f)
    set_figsize(6,20)
    display_grid(10, 3, ave_frames_slices)

In [None]:
optimizer.param_groups[0]['lr']

In [None]:
for p in model.parameters():
    print(p.norm())