In [1]:
import sys
sys.path.append('..')

import gc
import json
import os
from itertools import islice
from datetime import datetime
import pytz
from pytz import timezone
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import webdataset as wds

from model.baseline_3d_cnn import *

%load_ext autoreload
%autoreload 2

In [2]:
# Dataset root and the directory that holds the WebDataset shard archives.
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards')

# Read the run parameters from the JSON file (path is relative to the working directory).
with open('../parameters.json') as params_file:
    parameters = json.loads(params_file.read())

# Last expression of the cell: display the loaded parameters.
parameters

{'batch_size': 32}

In [10]:
TRAIN_FRACTION = 0.6  # share of shards used for training; the remainder is used for validation
NUM_WORKERS = 2       # DataLoader worker processes per loader

# os.listdir returns entries in arbitrary order, so sort the shard paths to keep
# the train/validation split identical across runs and machines.
urls = sorted(os.path.join(shards_dir, it) for it in os.listdir(shards_dir) if it.endswith('.tar'))
if not urls:
    raise FileNotFoundError("No .tar shards found in {}".format(shards_dir))

split_idx = round(len(urls) * TRAIN_FRACTION)
train_urls = urls[:split_idx]
val_urls = urls[split_idx:]


def make_dataset(shard_urls, shuffle_buffer=0):
    """
    Build the WebDataset pipeline shared by the train and validation splits.

    Inputs:
    - shard_urls: List of paths to the .tar shards to read.
    - shuffle_buffer: Size of the sample shuffle buffer; 0 disables shuffling.

    Returns: A dataset yielding ('volumes.pyd', 'labels.pyd', 'studynames.pyd') tuples.
    """
    dataset = wds.WebDataset(shard_urls)
    if shuffle_buffer > 0:
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.decode('torch').to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')


def make_loader(dataset):
    """
    Wrap a dataset in a DataLoader.

    Batching is done by the dataset itself (.batched), hence batch_size=None on
    the DataLoader.
    """
    return torch.utils.data.DataLoader(
        dataset.batched(parameters['batch_size']),
        num_workers=NUM_WORKERS,
        batch_size=None,
    )


train_dataset = make_dataset(train_urls, shuffle_buffer=parameters['batch_size'])
loader_train = make_loader(train_dataset)

# Validation metrics are aggregated over a full pass, so sample order does not
# matter and no shuffle buffer is needed.
val_dataset = make_dataset(val_urls)
loader_val = make_loader(val_dataset)

In [11]:
# Switch for GPU use, and the floating-point dtype used when tensors are moved to the device.
USE_GPU = True
dtype = torch.float

# Use CUDA only when it is both requested and available; otherwise fall back to the CPU.
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')

print(device)

cuda


In [17]:
def validate_model(loader, model, criterion, iteration, writer):
    """
    Evaluate a model on one full pass over a loader and log the results.

    Relies on the module-level `device` and `dtype` to place the batches.

    Inputs:
    - loader: Iterable yielding (x, y, z) batches. x is fed to the model after a
      singleton axis is inserted at position 1, y is the target passed to the
      criterion, and z is not used here.
    - model: A PyTorch Module to evaluate. It is switched to eval mode and is
      left in eval mode on return.
    - criterion: Loss callable invoked as criterion(scores, y).
    - iteration: Integer step used as the x-axis value for the tensorboard logs.
    - writer: Object providing add_scalar / add_scalars (e.g. a SummaryWriter).

    Returns: Nothing, but logs to the writer and prints the mean batch loss and
    the averaged metrics returned by calculate_metrics.
    """
    print('Checking accuracy on validation set')

    model.eval()  # set model to evaluation mode
    targets_np, scores_np, val_losses = [], [], []
    with torch.no_grad():
        # Run validation on validation batches
        for x, y, _ in loader:
            # Temporarily add code to unsqueeze #TEMP
            x = x.unsqueeze(1)

            # A single .to() handles both device placement and the float cast.
            # The previous extra cast to torch.cuda.FloatTensor was redundant on
            # GPU and raised on the CPU fallback.
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=dtype)

            scores = model(x)
            val_losses.append(criterion(scores, y).item())

            scores_np.extend(scores.cpu().numpy())
            targets_np.extend(y.cpu().numpy())

    if not val_losses:
        # Nothing to aggregate: np.mean([]) would be nan and the metrics would
        # be computed on empty arrays.
        print('Validation loader produced no batches; skipping validation metrics')
        return

    # Calculate metrics after running full validation set
    _, avg_dict = calculate_metrics(np.array(scores_np), np.array(targets_np))
    mean_loss = float(np.mean(val_losses))

    # Save metrics to tensorboard
    writer.add_scalar("Loss/validation", mean_loss, iteration)
    writer.add_scalar("Micro Accuracy/validation", avg_dict['accuracy'], iteration)
    writer.add_scalars("Micro Precision-Recall-F1/validation",
                       {'Precision': avg_dict['precision'], 'Recall': avg_dict['recall'], 'F1': avg_dict['fscore']}, iteration)

    # Print results
    print("Validation results, iter: {:2d}".format(iteration))
    print("Loss: {:.4f}, Micro accuracy: {:.3f}, Micro precision: {:.3f}, Micro recall: {:.3f}, Micro F1: {:.3f}"
          .format(mean_loss, avg_dict['accuracy'], avg_dict['precision'], avg_dict['recall'], avg_dict['fscore'])
         )

In [18]:
def train(model, optimizer, criterion, loader_train, loader_val, writer, epochs=5, device=device, val_every=1):
    """
    Train a model, logging per-batch train metrics and periodic validation metrics.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model.
    - criterion: Loss callable invoked as criterion(scores, y).
    - loader_train: Iterable yielding (x, y, z) training batches; z is not used.
    - loader_val: Loader handed to validate_model for the validation passes.
    - writer: Tensorboard writer used for logging. It is closed when this
      function exits, including when training is interrupted by an exception.
    - epochs: (Optional) A Python integer giving the number of epochs to train for.
    - device: (Optional) Device the model and training batches are moved to.
      validate_model reads the module-level `device`, so use the same one here.
    - val_every: (Optional) Run validation whenever the within-epoch batch index
      is a multiple of this value. 0 or None disables validation.

    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU

    # Running step across all epochs, used as the tensorboard x-axis. Keeping an
    # explicit counter (instead of deriving it from the last batch index) stays
    # correct even when an epoch yields no batches.
    global_step = 0
    try:
        for e in range(epochs):
            print("************EPOCH: {:2d} ***************".format(e))
            for t, (x, y, _) in enumerate(loader_train):
                model.train()  # validate_model leaves the model in eval mode, so re-enable every batch

                # Temporarily add code to unsqueeze #TEMP
                x = x.unsqueeze(1)

                # Device placement and float cast in one step. The previous cast
                # to torch.cuda.FloatTensor ignored `device` and failed on CPU.
                x = x.to(device=device, dtype=dtype)
                y = y.to(device=device, dtype=dtype)

                scores = model(x)
                loss = criterion(scores, y)
                curr_batchloss = loss.item()

                # Zero out all of the gradients for the variables which the optimizer
                # will update.
                optimizer.zero_grad()

                # This is the backwards pass: compute the gradient of the loss with
                # respect to each parameter of the model.
                loss.backward()

                # Actually update the parameters of the model using the gradients
                # computed by the backwards pass.
                optimizer.step()

                # Calculate train metrics
                _, avg_dict = calculate_metrics(scores.detach().cpu().numpy(), y.detach().cpu().numpy(), threshold=0.5)

                # Write train metrics to tensorboard (could move this when we run validation step..?)
                writer.add_scalar("Loss/train", curr_batchloss, global_step)
                writer.add_scalar("Micro Accuracy/train", avg_dict['accuracy'], global_step)
                writer.add_scalars("Micro Precision-Recall-F1/train",
                                   {'Precision': avg_dict['precision'], 'Recall': avg_dict['recall'], 'F1': avg_dict['fscore']}, global_step)

                # Run validation step
                if val_every and t % val_every == 0:
                    print('Iteration %d, train loss = %.4f' % (t, curr_batchloss))
                    validate_model(loader_val, model, criterion, global_step, writer)
                    print()
                    writer.flush()

                global_step += 1
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

In [31]:
# Network to train: the baseline 3D CNN with a single input channel.
model = baseline_3DCNN(in_num_ch=1)

# Name the tensorboard run directory after the current US/Pacific wall-clock time.
utc_now = datetime.now(tz=pytz.utc)
currtime = utc_now.astimezone(timezone('US/Pacific')).strftime('%Y-%m-%d_%H-%M-%S')
run_log_dir = os.path.join('../runs', currtime)
writer = SummaryWriter(log_dir=run_log_dir)

# Adam optimizer over all model parameters, paired with a binary cross-entropy loss.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.BCELoss()

In [32]:
# Force a full garbage-collection pass before training starts. As the last
# expression of the cell, gc.collect()'s return value is shown as the cell output.
gc.collect()

568

In [33]:
# Launch training: 5 epochs, with a validation pass after every training batch.
train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    loader_train=loader_train,
    loader_val=loader_val,
    writer=writer,
    epochs=5,
    device=device,
    val_every=1,
)

************EPOCH:  0 ***************
Iteration 0, loss = 0.7341
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation results, iter:  0
Loss: 0.6734, Micro accuracy: 0.681, Micro precision: 0.152, Micro recall: 0.189, Micro F1: 0.169

Iteration 1, loss = 1.9795
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))


Validation results, iter:  1
Loss: 0.6924, Micro accuracy: 0.421, Micro precision: 0.172, Micro recall: 0.622, Micro F1: 0.269

************EPOCH:  1 ***************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 1.2643
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation results, iter:  2
Loss: 0.7243, Micro accuracy: 0.477, Micro precision: 0.142, Micro recall: 0.405, Micro F1: 0.210

Iteration 1, loss = 1.1966
Checking accuracy on validation set
Validation results, iter:  3
Loss: 0.8648, Micro accuracy: 0.458, Micro precision: 0.130, Micro recall: 0.378, Micro F1: 0.193

************EPOCH:  2 ***************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 0.7779
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation results, iter:  4
Loss: 1.1169, Micro accuracy: 0.458, Micro precision: 0.130, Micro recall: 0.378, Micro F1: 0.193

Iteration 1, loss = 0.8344
Checking accuracy on validation set
Validation results, iter:  5
Loss: 1.4659, Micro accuracy: 0.458, Micro precision: 0.130, Micro recall: 0.378, Micro F1: 0.193

************EPOCH:  3 ***************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 0.6980
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation results, iter:  6
Loss: 1.8011, Micro accuracy: 0.458, Micro precision: 0.130, Micro recall: 0.378, Micro F1: 0.193

Iteration 1, loss = 0.6195
Checking accuracy on validation set
Validation results, iter:  7
Loss: 2.1176, Micro accuracy: 0.458, Micro precision: 0.130, Micro recall: 0.378, Micro F1: 0.193

************EPOCH:  4 ***************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 0.6164
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation results, iter:  8
Loss: 2.4588, Micro accuracy: 0.458, Micro precision: 0.130, Micro recall: 0.378, Micro F1: 0.193

Iteration 1, loss = 0.6044
Checking accuracy on validation set
Validation results, iter:  9
Loss: 2.6201, Micro accuracy: 0.458, Micro precision: 0.130, Micro recall: 0.378, Micro F1: 0.193



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
