In [1]:
# Code credits: Adapted bits and pieces from https://github.com/webdataset/webdataset/blob/master/docs/gettingstarted.ipynb

import sys
sys.path.append('..')

import gc
import json
import os
from itertools import islice
from datetime import datetime
import pytz
from pytz import timezone
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import webdataset as wds

from model.baseline_3d_cnn import *
from utils.model_utils import *

%load_ext autoreload
%autoreload 2

In [2]:
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards_small')

# Opening JSON file
with open('../parameters.json') as json_file:
    parameters = json.load(json_file)

parameters

{'batch_size': 32}

In [3]:
urls = [os.path.join(shards_dir, it) for it in os.listdir(shards_dir) if it.endswith('.tar')]
# All the data
#train_urls = urls[:round(len(urls)*0.6)]
#val_urls = urls[round(len(urls)*0.6):round(len(urls)*0.8)]
#test_urls = urls[round(len(urls)*0.8):]

# Smaller data just to run model once
train_urls = urls[:2]
val_urls = urls[2:3]
test_urls = urls[3:]

print("Number of train shards:", len(train_urls))
print("Number of validation shards:", len(val_urls))
print("Number of test shards:", len(test_urls))

Number of train shards: 2
Number of validation shards: 1
Number of test shards: 18


In [4]:
# Create dataset objects
train_dataset = (
    wds
    .WebDataset(train_urls)
    .shuffle(parameters['batch_size'])
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
)
loader_train = torch.utils.data.DataLoader(train_dataset.batched(1), num_workers=2, batch_size=None) #setting batch_size = None disables batching
# parameters['batch_size']
val_dataset = (
    wds
    .WebDataset(val_urls)
    .shuffle(parameters['batch_size'])
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
)
loader_val = torch.utils.data.DataLoader(val_dataset.batched(1), num_workers=2, batch_size=None)

test_dataset = (
    wds
    .WebDataset(test_urls)
    .shuffle(parameters['batch_size'])
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
)
loader_test = torch.utils.data.DataLoader(test_dataset.batched(1), num_workers=2, batch_size=None)

# for image, target in islice(dataset, 0, 2):
#     print(image.shape)

In [5]:
USE_GPU = True
dtype = torch.float

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(device)

cuda


In [6]:
def validate_model(loader, model, criterion, iteration, writer):
    print('Checking accuracy on validation set')

    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        targets_np, scores_np, val_losses, batch_sizes = [], [], [], []
        
        # Run validation on validation batches
        for x, y, z in loader:
            # Temporarily add code to unsqueeze #TEMP
            x = x.unsqueeze(axis=1)
            
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            x = x.type(torch.cuda.FloatTensor)
            y = y.to(device=device, dtype=dtype)
            scores = model(x)
            val_loss = criterion(scores, y)
            val_losses.append(val_loss.item())
            batch_sizes.append(x.shape[0])
            
            scores_np.extend(scores.cpu().numpy())
            targets_np.extend(y.cpu().numpy())
        
        # Calculate metrics after running full validation set
        scores_np, targets_np = np.array(scores_np), np.array(targets_np)
        
        # Log Metrics
        val_loss = np.average(val_losses, weights=batch_sizes)
        log_metrics(scores_np, targets_np, val_loss, iteration, writer, curr_mode="validation")
        
        # Print results
        print('Total iteration %d, validation loss = %.4f' % (iteration, val_loss))
#         print("Loss: {:.4f}, Micro accuracy: {:.3f}, Micro precision: {:.3f}, Micro recall: {:.3f}, Micro F1: {:.3f}"
#               .format(np.mean(val_losses), avg_dict['accuracy'], avg_dict['precision'], avg_dict['recall'], avg_dict['fscore'])
#              )

In [7]:
def train(model, optimizer, criterion, loader_train, loader_val, writer, epochs=5, device=device, val_every=1):
    """
    Train a model.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    batch_losses = []
    total_iter = 0
    for e in range(epochs):
        print("************EPOCH: {:2d} ***************".format(e))
        for t, (x, y, z) in enumerate(loader_train):
            model.train()  # put model to training mode
            
            # Temporarily add code to unsqueeze #TEMP
            x = x.unsqueeze(axis=1)
            
            x = x.to(device=device)  # move to device, e.g. GPU
            x = x.type(torch.cuda.FloatTensor)

            y = y.to(device=device, dtype=dtype)

            scores = model(x)
            loss = criterion(scores, y)
            curr_batchloss = loss.item()
            batch_losses.append(curr_batchloss)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()
            
            # Log metrics
            scores_np = scores.detach().cpu().numpy()
            targets_np = y.detach().cpu().numpy()
            log_metrics(scores_np, targets_np, curr_batchloss, total_iter+t, writer, curr_mode='train')
            
            # Run validation step
            if t % val_every == 0:
                print('Current epoch %d, epoch iteration %d, train loss = %.4f' % (e, t, curr_batchloss))
                validate_model(loader_val, model, criterion, total_iter+t, writer)
                print()
                writer.flush()
        total_iter += t+1
    writer.close()

In [8]:
model = baseline_3DCNN(in_num_ch=1)

currtime = datetime.now(tz=pytz.utc).astimezone(timezone('US/Pacific')).strftime('%Y-%m-%d_%H-%M-%S')
writer = SummaryWriter(log_dir=os.path.join('../runs', currtime))
optimizer = optim.Adam(model.parameters(), lr = 1e-3)
criterion = torch.nn.BCEWithLogitsLoss()

In [9]:
gc.collect()

66

In [10]:
train(model, optimizer, criterion, loader_train, loader_val, writer, epochs=3, device=device, val_every=1)

************EPOCH:  0 ***************
torch.Size([1, 16, 5, 64, 64])
Current epoch 0, epoch iteration 0, train loss = 0.6811
Checking accuracy on validation set


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


MemoryError: Caught MemoryError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
    data = next(self.dataset_iter)
  File "/opt/conda/lib/python3.7/site-packages/webdataset/iterators.py", line 359, in batched
    for sample in data:
  File "/opt/conda/lib/python3.7/site-packages/webdataset/iterators.py", line 282, in to_tuple
    for sample in data:
  File "/opt/conda/lib/python3.7/site-packages/webdataset/iterators.py", line 226, in map
    for sample in data:
  File "/opt/conda/lib/python3.7/site-packages/webdataset/iterators.py", line 190, in shuffle
    buf.append(next(data))  # skipcq: PYL-R1708
  File "/opt/conda/lib/python3.7/site-packages/webdataset/tariterators.py", line 153, in group_by_keys
    for fname, value in data:
  File "/opt/conda/lib/python3.7/site-packages/webdataset/tariterators.py", line 139, in tar_file_expander
    if handler(exn):
  File "/opt/conda/lib/python3.7/site-packages/webdataset/utils.py", line 12, in reraise_exception
    raise exn
  File "/opt/conda/lib/python3.7/site-packages/webdataset/tariterators.py", line 135, in tar_file_expander
    for sample in tar_file_iterator(source["stream"]):
  File "/opt/conda/lib/python3.7/site-packages/webdataset/tariterators.py", line 119, in tar_file_iterator
    if handler(exn):
  File "/opt/conda/lib/python3.7/site-packages/webdataset/utils.py", line 12, in reraise_exception
    raise exn
  File "/opt/conda/lib/python3.7/site-packages/webdataset/tariterators.py", line 116, in tar_file_iterator
    data = stream.extractfile(tarinfo).read()
  File "/opt/conda/lib/python3.7/tarfile.py", line 695, in read
    b = self.fileobj.read(length)
  File "/opt/conda/lib/python3.7/tarfile.py", line 537, in read
    buf = self._read(size)
  File "/opt/conda/lib/python3.7/tarfile.py", line 545, in _read
    return self.__read(size)
  File "/opt/conda/lib/python3.7/tarfile.py", line 577, in __read
    return t[:size]
MemoryError
