# Neural network for segmenting the left ventricle (LV) of the heart

## Introduction
This work was performed as part of the MRAI workshop (MIDL 2019 satellite meeting) with exercises designed by Esther Puyol (https://github.com/estherpuyol/MRAI_workshop).

### Objective
Train a simple deep neural network (DNN) to segment the left ventricle (LV) myocardium in cardiac MR images.

## Import modules
The network is built using classes and functions from the [PyTorch library](https://pytorch.org/docs/stable/index.html).

```python
import os
import numpy as np
import matplotlib.pyplot as plt
import torch, torch.nn as nn, torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import StratifiedShuffleSplit
```

## Download dataset
The data for this study come from the [Sunnybrook Cardiac Data](https://www.cardiacatlas.org/studies/sunnybrook-cardiac-data/).

A preprocessed subset of these data is used: it contains only left ventricle myocardium segmentations, and the images have been reduced in size in the XY plane (to 64 × 64 pixels).

```python
# download the preprocessed archive unless it is already present (IPython shell escape)
![ -f scd_lvsegs.npz ] || wget https://github.com/estherpuyol/MRAI_workshop/raw/master/scd_lvsegs.npz

data=np.load('scd_lvsegs.npz') # load all the data from the archive

images=data['images'] # images in BHW array order: .shape (420, 64, 64)
segs=data['segs'] # segmentations in BHW array order: .shape (420, 64, 64)
caseIndices=data['caseIndices'] # the indices in `images` for each case: .shape (45, 2)

images=images.astype(np.float32)/images.max() # normalize images by the global maximum intensity
```
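The `caseIndices` array groups the 420 images into 45 cases. The sketch below shows one way to recover a per-case stack of slices. It assumes that each row holds a half-open `[start, end)` range into the first axis of `images`; that convention is an assumption, not something the archive documents here. `split_by_case` and the `demo_*` arrays are hypothetical and synthetic, so running the sketch does not overwrite the real data.

```python
import numpy as np

def split_by_case(images, case_indices):
    """Return one (slices, H, W) array per case.

    Assumes (hypothetically) that each row of `case_indices` is a half-open
    [start, end) range along the first (batch) axis of `images`.
    """
    return [images[start:end] for start, end in case_indices]

# Synthetic stand-ins (not the real archive): 3 cases with 2, 3 and 1 slices.
demo_images = np.zeros((6, 64, 64), dtype=np.float32)
demo_case_indices = np.array([[0, 2], [2, 5], [5, 6]])

demo_cases = split_by_case(demo_images, demo_case_indices)
print([c.shape[0] for c in demo_cases])  # slices per case: [2, 3, 1]
```

Under the same assumption, consecutive rows should chain, i.e. `caseIndices[1:, 0]` equals `caseIndices[:-1, 1]`, which is a quick way to check the convention on the real data.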

## Split training and test data
This cell splits the data set into training and test sets by case, where the variable `n_training` sets the number of cases held out for testing.

```python
n_training = 6 # number of cases held out for testing

testIndex=caseIndices[-n_training,0] # first image index of the n_training-th last case; keep the last n_training cases for testing

# divide the images and segmentations into train/test sets
trainImages,trainSegs=images[:testIndex],segs[:testIndex]
testImages,testSegs=images[testIndex:],segs[testIndex:]
```
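Because the split index is taken from `caseIndices`, the boundary always falls between two cases, so no subject contributes slices to both sets. Splitting by image instead (for example, shuffling individual slices) would let correlated slices of the same patient appear in both training and testing. A minimal sketch of the same logic as a reusable function, run on synthetic arrays; `split_last_cases` and the `demo_*` names are hypothetical:

```python
import numpy as np

def split_last_cases(images, segs, case_indices, n_test):
    """Hold out the last `n_test` cases (hypothetical helper mirroring the cell above).

    Splitting at a case boundary keeps every slice of a subject in a single set.
    """
    test_start = case_indices[-n_test, 0]  # first image of the n_test-th last case
    return (images[:test_start], segs[:test_start],
            images[test_start:], segs[test_start:])

# Synthetic stand-ins: 4 cases of 3 slices each, image values equal to their index.
demo_case_indices = np.array([[0, 3], [3, 6], [6, 9], [9, 12]])
demo_images = np.arange(12, dtype=np.float32).reshape(12, 1, 1)
demo_segs = np.zeros_like(demo_images)

train_x, train_y, test_x, test_y = split_last_cases(
    demo_images, demo_segs, demo_case_indices, n_test=1)
print(len(train_x), len(test_x))  # 9 3
```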

## Define segmentation network
Here two classes are defined. 

The first is the loss function `DiceLoss`, which measures the overlap between the ground-truth segmentations and the network outputs.
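As a sketch of the quantity involved: for one batch item with sigmoid outputs $p_i$ and binary targets $t_i$ over pixels $i$, the soft Dice score and the loss averaged over a batch of $B$ items are, ignoring the small `smooth` constant that guards against division by zero,

$$
\mathrm{Dice}(p, t) = \frac{2\sum_i p_i t_i}{\sum_i p_i + \sum_i t_i},
\qquad
\mathcal{L}_{\mathrm{Dice}} = 1 - \frac{1}{B}\sum_{b=1}^{B} \mathrm{Dice}(p_b, t_b).
$$

For example, with four pixels, $t = (1, 1, 0, 0)$ and $p = (1, 0.5, 0.5, 0)$ give $\sum_i p_i t_i = 1.5$ and $\sum_i p_i + \sum_i t_i = 2 + 2 = 4$, so $\mathrm{Dice} = 2 \cdot 1.5 / 4 = 0.75$ and the loss for this item is $0.25$.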

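The second class, `SegNet`, is the neural network itself: a small convolutional encoder-decoder that subclasses the PyTorch base class `torch.nn.Module`. It downsamples the input once and then upsamples it back to the input resolution, producing one output value (logit) per pixel.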

```python
class DiceLoss(nn.modules.loss._Loss):
    '''Binary (soft) Dice loss used to assess segmentation overlap.'''
    def forward(self, source, target, smooth=1e-5):
        batchsize = target.size(0)
        source = source.sigmoid() # map the source logits onto the [0, 1] interval

        # flatten target and source arrays to 2D BV (batch, voxel) arrays
        tflat = target.reshape(batchsize, -1)
        pflat = source.reshape(batchsize, -1)

        intersection = pflat * tflat
        sums = pflat + tflat

        # per-item Dice score; `smooth` prevents divide-by-zero and gives a score of 1 when both masks are empty
        score = (2.0 * intersection.sum(1) + smooth) / (sums.sum(1) + smooth)

        # `score` is 1 for identical source and target and ~0 for disjoint ones, so the loss is 1 - mean score
        return 1 - score.sum() / batchsize

# Segmentation network: a small convolutional encoder-decoder
class SegNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.model=nn.Sequential(
            # layer 1: convolution, normalization, activation, downsampling (64x64 -> 32x32)
            nn.Conv2d(1,2,3,1,1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.MaxPool2d(3,2,1),
            # layer 2: convolution at the reduced resolution
            nn.Conv2d(2,4,3,1,1),
            # layer 3: transposed convolution, upsampling back to the input size (32x32 -> 64x64)
            nn.ConvTranspose2d(4,2,3,2,1,1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            # layer 4: output, one logit per pixel
            nn.Conv2d(2,1,3,1,1),
        )

    def forward(self,x):
        return self.model(x)
```
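`SegNet` halves the spatial resolution with `MaxPool2d(3,2,1)` and restores it with `ConvTranspose2d(4,2,3,2,1,1)`, so its output should have the same height and width as its input. This can be checked without running PyTorch by applying the output-size formulas from the PyTorch documentation for `Conv2d`, `MaxPool2d` (default `ceil_mode=False`) and `ConvTranspose2d`. The helpers `conv_out` and `conv_transpose_out` are a hypothetical sketch of those formulas along one spatial axis:

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    # Output size of Conv2d / MaxPool2d (ceil_mode=False) along one spatial axis.
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    # Output size of ConvTranspose2d along one spatial axis.
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Trace one spatial axis of a 64x64 input through the SegNet layers.
sizes = [64]
sizes.append(conv_out(sizes[-1], 3, 1, 1))               # layer 1: Conv2d(1,2,3,1,1)
sizes.append(conv_out(sizes[-1], 3, 2, 1))               # layer 1: MaxPool2d(3,2,1)
sizes.append(conv_out(sizes[-1], 3, 1, 1))               # layer 2: Conv2d(2,4,3,1,1)
sizes.append(conv_transpose_out(sizes[-1], 3, 2, 1, 1))  # layer 3: ConvTranspose2d(4,2,3,2,1,1)
sizes.append(conv_out(sizes[-1], 3, 1, 1))               # layer 4: Conv2d(2,1,3,1,1)
print(sizes)  # [64, 64, 32, 32, 64, 64]
```

Without the `output_padding` of 1 (the sixth argument of `ConvTranspose2d`), the transposed convolution would produce 63 pixels instead of 64. With PyTorch available, the equivalent check is `SegNet()(torch.zeros(1, 1, 64, 64)).shape`, which should report `torch.Size([1, 1, 64, 64])`.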

## Train network