# 3D U-Net on BraTS glioma dataset

In [None]:
import os
import argparse
import datetime
import numpy as np
from tqdm import tqdm

import mxnet as mx
from mxnet import gluon, autograd, ndarray as nd
from mxnet.gluon.utils import split_and_load

from unet import *

***
## Setup hyperparameters

In [None]:
# Data location and training schedule
args = argparse.Namespace(
    data_dir='../brats_2018_4D',  # Should contain 'training' and 'validation' dirs.
    resume='',
    start_epoch=0,
    epochs=600,
    batch_size=6,
    num_workers=6,
    optimizer='adam',
)

# Polynomial learning-rate decay spanning the whole run.
# NOTE(review): 43 is presumably the number of optimizer updates per epoch — confirm.
args.optimizer_params = {
    'learning_rate': 0.0001,
    'lr_scheduler': mx.lr_scheduler.PolyScheduler(max_update=43 * args.epochs, base_lr=0.0001, pwr=2),
}

# Devices to train on (single GPU); for multi-GPU use e.g.
# GPU_COUNT = 1
# args.ctx = [mx.gpu(i) for i in range(GPU_COUNT)]
args.ctx = [mx.gpu(0)]

# Unet architecture and pre/post-processing
vars(args).update(
    num_downs=4,  # Number of encoding/downsampling layers
    classes=4,  # Number of classes for segmentation, including background
    ngf=32,  # Number of channels in base/outermost layer
    use_bias=True,  # For conv blocks
    use_global_stats=True,  # For BN blocks
    crop_size_train=[80, 80, 80],  # Training patch size
    lesion_frac=0.9,  # Fraction of patch centerpoints to be placed within lesion (vs randomly within brain)
    warp_params={
        'theta_max': 45,
        'offset_max': 0,
        'scale_max': 1.25,
        'shear_max': 0.1,
    },
)

# Checkpoint naming: <save_dir>/<net_name>/<net_name>, where net_name encodes
# the patch size, the fold (if any) and the launch timestamp.
args.save_interval = args.epochs
args.save_dir = '../params'
if hasattr(args, 'fold'):
    fold_str = 'fold' + str(args.fold)
else:
    fold_str = 'foldAll'
time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H%M%S")
net_name = 'unet_{}_{}_{}'.format(args.crop_size_train[0], fold_str, time_str)
args.save_prefix = os.path.join(args.save_dir, net_name, net_name)

***
## Overrides

In [None]:
# Override the default of 4 classes: the dataset class in this notebook keeps
# tumor label 4 un-remapped and writes the extra SVID label as 3, so the metric
# definitions use labels 1-4 (plus background 0).
args.classes = 5

In [None]:
class MRISegDataset4DSVID(MRISegDataset):
    """Dataset with all inputs and the GT seg combined into a single 4D NIfTI.

    Each subject directory holds ``<subject>_4D.nii.gz``. After moving the 4th
    axis to the front, volumes 0-3 are used as the input channels and volume 4
    as the label map. For every split except ``'test'``, voxels labelled 2 in a
    separate ``<subject>_seg2.nii.gz`` file (looked up in ``svid_seg_dir``) are
    additionally written into the target as label 3.
    """

    # Directory holding the per-subject '<subject>_seg2.nii.gz' SVID label maps.
    # Kept as a class attribute so it can be overridden without editing __getitem__.
    svid_seg_dir = 'data/BRATS_Training_WMH_Revised'

    def __getitem__(self, idx):
        """Load, augment and normalize the sample at position ``idx``.

        Returns:
            Tuple ``(img, target)``: the transformed multichannel image and the
            label map with a leading singleton channel axis (before the sync
            transform is applied).

        Raises:
            RuntimeError: If ``self.mode`` is neither ``'train'`` nor ``'val'``.
        """
        _sub_name = os.path.basename(os.path.dirname(self.sub_dirs[idx]))

        # Load multichannel input data
        img_path = os.path.join(self.sub_dirs[idx], _sub_name + '_' + '4D' + '.nii.gz')
        img_raw = nib.load(img_path).get_fdata()
        img_raw = img_raw.transpose((3, 0, 1, 2))
        img_raw = np.flip(img_raw, 2)  # Correct AP orientation
        img = img_raw[0:4]

        # Load segmentation label map.
        # Compare by value: `is not` on a string literal tests object identity,
        # which only works by accident of interning.
        if self.split != 'test':
            target = img_raw[4]
            # Label 4 is intentionally NOT remapped to 3 here, because 3 is
            # used for the SVID label below.

            # Include SVID=3 labels
            svid_path = os.path.join(self.svid_seg_dir, _sub_name + '_seg2.nii.gz')
            target_svid = nib.load(svid_path).get_fdata()
            # Same AP flip as above (axis 1 here because this volume has no channel axis)
            target_svid = np.flip(target_svid, 1)
            target[target_svid == 2] = 3
        else:
            target = np.zeros_like(img[0, :])  # dummy segmentation
        target = np.expand_dims(target, axis=0)

        # Data augmentation
        if self.mode == 'train':
            img, target = self._sync_transform(img, target)
        elif self.mode == 'val':
            img, target = self._val_sync_transform(img, target)
        else:
            raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode))

        # Routine img specific processing (normalize, etc.)
        if self.transform is not None:
            img = self.transform(img, self.means[idx], self.stds[idx])

        return img, target

In [None]:
def init_brats_metrics():
    """Build a fresh accumulator dict for the Dice metrics.

    Returns:
        Dict keyed (in this order) by 'ET', 'TC', 'WT' and 'SVID'. Each value
        is its own dict holding the label ids that make up the region
        ('labels') and two counters, 'tp' and 'tot', both starting at 0.
    """
    region_labels = (
        ('ET', [4]),
        ('TC', [1, 4]),
        ('WT', [1, 2, 4]),
        ('SVID', [3]),
    )
    return {
        region: {'labels': labels, 'tp': 0, 'tot': 0}
        for region, labels in region_labels
    }

In [None]:
def brats_validate(model, data_loader, crop_size, overlap, ctx):
    """Predict segs patch-wise on val data and accumulate Dice metrics against the GT segs.

    Args:
        model: Network passed to ``get_output_mask`` to predict each patch.
        data_loader: Iterable yielding ``(data, label)`` pairs; its underlying
            dataset is also indexed once to read the volume dimensions.
        crop_size: Patch size forwarded to ``get_patch_iter``.
        overlap: Patch overlap forwarded to ``get_patch_iter``.
        ctx: Context the data patches are moved to before prediction.

    Returns:
        The dict from ``init_brats_metrics`` with 'tp'/'tot' accumulated over
        all subjects and patches, plus a 'DSC' entry (``2 * tp / tot``) per
        region. 'DSC' is ``nan`` when neither the labels nor the predictions
        contain the region (``tot == 0``).
    """
    # Setup metric dictionary
    metrics = init_brats_metrics()

    # Get the patch indices once, from the label shape of the first sample
    # (leading channel axis dropped). Materialize them as a list: the same
    # indices are reused for every subject, and a one-shot iterator would be
    # exhausted after the first subject, silently skipping all the others.
    # NOTE(review): relies on the loader's private `_dataset` attribute.
    dims = data_loader._dataset[0][1].shape[1:]
    patch_inds = list(get_patch_iter(dims, crop_size, overlap))

    # Iterate over subjects
    for data, label in data_loader:

        # Iterate over patches
        for inds in patch_inds:
            data_patch = get_patch(data, inds).as_in_context(ctx)
            label_patch = get_patch(label, inds)
            label_mask = label_patch.squeeze().asnumpy()

            output_mask = get_output_mask(model, data_patch).asnumpy()

            # Update metrics: 'tp' counts voxels inside the region in both
            # masks, 'tot' is the sum of the two region sizes.
            for metric in metrics.values():
                label_mask_bin = np.isin(label_mask, metric['labels'])
                output_mask_bin = np.isin(output_mask, metric['labels'])
                metric['tp'] += np.sum(label_mask_bin * output_mask_bin)
                metric['tot'] += np.sum(label_mask_bin) + np.sum(output_mask_bin)

    # Calculate overall metrics
    for metric in metrics.values():
        if metric['tot'] > 0:
            metric['DSC'] = 2 * metric['tp'] / metric['tot']
        else:
            # Dice is undefined (0/0) when the region is absent everywhere;
            # report nan explicitly instead of raising or emitting a warning.
            metric['DSC'] = float('nan')
    return metrics

In [None]:
def log_epoch_hooks(epoch, train_loss, metrics, logger, sw):
    """Write the per-epoch summary to the text logger and the summary writer.

    Args:
        epoch: Epoch number; also used as the global step of the scalars.
        train_loss: Training loss printed in the log line.
        metrics: Dict of per-region dicts, each with a 'DSC' entry. The values
            are read in dict order and reported as ET, TC, WT, SVID.
        logger: Object with an ``info(message)`` method.
        sw: Object with an ``add_scalar(tag=, value=, global_step=)`` method.

    Returns:
        The mean of all 'DSC' values.
    """
    dice_scores = np.array([entry['DSC'] for entry in metrics.values()])
    dice_mean = dice_scores.mean()

    summary_fields = (epoch, train_loss) + tuple(dice_scores) + (dice_mean, )
    logger.info('E %d | loss %.4f | ET %.4f | TC %.4f | WT %.4f | SVID %.4f | Avg %.4f' % summary_fields)

    # One curve per region, all grouped under the same 'Dice' tag
    curve_names = ('Val ET', 'Val TC', 'Val WT', 'Val SVID')
    for position, curve_name in enumerate(curve_names):
        sw.add_scalar(tag='Dice', value=(curve_name, dice_scores[position]), global_step=epoch)
    sw.add_scalar(tag='Dice', value=('Val Avg', dice_scores.mean()), global_step=epoch)
    return dice_mean

***
## Setup data loaders

In [None]:
# Load the precomputed normalization statistics. np.load on an .npz returns a
# lazily-read archive that holds the file open, so read both arrays inside a
# `with` block to make sure the file handle is released afterwards.
with np.load('data/normalization_stats.npz') as data:
    means_brain = nd.array(data['means_brain'])
    stds_brain = nd.array(data['stds_brain'])

In [None]:
# Training dataset: augmented patches of size crop_size_train, normalized by
# brats_transform with the precomputed means/stds.
trainset = MRISegDataset4DSVID(
    root=args.data_dir,
    split='train',
    mode='train',
    crop_size=args.crop_size_train,
    transform=brats_transform,
    means=means_brain,
    stds=stds_brain,
    lesion_frac=args.lesion_frac,
    warp_params=args.warp_params,
)

# Shuffled loader; an incomplete final batch is rolled over to the next epoch.
train_data = gluon.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=True,
    last_batch='rollover',
)

***
## Setup model

In [None]:
# Build the U-Net from the architecture hyperparameters set above
# (including any overrides applied to args since then).
model = UnetGenerator(
    num_downs=args.num_downs,
    classes=args.classes,
    ngf=args.ngf,
    use_bias=args.use_bias,
    use_global_stats=args.use_global_stats,
)

In [None]:
# Hybridize the network, then (re)initialize all parameters with Xavier
# initialization on the training devices.
model.hybridize()
model.collect_params().initialize(mx.init.Xavier(), force_reinit=True, ctx=args.ctx)

# Optionally resume from a checkpoint. args.resume is a file name relative to
# the checkpoint directory; use args.save_dir rather than repeating its
# literal value so the two cannot drift apart, and load onto the same devices
# the parameters were initialized on.
resume_name = args.resume.strip()
if resume_name:
    model.load_parameters(os.path.join(args.save_dir, resume_name), ctx=args.ctx)

***
## Train

In [None]:
# Softmax cross-entropy with the class scores along axis 1
# (assumes the network output is channels-first, i.e. (batch, classes, ...) — TODO confirm).
loss = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)

In [None]:
# Optimizer over all model parameters; the optimizer name and its parameters
# (learning rate and LR scheduler) come from the hyperparameter cell.
trainer = gluon.Trainer(model.collect_params(), args.optimizer, args.optimizer_params)

In [None]:
logger, sw = start_logger(args)

global_step = 0
try:
    for epoch in range(args.start_epoch, args.epochs):
        tbar = tqdm(train_data)
        train_loss = 0.
        for i, (data, label) in enumerate(tbar):
            n_batch = data.shape[0]
            # Drop the singleton channel axis of the label map, then shard the
            # batch across the training devices.
            label = label.squeeze(axis=1)
            label = split_and_load(label, args.ctx)
            data = split_and_load(data, args.ctx)

            # Only the forward pass needs to be recorded; backward and the
            # loss bookkeeping run outside the recording scope.
            with autograd.record():
                losses = [loss(model(X), Y) for X, Y in zip(data, label)]
            for device_loss in losses:
                device_loss.backward()
            trainer.step(n_batch)

            # Mean loss over the WHOLE batch (all device shards), as a plain
            # float so it can be formatted and logged directly.
            batch_loss = float(sum(device_loss.sum().asscalar() for device_loss in losses)) / n_batch
            train_loss += batch_loss

            # Mini-batch logging
            sw.add_scalar(tag='Cross_Entropy', value=('Train loss', batch_loss), global_step=global_step)
            global_step += 1
            tbar.set_description('E %d | loss %.4f' % (epoch, train_loss / (i + 1)))

        # Epoch logging / checkpointing.
        # NOTE(review): no validation is run here; 1 and 0 are fixed
        # placeholder arguments to save_params — confirm their meaning against
        # its definition.
        best_wt = save_params(model, 1, 0, epoch, args.save_interval, args.save_prefix)
    sw.export_scalars(args.save_prefix + '_scalar_dict.json')
finally:
    # Always release the summary writer, even if training is interrupted.
    sw.close()