# Predict segmentations for unlabeled validation or test set data

In [None]:
%reload_ext rpy2.ipython

import os
import argparse
import glob
import nibabel as nib
import numpy as np
from tqdm import tqdm_notebook as tqdm

import mxnet as mx
from mxnet import gluon, ndarray as nd

from unet import *

***
## Setup hyperparameters

In [None]:
# Every run setting lives on a single Namespace so the cells below can read
# them as `args.<name>`, the same way a command-line script would.
GPU_COUNT = 1

args = argparse.Namespace(
    # Input data, ensemble weights and output locations
    data_dir='../brats_2018_4D',
    weights_dir='../params/baseline/bagged_ensemble/ensemble',
    output_dir='../predictions/val__baseline__bagged_ensemble_soft_prediction_190101',
    # Data loading / devices
    num_workers=1,
    ctx=[mx.gpu(i) for i in range(GPU_COUNT)],  # to pin one device instead: [mx.gpu(1)]
    # Unet architecture
    num_downs=4,  # Number of encoding/downsampling layers
    classes=4,  # Number of classes for segmentation, including background
    ngf=32,  # Number of channels in base/outermost layer
    use_bias=True,  # For conv blocks
    use_global_stats=True,  # For BN blocks
    # Pre/post-processing
    pad_size_val=[240, 240, 160],  # Should be input vol dims unless 'crop_size_val' is larger
    crop_size_val=[240, 240, 160],  # Should be divisible by 2^num_downs
    overlap=0,  # Fractional overlap for val patch prediction (passed to brats_predict, which ignores it)
    output_dims=[240, 240, 155],  # Spatial size the predictions are cropped back to
)

***
## Setup data loader

In [None]:
# Per-channel normalization statistics for the test set. np.load returns an
# NpzFile that keeps the archive open, so use it as a context manager to
# release the file handle once both arrays have been copied into NDArrays.
# (Named `stats` rather than `data` so it is not confused with the batch
# variable used in the prediction loop.)
with np.load('normalization_stats_test.npz') as stats:
    means_brain = nd.array(stats['means_brain'])
    stds_brain = nd.array(stats['stds_brain'])

In [None]:
# Test split read through the 'val' mode pipeline. crop_size is given the
# padding size (presumably the dataset pads/crops each volume to it -- confirm
# in unet), and the means/stds loaded above are handed to the dataset.
testset = MRISegDataset4D(
    root=args.data_dir,
    split='test',
    mode='val',
    crop_size=args.pad_size_val,
    transform=brats_transform,
    means=means_brain,
    stds=stds_brain,
)
# batch_size=1 with shuffle=False keeps batch index == subject index; the
# prediction loop relies on that when it looks subject IDs up by index.
test_data = gluon.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    last_batch='keep',
    num_workers=args.num_workers,
)

***
## Extract template NIfTI header

In [None]:
# Borrow the header of one NIfTI volume from the first test subject as the
# template for every output image. os.listdir order is arbitrary, so sort the
# listing for a reproducible choice, and skip non-NIfTI entries (hidden files,
# notes, ...) that nibabel could not load.
subdir = os.path.normpath(testset.paths()[0])
nifti_names = sorted(
    name for name in os.listdir(subdir) if name.endswith(('.nii', '.nii.gz'))
)
if not nifti_names:
    raise FileNotFoundError('No NIfTI files found in {}'.format(subdir))
img_path = os.path.join(subdir, nifti_names[0])
hdr = nib.load(img_path).header

***
## Setup model and load ensemble weights

In [None]:
# Build the Unet with the architecture hyperparameters defined in `args`.
model = UnetGenerator(
    num_downs=args.num_downs,
    classes=args.classes,
    ngf=args.ngf,
    use_bias=args.use_bias,
    use_global_stats=args.use_global_stats,
)

In [None]:
# Give every parameter a value on the configured device(s); these values are
# overwritten when each ensemble member's weights are loaded later on.
model_params = model.collect_params()
model_params.initialize(ctx=args.ctx, force_reinit=True)

In [None]:
# Switch the Gluon block to hybrid (symbolic) execution for the repeated
# forward passes below.
model.hybridize(active=True)

In [None]:
# One `.params` file per ensemble member, in sorted (reproducible) order.
# Other entries in the directory (logs, hidden files, sub-folders) are skipped
# so load_parameters is only ever handed weight files.
weights_paths = [
    os.path.join(args.weights_dir, name)
    for name in sorted(os.listdir(args.weights_dir))
    if name.endswith('.params')
]
if not weights_paths:
    # Without this, the prediction loop would silently do nothing.
    raise FileNotFoundError('No .params files found in {}'.format(args.weights_dir))

***
## Predict test data (once per set of model `weights` in the ensemble)

Save intermediate output maps of voxelwise softmax class probabilities, scaled by 1000 and stored as integers, in one directory per set of weights under `runs/`.

In [None]:
def brats_predict(model, data, crop_size, overlap, n_classes, ctx):
    """Return voxelwise softmax class probabilities for a single-subject batch.

    The whole (padded) volume goes through ``model`` in one forward pass.

    Args:
        model: Gluon network mapping ``data`` to per-class scores.
        data: NDArray batch holding exactly one subject; presumably
            (1, channels, X, Y, Z) -- TODO confirm against the data loader.
        crop_size, overlap, n_classes: Accepted for interface compatibility
            but currently unused (no patch-wise/overlapping prediction).
        ctx: MXNet context the forward pass runs on.

    Returns:
        numpy array of class probabilities with the batch axis removed
        (presumably (classes, X, Y, Z)); softmax is taken over axis 0.

    Raises:
        ValueError: if the model output does not contain exactly one sample.
    """
    scores = model(data.as_in_context(ctx))
    if scores.shape[0] != 1:
        raise ValueError(
            'brats_predict expects a batch of one subject, got {}'.format(scores.shape[0]))
    # Drop only the batch axis. A bare .squeeze() would also remove any other
    # singleton axis (e.g. a single class), shifting the softmax axis.
    probs = scores[0].softmax(axis=0)
    return probs.asnumpy()

In [None]:
def img_unpad(img, dims):
    """Center-crop a channel-first volume back to its original spatial size.

    Axis 0 (channels/classes) is kept whole. For each of the next three axes a
    window of length ``dims[k]`` is taken, centred on that axis; an axis that
    is not longer than ``dims[k]`` is sliced from 0, so it keeps its current
    length (nothing is padded).

    Args:
        img: array with the channel axis first, followed by three spatial axes.
        dims: target sizes for the three spatial axes.

    Returns:
        The sliced array (a view for NumPy inputs).
    """
    window = [slice(None)]  # keep every channel
    for axis in range(3):
        target = dims[axis]
        excess = img.shape[axis + 1] - target
        start = excess // 2 if excess > 0 else 0
        window.append(slice(start, start + target))
    return img[tuple(window)]

In [None]:
os.makedirs(args.output_dir, exist_ok=True)

# Subject directories in loader order (the loader uses batch_size=1 and
# shuffle=False, so batch index == subject index).
subject_paths = testset.paths()

for weights_path in tqdm(weights_paths):
    # Each ensemble member writes its maps to runs/<weights file stem>/.
    model.load_parameters(weights_path, ctx=args.ctx[0])
    run_name = os.path.basename(weights_path).split('.params')[0]
    output_dir = os.path.join(args.output_dir, 'runs', run_name)
    os.makedirs(output_dir, exist_ok=True)
    for isub, (data, _) in enumerate(tqdm(test_data)):
        subID = os.path.basename(os.path.normpath(subject_paths[isub]))
        mask = brats_predict(model, data, args.crop_size_val, args.overlap, n_classes=args.classes, ctx=args.ctx[0])
        mask = img_unpad(mask, args.output_dims)  # Crop back to original BraTS dimensions
        mask = np.flip(mask, 2)  # Flip AP orientation back to original BraTS convention
        # Store probabilities as integers in units of 1/1000. Round instead of
        # relying on astype, which truncates and biases every value downwards.
        mask = np.rint(mask * 1000).astype(np.int16)
        mask = mask.transpose((1, 2, 3, 0))  # class axis last, as in the original layout
        mask_nii = nib.Nifti1Image(mask, None, header=hdr)
        # Pin the on-disk dtype: otherwise it is inherited from the template
        # header, which may not hold 0..1000 without lossy scaling.
        mask_nii.set_data_dtype(np.int16)
        mask_nii.to_filename(os.path.join(output_dir, subID + '.nii.gz'))

***
## Combine ensemble predictions

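* Sum each class's probability maps over all ensemble runs. Assign output class to background `0` if the mean predicted probability of the background class is at least 0.5.
* Otherwise, assign the foreground class (`1`-`3`) with the highest summed probability, then relabel class `3` as `4` to match the original BraTS label convention.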

In [None]:
output_dir = os.path.join(args.output_dir, 'final')
os.makedirs(output_dir, exist_ok=True)

# Each sub-directory of runs/ holds one ensemble member's probability maps
# (probabilities scaled by 1000). Sort for a reproducible order and ignore
# stray files.
run_dirs_parent = os.path.join(args.output_dir, 'runs')
run_dirs = sorted(
    os.path.join(run_dirs_parent, name)
    for name in os.listdir(run_dirs_parent)
    if os.path.isdir(os.path.join(run_dirs_parent, name))
)
if not run_dirs:
    # With zero runs every voxel would silently come out as background.
    raise FileNotFoundError('No ensemble run directories found in {}'.format(run_dirs_parent))
n_runs = len(run_dirs)
subject_paths = testset.paths()

for isub in tqdm(range(len(testset))):
    subID = os.path.basename(os.path.normpath(subject_paths[isub]))
    # Accumulate the per-run maps instead of stacking all runs in memory; for
    # the integer-valued maps written above the sums are exact, so the result
    # matches summing a full stack, at a fraction of the memory.
    mask_sum = np.zeros(tuple(args.output_dims) + (args.classes,))
    for run_dir in run_dirs:
        img_path = os.path.join(run_dir, subID + '.nii.gz')
        mask_sum += nib.load(img_path).get_fdata()
    # Best foreground class (labels 1..classes-1) per voxel ...
    mask_out = mask_sum[..., 1:].argmax(axis=-1) + 1
    # ... unless the mean background probability is at least 0.5.
    not_bg = mask_sum[..., 0] < (0.5 * n_runs * 1000)
    mask_out = mask_out * not_bg
    mask_out[mask_out == 3] = 4  # Convert tissue class labels back to original BraTS convention
    # Labels are small non-negative integers; store them as uint8 explicitly
    # rather than inheriting whatever dtype the template header carries.
    mask_nii = nib.Nifti1Image(mask_out.astype(np.uint8), None, header=hdr)
    mask_nii.set_data_dtype(np.uint8)
    mask_nii.to_filename(os.path.join(output_dir, subID + '.nii.gz'))