# Data normalization
Notebook that pre-computes image-volume means and standard deviations for each subject and each channel, and saves them to `../data/normalization_stats_<split>.npz`.

In [None]:
%reload_ext rpy2.ipython

import os
import numpy as np
from tqdm import tqdm

import mxnet as mx
from mxnet import gluon, autograd, ndarray as nd
from mxnet.gluon import nn, utils

import gluoncv

from unet_brats.unet import *

***
## Set up parameters

In [None]:
split = 'train'                 # subset to compute statistics for; also names the output .npz file
crop_size = [256, 256, 128]     # forwarded to MRISegDataset as crop_size
data_dir = '/path/to/gk/data'   # placeholder -- point this at the real data root

***
## Set up the dataset

In [None]:
# Raw (un-normalized) dataset. mode='val' is presumably the deterministic,
# non-augmented loading path -- confirm in MRISegDataset.
dataset = MRISegDataset(
    root=data_dir,
    split=split,
    mode='val',
    crop_size=crop_size,
)

***
## Plot before normalization

In [None]:
# First subject: dataset items unpack as (image, mask).
img, mask = dataset[0]
# Convert the mxnet NDArray to NumPy so the %%R cell can import it via `-i img`.
img = img.asnumpy()

In [None]:
%%R -i img -w 800 -h 300 -u px

library(neurobase)
# One ortho2 figure per channel (R indices 1 and 2 = first two channels),
# each laid out as a 1x3 row of panels. invisible() suppresses printing the list.
invisible(lapply(1:2, function(ch) ortho2(img[ch, , , ], mfrow = c(1, 3))))

***
## Calculate the mean and standard deviation for each subject and channel

In [None]:
# Per-subject, per-channel statistics: row i holds subject i's value for every
# channel. The channel count is taken from the data rather than hard-coded,
# so the cell works for any number of input channels.
channel_means = []
channel_stds = []
for data, _ in tqdm(dataset):
    data = data.asnumpy()
    # Reduce over the three volume axes, keeping one value per channel (axis 0).
    channel_means.append(data.mean(axis=(1, 2, 3)))
    channel_stds.append(data.std(axis=(1, 2, 3)))  # population std (ddof=0)

# float64 matches the dtype of the preallocated np.zeros arrays used before.
means = np.asarray(channel_means, dtype=np.float64)
stds = np.asarray(channel_stds, dtype=np.float64)

In [None]:
# Expected: (number of subjects, number of channels)
np.shape(means)

In [None]:
# Peek at the first five subjects' channel means.
means[:5]

In [None]:
# Peek at the first five subjects' channel standard deviations.
stds[:5]

In [None]:
# Persist the per-subject statistics. Create the output directory first so a
# missing folder does not make this cell fail after the (slow) statistics pass.
stats_path = os.path.join('../data', 'normalization_stats_' + split + '.npz')
os.makedirs(os.path.dirname(stats_path), exist_ok=True)
np.savez_compressed(stats_path, means=means, stds=stds)

***
## Confirm normalization

### Load the saved statistics and an example subject

In [None]:
# Load the saved per-subject statistics as mxnet NDArrays. np.load returns an
# NpzFile that holds an open file handle; the context manager closes it.
# nd.array copies the data, so the arrays stay valid after the file is closed.
stats_path = os.path.join('../data', 'normalization_stats_' + split + '.npz')
with np.load(stats_path) as stats:
    means = nd.array(stats['means'])
    stds = nd.array(stats['stds'])

In [None]:
def transform(img, means, stds):
    """Standardize each channel of ``img`` to zero mean and unit std.

    Args:
        img: image volume with channels on axis 0 followed by three volume
            axes (assumed (C, X, Y, Z), matching the per-channel reduction
            over axes (1, 2, 3) used to compute the statistics).
        means: per-channel means, one entry per channel (presumably the row of
            the saved statistics selected for this subject by MRISegDataset --
            confirm).
        stds: per-channel standard deviations, laid out like ``means``.

    Returns:
        ``(img - means) / stds``, broadcast over the three volume axes.
        Channels whose std is exactly 0 (constant channels) are only
        mean-centered, so they come out as zeros instead of NaN/inf.
    """
    means = means.reshape(-1, 1, 1, 1)
    stds = stds.reshape(-1, 1, 1, 1)
    # Replace zero stds with 1. This is written as arithmetic on the comparison
    # so it works unchanged for mxnet NDArrays and NumPy arrays.
    stds = stds + (stds == 0)
    return (img - means) / stds

In [None]:
# Same dataset as before, now standardized by `transform` using the saved
# per-subject means/stds.
dataset = MRISegDataset(
    root=data_dir,
    split=split,
    mode='val',
    crop_size=crop_size,
    transform=transform,
    means=means,
    stds=stds,
)

In [None]:
# Normalized image volume of the first subject (item index 0 is the image, 1 the mask).
data = dataset[0][0]
data = data.asnumpy()

### Calculate `mean` and `std`

(For every channel, these should be close to `mean=0` and `std=1`.)

In [None]:
# Per-channel mean of the normalized volume (iterating a NumPy array walks axis 0, the channels).
[channel.mean() for channel in data]

In [None]:
# Per-channel standard deviation of the normalized volume (iterating walks axis 0, the channels).
[channel.std() for channel in data]

***
## Plot after normalization

In [None]:
# Normalized dataset for plotting; same arguments as the construction in the
# "Confirm normalization" section.
dataset = MRISegDataset(
    root=data_dir,
    split=split,
    mode='val',
    crop_size=crop_size,
    transform=transform,
    means=means,
    stds=stds,
)

In [None]:
# First subject after normalization: dataset items unpack as (image, mask).
img, mask = dataset[0]
# Convert to NumPy so the %%R cell can import it via `-i img`.
img = img.asnumpy()

In [None]:
%%R -i img -w 800 -h 300 -u px

library(neurobase)
# One ortho2 figure per channel (R indices 1 and 2), each as a 1x3 row of panels.
# window = c(-4, 4) fixes the display range to [-4, 4] on the normalized scale.
invisible(lapply(1:2, function(ch) ortho2(img[ch, , , ], mfrow = c(1, 3), window = c(-4, 4))))