<h1> Lung segmentation of Chest X-Rays </h1>
Image segmentation is an interesting and widely used application of neural networks.<br>
We can think of segmentation as a "per-pixel" image classification, where the output is not a vector of class activations for the whole image, but class activations for every pixel! We train a segmentation network very similarly to a simple classifier and can even use the same loss function. The difference lies in the structure of the network. As this suggests, the output of a segmentation network needs to be an image with as many channels as there are classes, with the same height and width as the input. Therefore a network that only downsamples, like a typical classifier, will not work!<br>
Instead, we need an encoder-decoder structure, something like an Autoencoder network!
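To make "per-pixel classification" concrete, here is a minimal NumPy sketch. The `logits` array is a made-up stand-in for the network's output on one image (2 classes x 3 x 3 pixels); taking the argmax over the class axis gives one class index per pixel, which is the predicted segmentation mask.

```python
import numpy as np

# Made-up network output for ONE image: 2 classes (channels) x 3 x 3 pixels.
# Channel 0 = "not lung" score, channel 1 = "lung" score.
logits = np.array([
    [[2.0, 0.1, 0.3],
     [1.5, 0.2, 0.0],
     [3.0, 2.5, 0.4]],   # class 0 scores
    [[0.5, 1.0, 2.0],
     [0.3, 0.9, 1.2],
     [0.1, 0.2, 1.9]],   # class 1 scores
])

# Per-pixel classification: pick the highest-scoring class at every pixel.
mask = np.argmax(logits, axis=0)
print(mask.shape)  # (3, 3): one class index per pixel
print(mask)
```

The test cell at the end of this notebook does the same thing on real outputs with `torch.argmax(output, dim=1, keepdim=True)`.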

<h3>Autoencoders</h3>
Autoencoders are a fairly straightforward network structure, characterised by a "bottleneck" where the input is "compressed" before being upsampled again. This network can be used to create compressed representations of images by training the model to reconstruct the input at the output. It could also be used for our segmentation problem! However, in segmentation we don't really want our network to compress our image; we want it to do some "work" on the image and then give us a segmented version of the input (a mask of the same size)!
<img src="https://miro.medium.com/max/3148/1*44eDEuZBEsmG_TCAKRI3Kw@2x.png" width="750" align="center">

[Autoencoders](https://towardsdatascience.com/applied-deep-learning-part-3-autoencoders-1c083af4d798)
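A toy illustration of why a bottleneck loses detail, assuming fixed (not learned) operations: 2x2 average pooling stands in for the encoder and nearest-neighbour upsampling for the decoder. A trained autoencoder learns its own compression, so this only sketches the idea. A 4x4 checkerboard has detail at the single-pixel scale, and after compressing it to 2x2 and expanding it again that detail is gone.

```python
import numpy as np

def avg_pool_2x2(x):
    """Downsample an (H, W) array by averaging non-overlapping 2x2 blocks."""
    h, w = x.shape
    return x.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))

def nearest_upsample_2x(x):
    """Upsample an (H, W) array by repeating each value in a 2x2 block."""
    return x.repeat(2, axis=0).repeat(2, axis=1)

# A 4x4 image with fine, 1-pixel detail (a checkerboard of 0s and 1s).
checkerboard = (np.indices((4, 4)).sum(axis=0) % 2).astype(float)

# "Compress" to 2x2, then "decompress" back to 4x4.
bottleneck = avg_pool_2x2(checkerboard)
reconstruction = nearest_upsample_2x(bottleneck)
print(bottleneck)       # every 2x2 block averages to 0.5
print(reconstruction)   # the checkerboard detail is gone
```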

In [None]:
import numpy as np 
from random import shuffle
from PIL import Image
import os
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import time
from IPython.display import clear_output

In [None]:
n_epochs = 20
lr = 2e-4
batch_size = 16
data_path ="Chest_xray"

In [None]:
#Set device to GPU_indx if a GPU is available
GPU_indx = 0
device = torch.device(GPU_indx if torch.cuda.is_available() else 'cpu')
print(device)

<h2>Create our Dataset</h2>
Along with this notebook came a data folder containing two subfolders: one with the input images and the other with the segmentation masks (labels). Alongside the data is a CSV file which specifies which images should be used for training and which for testing. The following dataset class reads the train/test split CSV and stores the filenames of either the training or the testing split.<br>
There are many ways to construct such a dataset class as well as the format of your dataset folder structure, this is just one common way.

In [None]:
#Dataset class for our lung data (used by the data loader)
class LungDataset(Dataset):
    def __init__(self, root_data_dir, training = True):
        
        #Load all the filenames and their train/test flags
        #For every filename there is a 1 or a 0 indicating that it belongs to the 
        #training set or test set
        train_test = np.loadtxt(root_data_dir + "/Train_Test_split.csv", dtype="<U8")
        
        #Convert the flags from str to int and compare them to "training"
        #1 = True (Training)
        #0 = False (Testing)
        #Then use the resulting boolean array to index the filenames
        self.filenames = train_test[(train_test[:, 1].astype(int) == training), 0]
        
        #ToTensor object for converting image to tensor
        self.to_tensor = torchvision.transforms.ToTensor()
        
        self.root_data_dir = root_data_dir

    #Returns a single image and label pair
    def __getitem__(self, index):
        
        #Read image and labels
        image = Image.open(self.root_data_dir + '/images/' + self.filenames[index])
        label = Image.open(self.root_data_dir + '/labels/' + self.filenames[index])

        #Make image in range (-1,1)
        image = self.to_tensor(image)
        image = (image-0.5)/0.5
        image = image[0:1,:,:] #Images are grayscale, only need one channel
        
        #Cross entropy loss needs labels as LongTensor type
        label = self.to_tensor(label).type(torch.LongTensor).squeeze(0)

        return image, label
    
    def __len__(self):
        return len(self.filenames)
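The boolean indexing in `__init__` can be hard to read, so here is the same expression applied to a small in-memory array. The filenames are hypothetical, and the two-column layout (filename, then a 1/0 flag) is assumed from the comments in the class. The last lines check that the (-1, 1) normalisation in `__getitem__` is undone by `(image + 1) / 2`, which the sanity-check cell below uses for plotting.

```python
import numpy as np

# Hypothetical contents of Train_Test_split.csv after np.loadtxt(..., dtype="<U8"):
# column 0 = filename, column 1 = "1" (training) or "0" (testing).
train_test = np.array([
    ["0001.png", "1"],
    ["0002.png", "0"],
    ["0003.png", "1"],
])

def select_split(train_test, training=True):
    # Same expression as LungDataset.__init__: the int column is compared with
    # the bool flag (True == 1, False == 0) and the result indexes column 0.
    return train_test[(train_test[:, 1].astype(int) == training), 0]

print(select_split(train_test, True))   # training filenames
print(select_split(train_test, False))  # testing filenames

# ToTensor gives values in [0, 1]; the dataset maps them to [-1, 1],
# and (x + 1) / 2 maps them back for display.
pixels = np.array([0.0, 0.5, 1.0])
normalised = (pixels - 0.5) / 0.5
restored = (normalised + 1) / 2
```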


<h3>Create a dataset and dataloader</h3>

In [None]:
# Create train and test dataset
dataset_train = LungDataset(data_path)
dataset_test = LungDataset(data_path,False)
data_loader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
data_loader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)

<h3> Perform Sanity Check </h3>
It is prudent to perform a sanity check of the data correspondence, i.e. that each image is paired with the right label. It becomes a routine check after a while, but it is crucial for catching mistakes made while loading the data.

In [None]:
#create an iterator over the dataloader
dataloader_it = iter(data_loader_train)
#sample one batch from the iterator
image, label = next(dataloader_it)

plt.figure(figsize = (20,10))
img_out = torchvision.utils.make_grid(((image+1)/2)[0:4], 4)
lbl_out = torchvision.utils.make_grid(label[0:4].unsqueeze(1), 4).float()

out = torch.cat((img_out, lbl_out), 1)

plt.imshow(out.numpy().transpose((1, 2, 0)))
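Looking at the images is the main check, but a few numeric assertions catch problems the eye can miss, such as label values other than 0 and 1, or image and label sizes that do not match. The helper `check_batch` below is hypothetical and works on NumPy arrays; in this notebook it could be called as `check_batch(image.numpy(), label.numpy())`. Here it is exercised on a synthetic batch.

```python
import numpy as np

def check_batch(images, labels):
    """Hypothetical helper: basic consistency checks on one (image, label) batch.

    images: float array of shape (N, 1, H, W), expected in [-1, 1]
    labels: int array of shape (N, H, W), expected to contain only 0 and 1
    """
    assert images.ndim == 4 and images.shape[1] == 1, "images should be N x 1 x H x W"
    assert labels.ndim == 3, "labels should be N x H x W"
    assert images.shape[0] == labels.shape[0], "batch sizes differ"
    assert images.shape[2:] == labels.shape[1:], "image and label sizes differ"
    assert images.min() >= -1.0 and images.max() <= 1.0, "images not in [-1, 1]"
    assert set(np.unique(labels).tolist()) <= {0, 1}, "unexpected label values"
    return True

# Synthetic batch standing in for next(iter(data_loader_train)).
rng = np.random.default_rng(0)
images = rng.uniform(-1, 1, size=(4, 1, 8, 8))
labels = rng.integers(0, 2, size=(4, 8, 8))
print(check_batch(images, labels))
```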

<h2>The U-Net</h2>
The U-Net was developed specifically for image segmentation. The intuition is that the "autoencoder-like" structure extracts class information from the input image, while the skip connections let image "structure" information (contained in the encoder's feature maps) bypass the bottleneck. The network therefore does not have to learn to compress and reconstruct the structure of the image, which leads to sharper edges and higher-quality results.
<img src="https://miro.medium.com/max/1200/1*f7YOaE4TWubwaFF7Z1fzNw.png" width="750" align="center">

[U-Net](https://towardsdatascience.com/u-net-b229b32b4a71)
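In code, a skip connection is just a concatenation along the channel axis. The NumPy sketch below uses made-up feature maps in PyTorch's N x C x H x W layout, with the shapes that `x4` and `x2` have in `Unet1` (defined later in this notebook) for a 64 x 64 input. Concatenating two 64-channel maps gives the 128 channels that the following convolution expects.

```python
import numpy as np

# Made-up feature maps in N x C x H x W layout (batch of 1):
encoder_features = np.ones((1, 64, 32, 32))    # plays the role of x2 (from down2)
decoder_features = np.zeros((1, 64, 32, 32))   # plays the role of x4 (after up4)

# Skip connection: concatenate along the channel axis (axis 1), so the next
# convolution sees both the upsampled decoder features and the
# higher-resolution encoder features.
merged = np.concatenate([decoder_features, encoder_features], axis=1)
print(merged.shape)  # (1, 128, 32, 32)
```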

<h3>Transpose Convolution</h3>
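The U-Net model also introduces a new layer type, the "transpose convolution" (sometimes called "deconvolution", although it does not mathematically invert a convolution).<br>
The transpose convolution is a "learnable upsampling" method and reverses the spatial mapping of a convolution: where a convolution combines a patch of input pixels into one output pixel, a transpose convolution spreads each input pixel over a patch of output pixels. Each feature (pixel) in the input feature map multiplies a copy of the kernel, the copies are placed in the output at stride spacing, and any overlapping sections are added together. The easiest way to understand this is with the following animation (the blue square is the input and green is the output).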
<img src="https://miro.medium.com/max/986/1*yoQ62ckovnGYV2vSIq9q4g.gif" width="750" align="center">

[Transpose Convolution](https://medium.com/apache-mxnet/transposed-convolutions-explained-with-ms-excel-52d13030c7e8)

[Checkerboard Artifacts](https://distill.pub/2016/deconv-checkerboard/)
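The scatter-and-sum description above can be written out directly. This is a minimal single-channel NumPy sketch with no padding, bias, or batching, so it illustrates the arithmetic rather than reproducing `nn.ConvTranspose2d`. With no padding, the output size is `(H - 1) * stride + kernel_size`, which matches PyTorch's formula for `ConvTranspose2d` when `padding=0`, `output_padding=0` and `dilation=1`.

```python
import numpy as np

def conv_transpose2d(x, kernel, stride=1):
    """Single-channel transposed convolution (no padding, no bias).

    Each input value scales a copy of the kernel; the copy is placed in the
    output at (i * stride, j * stride) and overlapping copies are summed.
    """
    h, w = x.shape
    kh, kw = kernel.shape
    out = np.zeros(((h - 1) * stride + kh, (w - 1) * stride + kw))
    for i in range(h):
        for j in range(w):
            out[i * stride:i * stride + kh, j * stride:j * stride + kw] += x[i, j] * kernel
    return out

x = np.array([[1.0, 2.0],
              [3.0, 4.0]])
k = np.ones((2, 2))

# 2x2 kernel with stride 2 (as in nn.ConvTranspose2d(..., 2, stride=2)):
# the copies do not overlap, each input pixel becomes a 2x2 block, so H and W double.
print(conv_transpose2d(x, k, stride=2))

# Stride 1: the copies overlap and are summed, e.g. the centre is 1 + 2 + 3 + 4.
print(conv_transpose2d(x, k, stride=1))
```

With a 2x2 kernel and stride 2, as used in the U-Net below, every output pixel receives exactly one contribution, which avoids the uneven overlap discussed in the checkerboard-artifacts article (though, as that article notes, it does not rule out artifacts entirely).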

<h4> Define the network - U-Net</h4>
We will code the U-Net model in two ways: Unet1 and Unet2.<br>
The two networks are structurally identical. Unet1 writes out every block explicitly, which is easier to visualise and understand, while Unet2 is modular (built from reusable down and up blocks), which makes it easier to add, remove or modify layers.

In [None]:
# Unet1
class Unet1(nn.Module):
    def __init__(self):
        # Call the __init__ function of the parent nn.module class
        super(Unet1, self).__init__()
        
        # Define the first double conv block, which contains:
        # 1. conv, 32 channels out, 3x3 kernel and padding of 1
        #    (padding of 1 keeps the output height and width equal to the input's)
        # 2. nn.ReLU() activation
        # 3. conv, 32 channels out, 3x3 kernel and padding of 1
        # 4. another ReLU
        self.doubleconv1 = nn.Sequential (nn.Conv2d(1, 32, 3, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(32, 32, 3, padding=1),
                                    nn.ReLU()
                                   )
        
        # The second block contains a max pooling layer and two conv layers, each followed by a ReLU:
        # 1. max pooling to halve the image size
        # 2. conv, 64 channels out, 3x3 kernel and padding of 1
        # 3. ReLU
        # 4. conv, 64 channels out, 3x3 kernel and padding of 1
        # 5. another ReLU
        self.down2 = nn.Sequential (nn.MaxPool2d(2),
                                    nn.Conv2d(32, 64, 3, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(64, 64, 3, padding=1),
                                    nn.ReLU()
                                   )
        
        # The third block:
        # 1. max pooling to halve the image size
        # 2. conv, 128 channels out, 3x3 kernel and padding of 1
        # 3. ReLU
        # 4. conv, 128 channels out, 3x3 kernel and padding of 1
        # 5. another ReLU
        self.down3 =  nn.Sequential(nn.MaxPool2d(2),
                                    nn.Conv2d(64, 128, 3, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(128, 128, 3, padding=1),
                                    nn.ReLU()
                                   )

        
        # At the bottleneck we upsample to double the height and width of each feature map,
        # using nn.ConvTranspose2d with 64 channels out, a 2x2 kernel and stride of 2
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        
        # 1. conv, 128 channels in, 64 channels out, 3x3 kernel and padding of 1
        #    (128 channels in because we concatenate the 64-channel outputs
        #    of self.up4 and self.down2)
        # 2. ReLU
        # 3. conv, 64 channels out, 3x3 kernel and padding of 1
        # 4. another ReLU
        self.doubleconv5 = nn.Sequential(nn.Conv2d(128, 64, 3, padding=1),
                                         nn.ReLU(),
                                         nn.Conv2d(64, 64, 3, padding=1),
                                         nn.ReLU()
                                        )
            
        # use nn.ConvTranspose2d with 32 channels out, a 2x2 kernel and stride of 2
        self.up6 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        
        # 1. conv, 64 channels in (32 from self.up6 + 32 from self.doubleconv1),
        #    32 channels out, 3x3 kernel and padding of 1
        # 2. ReLU
        # 3. conv, 32 channels out, 3x3 kernel and padding of 1
        # 4. another ReLU
        # 5. conv, 2 channels out (one per class), 1x1 kernel and no padding
        self.doubleconv7 = nn.Sequential( nn.Conv2d(64, 32, 3, padding=1),
                                          nn.ReLU(),
                                          nn.Conv2d(32, 32, 3, padding=1),
                                          nn.ReLU(),
                                          nn.Conv2d(32, 2, 1, padding=0)
                                        )
        
    def forward(self, x):
        # pass the input through the network
        # shapes below are per image (batch dimension omitted) for a 64 x 64 input
        # x is 1 x 64 x 64
        x1 = self.doubleconv1(x)
        # x1 = 32 x 64 x 64
        x2 = self.down2(x1)
        # x2 = 64 x 32 x 32
        x3 = self.down3(x2)
        # x3 = 128 x 16 x 16
        x4 = self.up4(x3)
        # x4 = 64 x 32 x 32
        # torch.cat([x4,x2]) = 128 x 32 x 32 
        x5 = self.doubleconv5(torch.cat([x4,x2],dim=1))
        # x5 = 64 x 32 x 32
        x6 = self.up6(x5)
        # x6 = 32 x 64 x 64
        # torch.cat([x6,x1]) = 64 x 64 x 64
        x7 = self.doubleconv7(torch.cat([x6,x1],dim=1))
        # x7 = 2 x 64 x 64
    
        return x7
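The shape comments in `forward` assume a 64 x 64 input. The pure-Python sketch below (the function name is hypothetical) traces the same shapes for any input size and shows why the height and width must be divisible by 4 for this two-pooling U-Net: otherwise an upsampled map does not match its skip connection and `torch.cat` fails. In the notebook you could also check directly with `Unet1()(torch.zeros(1, 1, 64, 64)).shape`.

```python
def trace_unet1_shapes(h, w):
    """Trace (channels, height, width) through Unet1 for an h x w input.

    Assumes Unet1's layer settings: 3x3 convs with padding 1 keep H and W,
    MaxPool2d(2) halves them (flooring odd sizes), and
    ConvTranspose2d(kernel 2, stride 2) doubles them.
    """
    shapes = {"x": (1, h, w)}
    shapes["x1"] = (32, h, w)
    shapes["x2"] = (64, h // 2, w // 2)
    shapes["x3"] = (128, h // 4, w // 4)
    shapes["x4"] = (64, 2 * (h // 4), 2 * (w // 4))
    if shapes["x4"][1:] != shapes["x2"][1:]:
        raise ValueError("x4 and x2 differ in size; torch.cat would fail")
    shapes["x5"] = (64,) + shapes["x4"][1:]
    shapes["x6"] = (32, 2 * shapes["x5"][1], 2 * shapes["x5"][2])
    if shapes["x6"][1:] != shapes["x1"][1:]:
        raise ValueError("x6 and x1 differ in size; torch.cat would fail")
    shapes["x7"] = (2,) + shapes["x6"][1:]
    return shapes

print(trace_unet1_shapes(64, 64)["x7"])  # (2, 64, 64)
```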

In [None]:
#Unet2
class Unetdown(nn.Module):
    def __init__(self, input_nc, output_nc, first_layer = False):
        super(Unetdown, self).__init__()
        
        model = []
        if not first_layer:
            model += [nn.MaxPool2d(2)]
        
        model += [nn.Conv2d(input_nc, output_nc, 3, padding=1),
                  nn.ReLU(),
                  nn.Conv2d(output_nc, output_nc, 3, padding=1),
                  nn.ReLU()]
        
        self.model = nn.Sequential(*model)
        
    def forward(self, x):
        out = self.model(x)
        
        return out
      

class Unetup(nn.Module):
    def __init__(self, input_nc, output_nc, last_layer = False):
        super(Unetup, self).__init__()

        self.up= nn.ConvTranspose2d(input_nc, output_nc, 2, stride=2)
        model = []
        model += [nn.Conv2d(input_nc, output_nc, 3, padding=1),
                  nn.ReLU(),
                  nn.Conv2d(output_nc, output_nc, 3, padding=1),
                  nn.ReLU()]
        
        if last_layer:
            model += [nn.Conv2d(output_nc, 2, 1, padding=0)]
          
        self.model = nn.Sequential(*model)
            
    def forward(self, x1, x2):
        x1 = self.up(x1)
        out = self.model(torch.cat([x1,x2],dim=1))
        
        return out
            
         
class Unet2(nn.Module):
    def __init__(self):
        super(Unet2, self).__init__()
        
        self.down1 = Unetdown(1, 32, True)
        self.down2 = Unetdown(32, 64, False)
        self.down3 = Unetdown(64, 128, False)
        self.up4 = Unetup(128, 64, False)
        self.up5 = Unetup(64, 32, True)
        
    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.up4(x3, x2)
        x5 = self.up5(x4, x1)
        
        return x5
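One way to confirm that `Unet1` and `Unet2` really are the same structure is to compare their parameter counts, e.g. `sum(p.numel() for p in Unet1().parameters())` against the same expression for `Unet2()`. The sketch below does that arithmetic by hand, using hypothetical helper functions that mirror `Unetdown` and `Unetup` and assuming every conv layer keeps PyTorch's default `bias=True`.

```python
def conv_params(c_in, c_out, k):
    # Conv2d / ConvTranspose2d: c_in * c_out * k * k weights plus c_out biases.
    return c_in * c_out * k * k + c_out

def unetdown_params(c_in, c_out):
    # Mirrors Unetdown: optional MaxPool2d (no parameters) + two 3x3 convs.
    return conv_params(c_in, c_out, 3) + conv_params(c_out, c_out, 3)

def unetup_params(c_in, c_out, last_layer=False):
    # Mirrors Unetup: 2x2 transpose conv, then two 3x3 convs on the concatenated
    # input (c_in channels), plus an optional 1x1 conv to 2 classes.
    total = conv_params(c_in, c_out, 2)
    total += conv_params(c_in, c_out, 3) + conv_params(c_out, c_out, 3)
    if last_layer:
        total += conv_params(c_out, 2, 1)
    return total

total = (unetdown_params(1, 32) + unetdown_params(32, 64) + unetdown_params(64, 128)
         + unetup_params(128, 64) + unetup_params(64, 32, last_layer=True))
print(total)  # 465986
```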
        

<h3>Create our Model and optimizer </h3>

In [None]:
#Create UNet model - output is size (batch_size x 2 x H x W)
#The two channels correspond to the two classes: not lung (class 0) and lung (class 1).

#Move the model to the device (GPU if available)
model = Unet2().to(device)
print(model)

#Use an Adam optimiser to update the weights of the model
optimiser = optim.Adam(model.parameters(), lr = lr)

#Cross entropy: softmax over the two classes followed by the negative log-likelihood loss
loss_fn = nn.CrossEntropyLoss()
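`nn.CrossEntropyLoss` accepts model outputs of shape (N, C, H, W) together with long targets of shape (N, H, W), and by default averages the loss over every pixel in the batch. The NumPy function below sketches that computation (log-softmax over the class axis, then the negative log-likelihood of each pixel's true class), written only to show what the loss does; it ignores options such as class weights and `ignore_index`.

```python
import numpy as np

def pixelwise_cross_entropy(logits, labels):
    """Mean per-pixel cross-entropy, intended to mirror nn.CrossEntropyLoss()
    with its default mean reduction.

    logits: float array (N, C, H, W) of raw class scores
    labels: int array (N, H, W) of class indices in [0, C)
    """
    # Numerically stable log-softmax over the class axis (axis 1).
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class at every pixel.
    picked = np.take_along_axis(log_probs, labels[:, None, :, :], axis=1)
    # Negative log-likelihood, averaged over all N * H * W pixels.
    return -picked.mean()

# All-zero scores: each of the 2 classes gets probability 0.5, so the loss is log(2).
zero_logits = np.zeros((2, 2, 4, 4))
zero_labels = np.zeros((2, 4, 4), dtype=np.int64)
print(pixelwise_cross_entropy(zero_logits, zero_labels))

# Class 1 scored log(3) above class 0, so p(class 1) = 3/4 at every pixel.
logits = np.zeros((1, 2, 2, 2))
logits[:, 1] = np.log(3.0)
labels = np.ones((1, 2, 2), dtype=np.int64)
print(pixelwise_cross_entropy(logits, labels))  # equals -log(0.75)
```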

## Train the model

In [None]:
#Create empty lists to store the per-epoch losses (n_epochs is set at the top of the notebook)
Train_loss = []
Test_loss = []

for epoch in range(n_epochs):
    running_loss_train = 0.0
    running_loss_test = 0.0
    start_time = time.time()
    model.train()
    
    #For each training batch...
    for i, (image, label) in enumerate(data_loader_train):   
           
        #Move images and labels to device
        image = image.to(device)
        label = label.to(device)
        
        #Forward pass through model
        outputs = model(image)
        
        #Compute cross entropy loss
        loss = loss_fn(outputs, label)
        running_loss_train += loss.item()

        #Gradients are accumulated, so they should be zeroed before calling backwards
        optimiser.zero_grad()
        
        #Backward pass through model and update the model weights
        loss.backward()
        optimiser.step()
        
    running_loss_train /= len(data_loader_train)
    Train_loss.append(running_loss_train)
    
    #Compute validation loss
    model.eval()
    with torch.no_grad():
        for i, (image, label) in enumerate(data_loader_test):   
               
            image = image.to(device)
            label = label.to(device)

            outputs = model(image)
            loss = loss_fn(outputs, label)
            running_loss_test += loss.item()
    
        
    running_loss_test /= len(data_loader_test)
    Test_loss.append(running_loss_test)
    end_time = time.time()
    
    clear_output(True)
    print('[Epoch {0:02d}] Train Loss: {1:.4f}, Val Loss: {2:.4f}, Time: {3:.4f}s'.format(epoch, running_loss_train, running_loss_test,end_time - start_time))
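A subtle point about the loop above: dividing the summed batch losses by `len(data_loader_train)` averages per-batch means, so a small final batch counts as much as a full one. The sketch below, with made-up loss values and batch sizes, compares that with a sample-weighted mean. If you want the exact per-image average, one option is to accumulate `loss.item() * image.size(0)` and divide by `len(data_loader_train.dataset)`. The difference is usually small, so the notebook's version is fine for monitoring training.

```python
def batch_mean_of_means(batch_losses):
    # What the training loop computes: the average of per-batch mean losses.
    return sum(batch_losses) / len(batch_losses)

def sample_weighted_mean(batch_losses, batch_sizes):
    # Weight each batch's mean loss by its size; equals the mean over all samples.
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Made-up epoch: two full batches of 16 and a final partial batch of 2.
losses = [0.30, 0.30, 0.90]
sizes = [16, 16, 2]
print(batch_mean_of_means(losses))
print(sample_weighted_mean(losses, sizes))
```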



## Plot the loss curves

In [None]:
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.plot(Train_loss, '-', label = 'Training Loss')
plt.plot(Test_loss, '-', label = 'Validation Loss')
plt.legend()

## Test the model

In [None]:
data_loader_iter = iter(data_loader_test)

with torch.no_grad():
    for i in range(5):
        image, label = next(data_loader_iter)
        
        plt.subplot(1,3,1)
        plt.imshow(image[0,0,:,:], cmap='gray')
        plt.xlabel("Base Image")
        
        image = image.to(device)
        output = model(image)
        pred = torch.argmax(output,dim=1,keepdim=True)
        
        plt.subplot(1,3,2)
        plt.imshow(label[0,:,:], cmap='gray')
        plt.xlabel("Ground Truth")
        
        plt.subplot(1,3,3)
        plt.imshow(pred.cpu().numpy()[0,0,:,:], cmap='gray')
        plt.xlabel("Prediction")
        plt.show()
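Viewing a few predictions is a good qualitative check; for a number to track, segmentation work commonly reports the Dice coefficient and intersection over union (IoU) between the predicted and ground-truth masks. The NumPy function below is a minimal sketch for binary masks (1 = lung, 0 = background, matching this notebook's labels); the small `eps` term is an assumption that avoids division by zero when both masks are empty. In the cell above it could be applied as `dice_and_iou(pred.cpu().numpy()[0, 0], label[0].numpy())`.

```python
import numpy as np

def dice_and_iou(pred, target, eps=1e-7):
    """Dice coefficient and IoU for binary masks (1 = lung, 0 = background)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    dice = (2 * intersection + eps) / (pred.sum() + target.sum() + eps)
    iou = (intersection + eps) / (union + eps)
    return dice, iou

# Small made-up masks: 3 overlapping pixels, 4 pixels in each mask.
example_pred = np.array([[0, 1, 1],
                         [0, 1, 1],
                         [0, 0, 0]])
example_target = np.array([[0, 1, 1],
                           [0, 1, 0],
                           [0, 1, 0]])
dice, iou = dice_and_iou(example_pred, example_target)
print(dice, iou)
```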