In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss
from utils.run_model import train, test, save, load, get_loss_history
from Models.basic_models import average_model, two_layer_basic, post_process

set_figsize(20, 15)

Training directory found, 36 series
Validation directory found, 6 series
Testing directory found, 10 series


In [2]:
USE_GPU = True
dtype = torch.float32 # we will be using float throughout this tutorial
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using device:', device)

# "Pixl", "Slice", or "None"
preproc_type = "pixel"

validation = TAVR_3_Frame("__valid", preproc=preproc_type, preload=True)
val_loader = tavr_dataloader(validation, batch_size=4, shuffle=True, num_workers=2)
training = TAVR_3_Frame("__train", preproc=preproc_type, preload=True)
train_loader = tavr_dataloader(training,batch_size=6, shuffle=True, num_workers=2)


ave_model = average_model()
model = two_layer_basic()
post_proc = post_process(kind=preproc_type).to(device=device)
loss_fn = batch_l2_loss()

# CHANGE TO NAME OF JUPYTER NOTEBOOK
model_name = "Model 5 (Basic+PixelNorm) R1"

using device: cuda


In [10]:
learning_rate = 3e-3
momentum = 0.90
reg = 1e-7

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=reg, nesterov=True)

In [13]:
LOAD = True
iteration_num = 324

if LOAD:
    
    load(model_name, iteration_num, model, optimizer)
    loss_history = get_loss_history(model_name)
    model.to(device=device)
else:
    loss_history = None

model loaded from model_checkpoints/Model 5 (Basic+PixelNorm) R1/Model 5 (Basic+PixelNorm) R1-324


In [12]:
optimizer.state_dict()

{'param_groups': [{'dampening': 0,
   'lr': 0.003,
   'momentum': 0.9,
   'nesterov': True,
   'params': [140710311026408,
    140710311041208,
    140710311050264,
    140710311050192,
    140710311050480,
    140710207899976],
   'weight_decay': 1e-07}],
 'state': {}}

In [14]:
optimizer.state_dict()

{'param_groups': [{'dampening': 0,
   'lr': 0.003,
   'momentum': 0.9,
   'nesterov': True,
   'params': [140710311026408,
    140710311041208,
    140710311050264,
    140710311050192,
    140710311050480,
    140710207899976],
   'weight_decay': 1e-07}],
 'state': {140710207899976: {'momentum_buffer': tensor(1.00000e-02 *
          [ 1.1223], device='cuda:0')},
  140710311026408: {'momentum_buffer': tensor(1.00000e-02 *
          [[[[[ 0.4241,  0.3153, -0.9125],
              [ 0.1918,  1.3896,  0.4412],
              [-1.2447, -0.2620, -0.8182]],
   
             [[-0.5371,  0.2670, -0.5354],
              [-0.0886,  1.6924,  0.5198],
              [-0.6745,  0.3949, -0.8940]],
   
             [[-0.6876,  0.5229,  0.5831],
              [-1.0059,  0.5198,  0.7415],
              [-1.3470, -0.0137,  0.0587]]]],
   
   
   
           [[[[ 1.5616,  0.4495,  3.5651],
              [-0.3727, -6.4716, -4.3946],
              [ 2.6364, -2.7429, -1.8214]],
   
             [[ 6.2029,  2.0

In [None]:
# If multiple GPU
model = nn.DataParallel(model)

In [5]:
train(model, post_proc, optimizer, train_loader, val_loader, loss_fn, device, 
         model_name, loss_history,
          epochs=3, print_every=30, print_level=4, lr_decay=1)

RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #4 'other'

In [9]:
# One last test + visualize results on 1 validation sequence
val_seq = TAVR_Sequence("__valid")
test(model, post_proc, val_loader, loss_fn, device)
test(ave_model, post_proc, val_loader, loss_fn, device)

with torch.no_grad():
    model.eval()
    ave_frames = [val_seq[0][0]]
    for i in range(2,9,2):
        ave_frame = model((val_seq[0][i-2][None,:], val_seq[0][i][None,:]))
        ave_frames += [ave_frame[0][0],  val_seq[0][i]]
    ave_frames += [val_seq[0][9]]
    ave_frames_slices = []
    for f in ave_frames:
        ave_frames_slices += get_central_slices(f)
    set_figsize(6,20)
    display_grid(10, 3, ave_frames_slices)

Validation loss 52.8518 over 81 frames
Validation loss 36.8170 over 81 frames


RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #2 'weight'