# Implement highres3dnet

Paper: https://arxiv.org/abs/1707.01992

Preprocessing steps
---
- Anatomical
    1. Load.
    1. Squeeze data array.
    1. Check for 3 dimensions.
    1. Add one color channel.
- Labels
    1. Load.
    1. Squeeze data array.
    1. Check for 3 dimensions.
    1. One-hot encode.
   
Loading steps
---
1. Load anatomical and labels.
1. Get blocks for each array.
1. Feed those blocks into the model in batches.

Questions
---
1. Should we require the batch size to be evenly divisible by the number of blocks (viewpoints) in a volume? Probably.

Todo
---
- Look into [`keras.utils.multi_gpu_model`](https://keras.io/utils/#multi_gpu_model) to train on multiple GPUs. This is only available for the TensorFlow backend.
- Learn about one-hot encoding / decoding. Encode the array of labels with `K.one_hot()`. Decode the predictions with `K.argmax()`. `K.one_hot()` simply calls `tf.one_hot()`, so this would restrict us to the TensorFlow backend, which is fine for now.


Notes
---
- NiftyNet trains on a sliding window over the 3D data. This implementation should have the same ability, in order to lower GPU memory requirements.

Future considerations
---
- Modify `highres3dnet` to use the ResNeXt architecture.

In [None]:
import logging
import os

import nibabel as nib
import numpy as np
import pandas as pd
import skimage
from sklearn.feature_extraction.image import extract_patches
import tensorflow as tf
from tensorflow.python.keras import backend as K

from highres3dnet import dice_coef, dice_loss, HighRes3DNet

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

In [None]:
# Notebook-wide settings read by the cells below.
config = dict(
    batch_size=16,
    block_shape=(64, 64, 64, 1),  # size of input to model.
    image_data_format='channels_last',
    n_classes=2,
    volume_shape=(256, 256, 256, 1),
)

# Apply the configured image data format to the Keras backend.
K.set_image_data_format(config['image_data_format'])

In [None]:
# CSV describing the input data. Later cells read its 't1' and 'brainmask'
# columns as file paths to volumes.
# NOTE(review): absolute path, presumably specific to one machine/cluster --
# confirm before running elsewhere.
input_filepath = "/om/user/jakubk/nobrainer-code/niftynet_to_keras/t1_brainmask.csv"
df_input = pd.read_csv(input_filepath)
# Last expression of the cell, so the notebook displays the first rows.
df_input.head()

In [None]:
# Build the network. The input shape is the per-block shape stored under
# `config['block_shape']` (there is no 'block_size' key in `config`).
model = HighRes3DNet(n_classes=config['n_classes'], input_shape=config['block_shape'])

# Use multiple GPUs.
# gpu_ids = [int(ss) for ss in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
# model = keras.utils.multi_gpu_model(model, gpus=gpu_ids)

# Not the last expression of the cell, so this line displays nothing; only
# the summary below is printed.
model.input_shape
model.summary()

In [None]:
# Load one anatomical volume and its label volume from row `row` of the CSV.
row = 0
offset = 65  # only referenced by the commented-out cropping code below.
data_input = load_volume(df_input.loc[row, 't1'])
# data_input = data_input[offset:volume_input_shape[0]+offset, 
#                         offset:volume_input_shape[1]+offset, 
#                         offset:volume_input_shape[2]+offset]

# Raises ValueError unless the array has exactly 3 dimensions.
_validate_dims(data_input)
# data_input = _reshape(data_input, volume_input_shape)
# data_input = np.expand_dims(data_input, 0)

data_test = load_volume(df_input.loc[row, 'brainmask'])
# data_test = data_test[offset:volume_input_shape[0]+offset, 
#                       offset:volume_input_shape[1]+offset, 
#                       offset:volume_input_shape[2]+offset]
_validate_dims(data_test)

# One-hot encode the labels, taking the class count from `config` so it
# matches the `n_classes` the model is built with.
data_test = K.one_hot(data_test, num_classes=config['n_classes'])

# Stack the single encoded volume along a new leading axis.
labels = K.stack((data_test,))

In [None]:
# Compile with the Adam optimizer and the Dice loss imported from highres3dnet.
model.compile(optimizer='adam', loss=dice_loss)

In [None]:
# model.fit(data_input, data_test, batch_size=1, verbose=1)
# `labels` is a backend tensor built with `K.stack`; evaluate it to get the
# array passed to `fit` as `labels_np`.
# NOTE(review): assumes `labels_np` was meant to be the evaluated `labels`
# tensor -- confirm.
labels_np = K.eval(labels)
model.fit(data_input, labels_np, batch_size=1, verbose=1)

In [None]:
import math


def load_volume(filepath, return_affine=False, c_contiguous=True):
    """Load the volume at `filepath` and return its data array.

    The image is opened with `nibabel.load` and its `dataobj` is converted
    to a Numpy array.

    Parameters
    ----------
    filepath : path of the volume file to load.
    return_affine : if true, return a `(data, affine)` tuple instead of the
        data array alone.
    c_contiguous : if true, pass the data through `np.ascontiguousarray`.
        This costs time while loading but saves time later when viewing
        blocks of the data with `skimage`.
    """
    image = nib.load(filepath)
    volume = np.asarray(image.dataobj)
    volume = np.ascontiguousarray(volume) if c_contiguous else volume
    return (volume, image.affine) if return_affine else volume


def _validate_dims(a, ndim=3):
    """Raise `ValueError` if Numpy array `a` has fewer than `ndims` dimensions."""
    if a.ndim != ndim:
        msg = "Expected {} dimensions but got {}.".format(ndim, a.ndim)
        raise ValueError(msg.format(a.ndim))


def get_blocks(a, block_shape):
    """Return the non-overlapping blocks of `a`, stacked along the first axis.

    The array is tiled with `skimage.util.view_as_blocks` and the result is
    reshaped to `(-1, *block_shape)`, so that each entry along the first axis
    is one block.

    Examples
    --------
    >>> arr = np.ones((4*4)).reshape(4, 4)
    >>> blocks = get_blocks(arr, (2, 2))
    >>> blocks.shape
    (4, 2, 2)
    """
    tiled = skimage.util.view_as_blocks(a, block_shape)
    stacked_shape = (-1,) + tuple(block_shape)
    return tiled.reshape(stacked_shape)


def get_num_blocks(volume_shape, block_shape):
    """Return number of non-overlapping blocks of `block_shape` in `volume_shape`.

    Parameters
    ----------
    volume_shape : sequence of ints, shape of the full volume.
    block_shape : sequence of ints, shape of one block. Must have the same
        length as `volume_shape`, and every volume dimension must be an
        exact multiple of the corresponding block dimension.

    Returns
    -------
    int
        Product over all axes of `volume_shape[i] // block_shape[i]`.

    Raises
    ------
    ValueError
        If the shapes differ in length, if a block dimension is not
        positive, or if the block does not evenly tile the volume.
    """
    if len(volume_shape) != len(block_shape):
        raise ValueError("Volume and block must have same number of dimensions.")

    volume = np.asarray(volume_shape)
    block = np.asarray(block_shape)

    # Checked first so the modulo below never divides by zero.
    if np.any(block <= 0):
        raise ValueError(
            "Block dimensions must be positive but got {}.".format(
                tuple(block_shape)))

    # A fractional count is meaningless: blocks must tile the volume exactly.
    if np.any(volume % block):
        raise ValueError(
            "Volume shape {} is not evenly divisible by block shape {}.".format(
                tuple(volume_shape), tuple(block_shape)))

    # Plain `int` so the count can be used for indexing and arithmetic.
    return int(np.prod(volume // block))


class VolumeSequence(tf.keras.utils.Sequence):
    """Keras `Sequence` meant to serve batches of blocks from paired volumes.

    Parameters
    ----------
    x_files : sequence of file paths to the input volumes.
    y_files : sequence of file paths to the label volumes, parallel to
        `x_files`.
    batch_size : number of blocks per batch.
    volume_shape : shape shared by every volume.
    block_shape : shape of one block. If `None`, `volume_shape` is used, so
        that each volume counts as a single block.

    Notes
    -----
    Batch retrieval is not implemented yet: `__getitem__` raises
    `NotImplementedError`. Only the size bookkeeping (`__len__`,
    `_blocks_per_volume`, `_volumes_per_batch`) is functional.
    """

    def __init__(self, x_files, y_files, batch_size, volume_shape,
                 block_shape=None):
        super().__init__()
        if len(x_files) != len(y_files):
            raise ValueError(
                "Expected the same number of x and y files but got {} and {}."
                .format(len(x_files), len(y_files)))
        self.x, self.y = x_files, y_files
        self.batch_size = batch_size
        self.volume_shape = volume_shape
        # Without a block shape, treat the whole volume as one block so the
        # block arithmetic below still works.
        self.block_shape = volume_shape if block_shape is None else block_shape

    def __len__(self):
        """Return the number of batches needed to cover every block once."""
        total_blocks = len(self.x) * self._blocks_per_volume
        return int(math.ceil(total_blocks / self.batch_size))

    def __getitem__(self, idx):
        """Return batch number `idx`. Assumes every volume has the same shape.

        Not implemented yet. Raising here is deliberate: silently returning
        an empty tuple would only fail later, when a consumer tries to unpack
        the batch.
        """
        # TODO: locate the volume(s) that hold blocks
        # [idx * batch_size, (idx + 1) * batch_size), load them, and return
        # the matching (inputs, labels) arrays.
        raise NotImplementedError(
            "VolumeSequence.__getitem__ is not implemented yet.")

    @property
    def _blocks_per_volume(self):
        """Number of non-overlapping blocks per volume."""
        return get_num_blocks(volume_shape=self.volume_shape,
                              block_shape=self.block_shape)

    @property
    def _volumes_per_batch(self):
        """Number of volumes covered by one batch.

        This is a plain ratio, so it is fractional when the batch size is not
        a multiple of the number of blocks per volume.
        """
        return self.batch_size / self._blocks_per_volume

In [None]:
# Build a sequence over every row of the CSV: the 't1' column supplies the
# input files and the 'brainmask' column the label files. Batch size and
# shapes come from `config`.
aa = VolumeSequence(x_files=df_input['t1'], 
                    y_files=df_input['brainmask'], 
                    batch_size=config['batch_size'], 
                    volume_shape=config['volume_shape'], 
                    block_shape=config['block_shape'],
)

In [None]:
# __getitem__ should find the correct volume(s) and return the arrays from that.
# batch 0 ...

In [None]:
# Display the number of non-overlapping blocks per volume for `aa`.
aa._blocks_per_volume

In [None]:
# Display the batch size divided by the number of blocks per volume.
aa._volumes_per_batch