# Neural network for segmenting LV of heart

## Introduction
This work was performed as part of the MRAI workshop (MIDL 2019 satellite meeting) with exercises designed by Esther Puyol (https://github.com/estherpuyol/MRAI_workshop).

### Objective
Train a simple neural network to automatically segment the Left ventricle from 2D short axis cardiac MR images.

## Import modules
The network is built using functions and classes from the [pytorch library](https://pytorch.org/docs/stable/index.html) 

In [0]:
import os
import numpy as np
import pylab as plt#
import tensorflow as tf
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import StratifiedShuffleSplit

## Download dataset
The data for this study is from the [Sunnybrook Cardiac Data](https://www.cardiacatlas.org/studies/sunnybrook-cardiac-data/)

A preprocessed subset of this data is used, where the data is filtered to contain only left ventricle myocardium segmentations and reduced in XY dimensions.

In [3]:
![ -f scd_lvsegs.npz ] || wget https://github.com/estherpuyol/MRAI_workshop/raw/master/scd_lvsegs.npz

data = np.load('scd_lvsegs.npz') # load all the data from the archive

images = data['images'] # images in BHW array order : .shape (420, 64, 64)]
segs = data['segs'] # segmentations in BHW array order : .shape (420, 64, 64)
caseIndices = data['caseIndices'] # the indices in `images` for each case : .shape (45, 2)

images = images.astype(np.float32)/images.max() # normalize images

--2019-08-13 08:47:44--  https://github.com/estherpuyol/MRAI_workshop/raw/master/scd_lvsegs.npz
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/estherpuyol/MRAI_workshop/master/scd_lvsegs.npz [following]
--2019-08-13 08:47:44--  https://raw.githubusercontent.com/estherpuyol/MRAI_workshop/master/scd_lvsegs.npz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2133403 (2.0M) [application/octet-stream]
Saving to: ‘scd_lvsegs.npz’


2019-08-13 08:47:45 (48.7 MB/s) - ‘scd_lvsegs.npz’ saved [2133403/2133403]



## Split training and test data
This code block splits the data set into training and test data where the variable n_training determines the number of test cases.  

In [0]:
n_training = 6

testIndex = caseIndices[-n_training,0] # keep the last n_training cases for testing

# divide the images, segmentations, and categories into train/test sets
trainImages,trainSegs = images[:testIndex],segs[:testIndex]
testImages,testSegs = images[testIndex:],segs[testIndex:]

## Define segmentation network
Here two classes are defined: 

The first describes the loss function ```DiceLoss``` which will provide a measure of overlap between ground truth segmentations and the network outputs. 

The second class ```SegNet``` is our artifical neural network that inherits methods from the base NN class torch.nn.Module. 


In [0]:
class DiceLoss(tf.keras.losses.Loss):
    '''This defines the binary dice loss function used to assess segmentation overlap.'''
    def call(self, y_pred, y_true, axis=(1, 2, 3), smooth=1e-5):
        y_pred = tf.math.sigmoid(y_pred)
        y_true = tf.math.sigmoid(y_true)
        inse = tf.reduce_sum(y_pred * y_true, axis=axis)
        l = tf.reduce_sum(y_pred, axis=axis)
        r = tf.reduce_sum(y_true, axis=axis)

        dice = (2. * inse + smooth) / (l + r + smooth)
        ##
        dice = tf.reduce_mean(dice, name='dice_coe')
        return dice 


"""
class unet_block(tf.keras.Model):

    def __init__(self, Cin, Cout, subblock):
        super().__init__()
        self._encode = tf.keras.layers.Conv2D(Cout, (3,3), strides=(1,1), padding='valid', data_format='channels_first')
        self._encode_norm = tf.keras.layers.BatchNormalization(axis=1)
        self._encode_dropout = tf.keras.layers.Dropout(rate=0.2)
        self._encode_activation = tf.keras.layers.Activation(tf.nn.leaky_relu)

        self._subblock = subblock

        self._decode = tf.keras.layers.Conv2DTranspose(Cin, (3,3), strides=(1,1), padding='valid', data_format='channels_first')
        self._decode_norm = tf.keras.layers.BatchNormalization(axis=1)
        self._decode_dropout = tf.keras.layers.Dropout(rate=0.2)
        self._decode_activation = tf.keras.layers.Activation(tf.nn.leaky_relu)

    def call(self, tensor_in):
#        Forward pass - encode block -> subblocks -> decode block
        enc = self._encode(tensor_in)
        enc = self._encode_norm(enc)
        #if training:
        enc = self._encode_dropout(enc)
        enc = self._encode_activation(enc)

        sub = self._subblock(enc)
        sub = tf.concat([enc,sub], axis=1)

        dec = self._decode(sub)
        dec = self._decode_norm(dec)
        #if training:
        dec = self._decode_dropout(dec)
        dec = self._decode_activation(dec)

        return dec


class Unet(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # bottom subblock
        net = tf.keras.Sequential()
        net.add(tf.keras.layers.Conv2D(64, (3,3), strides=2, padding='same', data_format='channels_first'))
        net.add(tf.keras.layers.BatchNormalization(axis=1))
        net.add(tf.keras.layers.PReLU())

        # build the unet structure from the bottom up
        net=unet_block(32,64,net)
        net=unet_block(16,32,net)
        net=unet_block(8,16,net)
        net=unet_block(4,8,net)

        # final top-level structure omits dropout and applies sigmoid to the output
        self.model = tf.keras.Sequential()
        self.model.add(tf.keras.layers.Conv2D(4, (3,3), strides=1, padding='same', data_format='channels_first'))
        self.model.add(tf.keras.layers.BatchNormalization(axis=1))
        self.model.add(tf.keras.layers.PReLU())
        self.model.add(net)
        self.model.add(tf.keras.layers.Conv2DTranspose(1, (3,3), strides=1, padding='same', data_format='channels_first'))

    def call(self, x):
        return self.model(x)

"""

class Unet(tf.keras.Model):
    """tf unet based on design https://keunwoochoi.wordpress.com/2017/10/11/u-net-on-keras-2-0/"""
    def __init__(self, conv_channels=[1], kernels=[5], training=True):
        super().__init__()

        self._conv_channels = conv_channels
        self._kernels = kernels

        # check kernel/substack size are equal and make them equal if only a single kernel size is used
        if (len(conv_channels) is not len(kernels)):
            if len(kernels) is 1:
                self._kernels = [item for item in self._kernels for i in range(len(self._conv_channels))]
            else:
                raise AssertionError('conv_channels and kernel lists must have equal length')

        print(self._conv_channels)
        print(self._kernels)

        self._encoder = []

        # first bit define input
        input_tensor = tf.keras.Input(shape=(1, 64, 64))
        self._enc = input_tensor

        # encode module
        for n, C_out in enumerate(self._conv_channels):
            print('%d : %d' % (n, C_out))
            self._enc = tf.keras.layers.Conv2D(C_out, self._kernels[n], strides=2, padding='same', data_format='channels_first')(self._enc)
            self._enc = tf.keras.layers.BatchNormalization(axis=1)(self._enc)
            self._enc = tf.keras.layers.PReLU()(self._enc)
            self._encoder.append(self._enc)
        
        # reverse channel list
        self._conv_channels = self._conv_channels[::-1]
        print(self._encoder)

        # decode module
        self._dec = self._enc
        for n, C_in in enumerate(self._conv_channels):
            print('%d : %d' % (n, C_in))
            idx_rev = len(self._conv_channels) - n - 1
            print(idx_rev)

            self._dec = tf.keras.layers.Conv2DTranspose(C_in, self._kernels[n], strides=2, padding='same', data_format='channels_first')(self._enc)
            self._dec = tf.keras.layers.BatchNormalization(axis=1)(self._dec)
            self._dec = tf.keras.layers.PReLU()(self._dec)
            self._dec = tf.concat([self._dec, self._encoder[idx_rev]], axis=1)
        
        self.model = tf.keras.Model(inputs=[input_tensor],outputs=[output_tensor])

    def call(self, x):
        return self.model(x)

## Train network

In [200]:
# store the training data as tensors
trainTensor = tf.convert_to_tensor(trainImages[:,None].astype(np.float32))
segTensor = tf.convert_to_tensor(trainSegs[:,None].astype(np.float32))

# create network object
channels= [2, 4, 8, 16]
net = Unet(channels)

# choose a device and  (remember to set Google Colab environment runtime to use GPU)
tf.device("/device:GPU:0")

# define optimizer and loss function
opt = tf.keras.optimizers.Adam(learning_rate=0.005)

#net.compile(opt,loss=DiceLoss(), metrics=['accuracy'])

net.compile(opt,loss=tf.keras.losses.BinaryCrossentropy())

net.fit(trainTensor, segTensor, batch_size=32, epochs=5, steps_per_epoch=200)

net.summary()

# result
#sample = np.random.randint(0, trainTensor.shape[0]-1 )  # choose random sample to visualise segmentation the network predicted for it
sample=10
print('Showing results from random sample: %d' % sample)

pred = net.predict(trainTensor, steps=1)
pred.shape

predSample=np.squeeze(pred[sample,:])

fig,ax=plt.subplots(1,5,figsize=(20,5))
#ax[0].set_title('Loss')
#ax[0].semilogy(losses)
ax[1].set_title('Sample Image')
ax[1].imshow(trainImages[sample])
ax[2].set_title('Sample Ground Truth')
ax[2].imshow(trainSegs[sample])
ax[3].set_title('Sample Logits')
ax[3].imshow(predSample)
ax[4].set_title('Sample Predicted Segmentation')
ax[4].imshow(predSample>0.5)

ListWrapper([2, 4, 8, 16])
ListWrapper([5, 5, 5, 5])
0 : 2
1 : 4
2 : 8
3 : 16
ListWrapper([<tf.Tensor 'p_re_lu_315/add:0' shape=(?, 2, 32, 32) dtype=float32>, <tf.Tensor 'p_re_lu_316/add:0' shape=(?, 4, 16, 16) dtype=float32>, <tf.Tensor 'p_re_lu_317/add:0' shape=(?, 8, 8, 8) dtype=float32>, <tf.Tensor 'p_re_lu_318/add:0' shape=(?, 16, 4, 4) dtype=float32>])
0 : 16
3


ValueError: ignored