# Predict out-of-fold samples of the training cross-validation (i.e. the internal validation set)

In [None]:
%reload_ext rpy2.ipython

import os
import re
import argparse
import nibabel as nib
import numpy as np
from tqdm import tqdm_notebook as tqdm

import mxnet as mx
from mxnet import gluon, ndarray as nd

from unet import *

***
## Set up hyperparameters

In [None]:
# All notebook settings are collected on one argparse.Namespace, built in a
# single call (attributes are set in the order listed here).
args = argparse.Namespace(
    # --- Paths ---
    data_dir='../brats_2018_4D',
    weights_dir='../params/svid5/training_cv/ensemble/',
    output_dir='../predictions/training__svid5__training_cv_190504',

    # --- Cross-validation folds whose weights are used for prediction ---
    folds=np.arange(10),

    # --- Data loading and compute devices ---
    num_workers=1,
    # Single GPU; for several GPUs use e.g. [mx.gpu(i) for i in range(GPU_COUNT)]
    ctx=[mx.gpu(0)],

    # --- U-Net architecture ---
    num_downs=4,              # Number of encoding/downsampling layers
    classes=5,                # Number of segmentation classes, including background
    ngf=32,                   # Number of channels in the base/outermost layer
    use_bias=True,            # For conv blocks
    use_global_stats=True,    # For BN blocks

    # --- Pre/post-processing ---
    pad_size_val=[240, 240, 160],   # Should be input vol dims unless 'crop_size_val' is larger
    crop_size_val=[240, 240, 160],  # Should be divisible by 2^num_downs
    overlap=0,                      # Fractional overlap for val patch prediction, combined with voting
    output_dims=[240, 240, 155],    # Original BraTS dims; predictions are cropped back to this
)

***
## Load normalization data

In [None]:
# Load the precomputed intensity-normalization statistics.
# np.load on an .npz archive returns an NpzFile that keeps the underlying file
# open; the `with` block closes it once the arrays have been copied into
# MXNet NDArrays (nd.array makes its own copy of the data).
# Only means_brain / stds_brain are passed to the dataset later in this
# notebook; means / stds are loaded as well.
with np.load('normalization_stats.npz') as data:
    means       = nd.array(data['means'])
    stds        = nd.array(data['stds'])
    means_brain = nd.array(data['means_brain'])
    stds_brain  = nd.array(data['stds_brain'])

***
## Set up the model and load ensemble weights

In [None]:
# Build the U-Net generator. Every architectural hyperparameter is read from
# `args`; it has to match the architecture the ensemble weights were trained
# with, or load_parameters() below cannot load them.
model = UnetGenerator(
    num_downs=args.num_downs,                # depth of the encoder
    classes=args.classes,                    # output channels (incl. background)
    ngf=args.ngf,                            # base channel width
    use_bias=args.use_bias,                  # bias terms in conv blocks
    use_global_stats=args.use_global_stats,  # BN uses stored running stats
)

In [None]:
# Allocate and initialize every parameter on the configured device(s). These
# values are placeholders: each fold's weights are loaded over them
# (model.load_parameters) before any prediction is made.
model.collect_params().initialize(ctx=args.ctx, force_reinit=True)

In [None]:
# Collect the ensemble weight files (sorted by name) and keep only those whose
# fold index, parsed from the "fold<N>" tag in the file name, is in args.folds.
_FOLD_PATTERN = re.compile(r'.*fold([0-9]+)')


def _weights_fold(path):
    """Return the integer fold index encoded as ``fold<N>`` in *path*'s file name.

    The greedy ``.*`` prefix makes the last ``fold<N>`` tag in the name win.
    ``[0-9]+`` accepts multi-digit folds (``fold12`` -> 12) instead of
    silently truncating to the first digit.

    Raises:
        ValueError: if the file name contains no ``fold<N>`` tag.
    """
    name = os.path.basename(path)
    match = _FOLD_PATTERN.match(name)
    if match is None:
        raise ValueError('Cannot parse fold index from weights file name: %r' % name)
    return int(match.group(1))


weights_paths = [os.path.join(args.weights_dir, X) for X in sorted(os.listdir(args.weights_dir))]
weights_folds = np.array([_weights_fold(X) for X in weights_paths])
ikeep = np.where(np.isin(weights_folds, args.folds))[0]
weights_paths = [weights_paths[i] for i in ikeep]

***
## Predict out-of-fold validation data

In [None]:
# Predict every out-of-fold (validation) subject with the weights of the fold it
# was held out from. One label map is written per subject to
# <args.output_dir>/runs/<weights file stem>/<subject ID>.nii.gz.
os.makedirs(args.output_dir, exist_ok=True)

# Loop over weights files (each tagged with the fold it was trained for)
for ifold, weights_path in enumerate(tqdm(weights_paths)):
    weights_name = os.path.basename(weights_path)

    # Determine the fold from the "fold<N>" tag in the file name ([0-9]+ so
    # multi-digit folds parse correctly), then look up that fold's held-out
    # subjects. n/k/seed must reproduce the training split for these
    # predictions to be truly out-of-fold.
    fold = int(re.match(r'.*fold([0-9]+)', weights_name).group(1))
    fold_inds = get_k_folds(n=285, k=10, seed=1)[fold]

    # Set up the dataloader. shuffle=False and batch_size=1 keep batch i aligned
    # with valset.paths()[i], which the subject loop below relies on.
    valset = MRISegDataset4D(root=args.data_dir, split='val', mode='val', crop_size=args.pad_size_val,
                             transform=brats_transform, means=means_brain, stds=stds_brain,
                             fold_inds=fold_inds)
    val_data = gluon.data.DataLoader(valset, batch_size=1, num_workers=args.num_workers,
                                     shuffle=False, last_batch='keep')
    subject_paths = valset.paths()

    # Load the template NIfTI header once (the first file of the first subject);
    # it supplies the geometry of every saved mask.
    # NOTE(review): os.listdir order is arbitrary; assumes every file in the
    # subject directory is a loadable image with the same geometry.
    if ifold == 0:
        subdir = os.path.normpath(subject_paths[0])
        img_path = os.path.join(subdir, os.listdir(subdir)[0])
        hdr = nib.load(img_path).header

    # Load this fold's weights into the shared model
    model.load_parameters(weights_path, ctx=args.ctx[0])

    # Set up the output dir: one subdirectory per weights file
    output_dir = os.path.join(args.output_dir, 'runs', weights_name.split('.params')[0])
    os.makedirs(output_dir, exist_ok=True)

    # Loop over out-of-fold subjects (one subject per batch)
    for isub, (batch_data, _) in enumerate(tqdm(val_data)):
        subID = os.path.basename(os.path.normpath(subject_paths[isub]))
        mask = brats_predict(model, batch_data, args.crop_size_val, args.overlap,
                             n_classes=args.classes, ctx=args.ctx[0])
        mask = img_unpad(mask, args.output_dims)  # Crop back to original BraTS dimensions
        mask = np.flip(mask, 1)  # Flip AP orientation back to original BraTS convention
        # NOTE: labels are saved exactly as predicted. If the model was trained
        # on remapped labels, add `mask[mask == 3] = 4` here to convert tissue
        # class labels back to the original BraTS convention.
        mask = mask.astype(np.int16)
        mask_nii = nib.Nifti1Image(mask, None, header=hdr)
        # When a template header is supplied, nibabel keeps that header's on-disk
        # dtype, which may not be int16. Force int16 so the label map is stored
        # as integers.
        mask_nii.set_data_dtype(np.int16)
        mask_nii.to_filename(os.path.join(output_dir, subID + '.nii.gz'))

***
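## Copy predictions into a `final` directory

Each `runs/<weights name>` subdirectory holds the out-of-fold predictions made with one weights file. Merge them into a single `final` directory inside `args.output_dir`:

```bash
cd ../predictions/training__svid5__training_cv_190504  # args.output_dir, relative to this notebook
mkdir -p final
cp runs/*/* final/
```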