# Data processing and sampling with tensorflow data input pipelines 

The aim of this tutorial is to learn how to:
- read 3D images in nifti format for training a neural network
- use the tensorflow Dataset class for efficient sampling of mini batches and on-the-fly data augmentation

This tutorial uses simulated PET/MR data and a network that has two input channels and one output channel. However, the basic concept of using a tensor data input pipeline generalizes easily to other examples using different data, different dimensions, or a different number of channels.

**Before you run this notebook: Make sure that you have run the** ```00_introduction.ipynb``` **notebook, which downloads the data needed in the following notebooks into the folder** ```brainweb_petmr```

This tutorial is inspired by this keras tutorial https://keras.io/examples/vision/3D_image_classification/ on 3D CT image classification, which is also highly recommended.

## Background

Efficient data sampling and data augmentation are crucial when training convolutional neural networks (CNNs), especially when using 3D images. Tensorflow offers the tf.data.Dataset class, which allows us to do that in a very elegant and efficient way. In the following, we will read a few simulated PET and MR data sets and demonstrate how to set up a tensorflow Dataset pipeline with on-the-fly data augmentation.

## Data set

The data set we will use consists of 20 subjects derived from the brainweb phantom. 

For each subject the data is organized as follows:
- subjectXX
  - mu.nii.gz -> (attenuation image)
  - t1.nii.gz -> (high resolution and low noise T1 MR)
  - sim_0 -> first simulated PET acquisition
    - true_pet.nii.gz -> true tracer uptake
    - osem_psf_counts_0.0E+00.nii.gz -> OSEM recon of simulated noise free data
    - osem_psf_counts_1.0E+07.nii.gz -> OSEM recon of simulated noisy data (1e7 counts) -> high noise level
    - osem_psf_counts_5.0E+08.nii.gz -> OSEM recon of simulated noisy data (5e8 counts) -> low noise level

  - sim_1 -> second simulated PET acquisition
    - osem_psf_counts_0.0E+00.nii.gz
    - osem_psf_counts_1.0E+07.nii.gz
    - osem_psf_counts_5.0E+08.nii.gz
    - true_pet.nii.gz
  - sim_2 -> third simulated PET acquisition
    - osem_psf_counts_0.0E+00.nii.gz
    - osem_psf_counts_1.0E+07.nii.gz
    - osem_psf_counts_5.0E+08.nii.gz
    - true_pet.nii.gz

All data sets have a shape of (256,256,258) and a voxel size of 1mm x 1mm x 1mm and are provided in nifti format. All PET acquisitions have different contrasts.
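
As a minimal sketch of how this layout maps to file paths, the snippet below builds the expected file names for one simulated acquisition with pathlib. The helper name `expected_files` and the folder name `subject04` are hypothetical; only the folder layout and the file names are taken from the list above. Note how the format specifier `0.1E` turns a count level such as `5e8` into the string `5.0E+08` used in the file names.

```python
import pathlib

def expected_files(subject_dir, sim=0, counts=1e7):
    """Return the paths of the files listed above for one simulated acquisition.

    Hypothetical helper for illustration; it only builds paths and does not
    check that the files exist.
    """
    subject_dir = pathlib.Path(subject_dir)
    sim_dir = subject_dir / f'sim_{sim}'
    return {
        'mu': subject_dir / 'mu.nii.gz',
        't1': subject_dir / 't1.nii.gz',
        'osem': sim_dir / f'osem_psf_counts_{counts:0.1E}.nii.gz',
        'true_pet': sim_dir / 'true_pet.nii.gz',
    }

# 'subject04' is an assumed folder name following the subjectXX pattern
files = expected_files(pathlib.Path('brainweb_petmr') / 'subject04', sim=1, counts=5e8)
for key, path in files.items():
    print(key, path)
```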

In [None]:
# import all the modules that we need for this tutorial

import tensorflow as tf
import nibabel as nib
import numpy as np
import pathlib

# enable interactive plots with the ipympl package
%matplotlib widget

Make sure that the simulated brainweb PET/MR data sets were downloaded and that the main data path in the cell below is correct. Let's first find all data directories. In this tutorial **we load only the first 4 subjects to speed up execution**. Since there are 3 simulated acquisitions per subject, we will get in total 4*3 = 12 data sets.

In [None]:
# adjust this variable to the path where the simulated PET/MR data from zenodo was unzipped
data_dir   = pathlib.Path('brainweb_petmr')
batch_size = 10
nsubjects  = 4

# get the paths of the first nsubjects subjects
# we only use a few subjects in this tutorial to speed up the data reading
subject_paths = sorted(list(data_dir.glob('subject??')))[:nsubjects]

Each simulated data set contains a low-resolution, noisy standard OSEM PET reconstruction, a high-resolution, low-noise T1 MR, and a high-resolution, low-noise target image. All image volumes are saved in nifti format. Let's define a first helper function that uses nibabel to load a 3D nifti volume in a defined orientation (LPS). The canonical orientation used by nibabel is RAS, which is why we have to flip axes 0 and 1.
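
To illustrate the axis flip, here is a small numpy-only sketch on a toy array (no nifti file is involved). In RAS orientation, the index along axis 0 increases towards the patient's right and the index along axis 1 towards anterior. Reversing both axes yields LPS (left, posterior, superior), while axis 2 stays untouched. The toy array and its shape are assumptions for illustration only.

```python
import numpy as np

# toy "volume" in RAS orientation with shape (n0, n1, n2) = (2, 3, 4)
vol_ras = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# reverse axis 0 (R -> L) and axis 1 (A -> P); axis 2 (S) is kept
vol_lps = np.flip(vol_ras, (0, 1))

# voxel (i, j, k) in RAS ends up at (n0 - 1 - i, n1 - 1 - j, k) in LPS
n0, n1, n2 = vol_ras.shape
i, j, k = 0, 2, 1
same_voxel = vol_lps[n0 - 1 - i, n1 - 1 - j, k] == vol_ras[i, j, k]
print(same_voxel)
```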

In [None]:
def load_nii_in_lps(fname):
  """ function that loads nifti file and returns the volume and affine in 
      LPS orientation
  """
  nii = nib.load(fname)
  nii = nib.as_closest_canonical(nii)
  vol = np.flip(nii.get_fdata(), (0,1))

  return vol, nii.affine

When training neural networks, it is important to normalize the intensity of the input data. In the tutorial, we use a robust maximum which is the maximum of a heavily smoothed version of the input volume. The smoothing is important when working with noisy data.
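
The following numpy-only sketch illustrates why the smoothing matters. It uses non-overlapping block averaging as a simplified stand-in for the strided average pooling used in this tutorial, and the toy volume with a single outlier voxel is an assumption for illustration.

```python
import numpy as np

# toy volume: uniform value 1 with one extreme "noise" voxel
vol = np.ones((16, 16, 16), dtype=np.float32)
vol[5, 9, 2] = 100.0

def block_mean_max(volume, b=4):
    """Max of the volume after averaging over non-overlapping b x b x b blocks.

    Simplified stand-in for average pooling; assumes that all dimensions
    are divisible by b.
    """
    n0, n1, n2 = volume.shape
    blocks = volume.reshape(n0 // b, b, n1 // b, b, n2 // b, b)
    return blocks.mean(axis=(1, 3, 5)).max()

plain_max = vol.max()              # dominated by the single outlier
smooth_max = block_mean_max(vol)   # outlier is averaged with 63 neighbours
print(plain_max, smooth_max)
```

The plain maximum follows the outlier, whereas the maximum of the block-averaged volume stays close to the typical intensity, which makes it a more stable normalization factor for noisy images.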

In [None]:
def robust_max(volume, n = 7):
    """ function that returns the max of a heavily smoothed version of the input volume
        
        for the smoothing we use tensorflow's strided average pooling (which is faster than the numpy / scipy implementation) 
    """
    # to use tf's average pooling we first have to convert the numpy array to a tf tensor
    # for the pooling layers, the shape of the input needs to be [1,n0,n1,n2,1]
    t = tf.convert_to_tensor(np.expand_dims(np.expand_dims(volume,0),-1).astype(np.float32))
    
    return tf.nn.avg_pool(t,2*n + 1,n,'SAME').numpy().max()

Let's define another helper function that loads all 3 nifti volumes of a data set and also performs an intensity normalization. For the latter, we divide both PET images by the "robust" max of the input PET image, and the MR by its own "robust" max, where the "robust" max is the maximum of a heavily downsampled (pooled) volume. This is more stable when working with very noisy data.

In [None]:
def load_data_set(subject_path, sim = 0, counts = 1e7):

  # get the subject number from the path
  data_id = int(subject_path.parts[-1][-2:])

  # setup the file names
  mr_file   = pathlib.Path(subject_path) / 't1.nii.gz'
  osem_file = pathlib.Path(subject_path) / f'sim_{sim}' / f'osem_psf_counts_{counts:0.1E}.nii.gz'
  target_file = pathlib.Path(subject_path) / f'sim_{sim}' / 'true_pet.nii.gz'

  # load nifti files in LPS orientation
  mr, mr_aff = load_nii_in_lps(mr_file)
  osem, osem_aff = load_nii_in_lps(osem_file)
  target, target_aff = load_nii_in_lps(target_file)

  # normalize the intensities of the MR and PET volumes
  mr_scale   = robust_max(mr)
  osem_scale = robust_max(osem)

  mr     /= mr_scale
  osem   /= osem_scale
  target /= osem_scale

  return osem, mr, target, osem_scale, mr_scale

In many CNN training scenarios, on-the-fly data augmentation (e.g. cropping, rotating, change of contrast) is very useful. Moreover, when working with 3D data sets, networks are often trained on smaller patches due to **memory limitations on the available GPUs**. 
Here, we define a function that samples a random 3D patch at the same position from the input and target data sets. For the sampling we make use of ```tf.image.random_crop```.
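
The key point is that input and target must be cropped at the same position. A minimal numpy sketch of that idea is shown here. It is not the implementation used in this tutorial (which applies ```tf.image.random_crop``` to the channel-wise concatenation of input and target), and the function name `paired_random_crop` as well as the toy arrays are hypothetical.

```python
import numpy as np

def paired_random_crop(x, y, size, rng):
    """Crop the same random patch from x (n0,n1,n2,cx) and y (n0,n1,n2,cy)."""
    # draw one start index per spatial axis such that the patch fits
    starts = [int(rng.integers(0, n - s + 1)) for n, s in zip(x.shape[:3], size)]
    sl = tuple(slice(st, st + s) for st, s in zip(starts, size))
    # the same spatial slices are applied to both arrays, channels are kept
    return x[sl], y[sl]

rng = np.random.default_rng(0)
x_toy = rng.random((32, 32, 32, 2)).astype(np.float32)
y_toy = 2 * x_toy[..., :1]   # target tied voxel-wise to input channel 0

x_crop, y_crop = paired_random_crop(x_toy, y_toy, (8, 8, 8), rng)
print(x_crop.shape, y_crop.shape)
```

Because the toy target is a voxel-wise function of the first input channel, checking that `y_crop` still equals `2 * x_crop[..., :1]` is a quick way to confirm that both patches come from the same location.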

In [None]:
def train_augmentation(x, y, s0 = 64, s1 = 64, s2 = 64):
  """data augmentation function for training 
     
     the input x has shape (n0,n1,n2,2) and the input y has shape (n0,n1,n2,1)
  """

  # do the same random crop of input and output
  z = tf.concat([x,y], axis = -1)
  z_crop = tf.image.random_crop(z, [s0,s1,s2,z.shape[-1]])

  x_crop = z_crop[...,:2]
  # use a slice (2:) to keep the channel axis such that y_crop has shape (s0,s1,s2,1)
  y_crop = z_crop[...,2:]

  return x_crop, y_crop

Now let's loop over all data directories and store all images in 2 big numpy arrays. The first array ```x``` contains the input and the second array ```y``` the target for our CNN during training. When working with 3D volumes, the shape of the input and output of the CNN has to be ```(nbatch,n0,n1,n2,nchannels)``` where ```nbatch``` is the mini batch length, ```n0,n1,n2``` are the spatial dimensions, and ```nchannels``` is the number of channels. In this example, we have two input channels (OSEM PET and T1 MR) and one output channel (target PET image).

Reading one data set takes up to 5s due to calculation of the robust maximum.
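
Before allocating the arrays, it can help to check their shape and memory footprint. The sketch below redoes that arithmetic for the settings used in this tutorial (4 subjects with 3 acquisitions each, a crop of 40, 30, and 40 voxels on each side of the three axes, float32 storage). The variable names are chosen for illustration only.

```python
# full volume shape and the number of voxels removed on each side of every axis
full_shape = (256, 256, 258)
crop = (40, 30, 40)
cropped_shape = tuple(n - 2 * c for n, c in zip(full_shape, crop))

n_sets = 4 * 3          # 4 subjects x 3 simulated acquisitions
bytes_per_value = 4     # float32

voxels = cropped_shape[0] * cropped_shape[1] * cropped_shape[2]
x_bytes = n_sets * voxels * 2 * bytes_per_value   # two input channels
y_bytes = n_sets * voxels * 1 * bytes_per_value   # one target channel

print(cropped_shape)
print(f'x: {x_bytes / 1e9:.2f} GB, y: {y_bytes / 1e9:.2f} GB')
```

With these settings the cropped volumes have shape (176,196,178), and the two arrays take roughly 0.6 GB and 0.3 GB of host memory, respectively.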

In [None]:
# read all the input data sets
# we apply a slight crop to exclude background regions 

x = np.zeros((3*len(subject_paths),176,196,178,2), dtype = np.float32)
y = np.zeros((3*len(subject_paths),176,196,178,1), dtype = np.float32)

# load all the data sets and sort them into the x and y numpy arrays
for i,subject_path in enumerate(subject_paths):
  for sim in range(3):
    print(f'loading {subject_path} simulation {sim}')
    data = load_data_set(subject_path, sim = sim, counts = 1e7)
    
    # - for every subject we have 3 simulated acquisitions such that the position of the current acq. is 3*i + sim
    # - [40:-40,30:-30,40:-40] is used to crop the image in every direction to ignore empty background regions
    #   which saves memory and avoids sampling of too many (small) empty patches
    x[3*i + sim,...,0] = data[0][40:-40,30:-30,40:-40]
    x[3*i + sim,...,1] = data[1][40:-40,30:-30,40:-40]
    y[3*i + sim,...,0] = data[2][40:-40,30:-30,40:-40]

After we have read some of the available data sets, let's visualize the first data set.

In [None]:
import pymirc.viewer as pv
vi = pv.ThreeAxisViewer([x[...,0].squeeze(), x[...,1].squeeze(), y[...,0].squeeze()],
                           imshow_kwargs = {'vmin':0,'vmax':1.4}, rowlabels = [f'input 0', f'input 1', f'target'])

Let's create a tensorflow data set from our numpy arrays ```x``` and ```y```, which are stored in the host memory. Moreover, we use the ```shuffle``` and ```map``` methods to shuffle the data and to apply our data augmentation function on the fly. More information on the tensorflow dataset class can be found here: https://www.tensorflow.org/api_docs/python/tf/data/Dataset
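
To see what each step of such a pipeline does, here is a minimal sketch on small random toy arrays. The array sizes, the patch size, and the names `toy_crop` and `toy_dataset` are assumptions for illustration. The calls themselves (`from_tensor_slices`, `shuffle`, `map`, `batch`, `prefetch`, `tf.image.random_crop`) are standard tensorflow API.

```python
import numpy as np
import tensorflow as tf

# toy stand-ins for the training arrays: 6 volumes of shape (8, 8, 8)
x_toy = np.random.rand(6, 8, 8, 8, 2).astype(np.float32)  # two input channels
y_toy = np.random.rand(6, 8, 8, 8, 1).astype(np.float32)  # one target channel

def toy_crop(x, y, s=4):
    """Crop the same random (s, s, s) patch from input and target."""
    z = tf.concat([x, y], axis=-1)
    z_crop = tf.image.random_crop(z, [s, s, s, 3])
    return z_crop[..., :2], z_crop[..., 2:]

toy_dataset = (
    tf.data.Dataset.from_tensor_slices((x_toy, y_toy))  # one element per volume
    .shuffle(len(x_toy))   # shuffle buffer covers all volumes
    .map(toy_crop)         # augmentation is applied per volume, on the fly
    .batch(4)              # stack the equally sized patches into mini batches
    .prefetch(2)           # prepare upcoming batches in the background
)

# iterating once over the data set yields ceil(6 / 4) = 2 mini batches
batch_shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in toy_dataset]
print(batch_shapes)
```

Batching after the `map` step matters here: the augmentation then acts on single volumes, and every patch has the same shape when the mini batch is assembled. The last mini batch is smaller because 6 is not divisible by 4.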

In [None]:
train_loader = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_loader.shuffle(len(x)).map(lambda x,y: train_augmentation(x,y, s0 = 64, s1 = 64, s2 = 64)).batch(batch_size).prefetch(2)

Finally, let's draw a mini-batch and let's visualize all 3D data sets in the mini batch.

In [None]:
x_batch, y_batch = list(train_dataset.take(1))[0]

You can click in the plots and use your arrow keys to move through the slices / samples in the mini batch. The left/right arrow keys move through the samples, and the up/down arrow keys move through the slices.

In [None]:
vi = pv.ThreeAxisViewer([x_batch[...,0].numpy().squeeze(), x_batch[...,1].numpy().squeeze(), 
                         y_batch.numpy().squeeze()],
                         imshow_kwargs = {'vmin':0,'vmax':1.4}, rowlabels = [f'input 0', f'input 1', f'target'])

## Now it's your turn - recommended exercise
Now it is your turn to familiarize yourself with the tensorflow dataset input pipeline:
1. Create a 2nd training data set loader similar to ```train_dataset``` that samples random patches with a different size (e.g. 128,128,128)
2. Write your own on-the-fly data augmentation function that randomly changes the contrast of the input MR image (2nd channel). To do so, have a look into https://www.tensorflow.org/api_docs/python/tf/image/random_contrast

## What's next
In the following notebooks we will learn:
- how to set up a simple convolutional neural network (CNN) in tensorflow
- how to train a CNN with our data input pipeline
- how to monitor training