In [1]:
# Notebook setup: automatically reload edited project modules before each
# cell runs, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

# Project-local modules; utils/ and Models/ must be importable from the
# notebook's working directory. The "... directory found, N series" lines in
# this cell's output presumably come from utils.tavr_torch at import time.
import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss
from utils.run_model import train, test, save,batch_invariant_l2_test, load, get_loss_history
from Models.basic_models import average_model, two_layer_basic, post_process
from Models.nm_layer import nm_layer_net

# Default figure size for the project's plotting helpers
# (presumably width, height — TODO confirm against utils.visualization).
set_figsize(20, 15)

Training directory found, 36 series
Validation directory found, 6 series
Testing directory found, 10 series


In [2]:
# Run on the GPU when one is available and requested, otherwise on the CPU.
USE_GPU = True
dtype = torch.float32  # floating-point dtype intended for this notebook's tensors
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using device:', device)

# Normalization mode shared by the datasets and by post_process below.
# This run uses "pixel"; the full set of accepted values lives in
# utils.tavr_torch (TODO confirm — presumably also a slice-wise and a no-op mode).
preproc_type = "pixel"

# Three-frame datasets and shuffled loaders for the validation and training splits.
validation = TAVR_3_Frame("__valid", preproc=preproc_type, preload=False)
val_loader = tavr_dataloader(validation, batch_size=4, shuffle=True, num_workers=2)
training = TAVR_3_Frame("__train", preproc=preproc_type, preload=False)
train_loader = tavr_dataloader(training, batch_size=4, shuffle=True, num_workers=2)

# ave_model is the baseline the learned model is compared against.
ave_model = average_model()
# Place the model on `device` as soon as it is built. This way it matches
# post_proc (also on `device`) even when no checkpoint is loaded later; the
# optimizer built afterwards then references parameters that already live there.
model = nm_layer_net([4, 8], [4, 1], True).to(device=device)
post_proc = post_process(kind=preproc_type).to(device=device)
loss_fn = batch_l2_loss()

# Run name passed to load() / get_loss_history(); it names the checkpoint
# directory. Keep it in sync with this notebook's filename.
model_name = "Model 23 (2,2 layer, Residual+PixNorm) Run 0"

using device: cpu


In [3]:
# Optimizer hyperparameters: Adam step size and L2 weight-decay strength
# (reg = 0 turns weight decay off).
learning_rate, reg = 1e-3, 0

optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=reg,
)

In [18]:
# Optionally resume this run from a saved checkpoint.
LOAD = True
iteration_num = 324  # checkpoint iteration to restore

if LOAD:
    # Presumably restores both model weights and optimizer state. Tensors are
    # mapped to the CPU so that checkpoints saved on a GPU load on any machine.
    load(model_name, iteration_num, model, optimizer, map_location='cpu')
    loss_history = get_loss_history(model_name)
    # Because of map_location='cpu', the restored optimizer state (e.g. Adam's
    # running moment estimates) sits on the CPU. optimizer.step() needs it on
    # the same device as the parameters, so move every tensor entry to `device`.
    # Non-tensor entries, such as an integer step counter in older PyTorch
    # versions, are left alone. .to() is a no-op when `device` is the CPU.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)
else:
    loss_history = None

# Place the model on `device` whether it was just restored or is freshly
# initialized, so it always matches post_proc.
model.to(device=device)

model loaded from model_checkpoints/Model 23 (2,2 layer, Residual+PixNorm) Run 0/Model 23 (2,2 layer, Residual+PixNorm) Run 0-324


In [35]:
# Evaluate the model and the averaging baseline on the validation split, one
# frame per batch. A dedicated loader is used so the batch_size=4 `val_loader`
# from the configuration cell is not silently replaced for later cells.
# NOTE: at checkpoint 324 the model (37.8831) was marginally *worse* than the
# averaging baseline (37.8815) on this split.
val_eval_loader = tavr_dataloader(validation, batch_size=1, shuffle=True, num_workers=5)

M23 = batch_invariant_l2_test(model, post_proc, val_eval_loader, loss_fn, device)
ave = batch_invariant_l2_test(ave_model, post_proc, val_eval_loader, loss_fn, device)

Validation loss 37.8831 over 81 frames
Validation loss 37.8815 over 81 frames


In [36]:
# Evaluate the model and the averaging baseline on the held-out test split,
# one frame per batch. The dataset is bound to `testing` (matching `training`
# and `validation`). Naming it `test` would shadow the test() function
# imported from utils.run_model.
# The helper appears to label its printout "Validation loss" whichever split
# it is given.
testing = TAVR_3_Frame("__test", preproc=preproc_type, preload=False)
test_cons_loader = tavr_dataloader(testing, batch_size=1, shuffle=True, num_workers=2)
M23_test = batch_invariant_l2_test(model, post_proc, test_cons_loader, loss_fn, device)
ave_test = batch_invariant_l2_test(ave_model, post_proc, test_cons_loader, loss_fn, device)

Validation loss 44.6286 over 185 frames
Validation loss 44.6268 over 185 frames


tensor(44.6268)