To run this notebook provide the following:
- path to the training dataset;
- path to the validation dataset;
- path to the model (model and optimizer state dictionaries), if training is resumed;
- path to store the results (text files containing the model settings, the loss function values
  and the average dice scores per epoch).

In [None]:
import torch.nn as nn
import torch.optim
import torch.nn.functional as F
import torch.utils.data
from datetime import date

from db_load_pytable import LoadPyTable
from model import UNet
from augmentation import *
from dice_generalization import *

In [None]:
## ----------------------------------- Input / output paths --------------------------------
## Locations of the training and the validation dataset;
path_to_trainset = ""
path_to_valset = ""

## Checkpoint (model and optimizer state dictionaries) to continue training from;
## only read when training is resumed;
path_to_model = None

## Prefix prepended to the name of every file written by this notebook
## (settings, model checkpoints, losses, scores);
path_to_store = ""

In [None]:
## =================================== Hyper-parameters ====================================
## ---------------- Model architecture (keyword arguments passed to UNet);
depth, filter_number = 4, 4
kernel_size, padding = 3, 1

## ---------------- Data augmentation;
## Number of augmentation steps performed within one training epoch;
steps = 350

## ---------------- Parameter initialization;
## The PyTorch default initialization is kept, see
## https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d ;
## Set resume_training to True to load model and optimizer state from path_to_model;
## `strict` is forwarded to model.load_state_dict;
resume_training, strict = False, True

## ---------------- Optimizer parameters;
optimizer = 'Adam'
batchsize = 4
lr = 1e-3  ## learning rate;
epochs = 100

## ---------------- Outputs;
## Loss function values are printed every `output_number`-th epoch;
output_number = 10

In [None]:
## ------------------- Transformation parameters (data augmentation) -----------------------

## Dictionary with the transformation parameters handed to the Transform object below;
ranges = {
    'degree': 10,
    'scale_yx': (0.9, 1.1),
    'shear_yx': 0.2,
    'sigma_points': (3, 12),
}

## Transform object used in the training loop to augment every batch;
## device_mode "cuda" allocates the transformed output on GPU;
transform = Transform()
transform.warp_mode = 'reflect'
transform.device_mode = 'cuda'
transform.ranges = ranges

## ----------------------------------- Evaluation metric -----------------------------------
## Dice object used to compute the average dice score at each epoch;
dice = Dice()

In [None]:
## ------------------------- Save current settings to a text file --------------------------
## Writes a human-readable summary of the current run (model architecture, optimizer
## settings, augmentation parameters and dataset paths) to '<path_to_store>settings.txt'.

## Free-text remarks stored at the top of the settings file;
comments = '''
'''

## Render the augmentation parameters from the live `ranges` dictionary and `transform`
## object rather than from a hand-typed copy, so the log cannot drift from the values that
## are actually used for training;
ranges_description = '\n'.join(
    'ranges[%r] = %r' % (key, value) for key, value in ranges.items()
)

settings_description = '''Comments: %s

Model:
    Model depth: %s;
    Filter number: %s;
    Convolutional kernel size: %s;
    Padding: %s;

Parameter Initialization: default initialization;
    Training is resumed: %s;
    If resumed training, model is taken from: %s;
    
Adam optimizer settings: 
    Batchsize: %s;
    Learing rate: %s;
    Epochs: %s;

During each training epoch %s augmentation steps are performed;

Data augmentation. Transformation parameters: 
%s
transform.warp_mode = %r

Data is taken from:
    Training set: %s
    Validation set: %s
'''%(comments, depth, filter_number, kernel_size, padding, resume_training, path_to_model,
     batchsize, lr, epochs, steps, ranges_description, transform.warp_mode,
     path_to_trainset, path_to_valset)

## The text is fully built before the file is opened, and the context manager closes the
## file even if the write fails;
with open(path_to_store + 'settings.txt', 'w') as file:
    file.write(settings_description)

In [None]:
## ---------------------------------------- Data -------------------------------------------
## Open the training and the validation dataset;
trainset = LoadPyTable(path_to_trainset)
valset = LoadPyTable(path_to_valset)

## Wrap both datasets in torch data loaders using the configured batch size;
## only the training loader shuffles the data;
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batchsize,
    shuffle=True,
)
valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=batchsize,
)

In [None]:
## ---------------------------------------- Model ------------------------------------------
## Build the U-Net from the architecture hyper-parameters and allocate it on GPU;
## this happens BEFORE the optimizer is constructed;
model = UNet(
    depth=depth,
    filter_number=filter_number,
    kernel_size=kernel_size,
    padding=padding,
).cuda()

## When training is resumed, restore the model parameters from the stored checkpoint;
if resume_training:

    dictionary = torch.load(path_to_model)
    model.load_state_dict(dictionary['model_state_dict'], strict=strict)

    print('Model training is resumed. Model and optimizer parameters are taken from:')
    print('\t', path_to_model)

## ------------------------------------ Optimizer (Adam) -----------------------------------
optim = torch.optim.Adam(model.parameters(), lr=lr)

## When training is resumed, restore the optimizer state from the same checkpoint;
## For transfer learning comment this part out, so that only the model parameters are
## loaded (via model.load_state_dict above);
if resume_training:

    optim.load_state_dict(dictionary['optimizer_state_dict'])
    print('Resuming training, loaded optimizer state dictionary.')

## ---------------------------------- Loss function (BCE) ----------------------------------
loss_fn = nn.BCELoss()

## ------------- Per-epoch history: loss function values and mean dice scores --------------
trainloss = []
train_mean_scores = []
valloss = []
val_mean_scores = []

In [None]:
#### ---------------------------------------------------------------------- Loop over epochs;
## Trains the model for `epochs` epochs. Every epoch performs `steps` optimization steps,
## each on one randomly drawn and augmented training batch, and then evaluates the model on
## all validation batches. The loss summed over the steps / batches and the mean dice score
## are appended to trainloss, valloss, train_mean_scores and val_mean_scores.

## A checkpoint (model and optimizer state plus the history so far) is written every
## `checkpoint_interval`-th epoch;
checkpoint_interval = 3

for k in range(epochs):

    ## Loss function value, accumulated over the augmentation steps of this epoch;
    run_loss = 0

    ## NB: Pass model to the training mode;
    model = model.train()

    ## Dice score of each step, averaged over classes and samples in the batch;
    trainbatch_scores = []

    #### -------------------------------------------------------- Loop over augmentation steps;
    for step in range(steps):

        ## NB: Set gradients to zero not to accumulate them over the training steps;
        optim.zero_grad()

        ## Draw one random batch: [image batch, rs batch]; (rs - reference standard);
        ## The training loader shuffles, so the first batch of a fresh iterator is a random
        ## batch. Only this single batch is read; materializing the whole loader as a list
        ## on every step would load the complete training set `steps` times per epoch;
        ## Output: two 4D torch tensors of size [N,1,H,W]; (requires_grad = False);
        ## (N - batch size);
        sample = next(iter(trainloader))
        inputs, targets = sample[0], sample[1]

        ## ---------------------------- Data Augmentation ------------------------------

        ## Pass size of the input tensor to the Transform class object;
        ## (the batch size may differ from the configured one for small datasets);
        transform.dim = inputs.shape

        ## Transform image, target batch (elastic, affine, random flip, intensity
        ## augmentation);
        ## Output: two 4D torch tensors of size [N,1,H,W], allocated on GPU;
        t_inputs, t_targets = transform.batch_transform(inputs, targets)

        ## -----------------------------------------------------------------------------

        ## Pass transformed inputs to the model;
        ## Output: 4D torch tensor [N,1,H,W];
        ## (requires_grad = True => computational graph will be stored to compute gradients);
        prediction = model(t_inputs)

        ## Compute loss function value, given model prediction and transformed targets;
        loss = loss_fn(prediction, t_targets)

        ## Increment run loss variable;
        ## (use .item() method to retrieve float value from the output tensor);
        run_loss += loss.item()

        ## Compute the loss function gradients and update the model parameters;
        loss.backward()
        optim.step()

        ## ------------------------------- Compute dice score -----------------------------
        ## Don't store the computational graph, unnecessary here;
        with torch.no_grad():

            ## Inputs: two 4D torch tensors of size [N,1,H,W], output: float;
            trainbatch_sd = dice.compute_average_dice(prediction, t_targets)
            trainbatch_scores.append(trainbatch_sd)

        ## -------------------------------------------------------------------------------
    #### -------------------------------------------------------------- Loop over steps ends;

    ## Save loss function value accumulated over steps;
    trainloss.append(run_loss)

    ## Save dice scores, averaged over classes, samples in the batch and augmentation steps;
    ## Convert list to numpy array to be able to compute mean;
    trainbatch_scores = np.asarray(trainbatch_scores)
    train_mean_scores.append( np.mean(trainbatch_scores) )

    ## Re-initialize value of the loss function; (used in the validation loop);
    run_loss = 0

    ## NB: pass the model to evaluation mode;
    model = model.eval()

    ## Dice score of each validation batch, averaged over classes and samples in the batch;
    valbatch_scores = []

    #### ------------------------------------------------------ Loop over validation batches;
    for i, sample in enumerate(valloader):

        ## Retrieve image and reference standard batch, allocate both on GPU;
        ## Output: two 4D torch tensors of size [N,1,H,W]; (requires_grad = False);
        inputs, targets = sample[0].cuda(), sample[1].cuda()

        ## Nothing is optimized during validation: run the forward pass, the loss and the
        ## dice score without storing a computational graph;
        with torch.no_grad():

            ## Output: 4D torch tensor of size [N,1,H,W]; (requires_grad = False);
            prediction = model(inputs)

            ## Compute value of the loss function, i.e., after model parameters update;
            loss = loss_fn(prediction, targets)

            ## ----------------------------- Compute dice score ---------------------------
            valbatch_sd = dice.compute_average_dice(prediction, targets)

        ## Increment run loss variable;
        run_loss += loss.item()
        valbatch_scores.append(valbatch_sd)
        ## --------------------------------------------------------------------------------
    #### ----------------------------------------------- Loop over validation batches ends;

    ## Save loss function value, accumulated over batches;
    valloss.append(run_loss)

    ## Save dice scores, averaged over classes, samples in the batch and number of batches;
    ## Convert list to numpy array to be able to compute mean;
    valbatch_scores = np.asarray(valbatch_scores)
    val_mean_scores.append( np.mean(valbatch_scores) )

    ## Output loss function value each `output_number`-th epoch;
    if k % output_number == 0:

        print('Loss at {}th iteration, training set: {}'.format(k,round(trainloss[k],2)))
        print('Loss at {}th iteration, validation set: {}\n'.format(k,round(valloss[k],2)))

    ## Save model and optimizer state dictionary together with the history so far;
    if k % checkpoint_interval == 0:

        ## Checkpoint name contains the current date and the epoch index;
        model_name = 'model_'+ str(date.today()) + '_epoch_'+ str(k) + '.pt'

        torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optim.state_dict(),
                    'epoch': k,
                    'trainloss': trainloss,
                    'valloss': valloss,
                    'train_scores': train_mean_scores,
                    'val_scores': val_mean_scores
                    }, path_to_store + model_name)
#### ----------------------------------------------------------------- Loop over epochs ends;

In [None]:
## Save the final model parameters (model state dictionary only);
## The '_final' suffix keeps this file apart from the epoch checkpoints written in the
## training loop as 'model_<date>_epoch_<k>.pt'. Without it, a last epoch that is also a
## checkpoint epoch (e.g. k = 99 with a checkpoint every 3rd epoch) would have its
## checkpoint, including the optimizer state needed to resume training, overwritten by
## this bare state dictionary;
model_name = 'model_'+ str(date.today()) + '_epoch_'+ str(k) + '_final.pt'
torch.save( model.state_dict(), path_to_store + model_name)

In [None]:
## ------------------------- Save losses and dice scores to text files ---------------------
## Turn the per-epoch histories into numpy arrays;
trainloss, valloss = np.asarray(trainloss), np.asarray(valloss)
train_mean_scores = np.asarray(train_mean_scores)
val_mean_scores = np.asarray(val_mean_scores)

## Loss function values: one row per epoch, training and validation side by side;
header = (
    '\n'
    'Col.0: Training Loss\n'
    'Col.1: Validation Loss \n'
)
losses_table = np.c_[trainloss, valloss]
np.savetxt(path_to_store + 'losses', losses_table, header=header)

## Average dice scores: one row per epoch, training and validation side by side;
header = (
    '\n'
    'Col.0: Average 2D soft dice, training data\n'
    'Col.1: Average 2D soft dice, validation data\n'
)
scores_table = np.c_[train_mean_scores, val_mean_scores]
np.savetxt(path_to_store + 'scores', scores_table, header=header)