In [41]:
# Code credits: Adapted bits and pieces from https://github.com/webdataset/webdataset/blob/master/docs/gettingstarted.ipynb

import sys
sys.path.append('..')

import gc
import json
import os
from itertools import islice
from datetime import datetime
import pytz
from pytz import timezone
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
import skimage.transform as st


import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import webdataset as wds

from model.baseline_3d_cnn import *
from utils.model_utils import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [42]:
# Root of the data folder and the sub-folder holding the .tar shards.
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards')

# Read the run parameters from the JSON config one level up.
with open('../parameters.json') as json_file:
    parameters = json.loads(json_file.read())

# Batch size is shared by every loader built later in the notebook.
batch_size = parameters['batch_size']
parameters

{'batch_size': 16}

In [3]:
# Collect every shard archive in the shards directory.
# os.listdir() returns entries in arbitrary order, so the list is sorted to
# keep the train/validation/test assignment identical from run to run.
urls = sorted(
    os.path.join(shards_dir, it) for it in os.listdir(shards_dir) if it.endswith('.tar')
)

# All the data: 60% train / 20% validation / 20% test, split by shard.
train_end = round(len(urls) * 0.6)
val_end = round(len(urls) * 0.8)
train_urls = urls[:train_end]
val_urls = urls[train_end:val_end]
test_urls = urls[val_end:]

# Smaller data just to run model once
# train_urls = urls[:2]
# val_urls = urls[2:3]
# test_urls = urls[3:]


print("Number of train shards:", len(train_urls))
print("Number of validation shards:", len(val_urls))
print("Number of test shards:", len(test_urls))

Number of train shards: 29
Number of validation shards: 10
Number of test shards: 10


In [4]:
def _make_dataset(shard_urls, shuffle_buffer=None):
    """Build the WebDataset pipeline shared by the train/val/test splits.

    Inputs:
    - shard_urls: list of shard (.tar) paths to read from
    - shuffle_buffer: (Optional) buffer size passed to .shuffle(); when None
      the shuffle stage is left out entirely

    Returns: a dataset whose samples are
    ('volumes.pyd', 'labels.pyd', 'studynames.pyd') tuples.
    """
    dataset = wds.WebDataset(shard_urls)
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.decode('torch').to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')


def _make_loader(dataset):
    """Wrap a dataset in a DataLoader; batching is done by the dataset itself."""
    # batch_size=None disables the DataLoader's own batching, since
    # .batched(batch_size) already groups the samples.
    return torch.utils.data.DataLoader(dataset.batched(batch_size), num_workers=0, batch_size=None)


# Create dataset objects
# NOTE(review): the training split is intentionally left unshuffled here, as in
# the original cell; pass shuffle_buffer=batch_size to enable shuffling.
train_dataset = _make_dataset(train_urls)
loader_train = _make_loader(train_dataset)

val_dataset = _make_dataset(val_urls, shuffle_buffer=batch_size)
loader_val = _make_loader(val_dataset)

test_dataset = _make_dataset(test_urls, shuffle_buffer=batch_size)
loader_test = _make_loader(test_dataset)

# for image, target in islice(dataset, 0, 2):
#     print(image.shape)

### Original quality images:

In [5]:
# Force a garbage-collection pass; the cell output is the value returned by gc.collect().
gc.collect()

44

In [6]:
# patient_num = 1

# for t, (x, y, z) in enumerate(loader_train):
#     if t > 0:
#         break
#     tmp = x[patient_num, :, :, :].detach().numpy()
# del x, y, z

In [7]:
# fig, axs = plt.subplots(5,8, figsize=(15, 6))
# axs = axs.ravel()
# for i in range(40):
#     axs[i].imshow(tmp[i,:,:])

### Downsampled quality images

In [8]:
# fig, axs = plt.subplots(5,8, figsize=(15, 6))
# axs = axs.ravel()
# for i in range(40):
#     axs[i].imshow(st.resize(tmp[i,:,:], (256,256)))

In [9]:
# del tmp
# gc.collect()

In [10]:
# Toggle for GPU usage and the floating-point type used for inputs/targets.
USE_GPU = True
dtype = torch.float

# Use CUDA only when it is both requested and available; otherwise use the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

print(device)
print(dtype)

cuda
torch.float32


In [32]:
def validate_model(loader, model, criterion, iteration, writer):
    """
    Evaluate a model on every batch of a loader and log the results.

    Inputs:
    - loader: Iterable yielding (x, y, z) batches; only x and y are used here.
    - model: A PyTorch Module to evaluate (switched to eval mode by this call).
    - criterion: Callable computing the loss as criterion(scores, y).
    - iteration: Global step number passed to log_metrics and printed.
    - writer: Writer object forwarded to log_metrics.

    Returns: The validation loss averaged over batches, weighted by batch size.
    """
    print('Checking accuracy on validation set')

    # Send the inputs to the device the model actually lives on, so this works
    # even when the caller placed the model somewhere other than the global
    # `device`. Fall back to the global for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    # Build the resize transform once instead of once per batch.
    resize = transforms.Resize(size=(256, 256))

    targets_np, scores_np, val_losses, batch_sizes = [], [], [], []

    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        # Run validation on validation batches
        for x, y, _ in loader:
            x = resize(x)

            # Unsqueeze because we only have 1 channel (axis=1) of this 3d image
            x = x.unsqueeze(axis=1)

            x = x.to(device=run_device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=run_device, dtype=dtype)

            scores = model(x)
            val_losses.append(criterion(scores, y).item())
            batch_sizes.append(x.shape[0])

            scores_np.extend(scores.cpu().numpy())
            targets_np.extend(y.cpu().numpy())

    # Calculate metrics after running full validation set
    scores_np, targets_np = np.array(scores_np), np.array(targets_np)

    # Weight each batch loss by its size so a short final batch is not over-counted.
    val_loss = np.average(val_losses, weights=batch_sizes)
    log_metrics(scores_np, targets_np, val_loss, iteration, writer, curr_mode="validation")

    # Print results
    print('Total iteration %d, validation loss = %.4f' % (iteration, val_loss))

    # Return validation loss
    return val_loss

In [33]:
def train(model, optimizer, criterion, loader_train, loader_val, log_dir, epochs=5, device=device, val_every=1):
    """
    Train a model, periodically validating and checkpointing it.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - criterion: Callable computing the loss as criterion(scores, y)
    - loader_train: Iterable yielding (x, y, z) training batches; z is unused
    - loader_val: Loader handed to validate_model for each validation pass
    - log_dir: Directory for the tensorboard logs; checkpoints are written to
      its 'Checkpoints' sub-directory
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    - device: (Optional) Device the model and the batches are moved to
    - val_every: (Optional) Validate and checkpoint every `val_every`
      iterations within an epoch

    Returns: A tuple (train_loss_dict, val_loss_dict). Each maps 'epoch_<n>' to
    the list of per-batch train losses / validation losses of that epoch.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    total_iter = 0
    val_loss_dict = {}
    train_loss_dict = {}

    # Build the resize transform once instead of once per batch.
    resize = transforms.Resize(size=(256, 256))

    # Make sure checkpoints can be written even if the caller did not
    # create the directory beforehand.
    os.makedirs(os.path.join(log_dir, 'Checkpoints'), exist_ok=True)

    # Writer for tensorboard
    writer = SummaryWriter(log_dir)

    try:
        for e in range(1, epochs+1):
            ep_train_losses = []
            ep_val_losses = []
            t = 0  # stays 0 when the loader yields no batches

            print("************EPOCH: {:2d} ***************".format(e))
            # Iterations are 1-indexed within the epoch.
            for t, (x, y, z) in enumerate(loader_train, start=1):
                # validate_model switches to eval mode, so re-enable training
                # mode at the start of every iteration.
                model.train()

                # Have to add resize to make dataset more manageable..
                x = resize(x)
                # Unsqueeze because we only have 1 channel (axis=1) of this 3d image
                x = x.unsqueeze(axis=1)

                x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
                y = y.to(device=device, dtype=dtype)

                scores = model(x)
                loss = criterion(scores, y)
                curr_batchloss = loss.item()
                ep_train_losses.append(curr_batchloss)

                # Zero out all of the gradients for the variables which the
                # optimizer will update.
                optimizer.zero_grad()

                # This is the backwards pass: compute the gradient of the loss
                # with respect to each parameter of the model.
                loss.backward()

                # Actually update the parameters of the model using the
                # gradients computed by the backwards pass.
                optimizer.step()

                # Log metrics
                scores_np = scores.detach().cpu().numpy()
                targets_np = y.detach().cpu().numpy()
                log_metrics(scores_np, targets_np, curr_batchloss, total_iter+t, writer, curr_mode='train')

                # Run validation step every val_every iteration
                if t % val_every == 0:
                    print('Current epoch %d, epoch iteration %d, train loss = %.4f' % (e, t, curr_batchloss))
                    val_loss = validate_model(loader_val, model, criterion, total_iter+t, writer)
                    ep_val_losses.append(val_loss)

                    # Also save checkpoint of model
                    ckpt_path = os.path.join(log_dir, 'Checkpoints', 'ep_%d_iter_%d_ckpt.pt' % (e, t))
                    torch.save({
                        'epoch': e,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss,
                    }, ckpt_path)

                    print()
                    writer.flush()

            train_loss_dict['epoch_'+str(e)] = ep_train_losses
            val_loss_dict['epoch_'+str(e)] = ep_val_losses
            # t is already the 1-indexed count of batches seen this epoch, so
            # adding it keeps the global step contiguous across epochs.
            total_iter += t
    finally:
        # Close the summarywriter for tensorboard, even if training fails.
        writer.close()

    # Return train and validation loss
    return (train_loss_dict, val_loss_dict)

In [36]:
# Make log directory and checkpoint directory.
# The timestamp name is immediately replaced by a fixed run name below;
# remove the override line to get a unique directory per run.
dir_nm = datetime.now(tz=pytz.utc).astimezone(timezone('US/Pacific')).strftime('%Y-%m-%d_%H-%M-%S')
dir_nm = "first_full_c2fc2"
log_dir = os.path.join('../runs', dir_nm)
# makedirs creates the missing parents ('../runs', log_dir) and, with
# exist_ok=True, lets the cell be re-run without raising FileExistsError.
os.makedirs(os.path.join(log_dir, 'Checkpoints'), exist_ok=True)


# Model, optimizer, criterion
model = baseline_3DCNN(in_num_ch=1)
optimizer = optim.Adam(model.parameters(), lr = 1e-4)
criterion = torch.nn.BCEWithLogitsLoss()

In [37]:
# Force a garbage-collection pass; the cell output is the value returned by gc.collect().
gc.collect()

44

In [38]:
# Samples stored per shard; used only for the iteration estimate printed below.
shard_size = 32
print(
    "Number of train iterations per epoch:",
    (len(train_urls) * shard_size) // batch_size,
)

Number of train iterations per epoch: 58


In [39]:
# tmp  = (
#     wds
#     .WebDataset(sorted(train_urls)[:2])
# )
# for i, sample in enumerate(tmp):
#     print("IIIIIIIIII:", i)
#     for key, value in sample.items():
#         print(key, repr(value)[:50])
#     print()

In [40]:
# Run the full training: 10 epochs, validating and checkpointing every 5 iterations.
train_loss_dict, val_loss_dict = train(
    model,
    optimizer,
    criterion,
    loader_train,
    loader_val,
    log_dir,
    epochs=10,
    device=device,
    val_every=5,
)

************EPOCH:  1 ***************
torch.Size([16, 8, 5, 32, 32])
torch.Size([16, 40960])


ValueError: 000031.labels.pyd: duplicate file name in tar file labels.pyd dict_keys(['__key__', 'labels.pyd', 'studynames.pyd', 'volumes.pyd'])

### Loading a checkpoint

In [None]:
# ckpt_path = os.path.join(log_dir, 'Checkpoints', 'ep_0_iter_3_ckpt.pt')
# ckpt = torch.load(ckpt_path)

# ckpt_model = baseline_3DCNN(in_num_ch=1)
# ckpt_model.load_state_dict(ckpt['model_state_dict'])
# ckpt_model.state_dict()