<a href="https://colab.research.google.com/github/desikazone/CFCM-2D/blob/master/CoarseToFineContextMemory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://www.nvidia.com/dli"> <img src="DLI Header.png" alt="Header" style="width: 400px;"/> </a>

# Lung segmentation in X-Ray images with Coarse to Fine Context Memory

#  <span style="color:green">Overview</span>

Medical image segmentation is an essential component of many medical imaging analysis solutions. By partitioning a medical image into disjoint segments, it is possible, for example, to quantify the extent of a tumor or to study the anatomical structure of a region of interest. It is also a crucial first step for subsequent algorithms: by locating the region of interest, the search space for computationally expensive algorithms can be reduced significantly. These are only a few examples of the wide range of applications for segmentation. 

Although there has been significant progress in this active area of research, there is no unique solution for this challenging task. In recent years, approaches based on deep learning have proven to be particularly successful and an architecture that has widely been used is the Encoding-Decoding approach. 

In this lab we will investigate this popular architecture and discuss several strategies of fusing features from the encoding path with the decoding path. In the second part, we will highlight some of the best principles for structuring code for deep learning projects. Finally, we will see how the fusion of features can be further improved with the example of a recent architecture published at MICCAI and lung segmentation in x-ray images. However, it should be noted that the approach is general and can be applied for various medical image types. 


#  <span style="color:green">Introduction</span>
## The importance of skip connections in biomedical image segmentation
In a medical segmentation task, we want to map every pixel location of a medical image to a distinct class value representing, for example, background or tumor. Consequently, the input (image) and the output (segmentation) usually have the same spatial extent. A straightforward way to design a neural network architecture for this goal would be to stack several fully connected layers without changing the spatial dimensionality. However, this approach would very quickly lead to an explosion in the number of parameters. Convolutional operations significantly reduce the number of parameters, as the same kernel is slid over the entire image. Furthermore, they are translation equivariant: if the object of interest is shifted in the input, the response shifts with it, so the object is detected regardless of its position. 

Designing a fully convolutional network without reducing the spatial dimensions would be possible, but research indicates that an encoder-decoder network is more effective [Ref. 6]. An encoder-decoder network typically looks like this:

![Encoder-Decoder Network](encoder.png)

The contracting path (Encoder) maps the image to a feature representation that is often lower dimensional than the original input. The expanding path (Decoder) maps the feature representation to the output space. Originally, this type of network was known in connection with auto-encoders, where the output space is the same as the input space. In the case of segmentation the output space often has the same spatial extent as the input space, but represents different content (classes). 

Thanks to the downsampling in the contracting path, fewer parameters need to be trained. However, this comes at the cost of losing spatial information. One approach to counteract this is the use of so-called skip connections: allowing the gradient to skip part of the network and to flow directly from a layer of the contracting part to the expanding path. 

## Fusing features from different layers
There are different techniques to realize these skip connections. Two of the most common ones are concatenation and summation:

### Feature concatenation
One possibility is to simply concatenate the layers. This requires the layers to have the same dimension in the concatenation direction. 
So a concatenation t1 = [1 2 3] with t2 = [4 5 6] could be t_new = [1 2 3 4 5 6].  In Tensorflow this operation is tf.concat ([see here](https://www.tensorflow.org/api_docs/python/tf/concat)).

<img src=concatenation.png width="600">


### Feature summation
Another widely used approach is element-wise summation. One very nice property is that it keeps the number of features fixed. A summation of t1= [1 2 3] and t2 = [4 5 6] would be t_new = [5 7 9]. 
<img src=summation.png width="600">
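
As a minimal sketch of the two fusion strategies above, the following NumPy snippet applies them to small feature maps. The array shapes and the names `encoder_features` and `decoder_features` are illustrative assumptions and are not taken from the CFCM implementation.

```python
import numpy as np

# Hypothetical feature maps in (batch, height, width, channels) layout
encoder_features = np.ones((1, 4, 4, 8), dtype=np.float32)
decoder_features = np.full((1, 4, 4, 8), 2.0, dtype=np.float32)

# Concatenation along the channel axis: the channel counts add up (8 + 8 = 16)
fused_concat = np.concatenate([encoder_features, decoder_features], axis=-1)

# Element-wise summation: shapes must match and the channel count is unchanged
fused_sum = encoder_features + decoder_features

print(fused_concat.shape)  # (1, 4, 4, 16)
print(fused_sum.shape)     # (1, 4, 4, 8)
```

Concatenation leaves it to the following layer to learn how to mix the two sources, at the cost of more input channels, while summation keeps the layer width fixed.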

### Coarse to Fine Context Memory (CFCM)
The combination of layers is still an active area of research. A recent example is CFCM. Similar to the above examples, it is based on a fully convolutional architecture consisting of an encoding and a decoding part. The core idea of the CFCM approach is to use a memory mechanism, implemented via convolutional LSTMs, for fusing features extracted from different layers of the encoder. The convolutional LSTMs thereby take the role of a coarse-to-fine focusing mechanism: they first perceive the global context of the input data, as the deepest activations are fed to the LSTM first, and later process fine-grained details, when the shallower, high-resolution features are considered. More details about convolutional LSTMs and this architecture are explained in the remainder of this exercise.

The original implementation of CFCM, which has served as inspiration to create this exercise, is available on http://github.com/faustomilletari/CFCM-2D.

<img src=cfcm_highlevel.png width="450">


### CFCM with "ladder" ResNet as backbone architecture
In this exercise, the encoding portion of CFCM is based
on a standard ResNet architecture while decoding is implemented using convolutional
LSTMs as explained above. 

![architecture](architecture.png)

The ResNet encoder has 4 super-blocks which operate at different resolutions. Each super-block (represented in the figure with a different color) contains multiple ResNet blocks. The features resulting from each ResNet block are forwarded to the LSTM-based decoding path. This kind of pattern is sometimes termed a "ladder" network, as it resembles a ladder where each step is a forward (long) skip connection.

For completeness, we include a picture representing the layer configuration of a ResNet block. In our implementation we also use Batch Normalization [Ref 5].
<img src=block.png width="300">

## Lung segmentation in X-Ray images
X-Ray is a widely used modality for investigating the condition of the patient's lung. However, not only the lung is visible on these images. They contain the 2D projection of the entire anatomy, including bone and the heart. Determining the outline of the lung is therefore not an easy task, but necessary for many computer-aided procedures, visualization methods or for quantifying the extent of the lung.  

### The Montgomery County X-ray Set
The dataset employed in this exercise comprises 138 annotated posterior-anterior chest x-rays and has been acquired from the tuberculosis control program of the Department of Health and Human Services of Montgomery County, MD, USA. The set contains 80 normal cases and 58 abnormal cases with manifestations of tuberculosis including effusions and miliary patterns. We split this dataset into a training, a validation and a test set and use it to train, evaluate and run inference with our DL approach. Example images from this dataset are shown below.

![X-ray Set](data.png)

#  <span style="color:green">Best practice for structuring code</span>

## Code and exercise structure
Structuring code for a machine learning problem so that it is both flexible and follows best practices is not an easy task. In this exercise we try to incorporate some of the best principles that have emerged in popular recent Python projects.

With the introduction of modern frameworks such as TensorFlow and PyTorch, most of the processes around the development of DL approaches have been standardized. During development of a typical project it is necessary to take care of only a handful of compartmentalized tasks such as:
* DATA
    1. Load data, standardize and augment it
    2. Split dataset into batches and iterate through them
* NETWORK
    1. Define a network architecture as computational graph
    2. Define suitable loss
* OPTIMIZATION
    1. Define optimization op
    2. Implement training and validation loops
    
### Handling Data

For **data** handling we define transforms in charge of loading, standardizing and modifying the dataset. Our transforms are chainable (stackable) such that data handling pipelines can be created. The dataset is stored in a python **dictionary** in order to allow this behaviour.

**Transforms** are implemented by classes. In the constructor of these transforms (`__init__` method in python) we pass the parameters of the transform. We define the `__call__` method to accept only one user defined argument which is the dictionary containing data. 

```
class ExampleTransform(object):
    # this transformation adds a constant to the images of a dataset
    
    def __init__(self, constant):
        # the constant that we add is passed as a parameter of this transform 
        self.constant = constant
        
    def __call__(self, data):
        # data is a dictionary containing the dataset. we suppose it has a field 'images'
        data['images'] = data['images'] + self.constant
        
        # the modified version of the data dictionary is returned as a result of the transform
        return data 
```

The example code above implements a simple transform that adds a constant to all of the images of the dataset. Other transforms can be implemented in the same way, including those that load a dataset from the filesystem and inject new data into the 'data' dictionary.
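
To illustrate how transforms of this kind can be chained, here is a minimal sketch. The classes `AddConstant` and `Scale` and the helper `apply_transforms` are hypothetical names introduced only for this illustration; they are not part of the exercise code.

```python
import numpy as np

class AddConstant(object):
    # hypothetical transform: adds a constant to every image
    def __init__(self, constant):
        self.constant = constant

    def __call__(self, data):
        data['images'] = data['images'] + self.constant
        return data

class Scale(object):
    # hypothetical transform: multiplies every image by a factor
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, data):
        data['images'] = data['images'] * self.factor
        return data

def apply_transforms(data, transforms):
    # each transform receives the dictionary returned by the previous one
    for transform in transforms:
        data = transform(data)
    return data

data = {'images': np.zeros((2, 4, 4), dtype=np.float32)}
data = apply_transforms(data, [AddConstant(1.0), Scale(0.5)])
print(data['images'][0, 0, 0])  # 0.5
```

Because every transform takes and returns the same dictionary, the order of the list fully defines the pipeline.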

In order to split the dataset into batches and iterate through these batches during training, validation and potentially testing, we need a batch iterator. This is implemented here through a Python class which acts as a generator. That is, we can use our batch iterator object in a for loop to get batches that we can use during training. The batch iterator returns a dictionary containing data at each iteration. It is also able to execute transforms (as defined above) both before and during iteration over the dataset. More details will be shown later in the exercise. 

### Network definition

In the following sections of this exercise you will find the implementation of the network computational graph in tensorflow. A few helper functions have been defined in order to break down the implementation in more manageable portions and group together code that can be re-used.

Loss layers can be defined very easily in tensorflow by implementing only the 'forward' computation of the loss and omitting the gradient implementation. This is possible thanks to the built-in automatic differentiation capabilities of tensorflow and other modern DL frameworks.

### Fitting the network's parameters to the data

In order to train, validate, and test the network we need to write the relevant code implementing the training, validation and testing loops (testing loop omitted here). The basic functionality of this code is to instantiate the network, define inputs (in terms of images and labels in this case), and finally instantiate batch iterators and optimizer.  
At this point we can iterate (using a for loop) through the batches which can be fed to the network in order to optimize it for the task at hand.

In [1]:
!pip install tensorflow==1.15


In [2]:
# In this section we import all the python packages used in this exercise

import numpy as np
import tensorflow as tf
import os
import copy

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from skimage import io
from skimage.transform import resize

from tensorflow.contrib import rnn
from tensorflow.contrib import slim

%matplotlib inline

# set random seeds for reproducibility
np.random.seed(4321)


## Loading the dataset
As previously explained, we understand all actions relative to data as **transformations**. Data loading is therefore understood as a transformation as well. A transformation, as already stated, is a callable (it is an object with a `__call__` method or a "pointer" to a function) which operates on a **dictionary** passed as input. **This dictionary contains the data**. The **transform changes** and/or adds to the content of this dictionary. 

In this exercise, loading the Montgomery XRay dataset is done by replacing/creating the fields 'images' and 'labels' of the data dictionary with a list of numpy tensors representing images. These images have been read from the filesystem itself (from a specific path where the dataset is stored) using scikit-image. Data is read when an object of class LungXRayDataset is called with a dictionary as argument. 

In [3]:
class LungXRayDataset(object):
    # This transformation loads data from the Montgomery County Lung XRay Dataset and updates the
    # 'data' dictionary accordingly
    
    def __init__(self, images_path, labels_path):
        # in this method we initialize the object 'LungXRayDataset' such that we know what paths
        # we need to read in order to load images and labels
        
        label_left_path = os.path.join(labels_path, 'left')
        label_right_path = os.path.join(labels_path, 'right')

        image_names = [f for f in os.listdir(images_path) if 'png' in f]

        self.image_path_list = [os.path.join(images_path, n) for n in image_names]
        self.label_left_path_list = [os.path.join(label_left_path, n) for n in image_names]
        self.label_right_path_list = [os.path.join(label_right_path, n) for n in image_names]
        
        assert len(self.image_path_list) == len(self.label_left_path_list)
        assert len(self.label_left_path_list) == len(self.label_right_path_list)

    def __call__(self, data):
        # This method 'transforms' the 'data' dictionary by adding two fields to it, 
        # 'images' and 'labels', which contain a list of images and relative labels that 
        # can be used to train our DL approach
        
        images = []
        labels = []
        
        for image_file, left_file, right_file in zip(self.image_path_list, self.label_left_path_list, self.label_right_path_list):
            image = io.imread(image_file, as_gray=True).astype(np.float32) # read image
            ll = io.imread(left_file, as_gray=True).astype(np.float32)  # read label for left lung
            lr = io.imread(right_file, as_gray=True).astype(np.float32)  # read label for right lung
            
            # fuse label for left and right lungs. In this exercise we do only binary segmentation
            label = np.zeros_like(image)
            label[ll > 0] = 1
            label[lr > 0] = 1
            
            image = (image - np.min(image)) / (np.max(image) - np.min(image))
            # append image and label to image and label lists
            images.append(image)
            labels.append(label)
        
        # update 'data' dictionary such that it can be further transformed and iterated
        data['images'] = images
        data['labels'] = labels
        
        return data

## Transforming and standardizing the data 
Not all the images for this dataset have the same size and exhibit the same content. Some pictures have large areas that are padded with black pixels, and their size is not standard. Here we define additional transforms which aim to standardize the data by removing the black padding that is present in some pictures (CropActualImage transformation) and resizing the images to a conventional size (ResizeImage transformation). We also define another transformation which takes care of shuffling (ShuffleData) the dataset before batching it and feeding it to the network for training. 
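
The cropping idea can be sketched on a toy array before looking at the transform classes. This is only an illustration under the assumption that padding pixels are exactly zero; the variable names are hypothetical.

```python
import numpy as np

# Toy 6x6 "image": a 3x2 block of non-zero pixels surrounded by black padding
image = np.zeros((6, 6), dtype=np.float32)
image[1:4, 2:4] = 1.0

# indices of the rows and columns that contain at least one non-zero pixel
rows = np.where(np.sum(image, axis=1) > 0)[0]
cols = np.where(np.sum(image, axis=0) > 0)[0]

# + 1 because the end of a Python slice is exclusive
cropped = image[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
print(cropped.shape)  # (3, 2)
```

The same row and column bounds would be applied to the label so that image and segmentation stay aligned.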

In [4]:
class CropActualImage(object):
    # this transformation crops the images such that padding (with zeros) is cropped out
    
    def __call__(self, data):
        # this method operates on 'data' dictionary. It replaces the fields 
        # 'images' and 'labels' with updated information
        images_t = []
        labels_t = []

        for image, label in zip(data['images'], data['labels']):
            non_zero_y = np.where(np.sum(image, axis=0) > 0)
            non_zero_x = np.where(np.sum(image, axis=1) > 0)

            # slice ends are exclusive, so we add 1 to keep the last non-zero row and column
            image_t = image[np.min(non_zero_x):np.max(non_zero_x) + 1, np.min(non_zero_y):np.max(non_zero_y) + 1]
            label_t = label[np.min(non_zero_x):np.max(non_zero_x) + 1, np.min(non_zero_y):np.max(non_zero_y) + 1]

            images_t.append(image_t)
            labels_t.append(label_t)

        data['images'] = images_t
        data['labels'] = labels_t
        
        return data

In [5]:
class ResizeImage(object):
    # this transformation resizes the images and labels to a common size specified at init
    
    def __init__(self, size):
        # a 2D iterable containing the desired image size needs to be specified at init
        self.size = size
        
    def __call__(self, data):
        # this method operates on 'data' dictionary. It replaces the fields 
        # 'images' and 'labels' with updated information
        data["images"] = np.array(
            [resize(i, (self.size[0], self.size[1]), preserve_range=True) for i in data['images']],
            dtype=data['images'][0].dtype
        )
        data['labels'] = np.array(
            [resize(i, (self.size[0], self.size[1]), preserve_range=True, order=0) for i in data['labels']],
            dtype=data['labels'][0].dtype
        )
        
        return data

In [6]:
class ShuffleData(object):
    # this transformation shuffles the dataset randomly
    
    def __init__(self, keys):
        # keys is a list of strings containing the fields of 
        # the data dictionary that should be shuffled
        self.keys = keys
        
    def __call__(self, data):
        # this method operates on the 'data' dictionary and it replaces the 
        # fields of the dictionary contained in self.keys with shuffled 
        # versions of them. All fields are shuffled in the same way
        data_length = len(data[self.keys[0]])
        new_order = np.random.permutation(data_length)

        for key in self.keys:
            data[key] = data[key][new_order]

        return data

## Batching and iterating

In this exercise the batch iterator takes care of the whole data loading/transformation/batching process. It gives us the ability to iterate through the dataset and specify the transformations that need to be applied to the data. Not all transformations need to be applied at the same time or at each loop iteration. We distinguish three cases:
* transforms applied BEFORE starting iterating through the data, for example loading (transforms_before_iterating)
* transforms applied to the whole training set at EACH EPOCH, for example shuffling (transforms_each_epoch)
* transforms applied to EACH BATCH separately, for example augmentation (transforms_each_iteration)

Here we implement such an object. Its `__iter__` method allows us to use it as a generator (e.g. we can write `for batch in iterator: ...`). 
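
As a minimal sketch of the generator pattern, assuming a single key and no transforms, a class becomes usable in a for loop as soon as its `__iter__` method yields values. `ToyBatchIterator` is a hypothetical, stripped-down name used only here.

```python
import numpy as np

class ToyBatchIterator(object):
    # hypothetical, stripped-down iterator: no transforms, a single key
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        n_data = len(self.data)
        for start in range(0, n_data, self.batch_size):
            # the last batch may be smaller than batch_size
            yield {'images': self.data[start:start + self.batch_size]}

iterator = ToyBatchIterator(np.arange(10), batch_size=4)
batch_sizes = [len(batch['images']) for batch in iterator]
print(batch_sizes)  # [4, 4, 2]
```

Each new `for` loop over the object calls `__iter__` again, which is what allows per-epoch work such as shuffling to happen at the start of every pass.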

In [7]:
class BatchIterator(object):
    # this object implements a batch iterator that can be used to iterate through a 
    # dataset during training or inference and to apply transforms to data dictionary
    
    def __init__(self, 
                 batch_size,
                 keys,
                 transforms_before_iterating=None, 
                 transforms_each_epoch=None, 
                 transforms_each_iteration=None
                ):
        self.batch_size = batch_size
        self.keys = keys
        # default to empty lists here rather than in the signature, so that
        # no mutable default argument is shared between instances
        self.transforms_before_iterating = transforms_before_iterating or []
        self.transforms_each_epoch = transforms_each_epoch or []
        self.transforms_each_iteration = transforms_each_iteration or []
        
        self.data = {}  # at the beginning we have an empty dataset (data dictionary)
        
        for transform in self.transforms_before_iterating:
            # for each transform in the list of transformation to do before iterating
            # apply that transform to the data dictionary
            self.data = transform(self.data)
        
    def __iter__(self):
        # deep copy of the current data such that we can transform it 
        # as we like without influencing next epochs. 
        # this is just a working copy of self.data
        data = copy.deepcopy(self.data) 
    
        for transform in self.transforms_each_epoch:
            # for each transform in the list of transformation to do at each epoch
            # apply that transform to the data dictionary
            data = transform(data)
        
        n_data = len(data[self.keys[0]])
        n_batches = int(np.ceil(n_data / self.batch_size))
        
        for i in range(n_batches):
            current_data = {}
            for key in self.keys:
                current_data[key] = data[key][i*self.batch_size:np.min([(i+1) * self.batch_size,  n_data])]
                
            for transform in self.transforms_each_iteration:
                # for each transform in the list of transformations to do on each batch
                # apply that transform to the data dictionary
                current_data = transform(current_data)
        
            yield current_data

In [8]:
##mount drive 
from google.colab import drive
drive.mount("/content/drive")




In [None]:
# defining model parameters
image_size = [256, 256]
batch_size = 16
num_epochs = 40
learning_rate = 0.000033

training_images_path = '../data/training/images'
training_labels_path = '../data/training/labels'

validation_images_path = '../data/validation/images'
validation_labels_path = '../data/validation/labels'

# declaring training batch iterator. the first 'transform' we do is loading the dataset. 

train_batch_iterator = BatchIterator(
    batch_size=batch_size, 
    keys=['images', 'labels'], 
    transforms_before_iterating=[
        LungXRayDataset(training_images_path, training_labels_path),  # first transform: load dataset (empty dataset > lung dataset)
        CropActualImage(),  # second transform: remove padding (lung dataset with padding > lung dataset)
        ResizeImage(image_size)  # third transform: resize images (lung dataset > resized lung dataset)
    ],
    transforms_each_epoch=[
        ShuffleData(keys=['images', 'labels']),  # at each epoch we shuffle the dataset
    ],
    transforms_each_iteration=[],  # this allows us to add augmentations etc at each iteration, but we don't use this in this exercise.
)

valid_batch_iterator = BatchIterator(
    batch_size=1, 
    keys=['images', 'labels'], 
    transforms_before_iterating=[
        LungXRayDataset(validation_images_path, validation_labels_path),
        CropActualImage(),  
        ResizeImage(image_size)
    ],
    transforms_each_epoch=[], 
    transforms_each_iteration=[],  
)

## Look at your data
It is important to always inspect the data before deciding which method is appropriate to solve the problem at hand. Inspecting the data can mean studying the statistics and distributions underlying the dataset, but in this case we are interested in inspecting it visually to make sure that it has been loaded and transformed correctly.

In [None]:
for data in train_batch_iterator:
    plt.imshow(np.concatenate([np.squeeze(img) for img in np.split(data['images'], data['images'].shape[0], axis=0)], axis=1), cmap='gray')
    plt.show()
    plt.imshow(np.concatenate([np.squeeze(img) for img in np.split(data['labels'], data['labels'].shape[0], axis=0)], axis=1), cmap='jet')
    plt.show()

#  <span style="color:green">Going further: CFCM network architecture</span>

## Defining the network architecture
We now define the CFCM network architecture. In this example we show CFCM34, which is based on ResNet34 and uses LSTMs to fuse the features coming from long skip connections at different resolutions. As shown in (Ref. ), the best performance on this dataset is obtained by a different, more complex architecture, but in order to reduce training time and computational load we find it appropriate to use CFCM34, which is still capable of yielding very good results.

### Utility functions, losses and model constants
Here we define a loss function and a scoring function based on the dice coefficient, together with a few helper functions that facilitate the construction of the computational graph. 

### Dice loss/score formulation
The dice coefficient measures the overlap between two (binary) contours and has been generalized and introduced as an objective function for FCNNs in [Ref. 3]. Since then it has been used in a number of scientific works and is now well established. The formulation used in this work is

$$\mathrm{DICE} = \frac{2 \sum_i G_i P_i}{\sum_i G_i^2 + \sum_i P_i^2}$$

where $G_i$ and $P_i$ are the ground-truth (Gt) and predicted (Pred) values at pixel $i$. This corresponds to the standard Dice coefficient when both Gt and Pred are binary.

Other formulations, such as

$$\mathrm{DICE} = \frac{2 \sum_i G_i P_i}{\sum_i G_i + \sum_i P_i}$$

have been proposed but behave slightly differently, especially when it comes to gradients. You can get more information about this topic in [Ref. 7].
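
A small NumPy check, under the assumption that the sums run over all pixels of a single image, illustrates that the squared and non-squared denominators agree for binary masks and differ for soft predictions. The function names `dice_squared` and `dice_plain` and the toy masks are hypothetical and only serve this illustration.

```python
import numpy as np

def dice_squared(pred, gt, eps=0.001):
    # squared terms in the denominator
    return 2.0 * np.sum(pred * gt) / (np.sum(pred ** 2) + np.sum(gt ** 2) + eps)

def dice_plain(pred, gt, eps=0.001):
    # plain sums in the denominator
    return 2.0 * np.sum(pred * gt) / (np.sum(pred) + np.sum(gt) + eps)

gt = np.array([[0, 1, 1], [0, 1, 0]], dtype=np.float32)
binary_pred = np.array([[0, 1, 0], [0, 1, 1]], dtype=np.float32)
soft_pred = 0.5 * binary_pred  # same support, but less confident

# for binary inputs x ** 2 == x, so both formulations give the same value
print(dice_squared(binary_pred, gt), dice_plain(binary_pred, gt))

# for soft predictions the two denominators, and hence the scores, differ
print(dice_squared(soft_pred, gt), dice_plain(soft_pred, gt))
```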

In [None]:
BASE_NUM_KERNELS = 64
EPS = 0.001

# LOSS

def dice(prediction, truth):
    return 2.0 * tf.reduce_sum(prediction * truth, axis=[1, 2, 3]) / (tf.reduce_sum((prediction ** 2 + truth ** 2), [1, 2, 3]) + EPS)

def dice_score(prediction, truth):
    dc = dice(prediction, truth)
    return tf.reduce_mean(dc, axis=0)

def dice_loss(prediction, truth):
    dc = dice(prediction, truth)
    return tf.reduce_mean(1.0 - dc, axis=0)

# COMMON BLOCKS/LAYERS

def batch_norm_relu(inputs, is_training):
    #net = tf.contrib.layers.batch_norm(inputs, is_training=is_training, zero_debias_moving_mean=True, decay=0.9)
    net = tf.nn.relu(inputs)
    return net


def conv2d_transpose(inputs, output_channels, kernel_size):
    return slim.conv2d_transpose(
        inputs,
        num_outputs=output_channels,
        kernel_size=kernel_size,
        stride=2,
    )


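def conv2d_fixed_padding(inputs, filters, kernel_size, stride):
    # For strided convolutions we pad explicitly (assuming NHWC inputs and an integer
    # kernel_size) so that the 3x3 path and the 1x1 projection shortcut have equal output sizes
    if stride > 1:
        pad_total = kernel_size - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])

    return slim.conv2d(
        inputs,
        filters,
        kernel_size,
        stride=stride,
        padding=('SAME' if stride == 1 else 'VALID'),
        activation_fn=None
    )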


def building_block(inputs, filters, is_training, projection_shortcut, stride):
    # this implements one block of the ResNet
    # (convolutional operations and skip connection included)
    
    shortcut = inputs
    inputs = batch_norm_relu(inputs, is_training)
    # The projection shortcut should come after the first batch norm and ReLU
    # since it performs a 1x1 convolution.
    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

    inputs = conv2d_fixed_padding(inputs=inputs, filters=filters, kernel_size=3, stride=stride)

    inputs = batch_norm_relu(inputs, is_training)

    inputs = conv2d_fixed_padding(inputs=inputs, filters=filters, kernel_size=3, stride=1)

    return inputs + shortcut



def block_layer_compressing(inputs, filters, blocks, stride, is_training, name):
    def projection_shortcut(inputs):
        return conv2d_fixed_padding(inputs=inputs, filters=filters, kernel_size=1, stride=stride)

    # Only the first block per block_layer uses projection_shortcut and strides
    inputs = building_block(inputs, filters, is_training, projection_shortcut, stride)

    layers_outputs = [inputs]

    for i in range(1, blocks):
        inputs = building_block(inputs, filters, is_training, None, 1)

        layers_outputs.append(tf.nn.relu(inputs))

    return tf.identity(inputs, name), layers_outputs


### Convolutional LSTMs
At the core of the feature fusion strategy of CFCM, there is a memory mechanism which is implemented via a convolutional LSTM. The decoder treats each block of the ResNet encoder as a single time-step. As discussed above, CFCM forwards the outputs of these blocks to the decoding path where the features are processed through LSTM cells. To this end, convolutional LSTMs are employed. Convolutional LSTMs have the capability of selectively updating their
internal states at each step depending on the result of a convolution.

There are multiple reasons why it is necessary to use a convolutional LSTM in this case:
* Number of parameters: using classic LSTMs we would need to resort to a fully connected layer to produce the features for the LSTM gates. This would result in a parameter explosion due to the high dimensionality of the input features and therefore impair learning.
* We deal with images: convolutions are an appropriate choice when information is spatially localized such as in the case of images.
* Speed and memory: fewer parameters mean fewer operations and less memory. Using classic LSTMs would be prohibitive in terms of computation time and memory requirements.  
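
To give a rough sense of scale for the parameter argument above, the following sketch counts the weights of a fully connected LSTM and of a convolutional LSTM. The feature map size (32x32 with 64 channels), the 3x3 kernel and the function names are hypothetical choices for illustration, not the CFCM configuration.

```python
def fc_lstm_params(height, width, in_channels, hidden_channels):
    # the feature maps are flattened into vectors
    n_in = height * width * in_channels
    n_hidden = height * width * hidden_channels
    # three gates plus the update signal, each with a weight matrix
    # acting on [input, hidden] and a bias
    return 4 * ((n_in + n_hidden) * n_hidden + n_hidden)

def conv_lstm_params(kernel_size, in_channels, hidden_channels):
    # one convolution over the concatenated [input, hidden] channels
    # producing 4 * hidden_channels output channels, plus biases
    weights = kernel_size * kernel_size * (in_channels + hidden_channels) * hidden_channels
    return 4 * (weights + hidden_channels)

fc = fc_lstm_params(32, 32, 64, 64)
conv = conv_lstm_params(3, 64, 64)
print(fc)    # 34360000512
print(conv)  # 295168
```

The convolutional count does not depend on the spatial size of the feature map, which is why it stays small even for high-resolution features.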

Here you can see a graphical representation of the convolutional LSTM employed in this work:
<img src=conv_lstm_basic.png width="600">

The input features, which in the CFCM architecture are forwarded from the encoding path, are represented in orange. These features are the results of convolutions and pooling operations on the network inputs. At each time-step a 4D input is presented to the LSTM. The first dimension is the batch size, the next two are the spatial dimensions (height and width) and the last dimension is the number of channels, which can be large. 

The hidden state and cell state are part of the LSTM architecture. Both hidden and cell state are propagated through the sequence. The hidden state is also the output of the LSTM cell. The hidden state/output is never computed directly from the previous hidden state but it is computed from the cell state instead. There is no direct connection between hidden state from a previous step and the output at the next.

As shown in the figure, each time step makes use of three feature sets: inputs, hidden state and cell state. The inputs are concatenated with the hidden state. A convolution is performed and its result is used to (1) pass a part of the information stored in the cell state through the forget gate; (2) compute new features which contribute to the cell state after being (3) decimated by the input gate; (4) compute a new hidden state.

### (Convolutional) LSTM in action
The main difference between classic LSTMs and convolutional ones is the usage of a convolution operation to obtain the update features for the cell as well as the gating signals for the forget, input and output gates.

To understand how both classic and convolutional LSTMs work we will break down their behaviour for one step of a sequence. In the following drawing we have unrolled the LSTM across two sequence steps and have highlighted with a red box the operations and components relative to one of the two steps. 

<img src=conv_lstm_time.png width="700">

The features fed into our LSTM (at time t) are:
1. the hidden state at time t (at time zero an initialization is provided instead)
2. the cell state at time t (at time zero an initialization is provided instead)
3. the input at time t

<img src=conv_lstm_step1.png width="700">

The inputs (at time t) are concatenated with the hidden state (at time t) and convolved with a kernel having 4×C output channels, which creates 4 sets of features with C channels each. These features are used to:
1. drive the LSTM gates (forget, input, output) 
2. produce an update signal for the LSTM state consisting of features computed from input and hidden state

<img src=conv_lstm_step2.png width="700">
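
The bookkeeping of this step can be sketched in plain Python. The snippet below computes the kernel shape of the gate convolution and splits the channel vector of a single pixel into the four feature sets, in the same order (`i, j, f, o`) used by the implementation further down. The helper names and the channel counts are hypothetical and chosen only for illustration.

```python
def conv_lstm_kernel_shape(filter_size, input_channels, hidden_channels):
    """Kernel shape [k_h, k_w, in_channels, out_channels] of the gate convolution."""
    return [
        filter_size[0],
        filter_size[1],
        input_channels + hidden_channels,  # channels of concat([inputs, hidden])
        4 * hidden_channels,               # four feature sets with C channels each
    ]


def split_feature_sets(channel_values, num_sets=4):
    """Split the channel vector of one pixel into equally sized consecutive sets."""
    size = len(channel_values) // num_sets
    return [channel_values[k * size:(k + 1) * size] for k in range(num_sets)]


# Assumed sizes: 3x3 kernel, 64 input channels, C = 64 hidden channels.
kernel_shape = conv_lstm_kernel_shape([3, 3], input_channels=64, hidden_channels=64)
print(kernel_shape)

# Channel vector of a single pixel for C = 2, i.e. 4 * C = 8 values.
i, j, f, o = split_feature_sets([0, 1, 2, 3, 4, 5, 6, 7])
print(i, j, f, o)
```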

The features that drive the three gates of the LSTM (forget, input and output) pass through a sigmoid non-linearity, so each gate value always lies between 0 and 1. The update signal uses a different non-linearity, which can in general be chosen by the user; in the case of CFCM it is a ReLU.

<img src=conv_lstm_step3.png width="700">

The cell state is multiplied by the forget gate. This means that some of the information contained in the cell state is attenuated (removed) by multiplying it with values close to zero, while other information is kept because it is multiplied with values close to one.

The input gate attenuates (removes) part of the update signal computed in the previous step. In this way only some of the information computed via convolution from the inputs and the hidden state contributes to the cell state update.

The output gate will be used to choose which information belonging to the cell state (after activation) will be used to create the new hidden state. 

<img src=conv_lstm_step4.png width="700">

Information belonging to the cell state after (forget) gating and new information from the update signal after (input) gating, are fused together via summation.

<img src=conv_lstm_step5.png width="700">

In this way we obtain a new cell state (cell state at time t+1) which we can also use to obtain a new hidden state.

<img src=conv_lstm_step6.png width="700">

The new hidden state (at time t+1) is obtained by (output) gating the cell state (at time t+1) after a ReLU non-linearity.

<img src=conv_lstm_step7.png width="700">
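
The gating steps walked through above can be checked by hand on a single value. The following sketch applies them to the features of one pixel and one channel, assuming the convolution has already produced the four pre-activation values `i`, `j`, `f` and `o`. The function names are hypothetical; the forget bias of 1.0 mirrors the default of the implementation below.

```python
import math


def sigmoid(x):
    """Logistic function, used for the three gates."""
    return 1.0 / (1.0 + math.exp(-x))


def relu(x):
    """Non-linearity applied to the update signal and the cell state in CFCM."""
    return max(0.0, x)


def lstm_pointwise_step(cell, i, j, f, o, forget_bias=1.0, activation=relu):
    """Apply the LSTM gating to the features of one pixel and one channel."""
    forget_gate = sigmoid(f + forget_bias)
    input_gate = sigmoid(i)
    output_gate = sigmoid(o)
    # Keep part of the old cell state and add the gated update signal.
    new_cell = cell * forget_gate + input_gate * activation(j)
    # The hidden state is the activated cell state filtered by the output gate.
    new_hidden = activation(new_cell) * output_gate
    return new_cell, new_hidden


# With these values all three gates evaluate to 0.5.
new_cell, new_hidden = lstm_pointwise_step(cell=2.0, i=0.0, j=3.0, f=-1.0, o=0.0)
print(new_cell, new_hidden)
```

Here half of the old cell state (1.0) and half of the update signal (1.5) are summed into the new cell state, and half of that value becomes the new hidden state.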

In [None]:
# IMPLEMENTATION OF CONV LSTM
class ConvLSTMCell(object):
    def __init__(self, shape, filter_size, num_features, forget_bias=1.0, activation=tf.nn.tanh, scope=''):
        self.shape = shape
        self.filter_size = filter_size
        self.num_features = num_features
        self.forget_bias = forget_bias
        self.activation = activation
        self.num_units = num_features
        self.scope = scope
    
    def zero_state(self, batch_size, dtype):
        init_state = tf.zeros([batch_size, self.shape[0], self.shape[1], self.num_features * 2])
        return init_state

    @property
    def state_size(self):
        return 2 * self.num_units

    @property
    def output_size(self):
        return self.num_units
    
    def __call__(self, inputs, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            cell, hidden = tf.split(axis=3, num_or_size_splits=2, value=state)
            
            features = lstm_conv(
                inputs=[inputs, hidden], 
                filter_size=self.filter_size, 
                num_features=self.num_features * 4,
                use_bias=True,
                init_bias=0.0,
                scope=self.scope + 'conv'
            )

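            # The convolution output holds four groups of channels:
            # i = input_gate, j = new_input (update signal), f = forget_gate, o = output_gate
            i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=features)

            # Forget part of the old cell state, then add the input-gated update signal
            new_cell = (cell * tf.nn.sigmoid(f + self.forget_bias) +
                        tf.nn.sigmoid(i) * self.activation(j))

            # The new hidden state is the activated cell state filtered by the output gate
            new_hidden = self.activation(new_cell) * tf.nn.sigmoid(o)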

            new_state = tf.concat(axis=3, values=[new_cell, new_hidden])
            
            return new_hidden, new_state


def lstm_conv(inputs, filter_size, num_features, use_bias, init_bias=0.0, scope='Conv'):
    n_in_chan = 0
    shapes = [elem.get_shape().as_list() for elem in inputs]
    
    for shape in shapes:
        n_in_chan += shape[3]

    dtype = inputs[0].dtype

    with tf.variable_scope(scope):
        weights = tf.get_variable(
            "weights", [filter_size[0], filter_size[1], n_in_chan, num_features], dtype=dtype)

        res = tf.nn.conv2d(tf.concat(axis=3, values=inputs), weights, strides=[1, 1, 1, 1], padding='SAME')
        
        if not use_bias:
            return res
        
        additive_bias = tf.get_variable(
            "bias", 
            [num_features],
            dtype=dtype,
            initializer=tf.constant_initializer(
                init_bias, 
                dtype=dtype
            )
        )
        
    return res + additive_bias

## Coarse to Fine Context Memory (CFCM) based on ResNet34

We are now ready to implement the CFCM architecture making use of the convolutional LSTM code we have implemented above and the helper functions that we have defined earlier. The following code implements the computational graph of CFCM based on ResNet34. 
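
Before reading the graph definition, it can help to trace the feature map sizes at each level. The sketch below assumes a 256x256 input (as in the placeholders used for training), an initial 2x2 max pooling, residual blocks that preserve the spatial size, and a 2x2 max pooling between levels. The value 64 for `base_num_kernels` is an assumption made for illustration (the notebook sets `BASE_NUM_KERNELS` elsewhere), and the function name is hypothetical.

```python
def cfcm_feature_shapes(input_size=256, base_num_kernels=64, num_levels=4):
    """Return (height, width, channels) of the features forwarded at each level."""
    size = input_size // 2  # max pooling right after the first convolution
    shapes = []
    for level in range(num_levels):
        # Channels double at every level; blocks keep the spatial size (stride=1).
        shapes.append((size, size, base_num_kernels * 2 ** level))
        if level < num_levels - 1:
            size //= 2  # max pooling between consecutive block layers
    return shapes


encoder_shapes = cfcm_feature_shapes()
for level, shape in enumerate(encoder_shapes, start=1):
    print('block_layer{}: {}'.format(level, shape))

# The decoding path visits the same levels from coarse to fine.
decoder_order = encoder_shapes[::-1]
```

The decoding path starts from the coarsest level and, assuming each `conv2d_transpose` call doubles the spatial size, moves back up to the input resolution.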

In [None]:
class CFCM34(object):
    layers = [3, 4, 6, 3]  # description of layers arrangement (ResNet 34)
    
    def __init__(self, num_classes, data_format='channels_last'):
        self.num_classes = num_classes
        self.data_format = data_format
    
    def __call__(self, inputs, is_training):
        base_num_kernels = BASE_NUM_KERNELS

        inputs = conv2d_fixed_padding(inputs=inputs, filters=base_num_kernels, kernel_size=7, stride=1)

        inputs = tf.layers.max_pooling2d(inputs=inputs, pool_size=2, strides=2, padding='SAME')

        output_b1, output_list_b1 = block_layer_compressing(
            inputs=inputs,
            filters=base_num_kernels,
            blocks=self.layers[0],
            stride=1,
            is_training=is_training,
            name='block_layer1'
        )

        output_b1 = tf.layers.max_pooling2d(
            inputs=output_b1, 
            pool_size=2, 
            strides=2, 
            padding='SAME',
            data_format=self.data_format
        )

        output_b2, output_list_b2 = block_layer_compressing(
            inputs=output_b1, 
            filters=base_num_kernels * 2, 
            blocks=self.layers[1],
            stride=1, 
            is_training=is_training, 
            name='block_layer2'
        )

        output_b2 = tf.layers.max_pooling2d(
            inputs=output_b2, 
            pool_size=2, 
            strides=2, 
            padding='SAME',
            data_format=self.data_format
        )

        output_b3, output_list_b3 = block_layer_compressing(
            inputs=output_b2, 
            filters=base_num_kernels * 4,  
            blocks=self.layers[2],
            stride=1, 
            is_training=is_training, 
            name='block_layer3'
        )

        output_b3 = tf.layers.max_pooling2d(
            inputs=output_b3, 
            pool_size=2, 
            strides=2, 
            padding='SAME',
            data_format=self.data_format
        )

        output_b4, output_list_b4 = block_layer_compressing(
            inputs=output_b3, 
            filters=base_num_kernels * 8, 
            blocks=self.layers[3],
            stride=1, 
            is_training=is_training, 
            name='block_layer4'
        )

        # lstm - decoding path

        initial_hidden = tf.zeros_like(output_b4)
        initial_cell = tf.zeros_like(output_b4)

        initial_state = tf.concat([initial_cell, initial_hidden], axis=3)

        shape = [output_b4.get_shape().as_list()[1], output_b4.get_shape().as_list()[2]]
        
        lstm_b4 = ConvLSTMCell(shape, [3, 3], num_features=base_num_kernels * 8, scope='lstm_b4',
                                    activation=tf.nn.relu)

        _, state = rnn.static_rnn(lstm_b4, output_list_b4[::-1], initial_state=initial_state, dtype=tf.float32)

        state = conv2d_transpose(state, kernel_size=2, output_channels=base_num_kernels * 8)

        shape = [output_b3.get_shape().as_list()[1], output_b3.get_shape().as_list()[2]]
        lstm_b3 = ConvLSTMCell(shape, [3, 3], num_features=base_num_kernels * 4, scope='lstm_b3',
                                    activation=tf.nn.relu)

        _, state = rnn.static_rnn(lstm_b3, output_list_b3[::-1], initial_state=state, dtype=tf.float32)

        state = conv2d_transpose(state, kernel_size=2, output_channels=base_num_kernels * 4)

        shape = [output_b2.get_shape().as_list()[1], output_b2.get_shape().as_list()[2]]
        lstm_b2 = ConvLSTMCell(shape, [3, 3], num_features=base_num_kernels * 2, scope='lstm_b2',
                                    activation=tf.nn.relu)

        _, state = rnn.static_rnn(lstm_b2, output_list_b2[::-1], initial_state=state, dtype=tf.float32)

        state = conv2d_transpose(state, kernel_size=2, output_channels=base_num_kernels * 2)

        shape = [output_b1.get_shape().as_list()[1], output_b1.get_shape().as_list()[2]]
        lstm_b1 = ConvLSTMCell(shape, [3, 3], num_features=base_num_kernels, scope='lstm_b1',
                                    activation=tf.nn.relu)

        output, state = rnn.static_rnn(lstm_b1, output_list_b1[::-1], initial_state=state, dtype=tf.float32)

        hidden = conv2d_transpose(output[-1], kernel_size=2, output_channels=base_num_kernels)

        conv_final = conv2d_fixed_padding(inputs=hidden, filters=base_num_kernels, kernel_size=3, stride=1)
        
        conv_final = tf.nn.relu(conv_final)

        outputs = conv2d_fixed_padding(inputs=conv_final, filters=self.num_classes, kernel_size=3, stride=1)

        return outputs


## Implementation of training + validation loops
We implement here both the training and validation loops, as well as the code needed to open the TensorFlow session and instantiate and initialize the graph. We define the input placeholders, the loss, and the optimization op (`fitting_op`) that trains the network. We loop over the training and validation sets for a number of epochs and plot results for both, to give you a sense of how the network training progresses over time. In the interest of time, validation runs only in epochs where the average training score exceeds 0.92. Training and validation can be run by executing the cell after the next.

In [None]:
def run_training_and_validation(network):
    with tf.Graph().as_default() as g:
        tf.set_random_seed(777)
        with tf.Session() as sess:
            # Instantiate the network class passed as argument (e.g. CFCM34)
            SegResNet = network(num_classes=1)

            # Placeholders for the inputs
            images = tf.placeholder(shape=[None, 256, 256, 1], dtype=tf.float32)
            labels = tf.placeholder(shape=[None, 256, 256, 1], dtype=tf.float32)
            is_training = tf.placeholder(shape=(), dtype=tf.bool)

            # Network output
            prediction = tf.sigmoid(SegResNet(images, is_training))

            # Loss function
            loss = dice_loss(prediction, labels)
            
            # Score function (binarize the prediction for accurate results!)
            score = dice_score(tf.cast(prediction > 0.5, dtype=tf.float32), labels)

            # Optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

            extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

            with tf.control_dependencies(extra_update_ops):
                fitting_op = optimizer.minimize(loss)

            # Initialization
            init_op = tf.global_variables_initializer()
            sess.run(init_op)

            training_scores = []
            valid_scores = []
            
            training_epochs_axis = []
            valid_epochs_axis = []

            for i in range(num_epochs):
                print('---------- TRAINING EPOCH {} out of {} ----------'.format(i, num_epochs))

                epoch_training_scores = []
                epoch_valid_scores = []
                
                training_epochs_axis.append(i)

                for data in train_batch_iterator:
                    _, curr_training_score, curr_prediction = sess.run(
                        [fitting_op, score, prediction], 
                        feed_dict={
                            images: data['images'][..., np.newaxis],
                            labels: data['labels'][..., np.newaxis],
                            is_training: True,
                        }
                    )

                    epoch_training_scores.append(curr_training_score)

                    plt.imshow(np.concatenate([np.squeeze(img) for img in np.split(data['images'], data['images'].shape[0], axis=0)], axis=1), cmap='gray')
                    plt.show()
                    plt.imshow(np.concatenate([np.squeeze(img) for img in np.split(data['labels'], data['labels'].shape[0], axis=0)], axis=1), cmap='jet')
                    plt.show()
                    plt.imshow(np.concatenate([np.squeeze(img) for img in np.split(curr_prediction, curr_prediction.shape[0], axis=0)], axis=1), cmap='jet')
                    plt.show()

                training_scores.append(np.mean(epoch_training_scores))
                print('Average TRAINING score for epoch {}: {}'.format(i, np.mean(epoch_training_scores)))
                
                # running validation only after the network attains decent performances (in the interest of time)
                
                if training_scores[-1] > 0.92:
                    print('---------- VALIDATION EPOCH {} out of {} ----------'.format(i, num_epochs))
                    
                    valid_epochs_axis.append(i)
                    
                    for data in valid_batch_iterator:

                        curr_valid_score, curr_prediction = sess.run(
                            [score, prediction], 
                            feed_dict={
                                images: data['images'][..., np.newaxis],
                                labels: data['labels'][..., np.newaxis],
                                is_training: False,
                            }
                        )

                        epoch_valid_scores.append(curr_valid_score)

                        plt.imshow(np.concatenate([np.squeeze(img) for img in np.split(data['images'], data['images'].shape[0], axis=0)], axis=1), cmap='gray')
                        plt.show()
                        plt.imshow(np.concatenate([np.squeeze(img > 0.5) for img in np.split(data['labels'], data['labels'].shape[0], axis=0)], axis=1), cmap='gray')
                        plt.show()
                        plt.imshow(np.concatenate([np.squeeze(img > 0.5) for img in np.split(curr_prediction, curr_prediction.shape[0], axis=0)], axis=1), cmap='gray')
                        plt.show()

                    valid_scores.append(np.mean(epoch_valid_scores))
                    print('Average VALIDATION score for epoch {}: {}'.format(i, np.mean(epoch_valid_scores)))
                
    return training_scores, valid_scores, training_epochs_axis, valid_epochs_axis

In [None]:
# We run training + validation by executing this cell
training_scores, valid_scores, train_axis, valid_axis = run_training_and_validation(CFCM34)

## Plotting scores
We now plot the training and validation scores in terms of Dice coefficient (the higher, the better). More information about the results and further experiments can be found in the original paper [Ref. 1].
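
As a reminder of what the plotted score measures, here is a minimal sketch of the Dice coefficient for two binary masks given as flat lists. It is an illustration only: the `dice_score` helper used in the training loop is defined elsewhere in the notebook and may differ, for example in how it handles smoothing or empty masks. The function name and the empty-mask convention below are assumptions.

```python
def dice_coefficient(prediction, target):
    """Dice coefficient between two flat binary masks (lists of 0/1 values)."""
    intersection = sum(p * t for p, t in zip(prediction, target))
    total = sum(prediction) + sum(target)
    if total == 0:
        return 1.0  # assumption: two empty masks count as a perfect match
    return 2.0 * intersection / total


perfect = dice_coefficient([1, 1, 0, 0], [1, 1, 0, 0])   # identical masks
disjoint = dice_coefficient([1, 1, 0, 0], [0, 0, 1, 1])  # no overlap
partial = dice_coefficient([1, 1, 0, 0], [1, 0, 0, 0])   # one shared pixel
print(perfect, disjoint, partial)
```

The score is 1 for identical masks, 0 for masks without overlap, and for the partial case it equals twice the shared pixel count divided by the total number of foreground pixels in both masks.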

In [None]:
plt.plot(train_axis, training_scores, 'b', valid_axis, valid_scores, 'r')
red_patch = mpatches.Patch(color='red', label='Validation')
blue_patch = mpatches.Patch(color='blue', label='Training')
plt.legend(handles=[red_patch, blue_patch])
plt.show()

## <span style="color:green">Conclusions</span>
In this exercise you have implemented CFCM, an interesting alternative to classic feature fusion for long skip connections in fully convolutional neural networks (FCNN) applied to medical image segmentation.
The multi-scale feature integration of CFCM is achieved via LSTMs, which process and aggregate features extracted at different resolutions and with different receptive fields. The architecture is based on ResNet34 and the training strategy is similar to the one adopted by more classical and widespread FCNN methods: images are fed to the algorithm together with their labels, and a loss function (Dice loss in this case) is optimized in a fully supervised manner.

By completing this exercise you have become familiar with concepts such as
* fully convolutional neural networks
* residual networks
* long/short skip connections to improve image segmentation
* batch normalization
* receptive field
* recurrent neural networks (RNN) 
* (convolutional) long short term memory cells (LSTMs)

You have also become familiar with a simple yet powerful way to structure a deep learning project: modularizing basic tasks such as data loading, manipulation and augmentation in a way that scales to larger projects and is compatible with practices adopted by frameworks such as PyTorch.

## References
* Ref. 1: *Milletari, F., Rieke, N., Baust, M., Esposito, M. and Navab, N., 2018. CFCM: Segmentation via Coarse to Fine Context Memory. arXiv preprint arXiv:1806.01413.*
* Ref. 2: *Ronneberger, O., Fischer, P. and Brox, T., 2015, October. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.*
* Ref. 3: *Milletari, F., Navab, N. and Ahmadi, S.A., 2016, October. V-net: Fully convolutional neural networks for volumetric medical image segmentation. In 3D Vision (3DV), 2016 Fourth International Conference on (pp. 565-571). IEEE.*
* Ref. 4: *Laina, I., Rieke, N., Rupprecht, C., Vizcaíno, J. P., Eslami, A., Tombari, F., and Navab, N. 2017. Concurrent segmentation and localization for tracking of surgical instruments. In International conference on medical image computing and computer-assisted intervention (pp. 664-672). Springer*
* Ref. 5: *Ioffe, S., & Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167*
* Ref. 6: *Drozdzal, M., Vorontsov, E., Chartrand, G., Kadoury, S. and Pal, C., 2016. The importance of skip connections in biomedical image segmentation. In Deep Learning and Data Labeling for Medical Applications (pp. 179-187). Springer, Cham*
* Ref. 7: *Milletari, F., 2018. Hough Voting Strategies for Segmentation, Detection and Tracking (Doctoral dissertation, Universität München)*