In [76]:
# Automatically reload imported modules (e.g. the project's utils/ and Models/
# packages) before each cell executes, so edits there are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

from time import perf_counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import utils.tavr_torch as tavr_torch
from utils.tavr_torch import TAVR_3_Frame, TAVR_1_Frame, TAVR_Sequence, tavr_dataloader
from utils.visualization import display_grid, z_stretch, visualize_frame, set_figsize, get_central_slices
from utils.loss_functions import batch_l2_loss, batch_mse_loss
from utils.run_model import train, test
from Models.basic_models import average_model, two_layer_basic


# Default figure size for plots in this notebook (passed straight to
# utils.visualization.set_figsize; later cells may override it).
DEFAULT_FIGSIZE = (20, 15)
set_figsize(*DEFAULT_FIGSIZE)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [77]:
# --- Device / reproducibility ---------------------------------------------
USE_GPU = True
# NOTE(review): `dtype` is not used in the cells shown; kept in case other code
# references it.
dtype = torch.float32
SEED = 0  # seeds weight init and DataLoader shuffling so runs are repeatable

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using device:', device)

# Seed before building models/loaders so both initialization and shuffle order
# are reproducible. (numpy seeded too in case the dataset code draws from it.)
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Data -------------------------------------------------------------------
VAL_BATCH_SIZE = 8
TRAIN_BATCH_SIZE = 4
NUM_WORKERS = 2

validation = TAVR_3_Frame("__valid")
# test() reports the loss over the whole validation set, so visiting order does
# not change the result; no shuffling keeps evaluation deterministic.
val_loader = tavr_dataloader(validation, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
training = TAVR_3_Frame("__train")
train_loader = tavr_dataloader(training, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# --- Models / loss ----------------------------------------------------------
ave_model = average_model()  # baseline (presumably averages its input frames — confirm in Models/basic_models)
model = two_layer_basic()    # learned model trained below
loss_fn = batch_mse_loss()

using device: cpu


In [None]:
# Benchmark validation speed for several DataLoader settings, timing both the
# learned model and the baseline on each. test() is called in the same order as
# before: (model, ave_model) for each config in turn.
BENCH_CONFIGS = [(4, 2), (16, 2), (4, 5), (16, 5)]  # (batch_size, num_workers)
bench_models = [("two_layer_basic", model), ("average_model", ave_model)]

bench_timings = []  # (batch_size, num_workers, model_name, seconds)
for batch_size, num_workers in BENCH_CONFIGS:
    # Loaders are built one at a time; workers only start once iterated.
    bench_loader = tavr_dataloader(validation, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    for model_name, net in bench_models:
        start = perf_counter()  # monotonic, intended for measuring intervals
        test(net, bench_loader, loss_fn, device)
        elapsed = perf_counter() - start
        bench_timings.append((batch_size, num_workers, model_name, elapsed))
        # Print as we go: a full run is slow and may be interrupted.
        print(f"batch_size={batch_size:>2}  num_workers={num_workers}  {model_name:<16} {elapsed:8.2f} s")

torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.025e+05, max_z_sum 287 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 925564.0
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.591e+05, max_z_sum 289 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 1176161.875
torch.Size([4, 1, 77, 256, 256]) torch.Size([4, 77, 256, 256])
	mean 2.344e+05, max_z_sum 268 torch.Size([4, 4, 77, 256, 256])
BS & loss 4 1077383.75
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.253e+05, max_z_sum 295 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 1001882.4375
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 1.472e+05, max_z_sum 310 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 623028.6875
torch.Size([4, 1, 73, 256, 256]) torch.Size([4, 73, 256, 256])
	mean 2.599e+05, max_z_sum 236 torch.Size([4, 4, 73, 256, 256])
BS & loss 4 1286085.5
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.492e+05, max_z_sum 252 torch

torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.835e+05, max_z_sum 266 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 1398198.625
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.512e+05, max_z_sum 280 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 1177267.375
torch.Size([4, 1, 73, 256, 256]) torch.Size([4, 73, 256, 256])
	mean 1.498e+05, max_z_sum 278 torch.Size([4, 4, 73, 256, 256])
BS & loss 4 629293.625
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.423e+05, max_z_sum 277 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 1147785.0
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.522e+05, max_z_sum 252 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 1313137.625
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 2.499e+05, max_z_sum 280 torch.Size([4, 4, 82, 256, 256])
BS & loss 4 1171121.75
torch.Size([4, 1, 82, 256, 256]) torch.Size([4, 82, 256, 256])
	mean 1.408e+05, max_z_sum 318 torc

In [78]:
# SGD (Nesterov momentum) hyperparameters. The learning rate is tiny presumably
# because the raw loss is on the order of 1e6 (see the training log), which keeps
# the weight update/norm ratio around 1e-3 in the recorded run.
learning_rate, momentum, reg = 1e-10, 0.90, 1e-7

optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=reg,   # L2 regularization strength
    nesterov=True,
)

In [79]:
# Options forwarded unchanged to utils.run_model.train.
# NOTE(review): the recorded run of this cell was stopped manually (KeyboardInterrupt).
train_options = dict(epochs=1, print_every=10, print_level=3, lr_decay=0.8)
train(model, optimizer, train_loader, val_loader, loss_fn, device, **train_options)


****Epoch 0 Iteration 0, loss = 2208868.0000
Validation loss 3691044.7500 over 81 frames
conv_a1.weight,   	norm: 4.1727e+00, 	update norm: 2.1939e-03 	Update/norm: 5.2577e-04
conv_a1.bias,   	norm: 2.9972e-01, 	update norm: 5.0899e-07 	Update/norm: 1.6982e-06
conv_b1.weight,   	norm: 3.9208e+00, 	update norm: 2.8334e-03 	Update/norm: 7.2265e-04
conv_b1.bias,   	norm: 2.7325e-01, 	update norm: 7.9115e-07 	Update/norm: 2.8953e-06
final.weight,   	norm: 1.4334e+00, 	update norm: 1.8729e-03 	Update/norm: 1.3066e-03
final.bias,   	norm: 7.2517e-02, 	update norm: 8.8662e-07 	Update/norm: 1.2226e-05

Iter 0... ... ... 
Iteration 10, loss = 1262448.0000
Validation loss 1824393.7500 over 81 frames
conv_a1.weight,   	norm: 4.1699e+00, 	update norm: 4.0360e-03 	Update/norm: 9.6790e-04
conv_a1.bias,   	norm: 2.9973e-01, 	update norm: 1.0821e-06 	Update/norm: 3.6104e-06
conv_b1.weight,   	norm: 3.9255e+00, 	update norm: 5.2436e-03 	Update/norm: 1.3358e-03
conv_b1.bias,   	norm: 2.7326e-01, 	updat

Process Process-44:
Process Process-43:
Traceback (most recent call last):
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._

KeyboardInterrupt: 

In [None]:
# One last evaluation of both models, then visualize predictions on the first
# validation sequence.
val_seq = TAVR_Sequence("__valid")
test(model, val_loader, loss_fn, device)
test(ave_model, val_loader, loss_fn, device)

# Index the dataset once; each val_seq[0] access may reload the whole sequence.
seq = val_seq[0]
# Send inputs to wherever the model's weights live: train()/test() were given
# `device`, so the model may no longer be on the CPU.
model_device = next(model.parameters()).device

# Display order: real frame 0; then for i = 2, 4, 6, 8 the model's prediction
# from frames (i-2, i) followed by real frame i; finally real frame 9.
# (`ave_frames` holds both real and predicted frames; name kept for compatibility.)
ave_frames = [seq[0]]
was_training = model.training
with torch.no_grad():
    model.eval()
    for i in range(2, 9, 2):
        before = torch.as_tensor(seq[i - 2][None, :], device=model_device)  # add batch dim
        after = torch.as_tensor(seq[i][None, :], device=model_device)
        pred = model((before, after))
        # Bring the prediction back to the CPU for the plotting helpers.
        ave_frames += [pred[0][0].cpu(), seq[i]]
model.train(was_training)  # restore whatever mode the model was in before
ave_frames += [seq[9]]

ave_frames_slices = []
for frame in ave_frames:
    ave_frames_slices += get_central_slices(frame)
set_figsize(6, 20)
# One row per frame; assumes get_central_slices returns 3 slices per frame.
display_grid(len(ave_frames), 3, ave_frames_slices)

In [20]:
# Time validation of both models under two loss functions (batch_mse_loss is
# already imported in the first cell).
# NOTE(review): the setup cell now sets loss_fn = batch_mse_loss(), so on a fresh
# kernel both losses below are identical. The recorded outputs came from an
# earlier session (In [20]) in which loss_fn was a different loss — presumably
# batch_l2_loss, which is imported but otherwise unused. Pass the loss you
# actually want to compare explicitly.
mse_loss_fn = batch_mse_loss()

loss_comparison = [
    (model, mse_loss_fn),
    (ave_model, mse_loss_fn),
    (model, loss_fn),
    (ave_model, loss_fn),
]
loss_timings = []
for net, fn in loss_comparison:
    start = perf_counter()  # monotonic clock for interval timing
    test(net, val_loader, fn, device)
    loss_timings.append(perf_counter() - start)

for elapsed in loss_timings:
    print(elapsed)

Validation loss 4033946.0000 over 81 frames
Validation loss 1468.3656 over 81 frames
Validation loss 0.2381 over 81 frames
Validation loss 0.0047 over 81 frames
105.04296255111694
14.075073480606079
107.34071278572083
13.867462873458862
