# Neural network for segmenting LV of heart

## Introduction
This work was performed as part of the MRAI workshop (MIDL 2019 satellite meeting) with exercises designed by Esther Puyol (https://github.com/estherpuyol/MRAI_workshop).

### Objective
Train a simple neural network to automatically segment the Left ventricle from 2D short axis cardiac MR images.

## Import modules
The network is built using functions and classes from the [pytorch library](https://pytorch.org/docs/stable/index.html) 

In [0]:
import os
import numpy as np
import pylab as plt#
import tensorflow as tf
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import StratifiedShuffleSplit

## Download dataset
The data for this study is from the [Sunnybrook Cardiac Data](https://www.cardiacatlas.org/studies/sunnybrook-cardiac-data/)

A preprocessed subset of this data is used, where the data is filtered to contain only left ventricle myocardium segmentations and reduced in XY dimensions.

In [0]:
![ -f scd_lvsegs.npz ] || wget https://github.com/estherpuyol/MRAI_workshop/raw/master/scd_lvsegs.npz

data = np.load('scd_lvsegs.npz') # load all the data from the archive

images = data['images'] # images in BHW array order : .shape (420, 64, 64)]
segs = data['segs'] # segmentations in BHW array order : .shape (420, 64, 64)
caseIndices = data['caseIndices'] # the indices in `images` for each case : .shape (45, 2)

images = images.astype(np.float32)/images.max() # normalize images

## Split training and test data
This code block splits the data set into training and test data where the variable n_training determines the number of test cases.  

In [0]:
n_training = 6

testIndex = caseIndices[-n_training,0] # keep the last n_training cases for testing

# divide the images, segmentations, and categories into train/test sets
trainImages,trainSegs = images[:testIndex],segs[:testIndex]
testImages,testSegs = images[testIndex:],segs[testIndex:]

## Define segmentation network
Here two classes are defined: 

The first describes the loss function ```DiceLoss``` which will provide a measure of overlap between ground truth segmentations and the network outputs. 

The second class ```SegNet``` is our artifical neural network that inherits methods from the base NN class torch.nn.Module. 


In [70]:
class DiceLoss(tf.keras.losses.Loss):
    '''This defines the binary dice loss function used to assess segmentation overlap.'''
    def call(self, y_true, y_pred):
        smooth=1e-5
        batchsize = target.size(0)
        source = source.sigmoid() # apply sigmoid to the source logits to impose it onto the [0,1] interval
        
        # flatten target and source arrays to 2D BV arrays
        tsum = target.view(batchsize, -1) 
        psum = source.view(batchsize, -1)
        
        intersection=psum*tsum
        sums = psum+tsum 

        # compute the score, the `smooth` value is used to smooth results and prevent divide-by-zero
        score = 2.0 * (intersection.sum(1) + smooth) / (sums.sum(1) + smooth)
        
        # `score` is 1 for perfectly identical source and target, 0 for entirely disjoint
        return 1 - score.sum() / batchsize



class unet_block(tf.keras.Model):

    def __init__(self, Cin, Cout, subblock):
        """initialise unit"""
        super().__init__()
        self._encode = tf.keras.layers.Conv2D(Cout, (3,3), strides=(2,2), padding='valid')
        self._encode_norm = tf.keras.layers.BatchNormalization(Cout)
        self._encode_dropout = tf.keras.layers.Dropout(rate=0.2)
        self._encode_activation = tf.keras.layers.Activation(tf.nn.leaky_relu)

        self._subblock = subblock

        self._decode = tf.keras.layers.Conv2DTranspose(Cin, (3,3), stride=(2,2), padding='valid')
        self._decode_norm = tf.keras.layers.BatchNormalization(Cin)
        self._decode_dropout = tf.keras.layers.Dropout(rate=0.2)
        self._decode_activation = tf.keras.layers.Activation(tf.nn.leaky_relu)

    def call(self, tensor_in, training=False):
        """Forward pass - encode block -> subblocks -> decode block"""
        enc = self._encode(tensor_in)
        enc = self._encode_norm(enc)
        if training:
            enc = self._encode_dropout(enc)
        enc = self._encode_activation(enc)

        sub = self._subblock(enc)
        sub = tf.concat([enc,sub], axis=1)

        dec = self._decode(sub)
        dec = self._decode_norm(dec)
        if training:
            dec = self._decode_dropout(dec)
        dec = self._decode_activation(dec)

        return dec


class Unet(tf.keras.Model):
    """Builds unet"""
    def __init__(self):
        super().__init__()

        filters_bottom = 64

        # bottom subblock
        net = tf.keras.Sequential()
        net.add(tf.keras.layers.Conv2D(filters_bottom, (3,3), strides=2, padding='valid'))
        net.add(tf.keras.layers.BatchNormalization(filters_bottom))
        net.add(tf.keras.layers.Activation(tf.nn.leaky_relu))

        # build the unet structure from the bottom up
        net=UnetBlock(32,filters_bottom,net)
        net=UnetBlock(16,32,net)
        net=UnetBlock(8,16,net)
        net=UnetBlock(4,8,net)

        # final top-level structure omits dropout and applies sigmoid to the output
        self.model =  tf.keras.Sequential(
                                          nn.Conv2d(1,4,3,1,1),
                                          nn.InstanceNorm2d(4),
                                          nn.PReLU(),
                                          net,
                                          nn.Conv2d(4,1,3,1,1),
                                          )

    def call(self, x):
        return self.model(x)

<tensorflow.python.keras.engine.sequential.Sequential object at 0x7f86752f1ef0>


## Train network

In [55]:
# store the training data as tensors
trainTensor = tf.convert_to_tensor(trainImages[:,None])
segTensor = tf.convert_to_tensor(trainSegs[:,None].astype(np.float32))

# create network object
net = Unet()

# choose a device and  (remember to set Google Colab environment runtime to use GPU)
#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tf.device("/device:GPU:0")


# move the net and tensors to device memory
#net = net.to(device)
#trainTensor = trainTensor.to(device)
#segTensor = segTensor.to(device)

# define optimizer and loss function
opt = tf.keras.optimizers.Adam(learning_rate=0.005)
loss=DiceLoss()

trainSteps = 5000
losses = []

# run through training steps
for t in range(1,trainSteps+1):
    opt.zero_grad()
    pred = net(trainTensor)
    if t == 1:
      print('pred shape: ' + str(pred.size()))
      print('segTensor shape: ' + str(segTensor.size()))
    lossval = loss(pred,segTensor)
    lossval.backward()
    opt.step()
        
    losses.append(lossval.item())
    if t%(trainSteps//20) == 0:
        print(t,lossval.item())    

# result
sample = np.random.randint(0, pred.shape[0]-1 )  # choose random sample to visualise segmentation the network predicted for it
print('Showing results from random sample: %d' % sample)
pred.shape
predSample=pred[sample,0].cpu().data.numpy()
fig,ax=plt.subplots(1,5,figsize=(20,5))
ax[0].set_title('Loss')
ax[0].semilogy(losses)
ax[1].set_title('Sample Image')
ax[1].imshow(trainImages[10])
ax[2].set_title('Sample Ground Truth')
ax[2].imshow(trainSegs[10])
ax[3].set_title('Sample Logits')
ax[3].imshow(predSample)
ax[4].set_title('Sample Predicted Segmentation')
ax[4].imshow(predSample>0.5)

TypeError: ignored