# Implement highres3dnet

Paper: https://arxiv.org/abs/1707.01992

Notes
---
- NiftyNet trains on a sliding window over the 3D data.

Future considerations
---
- Modify `highres3dnet` to use the ResNeXt architecture.

In [None]:
import logging
import os

import nibabel as nib
import numpy as np
import pandas as pd

from highres3dnet import dice_coef, dice_loss, HighRes3DNet

In [None]:
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

In [None]:
# CSV whose 't1' and 'brainmask' columns (read in a later cell) hold file paths
# to volumes; presumably one subject per row — confirm against the CSV.
# The default is a user-specific cluster path; set NOBRAINER_INPUT_CSV to point
# at a different file without editing the notebook.
input_filepath = os.environ.get(
    "NOBRAINER_INPUT_CSV",
    "/om/user/jakubk/nobrainer-code/niftynet_to_keras/t1_brainmask.csv",
)
df_input = pd.read_csv(input_filepath)
df_input.head()

In [None]:
def load_image(filepath):
    """Return data and affine given filepath to volume.

    Parameters
    ----------
    filepath : str or path-like
        Path to a volume readable by ``nibabel.load``.

    Returns
    -------
    data : numpy.ndarray
        Floating-point voxel data from ``img.get_fdata``.
    affine : numpy.ndarray
        The image's affine (``img.affine``).
    """
    img = nib.load(filepath)
    # Return the array without also storing a cached copy on the image object.
    # Calling ``uncache()`` before any data has been read would release nothing.
    return img.get_fdata(caching="unchanged"), img.affine


def _validate_dims(data):
    """Raise `ValueError` if array of data has fewer than 3 dimensions."""
    if data.ndim < 3:
        raise ValueError("Invalid number of dimensions. Input volume must have"
                         " at least 3 dimensions.")


def _reshape(data, newshape):
    """Return ``data`` reshaped to ``newshape``.

    This is a plain ``reshape``: the total number of elements must be
    unchanged, so it cannot resample a volume onto a different grid.
    Assumes ``data`` is a NumPy array (``.shape`` / ``.reshape``).
    """
    # Log the target shape first, then the source shape, matching the wording.
    logger.info("Reshaping image to shape %s from shape %s",
                newshape, data.shape)
    return data.reshape(newshape)

In [None]:
# Shape of one input example; the trailing 1 looks like a single channel axis
# (TODO confirm HighRes3DNet expects channels-last input).
volume_input_shape = (256, 256, 256, 1)
# NOTE(review): the first argument is presumably the number of output classes —
# confirm against highres3dnet.HighRes3DNet.
model = HighRes3DNet(2, input_shape=volume_input_shape)
# Display the model's input shape as a sanity check.
model.input_shape

In [None]:
# Load the T1 volume and its brain mask from row `row` of the CSV.
row = 0
data_input, data_input_affine = load_image(df_input.loc[row, 't1'])
_validate_dims(data_input)
# Reshape (not resample): only succeeds if the volume's element count matches
# `volume_input_shape`.
data_input = _reshape(data_input, volume_input_shape)

# The brain mask is the training target passed to `model.fit` below.
data_test, data_test_affine = load_image(df_input.loc[row, 'brainmask'])
_validate_dims(data_test)
data_test = _reshape(data_test, volume_input_shape)

In [None]:
# Prepend a length-1 batch axis so each array is a batch of one example.
data_input = data_input[np.newaxis, ...]
data_test = data_test[np.newaxis, ...]

In [None]:
# Compile with the 'adam' optimizer and the Dice loss from highres3dnet.
# NOTE(review): `dice_coef` is imported but not passed as a metric here.
model.compile('adam', dice_loss)

In [None]:
# Fit on the single loaded (T1 -> brain mask) example, one example per batch.
model.fit(data_input, data_test, batch_size=1, verbose=1)

In [None]:
# Build a toy batch by repeating the single loaded example three times.
# `data_input` and `data_test` already carry a leading batch axis of length 1
# (added in an earlier cell), so repeat along axis 0 rather than stacking,
# which would introduce a second batch axis.
batch_data = np.repeat(data_input, 3, axis=0)
mask_data = np.repeat(data_test, 3, axis=0)

In [None]:
# Single training step on the repeated batch (`batch_data` inputs, `mask_data`
# targets). NOTE(review): presumably Keras `Model.train_on_batch` — confirm.
model.train_on_batch(batch_data, mask_data)