# Implement highres3dnet

Paper: https://arxiv.org/abs/1707.01992

Preprocessing steps
---
- Anatomical
    1. Load.
    1. Squeeze data array.
    1. Check for 3 dimensions.
    1. Add one color channel.
- Labels
    1. Load.
    1. Squeeze data array.
    1. Check for 3 dimensions.
    1. One-hot encode.
   
Loading steps
---
1. Load anatomical and labels.
1. Get blocks for each array.
1. Feed those blocks into the model in batches.

Questions
---
1. Should we impose a restriction that the batch size must be evenly divisible by the number of blocks (viewpoints) in a volume? Probably.

Todo
---
- Look into [`keras.utils.multi_gpu_model`](https://keras.io/utils/#multi_gpu_model) to train on multiple GPUs. This is only available for the TensorFlow backend.
- Learn about one-hot encoding / decoding. Encode the array of labels with `K.one_hot()`. Decode the predictions with `K.argmax()`. `K.one_hot()` simply calls `tf.one_hot()`, so this would restrict us to the TensorFlow backend, which is fine for now.


Notes
---
- NiftyNet trains on a sliding window over the 3D data. This implementation should have the same ability, in order to lower GPU memory requirements.

Future considerations
---
- Modify `highres3dnet` to use the ResNeXt architecture.

In [None]:
import logging
import os
from warnings import warn

import nibabel as nib
import numpy as np
from numpy.lib.stride_tricks import as_strided
import pandas as pd
import tensorflow as tf
from tensorflow.python.keras import backend as K

from highres3dnet import dice_loss, HighRes3DNet

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# --- Training configuration -------------------------------------------------
# Number of output classes passed to the model and used for the target blocks.
NUM_CLASSES = 2
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
# CSV file list; the training loop reads its 't1' and 'brainmask' columns.
CSV_FILEPATH = "/om2/user/jakubk/openmind-surface-data/file-lists/master_file_list_brainmask.csv"
# CSV_FILEPATH = "/om/user/jakubk/nobrainer-code/niftynet_to_keras/t1_brainmask.csv"
# Shape of the non-overlapping blocks each volume is split into.
WINDOW_SHAPE = (128, 128, 128)
NUM_CHANNELS = 1
# Block shape with a trailing channel axis (matches 'channels_last' below).
INPUT_SHAPE = (*WINDOW_SHAPE, NUM_CHANNELS)
# dtype used when loading the label volume and for its one-hot encoding.
TARGET_DTYPE = 'uint8'
# Base directory handed to `get_tensorboard_dir` for logs and checkpoints.
TENSORBOARD_BASE_DIR = "/om/user/jakubk/nobrainer-code/niftynet_to_keras/models"

# Create a TensorFlow session, register it with the Keras backend, and use
# the channels-last data layout.
sess = tf.Session()
K.set_session(sess)
K.set_image_data_format('channels_last')

In [None]:
def _get_timestamp():
    """Return the current local time formatted as ``YYYY-MM-DD_HH:MM:SS``."""
    from datetime import datetime

    # Second resolution, with an underscore instead of the space between the
    # date and the time.
    return datetime.now().strftime("%Y-%m-%d_%H:%M:%S")


def get_tensorboard_dir(base_dir=None):
    """Return the path of a timestamped TensorBoard log directory.

    The path has the form ``<base_dir>/<run name>/logs``, where the run name
    encodes `NUM_CLASSES`, `LEARNING_RATE`, `BATCH_SIZE`, `WINDOW_SHAPE` and
    the current timestamp. The directory itself is not created here.

    Parameters
    ----------
    base_dir : str, optional
        Parent directory. Defaults to the current working directory.
    """
    root = os.getcwd() if base_dir is None else base_dir
    window_tag = "_".join(map(str, WINDOW_SHAPE))
    template = "highres3dnet-{num_classes}_classes-{lr}_lr-{batch}_batch-{window}_window-{ts}"
    run_name = template.format(
        num_classes=NUM_CLASSES,
        lr=LEARNING_RATE,
        batch=BATCH_SIZE,
        window=window_tag,
        ts=_get_timestamp(),
    )
    return os.path.join(root, run_name, 'logs')


def load_volume(filepath, return_affine=False, c_contiguous=True, dtype=None):
    """Load a volume with nibabel and return its data as a NumPy array.

    Parameters
    ----------
    filepath : str
        Path of the image file to load.
    return_affine : bool
        If true, return ``(data, affine)`` instead of just ``data``.
    c_contiguous : bool
        If true, return a C-contiguous array. This costs time while loading
        but avoids a copy later when block views of the data are taken.
    dtype : data-type, optional
        If given, the data are cast to this dtype.
    """
    image = nib.load(filepath)
    volume = np.asarray(image.dataobj)
    if dtype is not None:
        volume = volume.astype(dtype)
    # Drop whatever the image object may have cached; only `volume` is kept.
    image.uncache()
    if c_contiguous:
        volume = np.ascontiguousarray(volume)
    return (volume, image.affine) if return_affine else volume


def one_hot(a, n_values=None, **kwargs):
    """Return a one-hot encoding of the integer N-D array `a`.

    Parameters
    ----------
    a : array_like of int
        Non-negative class labels.
    n_values : int, optional
        Number of classes, i.e. the length of the new trailing axis. Defaults
        to ``max(a) + 1``. Pass it explicitly when the output depth must not
        depend on which labels happen to be present in `a`.
    **kwargs
        Forwarded to `numpy.eye` (for example ``dtype``).

    Returns
    -------
    ndarray
        Array of shape ``a.shape + (n_values,)``.

    Raises
    ------
    ValueError
        If `a` contains negative labels.
    IndexError
        If a label is greater than or equal to `n_values`.
    """
    # https://stackoverflow.com/a/37323404/5666087
    a = np.asarray(a)
    # Negative indices would silently wrap around to the last classes.
    if a.size and a.min() < 0:
        raise ValueError("labels must be non-negative")
    if n_values is None:
        n_values = int(np.max(a) + 1)
    # Row i of the identity matrix is the one-hot vector of label i.
    return np.eye(n_values, **kwargs)[a]


def _preprocess_data(data):
    """Split `data` into `WINDOW_SHAPE` blocks and append a channel axis.

    Returns an array of shape ``(n_blocks, *WINDOW_SHAPE, 1)``.
    """
    blocks = view_as_blocks(data, WINDOW_SHAPE)
    # Collapse the grid axes into one leading "block" axis.
    stacked = blocks.reshape((-1,) + tuple(WINDOW_SHAPE))
    return np.expand_dims(stacked, axis=-1)


def _preprocess_target(target):
    """One-hot encode a label volume and split it into blocks.

    Parameters
    ----------
    target : ndarray of int
        Label volume with values in ``[0, NUM_CLASSES)``.

    Returns
    -------
    ndarray
        Array of shape ``(n_blocks, *WINDOW_SHAPE, NUM_CLASSES)`` with dtype
        `TARGET_DTYPE`.

    Raises
    ------
    ValueError
        If `target` contains a label greater than or equal to `NUM_CLASSES`.
    """
    target = np.asarray(target)
    if target.size and int(target.max()) >= NUM_CLASSES:
        raise ValueError(
            "target contains label {} but NUM_CLASSES is {}".format(
                int(target.max()), NUM_CLASSES))
    # Encode with a fixed depth of NUM_CLASSES. Inferring the depth from
    # max(target) + 1 gives the wrong number of channels whenever a volume
    # lacks the highest label, which then breaks the blocking below.
    encoded = np.eye(NUM_CLASSES, dtype=TARGET_DTYPE)[target]
    new_shape = (*WINDOW_SHAPE, NUM_CLASSES)
    return view_as_blocks(encoded, new_shape).reshape(-1, *new_shape)


def view_as_blocks(arr_in, block_shape):
    """Return `arr_in` split into non-overlapping blocks by re-striding.

    Parameters
    ----------
    arr_in : ndarray
        N-d input array.
    block_shape : tuple
        Shape of a single block. It needs one strictly positive entry per
        dimension of `arr_in`, and each entry must divide the corresponding
        dimension of `arr_in` evenly.

    Returns
    -------
    ndarray
        Array of shape ``(arr_in.shape // block_shape) + block_shape``. A
        non-contiguous input is copied first (with a `RuntimeWarning`).

    Raises
    ------
    TypeError
        If `block_shape` is not a tuple.
    ValueError
        If `block_shape` has a non-positive entry, the wrong length, or does
        not evenly divide `arr_in.shape`.

    Notes
    -----
    Copied from `skimage.util.view_as_blocks` to avoid having to install the
    entire package + dependencies.
    """
    if not isinstance(block_shape, tuple):
        raise TypeError('block needs to be a tuple')

    block = np.array(block_shape)
    if np.any(block <= 0):
        raise ValueError("'block_shape' elements must be strictly positive")

    if arr_in.ndim != block.size:
        raise ValueError("'block_shape' must have the same length "
                         "as 'arr_in.shape'")

    full = np.array(arr_in.shape)
    # Remainders are non-negative, so "any non-zero" equals "sum != 0".
    if np.any(full % block):
        raise ValueError("'block_shape' is not compatible with 'arr_in'")

    if not arr_in.flags.contiguous:
        warn(RuntimeWarning("Cannot provide views on a non-contiguous input "
                            "array without copying."))

    contiguous = np.ascontiguousarray(arr_in)

    # Leading axes step from block to block; trailing axes step inside a block.
    grid = tuple(full // block)
    block_strides = tuple(
        stride * size for stride, size in zip(contiguous.strides, block_shape)
    )
    return as_strided(
        contiguous,
        shape=grid + block_shape,
        strides=block_strides + contiguous.strides,
    )


def blocks_to_volume(arr):
    """Combine an array of non-overlapping blocks into a single volume.

    This inverts ``view_as_blocks(vol, block_shape).reshape(-1, *block_shape)``
    for a grid with the same number of blocks along every axis.

    Parameters
    ----------
    arr : ndarray
        Array of shape ``(n_blocks, *block_shape)``. `n_blocks` must be a
        perfect k-th power, where k is the number of block dimensions, and
        the blocks must be in C order over the block grid.

    Returns
    -------
    ndarray
        Array of shape ``tuple(nn * s for s in block_shape)``, where
        ``nn ** k == n_blocks``.

    Raises
    ------
    ValueError
        If `arr` has fewer than two dimensions or `n_blocks` is not a perfect
        k-th power.

    Examples
    --------
    >>> mask.shape
    (8, 128, 128, 128)
    >>> blocks_to_volume(mask).shape
    (256, 256, 256)
    """
    arr = np.asarray(arr)
    if arr.ndim < 2:
        raise ValueError("'arr' must have shape (n_blocks, *block_shape)")

    n_blocks = arr.shape[0]
    block_shape = arr.shape[1:]
    new_ndim = len(block_shape)

    # Round rather than truncate: floating-point roots can land just below
    # the integer (64 ** (1 / 3) == 3.9999999999999996).
    nn = int(round(n_blocks ** (1 / new_ndim)))
    if nn ** new_ndim != n_blocks:
        raise ValueError(
            "number of blocks ({}) is not a perfect power of {}".format(
                n_blocks, new_ndim))

    # Pair every grid axis with its block axis: (g0, b0, g1, b1, ...).
    order = [
        axis
        for pair in zip(range(new_ndim), range(new_ndim, 2 * new_ndim))
        for axis in pair
    ]
    new_shape = tuple(nn * size for size in block_shape)
    return (
        arr.reshape((nn,) * new_ndim + block_shape)
        .transpose(order)
        .reshape(new_shape)
    )

In [None]:
# Table of input files; the training loop iterates over its rows.
df_input = pd.read_csv(CSV_FILEPATH)

# Build the network for NUM_CLASSES output classes and INPUT_SHAPE-sized blocks.
model = HighRes3DNet(n_classes=NUM_CLASSES, input_shape=INPUT_SHAPE)

# Use multiple GPUs.
# gpu_ids = [int(ss) for ss in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
# model = keras.utils.multi_gpu_model(model, gpus=gpu_ids)

# Compile with the Adam optimizer and the `dice_loss` imported from `highres3dnet`.
adam = tf.keras.optimizers.Adam(lr=LEARNING_RATE)
model.compile(adam, dice_loss)

In [None]:
# https://github.com/keras-team/keras/issues/5935#issuecomment-289041967
class MemoryCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints the process's peak memory use each epoch."""

    def on_epoch_end(self, epoch, log=None):
        """Print the maximum resident set size of this process.

        Parameters
        ----------
        epoch : int
            Index of the epoch that just ended (unused).
        log : dict, optional
            Metrics passed in by Keras (unused).
        """
        import resource
        import sys
        # max resident set size
        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # `ru_maxrss` is not a page count: Linux reports kilobytes and macOS
        # reports bytes, so multiplying by the page size overstates the usage.
        if sys.platform == 'darwin':
            n_bytes = max_rss
        else:
            n_bytes = max_rss * 1024
        usage = n_bytes / 1000000.0
        print("Usage: {:0.0f} Mb".format(usage))

In [None]:
# Timestamped log directory for this run, under TENSORBOARD_BASE_DIR.
_tensorboard_dir = get_tensorboard_dir(base_dir=TENSORBOARD_BASE_DIR)

print("++ Saving Tensorboard information to\n{}".format(_tensorboard_dir))

# Callbacks intended for `model.fit`.
callbacks = [
    # TensorBoard logs, without the graph, written to the log directory.
    tf.keras.callbacks.TensorBoard(
        log_dir=_tensorboard_dir,
        write_graph=False,
        batch_size=BATCH_SIZE,
    ),
    # Model checkpoints saved in the parent of the log directory; `period=50`
    # is the interval between checkpoints.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(_tensorboard_dir, "..", "model-{epoch:02d}.h5"),
        period=50,
    ),
    # Training log appended to `training.log` next to the checkpoints.
    tf.keras.callbacks.CSVLogger(
        filename=os.path.join(_tensorboard_dir, "..", "training.log"),
        append=True,
    ),
    # Peak memory usage printed after every epoch.
    MemoryCallback(),
]

In [None]:
# Train on one volume at a time: each row supplies a 't1' image path and a
# 'brainmask' label path.
for index, these_files in df_input.iterrows():
    
    # Anatomical volume -> blocks with a trailing channel axis.
    data = load_volume(these_files['t1'])
    data = _preprocess_data(data)
    
    # Label volume -> one-hot blocks.
    target = load_volume(these_files['brainmask'], dtype=TARGET_DTYPE)
    target = _preprocess_target(target)
    
    # One epoch over the blocks of this volume. The callbacks are disabled.
    model.fit(
        x=data,
        y=target,
        epochs=1,
        batch_size=BATCH_SIZE,
        verbose=1,
        # callbacks=callbacks
    )
    # Stop after the first row whose index exceeds 25; the check runs after
    # fitting, so that row is still trained on.
    if index > 25:
        break

# Inference

In [None]:
# To run on CPU only: an empty CUDA_VISIBLE_DEVICES hides every GPU.
# NOTE(review): presumably this must run before TensorFlow initialises CUDA — confirm.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [None]:
import logging
from warnings import warn

import nibabel as nib
import numpy as np
from numpy.lib.stride_tricks import as_strided
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from tensorflow.python.keras import backend as K

from highres3dnet import dice_loss, HighRes3DNet

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# --- Inference configuration ------------------------------------------------
NUM_CLASSES = 2
BATCH_SIZE = 1
# NOTE(review): the training cell uses 1e-4; this value is not used by the
# inference cells visible here.
LEARNING_RATE = 0.001
CSV_FILEPATH = "/om/user/jakubk/nobrainer-code/niftynet_to_keras/t1_brainmask.csv"
# Shape of the non-overlapping blocks each volume is split into.
WINDOW_SHAPE = (128, 128, 128)
NUM_CHANNELS = 1
# Block shape with a trailing channel axis (matches 'channels_last' below).
INPUT_SHAPE = (*WINDOW_SHAPE, NUM_CHANNELS)
TARGET_DTYPE = 'uint8'

# Create a TensorFlow session, register it with the Keras backend, and use
# the channels-last data layout.
sess = tf.Session()
K.set_session(sess)
K.set_image_data_format('channels_last')

In [None]:
# Load a saved checkpoint. `dice_loss` is supplied through `custom_objects`
# so that the custom loss named in the saved model can be resolved.
model = tf.keras.models.load_model(
    'models/highres3dnet-2_classes-0.0001_lr-1_batch-128_128_128_window-2018-01-29_12:57:45/model-00.h5', 
    custom_objects={'dice_loss': dice_loss}
)

In [None]:
# File list for inference; its 't1' column is read when loading a volume.
mindboggle_data = pd.read_csv('/om2/user/jakubk/mindboggle-101/all-101-files.csv')

In [None]:
# Scratch: empty 256 x 256 x 256 volume of little-endian int16.
new = np.zeros((256, 256, 256), dtype='<i2')

In [None]:
# Scratch: resize a copy of the anatomical array in place to 256^3.
# NOTE(review): `anat_true` is assigned in a later cell, so that cell has to
# be run first.
foo = anat_true.copy()
foo.resize(256, 256, 256)

In [None]:
# Load the T1 volume of row `ii` together with its affine.
ii = 0
anat_true, affine_anat_true = load_volume(mindboggle_data.loc[ii, 't1'], return_affine=True)
# Zero-pad the last axis by 48 voxels on each side.
# NOTE(review): presumably this brings the volume to a multiple of WINDOW_SHAPE — confirm.
anat_true = np.pad(anat_true, ((0,0),(0,0),(48, 48)), 'constant')
# Keep the padded, un-blocked volume for plotting.
anat_true_orig = anat_true.copy()
# Split into blocks with a trailing channel axis for the model.
anat_true = _preprocess_data(anat_true)

In [None]:
# Show the slice at index 150 of the second axis of the padded volume.
plt.matshow(anat_true_orig[:, 150, :], cmap='gray')
plt.show()

In [None]:
# Run the model on all blocks and print the elapsed wall-clock seconds.
import time
t0 = time.time()
prediction = model.predict(anat_true)
diff = time.time() - t0
print(diff)

In [None]:
# Decode the prediction: index of the largest value along the last (class) axis.
mask = prediction.argmax(-1)
mask.shape

In [None]:
# Reassemble the per-block masks into a single volume.
out = blocks_to_volume(mask)

In [None]:
# Wrap the mask in a NIfTI image that reuses the anatomical volume's affine.
mask_img = nib.Nifti1Image(out, affine_anat_true)

In [None]:
# Write the mask image to disk in the current working directory.
mask_img.to_filename('testout.nii.gz')

# Alternative reshape methods

In [None]:
import numpy as np

# Drop singleton axes.
# NOTE(review): `anat_proc` is not defined in any cell visible here.
anatp = anat_proc.squeeze()
anatp.shape

def permute_axes(a, div_len_by=2):
    """Split every axis of `a` by `div_len_by` and regroup the pieces.

    Each axis of length ``n`` is viewed as ``(n // div_len_by, div_len_by)``.
    The quotient axes are moved in front of the `div_len_by`-sized axes, and
    the result is reshaped to
    ``(div_len_by ** a.ndim, n0 // div_len_by, n1 // div_len_by, ...)``.
    The intermediate shapes are printed.

    NOTE(review): the final reshape regroups the flattened data; whether the
    leading axis indexes spatially contiguous sub-blocks depends on the axis
    lengths — confirm the intended layout.

    Parameters
    ----------
    a : ndarray
        Input array; every axis length must be divisible by `div_len_by`.
    div_len_by : int
        Factor by which to split every axis.

    Returns
    -------
    ndarray
        The reshaped array.
    """
    ndim = a.ndim
    quotients = [size // div_len_by for size in a.shape]

    # (q0, d, q1, d, ...). The split factor must be `div_len_by` itself;
    # a literal 2 only works for the default.
    split_shape = []
    for quotient in quotients:
        split_shape.extend((quotient, div_len_by))
    # div_len_by ** ndim pieces in total (8 for the 3-D default).
    final_shape = [div_len_by ** ndim] + quotients
    # Quotient axes (even positions) first, then split axes (odd positions).
    axis_order = tuple(range(0, 2 * ndim, 2)) + tuple(range(1, 2 * ndim, 2))

    a = a.reshape(split_shape)
    print("after reshape1", a.shape)
    a = a.transpose(axis_order)
    print("after transpose", a.shape)
    a = a.reshape(final_shape)
    print("after reshape2", a.shape)
    return a

# NOTE(review): `anat` is not defined in any cell visible here.
permute_axes(anat).shape

In [None]:
# Rebuild a 256^3 volume from `anatp`: view it as (128, 128, 128, 2, 2, 2),
# put each 2-sized axis next to a 128-sized axis, and merge the pairs.
new = (
    anatp.reshape((128, 128, 128, 2, 2, 2))
    .transpose(0, 3, 1, 4, 2, 5)
    .reshape(256, 256, 256)
)
print(new.shape)

assert new.shape == (256,256,256), "Incorrect shape"

# Show the slice at index 150 of the second axis.
plt.matshow(new[:, 150, :], cmap='gray')

In [None]:
# Scratch: fill a format string's named fields from a dict.
a = "{x} {y} {z}"
d = {"x":3, "y":"blah", "z":0}
a.format(**d)

# One-hot for aparc+aseg

In [None]:
import pandas as pd

In [None]:
# File list with an 'aparcaseg' column, and the CSV mapping original labels
# to new labels (it has a 'label' column).
aparcaseg_csv = "/om2/user/jakubk/openmind-surface-data/file-lists/master_file_list_aparcaseg.csv"
aparcaseg_mapping_csv = "/om2/user/jakubk/openmind-surface-data/data/FreeSurferColorLUT-mapping-108.csv"

In [None]:
# Load the file list, and the label mapping indexed by its 'label' column.
df = pd.read_csv(aparcaseg_csv)
aparcaseg_mapping = pd.read_csv(aparcaseg_mapping_csv, index_col='label')

In [None]:
# Load the aparc+aseg volume of row 10 as unsigned 64-bit integers.
data = load_volume(df.loc[10, 'aparcaseg'], dtype='uint64')

In [None]:
def _get_mapping(filepath):
    """Return a dict mapping original labels to new labels, read from a CSV.

    Parameters
    ----------
    filepath : str
        Path to a CSV file whose first column holds the original integer
        label and whose second column holds the new integer label. A header
        row whose first cell is ``'label'`` is skipped, as are empty rows.

    Returns
    -------
    dict
        ``{original_label: new_label}`` with integer keys and values.
    """
    import csv
    # Read the file that was passed in, not the module-level path.
    # `newline=''` is what the csv module expects for files it reads.
    with open(filepath, newline='') as csvfile:
        reader = csv.reader(csvfile)
        return {
            int(row[0]): int(row[1])
            for row in reader
            if row and row[0] != 'label'
        }

In [None]:
# Original-label -> new-label dictionary, plus an array of the new labels.
mapping = _get_mapping(aparcaseg_mapping_csv)
mapping_values_arr = np.array(list(mapping.values()))

In [None]:
# Replace values of array.

In [None]:
# Zero any values not in the list of labels we are using.
# The original labels are the keys of `mapping`; build the key array here,
# because no earlier cell defines `mapping_keys_arr`.
mapping_keys_arr = np.array(list(mapping.keys()))
_not_in_mappin_mask = ~np.isin(data, mapping_keys_arr)
data[_not_in_mappin_mask] = 0

# Convert to range [0, 107].
# Compare against an untouched copy, so a value written by one iteration
# cannot be picked up and remapped again by a later one.
_original = data.copy()
for orig, new in mapping.items():
    data[_original == orig] = new