In [1]:
import sys
sys.path.append('..')

import gc
import json
import os
from datetime import datetime
from itertools import islice

import numpy as np
import torch
import torch.optim as optim
import webdataset as wds
from sklearn.metrics import precision_recall_fscore_support
from torch.utils.tensorboard import SummaryWriter

from model.baseline_3d_cnn import *

%load_ext autoreload
%autoreload 2

In [2]:
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards')  # directory holding the WebDataset shard files

# Load the run's hyperparameters from the shared JSON config file.
with open('../parameters.json') as params_file:
    parameters = json.load(params_file)

parameters

{'batch_size': 32}

In [3]:
# Fraction of shard files used for training; the remainder is held out for validation.
TRAIN_FRACTION = 0.6

# Sort the shard paths: os.listdir returns entries in arbitrary, filesystem-dependent
# order, which would otherwise make the train/validation split differ between runs.
urls = sorted(os.path.join(shards_dir, it) for it in os.listdir(shards_dir))
assert urls, "no shard files found in {}".format(shards_dir)
n_train_shards = round(len(urls) * TRAIN_FRACTION)
train_urls = urls[:n_train_shards]
val_urls = urls[n_train_shards:]
assert train_urls and val_urls, "need at least one shard in each split, got {} shards".format(len(urls))


def make_dataset(shard_urls, shuffle_buffer=0):
    """Build a WebDataset pipeline yielding (data, target) tuples.

    - shard_urls: list of shard file paths to read.
    - shuffle_buffer: size of the sample shuffle buffer; 0 disables shuffling.
    """
    dataset = wds.WebDataset(shard_urls)
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.decode('torch').to_tuple('data.pyd', 'target.pyd')


def make_loader(dataset, batch_size, num_workers=2):
    """Wrap a dataset in a DataLoader.

    Batching is done by the dataset's `.batched` stage, so the DataLoader itself
    is created with batch_size=None (no additional collation).
    """
    return torch.utils.data.DataLoader(dataset.batched(batch_size), num_workers=num_workers, batch_size=None)


# Only the training samples are shuffled. Validation metrics are aggregated over the
# whole split, so a fixed order is sufficient there.
train_dataset = make_dataset(train_urls, shuffle_buffer=parameters['batch_size'])
loader_train = make_loader(train_dataset, parameters['batch_size'])

val_dataset = make_dataset(val_urls)
loader_val = make_loader(val_dataset, parameters['batch_size'])

In [4]:
USE_GPU = True
dtype = torch.float  # floating-point dtype that input and target batches are cast to

# Run on CUDA only when it is both requested and actually available; otherwise use the CPU.
device = torch.device('cuda') if (USE_GPU and torch.cuda.is_available()) else torch.device('cpu')

print(device)

cuda


In [14]:
def validate_model(loader, model, criterion, iteration, writer):
    """
    Run one full pass over a validation loader, then print and log the metrics.

    Inputs:
    - loader: Iterable of (x, y) validation batches.
    - model: The PyTorch Module being evaluated. Batches are moved to the device
      its parameters live on, so this works on both CPU and GPU.
    - criterion: Loss function called as criterion(scores, targets).
    - iteration: Global training step, used as the TensorBoard x-axis and in the printout.
    - writer: A SummaryWriter that receives the validation scalars.

    Returns: Nothing. The model's train/eval mode is restored before returning.
    """
    print('Checking accuracy on validation set')

    # Send inputs to wherever the model's parameters live, instead of assuming CUDA.
    # A parameterless model falls back to the notebook-wide `device`.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()  # set model to evaluation mode
    targets_np, scores_np, val_losses = [], [], []
    try:
        with torch.no_grad():
            # Run validation on validation batches
            for x, y in loader:
                # Temporarily add code to unsqueeze #TEMP
                x = x.unsqueeze(1)

                x = x.to(device=target_device, dtype=dtype)
                y = y.to(device=target_device, dtype=dtype)
                scores = model(x)
                val_losses.append(criterion(scores, y).item())

                scores_np.extend(scores.cpu().numpy())
                targets_np.extend(y.cpu().numpy())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if not val_losses:
        # Nothing to aggregate; np.mean would return nan and the metrics would be undefined.
        print("Validation loader yielded no batches; skipping validation metrics")
        return

    # Calculate metrics after running full validation set
    scores_np, targets_np = np.array(scores_np), np.array(targets_np)
    _, avg_dict = calculate_metrics(scores_np, targets_np)
    mean_val_loss = np.mean(val_losses)

    # Save metrics to tensorboard
    writer.add_scalar("Loss/validation", mean_val_loss, iteration)
    writer.add_scalar("Micro Accuracy/validation", avg_dict['accuracy'], iteration)
    writer.add_scalars("Micro Precision-Recall-F1/validation",
                       {'Precision': avg_dict['precision'], 'Recall': avg_dict['recall'], 'F1': avg_dict['fscore']}, iteration)

    # Print results
    print("Validation results, iter: {:2d}".format(iteration))
    print("Loss: {:.4f}, Micro accuracy: {:.3f}, Micro precision: {:.3f}, Micro recall: {:.3f}, Micro F1: {:.3f}"
          .format(mean_val_loss, avg_dict['accuracy'], avg_dict['precision'], avg_dict['recall'], avg_dict['fscore'])
          )

In [15]:
def train(model, optimizer, criterion, loader_train, loader_val, writer, epochs=5, device=device, val_every=1):
    """
    Train a model, logging per-batch metrics and periodically validating.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model.
    - criterion: Loss function called as criterion(scores, targets).
    - loader_train: Iterable of (x, y) training batches.
    - loader_val: Iterable of (x, y) validation batches, passed to validate_model.
    - writer: A SummaryWriter for TensorBoard logging. It is closed when training
      finishes or raises.
    - epochs: (Optional) A Python integer giving the number of epochs to train for.
    - device: (Optional) torch.device the model and batches are moved to.
    - val_every: (Optional) Run a full validation pass every `val_every` batches,
      counted from the start of each epoch.

    Returns: Nothing, but prints losses/validation metrics and writes them to TensorBoard.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    total_iter = 0  # global step across epochs, used as the TensorBoard x-axis
    try:
        for e in range(epochs):
            print("************EPOCH: {:2d} ***************".format(e))
            n_batches = 0
            for t, (x, y) in enumerate(loader_train):
                n_batches = t + 1
                step = total_iter + t
                # validate_model may leave the model in eval mode, so re-enable training every batch.
                model.train()

                # Temporarily add code to unsqueeze #TEMP
                x = x.unsqueeze(1)

                # Cast on the selected device. Hard-coding torch.cuda.FloatTensor
                # breaks the CPU fallback.
                x = x.to(device=device, dtype=dtype)
                y = y.to(device=device, dtype=dtype)

                scores = model(x)
                loss = criterion(scores, y)
                curr_batchloss = loss.item()

                # Standard update: clear stale gradients, backpropagate, then step.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Calculate train metrics
                _, avg_dict = calculate_metrics(scores.detach().cpu().numpy(), y.detach().cpu().numpy(), threshold=0.5)

                # Write train metrics to tensorboard
                writer.add_scalar("Loss/train", curr_batchloss, step)
                writer.add_scalar("Micro Accuracy/train", avg_dict['accuracy'], step)
                writer.add_scalars("Micro Precision-Recall-F1/train",
                                   {'Precision': avg_dict['precision'], 'Recall': avg_dict['recall'], 'F1': avg_dict['fscore']}, step)

                # Run validation step
                if t % val_every == 0:
                    print('Iteration %d, loss = %.4f' % (t, curr_batchloss))
                    validate_model(loader_val, model, criterion, step, writer)
                    print()
                    writer.flush()
            # Advance by the number of batches actually seen. Using the loop variable
            # here fails (or reuses a stale value) when an epoch yields no batches.
            total_iter += n_batches
    finally:
        writer.close()

In [16]:
# Seed torch (e.g. weight initialisation) and numpy so that runs are comparable.
# NOTE(review): the webdataset shuffle buffer may use its own RNG and not be covered by these seeds.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

model = baseline_3DCNN(in_num_ch=1)

# Log each run to its own timestamped sub-directory. Otherwise TensorBoard mixes the
# event files of every run written to '../runs' into a single curve.
writer = SummaryWriter(log_dir=os.path.join('../runs', datetime.now().strftime('%Y%m%d-%H%M%S')))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): BCELoss expects probabilities in [0, 1]. If baseline_3DCNN returns raw
# logits, BCEWithLogitsLoss is the numerically stable choice; confirm against the model.
criterion = torch.nn.BCELoss()

In [17]:
# Reclaim unreachable objects left over from earlier cells before training.
# The cell displays the number of unreachable objects the collector found.
collected_objects = gc.collect()
collected_objects

90

In [19]:
N_EPOCHS = 5
VAL_EVERY = 1  # validate over the full validation loader after every training batch

train(
    model, optimizer, criterion, loader_train, loader_val, writer,
    epochs=N_EPOCHS, device=device, val_every=VAL_EVERY,
)

Iteration 0, loss = 0.8985
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  0, Validation results:
Loss: 1.3067, Micro accuracy: 0.185, Micro precision: 0.112, Micro recall: 0.541, Micro F1: 0.185

Iteration 1, loss = 0.6982
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  1, Validation results:
Loss: 1.8957, Micro accuracy: 0.245, Micro precision: 0.111, Micro recall: 0.486, Micro F1: 0.181

Iteration 2, loss = 0.6740
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  2, Validation results:
Loss: 2.4898, Micro accuracy: 0.412, Micro precision: 0.112, Micro recall: 0.351, Micro F1: 0.170



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 0.6805
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  0, Validation results:
Loss: 3.1867, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 1, loss = 0.7191
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  1, Validation results:
Loss: 3.9313, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 2, loss = 0.6810
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  2, Validation results:
Loss: 4.6599, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 1.2561
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  0, Validation results:
Loss: 4.9607, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 1, loss = 0.9312
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  1, Validation results:
Loss: 4.9244, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 2, loss = 0.7389
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  2, Validation results:
Loss: 4.8355, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 0.6924
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  0, Validation results:
Loss: 4.6041, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 1, loss = 0.6435
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  1, Validation results:
Loss: 4.3412, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 2, loss = 0.6427
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  2, Validation results:
Loss: 4.0156, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iteration 0, loss = 0.6333
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  0, Validation results:
Loss: 3.7794, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 1, loss = 0.6960
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  1, Validation results:
Loss: 3.5508, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152

Iteration 2, loss = 0.6515
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Iter:,  2, Validation results:
Loss: 3.3188, Micro accuracy: 0.431, Micro precision: 0.102, Micro recall: 0.297, Micro F1: 0.152



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
