In [1]:
# Reload edited project modules (utils/, Models/) automatically before each cell runs,
# so changes to those .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline

# NOTE(review): keep this import order — the "... directory found, N series" lines in
# this cell's output appear to be printed by one of the project imports (likely
# utils.tavr_torch) as an import-time side effect; confirm before reordering.
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss
from utils.run_model import train, test, save, load, get_loss_history
from Models.basic_models import average_model, two_layer_basic, post_process

# Default figure size for the project's visualization helpers
# (presumably (width, height) in inches — confirm in utils.visualization).
set_figsize(20, 15)

Training directory found, 36 series
Validation directory found, 6 series
Testing directory found, 10 series


In [2]:
# --- Device -----------------------------------------------------------------
USE_GPU = True
dtype = torch.float32  # float32 throughout; kept for any later cell that uses it
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using device:', device)

# --- Reproducibility --------------------------------------------------------
# Seed torch (CPU and all CUDA devices) and numpy before building datasets and the
# model, so weight initialisation and shuffled batch order repeat on Restart & Run All.
# NOTE(review): assumes tavr_dataloader shuffles via torch's RNG (as
# torch.utils.data.DataLoader does) — confirm in utils.tavr_torch.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Data -------------------------------------------------------------------
# Preprocessing mode forwarded to the datasets and to post_process. This notebook
# uses "pixel" (matching "PixelNorm" in model_name below); other accepted values are
# presumably "slice" and "none" — confirm the exact spelling in utils.tavr_torch.
preproc_type = "pixel"

VAL_BATCH_SIZE = 4
TRAIN_BATCH_SIZE = 8
NUM_WORKERS = 2

validation = TAVR_3_Frame("__valid", preproc=preproc_type)
val_loader = tavr_dataloader(validation, batch_size=VAL_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
training = TAVR_3_Frame("__train", preproc=preproc_type)
train_loader = tavr_dataloader(training, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# --- Model, post-processing, loss -------------------------------------------
ave_model = average_model()
# Put the model on `device` now, before the optimizer is built from its parameters
# in the next cell (as recommended by the torch.optim docs). This also covers runs
# where no checkpoint is loaded.
model = two_layer_basic().to(device=device)
post_proc = post_process(kind=preproc_type).to(device=device)
loss_fn = batch_l2_loss()

# CHANGE TO NAME OF JUPYTER NOTEBOOK
model_name = "Model 5 (Basic+PixelNorm)"

using device: cuda


In [3]:
# SGD hyper-parameters: step size, momentum factor, and L2 weight-decay strength.
learning_rate, momentum, reg = 3e-3, 0.95, 1e-7

# Nesterov-momentum SGD over all model parameters; `reg` enters as weight_decay.
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=reg,
    nesterov=True,
)

In [4]:
# Optionally resume `model` and `optimizer` from a saved checkpoint of `model_name`.
LOAD = True
iteration_num = 81  # checkpoint iteration to restore when LOAD is True

# Move the model to `device` *before* restoring, and on both branches:
# - Optimizer.load_state_dict casts restored state (e.g. SGD momentum buffers) to
#   the device the parameters are on at load time. Loading first and moving
#   afterwards therefore leaves that state on the CPU while the parameters sit on
#   `device`. (Assumes `load` restores the optimizer via load_state_dict —
#   confirm in utils.run_model.)
# - A fresh run (LOAD = False) must also end up on `device`, like post_proc.
# Module.to moves parameters in place, so `optimizer` still references them.
model.to(device=device)

if LOAD:
    load(model_name, iteration_num, model, optimizer)
    loss_history = get_loss_history(model_name)
else:
    loss_history = None  # fresh run: no previous loss history

model loaded from model_checkpoints/Model 5 (Basic+PixelNorm)/Model 5 (Basic+PixelNorm)-81


In [5]:
# Wrap the model for data-parallel execution, but only when it can help:
# - more than one GPU is visible (with a single GPU the wrapper adds overhead only);
# - the notebook is actually running on CUDA (DataParallel requires the module's
#   parameters on cuda:0, so wrapping a CPU model would fail);
# - the model is not already wrapped, so re-running this cell stays idempotent.
# NOTE(review): once wrapped, model.state_dict() keys gain a "module." prefix.
# Checkpoints saved from the wrapped model will not load into the unwrapped model
# in the resume cell unless save/load handle the prefix — confirm in utils.run_model.
if (
    torch.cuda.device_count() > 1
    and device.type == 'cuda'
    and not isinstance(model, nn.DataParallel)
):
    model = nn.DataParallel(model)

In [6]:
# Evaluate the (possibly DataParallel-wrapped) model on the validation loader.
# The "Validation loss ... over N frames" line in this cell's output is presumably
# printed by `test`, and v_loss presumably holds that loss — confirm in utils.run_model.
# NOTE(review): this validation loss (~69.7) is on a very different scale from the
# training losses in loss_history (~0.2). Check whether `test` measures loss after
# post_proc (un-normalized), or sums rather than averages, while the training history
# is recorded in normalized space — otherwise the two numbers are not comparable.
v_loss = test(model, post_proc, val_loader, loss_fn, device)

Validation loss 69.7419 over 81 frames


In [13]:
# Display the stored training history for `model_name` (None when LOAD is False).
# The displayed dict has keys such as 'epoch', 'iteration', 'print_every' and 'train'.
# NOTE(review): this cell ran as In [13] right after In [6]; re-run the notebook
# top-to-bottom so execution counts and outputs reflect a clean kernel.
loss_history

{'epoch': 0,
 'iteration': 81,
 'print_every': 30,
 'train': [1.4583444595336914,
  1.386093258857727,
  1.303180456161499,
  1.1800551414489746,
  0.9898819923400879,
  0.8864600658416748,
  0.643781304359436,
  0.41178756952285767,
  0.2897035777568817,
  0.32744431495666504,
  0.41706836223602295,
  0.4686945974826813,
  0.5407416224479675,
  0.6122727394104004,
  0.5340273976325989,
  0.5399351119995117,
  0.4800082743167877,
  0.41271457076072693,
  0.3245951235294342,
  0.2616511285305023,
  0.24818198382854462,
  0.26347947120666504,
  0.2647935152053833,
  0.2880227267742157,
  0.29751893877983093,
  0.29350441694259644,
  0.3103369474411011,
  0.29251813888549805,
  0.2621140480041504,
  0.23710985481739044,
  0.23486699163913727,
  0.21325638890266418,
  0.19239456951618195,
  0.20148739218711853,
  0.2010451704263687,
  0.19279226660728455,
  0.20746038854122162,
  0.21256385743618011,
  0.2055629938840866,
  0.20884472131729126,
  0.19681207835674286,
  0.1944475919008255,
