<a href="https://colab.research.google.com/github/cholojuanito/deep-learning-cancer-detection/blob/main/unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cancer Detection

## Overview
Convolutional neural networks are used to solve many computer vision problems. One of the more modern problems researchers and machine learning engineers have tried to solve is detecting cancer in cross-sectioned microscopic slides of various tissue types. These slide images are called "whole slide images" (WSIs). They are *huge*, usually 1-2 GB each, and are digitized with fancy slide scanners.

Recently, the FDA granted its first-ever approval for a machine-learning-based digital pathology product ([read more here](https://www.paige.ai/news/news-press-release1/)). The digital pathology industry is still growing and in need of great minds!

### Prerequisites
* Know the Python programming language and packages like NumPy and Matplotlib
* Be familiar with PyTorch
* Have a basic understanding of machine and deep learning concepts

### You will learn
* How to build a dense prediction model for image segmentation problems
* How to convert research papers into usable networks using PyTorch

### Data set
The data is given as a set of 1024×1024 PNG images. These images are "tiles", small squares cut from the original WSI, because a single 1-2 GB image wouldn't fit in memory on most machines and would be extremely slow to train on.
Each input image (in the ```inputs_train``` or ```inputs_test``` directory) is an RGB image of a section of tissue,
and there is a file with the same name (in the matching ```outputs_train``` or ```outputs_test``` directory)
that has a dense, pixel-wise labeling of whether or not that section of tissue is cancerous
(white pixels mean “cancerous”, while black pixels mean “not cancerous”).
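To make the idea of tiles concrete, here is a minimal sketch of cutting a large image into non-overlapping 1024×1024 tiles with NumPy. `tile_image` is a hypothetical helper, not part of the dataset tooling. It assumes the whole image already fits in memory (a real 1-2 GB slide would usually be read region by region instead) and it simply drops partial tiles at the right and bottom edges.

```python
import numpy as np

def tile_image(image: np.ndarray, tile_size: int = 1024) -> list:
    """Cut an (H, W, C) image into non-overlapping tile_size x tile_size tiles.

    Assumption: partial tiles at the right/bottom edges are dropped.
    """
    height, width = image.shape[:2]
    tiles = []
    for top in range(0, height - tile_size + 1, tile_size):
        for left in range(0, width - tile_size + 1, tile_size):
            tiles.append(image[top:top + tile_size, left:left + tile_size])
    return tiles

# A small stand-in for a slide: 2048 x 3072 RGB pixels gives a 2 x 3 grid of tiles
fake_slide = np.zeros((2048, 3072, 3), dtype=np.uint8)
tiles = tile_image(fake_slide)
print(len(tiles), tiles[0].shape)  # 6 (1024, 1024, 3)
```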

The data has been pre-split into training and test sets.
Filenames also reflect whether or not the image has any cancer at all 
(files starting with ```pos_``` have some cancerous pixels, while files 
starting with ```neg_``` have no cancer anywhere).
All of the data is hand-labeled, so the dataset is not very large.
This means that overfitting is a real possibility.

An example ground-truth labeling and its corresponding input image are shown below, in that order.
(Both are also included in the downloadable dataset.)

![](http://liftothers.org/dokuwiki/lib/exe/fetch.php?w=200&tok=a8ac31&media=cs501r_f2016:pos_test_000072_output.png)
<img src="http://liftothers.org/dokuwiki/lib/exe/fetch.php?media=cs501r_f2016:pos_test_000072.png" width="200">

### Articles & Papers to read beforehand
* [U-Net: Convolutional Networks for Biomedical Image Segmentation | arXiv paper](https://arxiv.org/pdf/1505.04597.pdf)
* [Up-sampling Images | Towards Data Science](https://towardsdatascience.com/up-sampling-with-transposed-convolution-9ae4f2df52d0)

If you are new to CNNs, check out this article as well:
* [Intro to Convolutional Neural Networks | Towards Data Science](https://towardsdatascience.com/an-introduction-to-convolutional-neural-networks-eb0b60b58fd7)

___

## Part 1 - Implement a dense image segmentation network

#### Food for thought

The simplest network you could implement (with all the desired properties)
is just a single convolution layer with two filters and no ReLU!
Why is that? (Of course, it wouldn't work very well!)
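If you want to check your answer, this small sketch shows what such a network produces. One 3x3 convolution with two filters (and `padding=1`, an assumption so the size is preserved) turns every pixel of an RGB image into two scores, one per class, at the same resolution as the label mask. That per-pixel, per-class output is all a dense classifier needs; the batch and image sizes here are arbitrary.

```python
import torch
import torch.nn as nn

# The smallest dense-prediction "network": one 3x3 convolution, no activation.
# padding=1 keeps the output the same height and width as the input.
tiny_net = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, padding=1)

images = torch.randn(4, 3, 64, 64)   # a batch of 4 RGB tiles
logits = tiny_net(images)            # one score per class per pixel
print(logits.shape)                  # torch.Size([4, 2, 64, 64])

# Per-pixel class prediction: 0 = not cancerous, 1 = cancerous
predictions = logits.argmax(dim=1)
print(predictions.shape)             # torch.Size([4, 64, 64])
```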

#### 1.1 Dissecting the network topology

Below is an image from the U-Net paper illustrating what a U-Net "looks" like. You can probably guess why they named it "U"-Net.

![(Figure 1)](https://lh3.googleusercontent.com/qnHiB3B2KRxC3NjiSDtY08_DgDGTDsHcO6PP53oNRuct-p2QXCR-gyLkDveO850F2tTAhIOPC5Ha06NP9xq1JPsVAHlQ5UXA5V-9zkUrJHGhP_MNHFoRGnjBz1vn1p8P2rMWhlAb6HQ=w2400)

Let's take a look at all the parts of this network illustration.

We have, in order from the top-left of the "U", down to the bottom, and then up to the top-right of the "U":
1) 3x3 convolution followed by a Rectified Linear Unit (ReLU)
2) 2x2 max pool
3) 2x2 up convolution
4) 1x1 final convolution
5) "Copy and crop" for each level

Let's dissect each part, its importance and what PyTorch module(s) we will use to implement it.

*3x3 convolution and ReLU*

In the paper, these convolutions are "unpadded", so each one trims a 1-pixel border from the feature map (e.g. 572×572 → 570×570). On the way down the "U", the first convolution at each level also doubles the number of feature channels (64 → 128 → 256 → ...). Every one of these convolutions is followed by a ReLU, a non-linear activation function, which is what lets the network model non-linear relationships.

PyTorch modules: `Conv2d`, `ReLU`
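The size arithmetic is easy to check directly. This sketch assumes a single-channel 572×572 input like the one in Figure 1, and uses 8 feature channels instead of 64 purely to keep it fast. It compares the paper's unpadded convolutions with `padding=1`, which keeps the spatial size unchanged.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 572, 572)  # input size used in Figure 1

# Two unpadded 3x3 convolutions + ReLU, as in the paper: each trims 1 pixel per side
unpadded = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=0), nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=0), nn.ReLU(),
)
out_unpadded = unpadded(x)
print(out_unpadded.shape)  # torch.Size([1, 8, 568, 568])

# padding=1 adds a 1-pixel border first, so the spatial size is unchanged
padded = nn.Conv2d(1, 8, kernel_size=3, padding=1)
out_padded = padded(x)
print(out_padded.shape)    # torch.Size([1, 8, 572, 572])
```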

*2x2 max pool*

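This performs a "down-sampling" that halves the height and width of the feature maps (e.g. 568×568 → 284×284) by keeping the maximum value of each 2×2 window. It does not change the number of feature channels; the 3x3 convolution that follows is what doubles them.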

PyTorch modules: `MaxPool2d`
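Here is a quick shape check of that claim, assuming the 568×568 feature map from the first level of Figure 1 (with 8 channels instead of 64 to keep it cheap). Pooling halves the spatial size and leaves the channel count alone; the next unpadded convolution doubles the channels.

```python
import torch
import torch.nn as nn

features = torch.randn(1, 8, 568, 568)      # first-level output (8 channels here, 64 in the paper)
pooled = nn.MaxPool2d(kernel_size=2)(features)
print(pooled.shape)  # torch.Size([1, 8, 284, 284]); channels unchanged

# The next 3x3 convolution is what doubles the channel count (8 -> 16 here, 64 -> 128 in the paper)
next_conv = nn.Conv2d(8, 16, kernel_size=3, padding=0)
next_features = next_conv(pooled)
print(next_features.shape)  # torch.Size([1, 16, 282, 282])
```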

*2x2 up convolution*

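This is essentially the opposite of max pooling: it doubles the height and width of the feature maps (e.g. 28×28 → 56×56 at the bottom of Figure 1). In U-Net it also halves the number of feature channels (e.g. 1024 → 512).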

PyTorch modules: `ConvTranspose2d`
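A minimal sketch of the up-convolution at the bottom of Figure 1, assuming the paper's sizes (1024 channels at 28×28). With `kernel_size=2` and `stride=2`, the transposed convolution's output size is `(28 - 1) * 2 + 2 = 56`, and choosing 512 output channels halves the channel count.

```python
import torch
import torch.nn as nn

bottom = torch.randn(1, 1024, 28, 28)   # bottom of the "U" in Figure 1
up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
upsampled = up(bottom)
print(upsampled.shape)  # torch.Size([1, 512, 56, 56])
```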

*1x1 convolution*

We use this last convolution to make sure our outputs have the correct dimensionality, basically to make sure we have the correct number of classes at the end. In our case we are just looking for two classes, cancerous and non-cancerous parts of the image.

PyTorch modules: `Conv2d`
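A 1x1 convolution is the same as applying one small linear layer to the feature vector of every pixel independently. The sketch below checks that equivalence numerically with `torch.einsum`; the shapes (64 input features, 2 classes, a 16×16 feature map) are illustrative.

```python
import torch
import torch.nn as nn

features = torch.randn(2, 64, 16, 16)     # last decoder feature map
final = nn.Conv2d(64, 2, kernel_size=1)    # 64 features -> 2 class scores per pixel
logits = final(features)
print(logits.shape)  # torch.Size([2, 2, 16, 16])

# The same computation written as a linear map applied at every pixel
weight = final.weight.view(2, 64)          # (classes, features)
manual = torch.einsum('cf,nfhw->nchw', weight, features) + final.bias.view(1, 2, 1, 1)
print(torch.allclose(logits, manual, atol=1e-5))  # True
```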

*Copy and crop*

This last part is the operation represented by the grey arrows that cross the "U". At each level, we take the *output of the second 3x3 convolution + ReLU* on the left side of the "U", crop it, and concatenate its feature channels with the *output of the up-convolution* on the same level of the right side. Cropping is needed because the unpadded convolutions lose border pixels, so the left-side feature map is larger than the right-side one (e.g. 568×568 versus 392×392 at the top level).

PyTorch function: `torch.cat` for concatenating across the "U"
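Here is a sketch of "copy and crop" at the top level of Figure 1. It assumes a center crop, which is the usual choice; `center_crop` is a hypothetical helper, and the tensors use 8 channels instead of 64 to keep memory low. After cropping, `torch.cat` along dimension 1 stacks the two sets of feature channels.

```python
import torch

def center_crop(tensor: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    """Crop an (N, C, H, W) tensor around its center to (N, C, target_h, target_w)."""
    _, _, h, w = tensor.shape
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]

# Top level of Figure 1: encoder output is 568x568, up-convolution output is 392x392
skip = torch.randn(1, 8, 568, 568)
upsampled = torch.randn(1, 8, 392, 392)

cropped = center_crop(skip, upsampled.shape[2], upsampled.shape[3])
merged = torch.cat((cropped, upsampled), dim=1)  # concatenate along the channel dimension
print(merged.shape)  # torch.Size([1, 16, 392, 392])
```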


#### 1.2 Creating the U-Net network

Install some dependencies if they aren't installed already

In [1]:
%pip install torch torchvision tqdm


Below are the imports for all the packages and modules we will use. Run it before you run any other code.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, utils, datasets
from tqdm import tqdm
from torch.nn.parameter import Parameter
import pdb
import torchvision
import os
import gzip
import tarfile
import gc
from IPython.core.ultratb import AutoFormattedTB
__ITB__ = AutoFormattedTB(mode = 'Verbose',color_scheme='LightBg', tb_offset = 1)

# assert torch.cuda.is_available(), "You need to request a GPU from Runtime > Change Runtime"

Alright, so we know that we need the following PyTorch modules to start making this network:
* `Conv2d`
* `ReLU`
* `MaxPool2d`
* `ConvTranspose2d`

Let's think about the other info we will need to convert to code about the network. First, we will need to know how many input and output channels the image will have. We will call those `in_channels` and `out_channels`.

The network in the paper starts with a convolution that produces 64 feature channels so we will want a variable to hold that information. We also have a set number of "levels" or series of convolutional blocks down and up the "U". The network in the paper does this 4 times, but we can make it any number we want. Let's call those variables `num_features_start` and `u_depth`.
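These two hyper-parameters determine every channel count in the network, because each level down the "U" doubles the previous one. A small sketch (`unet_channel_sizes` is a hypothetical helper) of that doubling rule; the up path simply mirrors these sizes in reverse.

```python
def unet_channel_sizes(num_features_start=64, u_depth=4):
    """Return the feature-channel count at each level going down the "U",
    plus the channel count at the bottom. Each level doubles the previous one."""
    down = [num_features_start * 2 ** level for level in range(u_depth)]
    bottom = num_features_start * 2 ** u_depth
    return down, bottom

print(unet_channel_sizes())        # ([64, 128, 256, 512], 1024)
print(unet_channel_sizes(32, 3))   # ([32, 64, 128], 256)
```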

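That should be all the hyper-parameters we need to make this network. Below is the network I've made. It subclasses `nn.Module` so it fits seamlessly into PyTorch (parameter tracking, `.to(device)`, `train()`/`eval()`). One deliberate difference from the paper: every 3x3 convolution uses `padding=1`, so the convolutions do not shrink the feature maps and the skip connections can be concatenated without cropping. The trade-off is that the input height and width must be divisible by `2**u_depth` (16 for the default depth of 4), otherwise the shapes will not line up at the concatenation.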

In [3]:
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, num_features_start=64, u_depth=4):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depth = u_depth
        size = num_features_start

        self.conv = nn.Conv2d
        self.activation_func = nn.ReLU

        self.down_convs = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        self.up_convs = nn.ModuleList()

        previous_size = self.in_channels
        current_size = size
        # down the U
        for i in range(self.depth) :
            self.down_convs.append(nn.Sequential(
                self.conv(previous_size, current_size, kernel_size=3, stride=1, padding=1),
                self.activation_func(),
                self.conv(current_size, current_size, kernel_size=3, stride=1, padding=1),
                self.activation_func(),
            ))
            self.down_samples.append(nn.MaxPool2d(kernel_size=2))
            previous_size = current_size
            current_size *= 2

        # bottom convolutions
        self.bottom = nn.Sequential(
            self.conv(previous_size, current_size, kernel_size=3, stride=1, padding=1),
            self.activation_func(),
            self.conv(current_size, current_size, kernel_size=3, stride=1, padding=1),
            self.activation_func()
        )

        # up the U
        for i in range(self.depth) :
            next_size = current_size//2
            self.up_samples.append(nn.Sequential(
                nn.ConvTranspose2d(current_size, next_size, kernel_size=2, stride=2, padding=0),
                self.activation_func(),
            ))
            self.up_convs.append(nn.Sequential(
                self.conv(current_size, next_size, kernel_size=3, stride=1, padding=1),
                self.activation_func(),
                self.conv(next_size, next_size, kernel_size=3, stride=1, padding=1),
                self.activation_func(),
            ))
            current_size = next_size

        # last convolutional layer
        self.final = self.conv(current_size, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # go down the U
        activations = []
        for i in range(self.depth):
            x = self.down_convs[i](x)
            activations.append(x)
            x = self.down_samples[i](x)
        
        x = self.bottom(x)

        # back up the U
        for i in range(self.depth):
            x = self.up_samples[i](x)
            x = torch.cat((activations[-(i+1)], x), 1)
            x = self.up_convs[i](x)

        return self.final(x)

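def pad_to_shape(tensor, out_shape):
    """
    Zero-pads the end of each spatial dimension of tensor up to out_shape.
    (Not used by the UNet above, which keeps spatial sizes aligned with padding=1.)
    Args:
        tensor: image tensor to pad, shape (N, C, H, W) or (N, C, D, H, W)
        out_shape: desired output shape, with the same number of dimensions
    Returns:
        Zero-padded tensor of shape out_shape.
    """
    if len(out_shape) == 4:
        pad = (0, out_shape[3] - tensor.shape[3], 0, out_shape[2] - tensor.shape[2])
    elif len(out_shape) == 5:
        pad = (0, out_shape[4] - tensor.shape[4], 0, out_shape[3] - tensor.shape[3], 0, out_shape[2] - tensor.shape[2])
    else:
        raise ValueError(f"out_shape must have 4 or 5 dimensions, got {len(out_shape)}")
    return F.pad(tensor, pad)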
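Because the `UNet` above uses `padding=1`, only the 2x2 max pools change the spatial size. If a pooling step has to round an odd size down, the up-convolution cannot restore it and the `torch.cat` in `forward` fails with a size mismatch. The hypothetical helper below traces the size at each level and fails early instead; it is a sanity-check sketch, not part of the model.

```python
def padded_unet_sizes(height, width, u_depth=4):
    """Spatial size at each level of a U-Net whose 3x3 convolutions use padding=1.

    Only the 2x2 max pools change the size, so each level halves it. The
    up-convolutions double it again, which only matches the saved skip
    connection if every halving was exact.
    """
    sizes = [(height, width)]
    for _ in range(u_depth):
        if height % 2 or width % 2:
            raise ValueError(f"{height}x{width} is not divisible by 2; "
                             f"use a size divisible by 2**u_depth = {2 ** u_depth}")
        height, width = height // 2, width // 2
        sizes.append((height, width))
    return sizes

print(padded_unet_sizes(512, 512))  # [(512, 512), (256, 256), (128, 128), (64, 64), (32, 32)]
```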

___

## Part 2 - Dataset

### Preparing the dataset
Next, we wrap our image dataset into a PyTorch `Dataset` object to make it easier to load the data into memory and create batches.

In [4]:
class CancerDataset(Dataset):
  def __init__(self, root, download=True, size=512, train=True):
    if download and not os.path.exists(os.path.join(root, 'cancer_data')):
      datasets.utils.download_url('http://liftothers.org/cancer_data.tar.gz', root, 'cancer_data.tar.gz', None)
      self.extract_gzip(os.path.join(root, 'cancer_data.tar.gz'))
      self.extract_tar(os.path.join(root, 'cancer_data.tar'))
    
    postfix = 'train' if train else 'test'
    root = os.path.join(root, 'cancer_data', 'cancer_data')
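    image_transform = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])
    # Nearest-neighbour resizing keeps the masks binary (0 or 1); bilinear would blend edge pixels
    label_transform = transforms.Compose([transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    self.dataset_folder = torchvision.datasets.ImageFolder(os.path.join(root, 'inputs_' + postfix), transform=image_transform)
    self.label_folder = torchvision.datasets.ImageFolder(os.path.join(root, 'outputs_' + postfix), transform=label_transform)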

  @staticmethod
  def extract_gzip(gzip_path, remove_finished=False):
    print('Extracting {}'.format(gzip_path))
    with open(gzip_path.replace('.gz', ''), 'wb') as out_f, gzip.GzipFile(gzip_path) as zip_f:
      out_f.write(zip_f.read())
    if remove_finished:
      os.unlink(gzip_path)
  
  @staticmethod
  def extract_tar(tar_path):
    print('Untarring {}'.format(tar_path))
    with tarfile.open(tar_path) as z: # closes the archive when extraction finishes
      z.extractall(tar_path.replace('.tar', ''))

  def __getitem__(self,index):
    img = self.dataset_folder[index]
    label = self.label_folder[index]
    return img[0],label[0][0]
  
  def __len__(self):
    return len(self.dataset_folder)
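One detail worth checking when resizing label masks is the interpolation mode. Bilinear resizing blends neighbouring pixels, so a black/white mask gains in-between values along the edges of cancerous regions; those are not valid class labels, `labels.long()` truncates them to 0, and they never compare equal to a predicted class. This sketch uses `F.interpolate` on a tiny made-up 4×4 mask to show the effect; nearest-neighbour resizing keeps only 0 and 1.

```python
import torch
import torch.nn.functional as F

# A 4x4 binary mask: left half "not cancerous" (0), right half "cancerous" (1)
mask = torch.tensor([[0., 0., 1., 1.]] * 4).view(1, 1, 4, 4)

bilinear = F.interpolate(mask, size=(3, 3), mode='bilinear', align_corners=False)
nearest = F.interpolate(mask, size=(3, 3), mode='nearest')

print(bilinear.unique())  # includes a value strictly between 0 and 1 (about 0.5)
print(nearest.unique())   # tensor([0., 1.])
```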

___

## Part 3 - Training the model

### Trainer class
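Here we are going to bundle up the general training steps into a single class. It works with whatever network, optimizer, loss function, and data loaders you plug into it; its PyTorch-specific parts are moving batches to the GPU, turning gradient tracking on and off, and calling `backward()` and `step()`. In our case, training involves the following steps: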
1. For each epoch*:
    1. Ensure the gradient is turned on (this is how the network learns through backpropagation)
    2. For each batch of images in the **training** dataset:
        1. Load the batch into GPU memory
        2. Reset the gradient
        3. Pass the images through the U-Net
        4. Calculate loss/error
        5. Track the accuracy
        6. Call `backward()` to perform backpropagation
        7. Take a "step" in the solution space based on the gradient
    3. Turn off the gradient
    4. For each batch of images in the **validation** dataset:
        1. Load the batch into GPU memory
        2. Pass the images through the U-Net
        3. Calculate loss/error
        4. Track the accuracy

\* - **epoch** means one round through both the training and validation datasets
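The inner training steps above can be sketched on a single random batch. This is an illustration under stated assumptions: a single convolution stands in for the U-Net, the images and masks are random, and everything runs on the CPU. It highlights the shape contract of `nn.CrossEntropyLoss` for segmentation: the scores are `(N, classes, H, W)` and the targets are `(N, H, W)` integer class indices, which is why the labels are converted with `.long()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for one batch: 2 RGB images and their 0/1 masks
imgs = torch.randn(2, 3, 32, 32)
labels = (torch.rand(2, 32, 32) > 0.85).float()    # roughly 15% "cancerous" pixels

model = nn.Conv2d(3, 2, kernel_size=3, padding=1)  # stand-in for the U-Net
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_func = nn.CrossEntropyLoss()

model.train()
optimizer.zero_grad()                     # 2. reset the gradient
y_hat = model(imgs)                       # 3. forward pass: (N, 2, H, W) class scores
loss = loss_func(y_hat, labels.long())    # 4. targets are (N, H, W) class indices
accuracy = (y_hat.argmax(dim=1) == labels.long()).float().mean()  # 5. pixel accuracy
loss.backward()                           # 6. backpropagation
optimizer.step()                          # 7. update the weights

print(f"loss: {loss.item():.4f}, accuracy: {accuracy.item():.4f}")
```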

In [5]:
class SegmentationTrainer:
    def __init__(self):
        self.network = None
        self.optimizer = None
        self.loss_func = None
        self.train_dataloader = None 
        self.val_dataloader = None

        self.device = 'cuda'
        self.epoch_count = 0
        self.loop = None

    def increment_epoch_count(self):
        self.epoch_count += 1

    def train(self, num_epochs):
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []

        for i in range(1, num_epochs+1):
            self.loop = tqdm(total=len(self.train_dataloader), position=0, leave=False)
            # Train
            print(f'\nTraining for epoch {i} of {num_epochs}')
            train_loss, train_acc = self.train_epoch_(self.train_dataloader, training=True)
            self.loop.close()

            self.loop = tqdm(total=len(self.val_dataloader), position=0, leave=False)
            # Validate
            print(f"\nValidation for epoch {i} of {num_epochs}")
            val_loss, val_acc = self.train_epoch_(self.val_dataloader, training=False)
            self.loop.close()
            

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            train_accuracies.append(train_acc)
            val_accuracies.append(val_acc)

        return  train_losses, val_losses, train_accuracies, val_accuracies

    def train_epoch_(self, data_loader, training=True):
        if training:
            self.network.train() 
            torch.set_grad_enabled(True)
        else:
            self.network.eval()
            torch.set_grad_enabled(False)

        loss_sum = 0
        acc_sum = 0
        for i, (imgs, labels) in enumerate(data_loader):
            imgs, labels = imgs.to(self.device, non_blocking=True), labels.to(self.device,non_blocking=True) # non_blocking is a speed up (async)
            
            if training:
                self.optimizer.zero_grad() # Set gradient to zero

            y_hat = self.network(imgs)
            loss = self.loss_func(y_hat, labels.long())
            loss_sum += loss.detach().sum().item()

            
            preds = y_hat.argmax(1) # predicted class index for every pixel
            accuracy = (preds == labels).float().mean()
            acc_sum += accuracy.detach().item()

            mem_allocated = torch.cuda.memory_allocated(0) / 1e9

            self.loop.set_description('loss: {:.4f}, accuracy: {:.4f}, mem: {:.2f}'.format(loss.detach().sum().item(), accuracy, mem_allocated))
            self.loop.update(1)

            if training:
                loss.backward() # Compute gradient, for weight with respect to loss
                self.optimizer.step() # Take step in the direction of the negative gradient

        num_batches = len(data_loader) # average over every batch in the epoch
        return loss_sum / num_batches, acc_sum / num_batches

### Preparation
Now we instantiate everything we need for training: the datasets, data loaders, network, loss function, and optimizer.

In [6]:
# Instantiate data sets
train_dataset = CancerDataset('D:\\tmp', train=True)
val_dataset = CancerDataset('D:\\tmp', train=False)

Downloading http://liftothers.org/cancer_data.tar.gz to D:\tmp\cancer_data.tar.gz


100%|██████████| 2750494655/2750494655 [15:32<00:00, 2948261.12it/s]


Extracting D:\tmp\cancer_data.tar.gz
Untarring D:\tmp\cancer_data.tar


In [6]:
# Instantiate data loaders
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=4,
                          num_workers=4,
                          pin_memory=True,)

val_loader = DataLoader(val_dataset,
                        shuffle=False, # order doesn't matter for validation
                        batch_size=4,
                        num_workers=4,
                        pin_memory=True)

In [11]:
# Instantiate the network
device = "cuda"

net = UNet(
    3, # in_channels: RGB input
    2, # out_channels: one score per class (not cancerous, cancerous)
    num_features_start=64,
    u_depth=4
).to(device)

# Instantiate loss function and optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(
    net.parameters(),
    lr = learning_rate
)
loss_func = nn.CrossEntropyLoss().to(device)

In [12]:
# Instantiate the trainer
trainer = SegmentationTrainer()
trainer.network = net
trainer.optimizer = optimizer
trainer.loss_func = loss_func
trainer.train_dataloader = train_loader
trainer.val_dataloader = val_loader

### Training
Next, we create a loop that trains for the number of epochs we want. We will use 10 so we aren't waiting for hours.

In [None]:
num_epochs = 10
def train_loop(num_epochs):
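  try:
    gc.collect()
    print(f'GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB')
    return trainer.train(num_epochs)

  except Exception:
    __ITB__() # print a verbose traceback
    raise     # re-raise so a failed run doesn't silently return None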
    
train_losses, val_losses, train_acc, val_acc = train_loop(num_epochs)

##### WARNING: You may run into an error that says `RuntimeError: CUDA out of memory`

This means the memory needed for a batch exceeds what your GPU has available. You can fix it by using a GPU with more memory, or by reducing the image `size` passed to `CancerDataset` or the `batch_size` of the data loaders, and then restarting the runtime.
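A rough back-of-the-envelope sketch shows why those two knobs help so much. The hypothetical helper below computes the size of a single float32 feature map; the forward pass keeps many tensors of this size around for backpropagation, so the real total is much larger, but it scales the same way: linearly with the batch size and with the number of pixels.

```python
def feature_map_gb(batch_size, channels, height, width, bytes_per_value=4):
    """Size in GB of one float32 tensor of shape (batch, channels, height, width)."""
    return batch_size * channels * height * width * bytes_per_value / 1e9

# One 64-channel feature map at the top level of the U-Net, batch of 4 at 512x512
print(f"{feature_map_gb(4, 64, 512, 512):.3f} GB")  # 0.268 GB
# Halving the image size (4x fewer pixels) and the batch size (2x) shrinks it 8x
print(f"{feature_map_gb(2, 64, 256, 256):.3f} GB")  # 0.034 GB
```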


___

## Part 4 - Display performance

A key part of determining the training performance and potential production performance of any model is to calculate and visualize its accuracy.


### Plot performance over time




In [None]:
# Graphing functions
def show_loss_graph(title, train_losses, val_losses):
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.title(title)
    plt.ylabel("Avg. Loss per epoch")
    plt.xlabel("Epochs")
    plt.ylim(0, max(max(train_losses), max(val_losses)))
    plt.xlim(0, len(train_losses))
    plt.legend()
    plt.show()

def show_accuracy_graph(title, train_acc, val_acc):
    plt.plot(train_acc, label="Train Accuracy")
    plt.plot(val_acc, label="Validation Accuracy")
    plt.title(title)
    plt.ylabel("Avg. Accuracy per epoch (fraction)")
    plt.xlabel("Epochs")
    plt.ylim(0, 1.0)
    plt.xlim(0, len(train_acc))
    plt.legend()
    plt.show()

**NOTE:**

Always guessing that a pixel is not cancerous will give you an accuracy of ~85%, because the large majority of pixels in this dataset are non-cancerous. Cancerous pixels are the minority class, and they are exactly the ones we want the network to detect.
A trained network needs to beat this ~85% baseline before we can call it a successful model.
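You can compute this majority-class baseline yourself from any batch of label masks. The sketch below uses a hypothetical helper and a made-up toy batch (3 cancerous pixels out of 32); the ~85% figure above comes from the real dataset, not from this example.

```python
import torch

def all_negative_accuracy(labels: torch.Tensor) -> float:
    """Pixel accuracy of always predicting 0 ("not cancerous") for 0/1 masks."""
    predictions = torch.zeros_like(labels)
    return (predictions == labels).float().mean().item()

# Toy batch: 2 masks of 4x4 pixels with 3 cancerous pixels out of 32
labels = torch.zeros(2, 4, 4)
labels[0, 0, :3] = 1
print(all_negative_accuracy(labels))  # 0.90625
```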

Below are some example graphs of accuracy and loss over the course of training.

![](http://liftothers.org/dokuwiki/lib/exe/fetch.php?w=400&tok=d23e0b&media=cs501r_f2016:training_accuracy.png)
![](http://liftothers.org/dokuwiki/lib/exe/fetch.php?w=400&tok=bb8e3c&media=cs501r_f2016:training_loss.png)

In [None]:
show_loss_graph("Loss graph", train_losses, val_losses)

show_accuracy_graph("Accuracy graph", train_acc, val_acc)

___

## Part 5 - Generating predictions

This is the real test to see how well your model accomplished its task. We need it to work outside of a training environment and in the real world with novel data points, in our case, new images. This is commonly known as inference.

First let's define a function that converts our tensors to images so we can see what they look like

In [None]:
def show_tensor_as_img(tensor):
    img = tensor.detach().cpu().numpy() # works for GPU tensors and tensors that track gradients
    plt_img = np.transpose(img, (1, 2, 0))
    if (plt_img.shape[2] < 2):
        plt.imshow(plt_img.squeeze() * 255, cmap='gray', vmin=0, vmax=255)
    else:
        plt.imshow(plt_img)
    plt.show()

### Inference
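A key aspect of inference is that it should be fast and must not update the network: the weights of all our network connections stay fixed. Weights only change when `optimizer.step()` is called, so simply passing images through the network does not modify them. We still call the `Module.eval()` method to switch layers such as dropout and batch normalization to their inference behavior (our U-Net has neither, but it is good practice), and we wrap the forward pass in `torch.no_grad()` so PyTorch does not spend time and memory recording a computation graph for gradients.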
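The difference between `eval()` and `torch.no_grad()` is easy to see on a toy model. This sketch uses a small linear layer with dropout as a stand-in (the U-Net itself has no dropout): `eval()` makes the output deterministic but still records gradients, while `no_grad()` is what turns gradient tracking off.

```python
import torch
import torch.nn as nn

torch.set_grad_enabled(True)  # make sure gradient tracking is on (the trainer may have switched it off)

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(1, 8)

model.eval()                             # dropout is now a no-op, so the output is deterministic
print(torch.equal(model(x), model(x)))   # True

out = model(x)
print(out.requires_grad)                 # True: eval() alone still builds the autograd graph

with torch.no_grad():                    # skip graph construction during inference
    out = model(x)
print(out.requires_grad)                 # False
```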

Here we will just use image #172 from the validation dataset, since we know its label contains some cancerous tissue. This way we can compare the model's prediction against both the input image and the ground truth.

In [None]:
test_img, test_label = val_dataset[172]
print('Base Image')
show_tensor_as_img(test_img)

print('Diagnosis - Truth')
show_tensor_as_img(test_label.unsqueeze(0))

# We aren't training so we need to put the network into `eval` mode
net.eval()

img_batch = torch.unsqueeze(test_img, 0)
img_batch = img_batch.to(device)
with torch.no_grad(): # no gradients are needed for inference
    prediction = net(img_batch)
prediction = prediction.argmax(1).to('cpu')

print('Diagnosis - Prediction')
show_tensor_as_img(prediction)


----

## Part 6 - Saving the model

Alright! We've made a pretty good model. But what good is it if it disappears the moment we end this coding session? We would have to train a new one every time, and that's just not an efficient use of resources. So we will learn a few ways to save our model's state. This will allow us to pick up where we left off if we need to continue training and fine-tuning a model, or to use the model in a production inference environment.
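As a starting point, here is a minimal sketch of the common PyTorch pattern of saving `state_dict`s. A single convolution stands in for the trained U-Net, the checkpoint path and the `"epoch"` key are illustrative choices, and saving the optimizer state alongside the weights is what makes it possible to resume training rather than only run inference.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Conv2d(3, 2, kernel_size=3, padding=1)   # stand-in for the trained U-Net
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

path = os.path.join(tempfile.gettempdir(), "unet_checkpoint.pt")  # example location

# Save the weights plus the optimizer state so training can be resumed later
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": 10,
}, path)

# Later: rebuild the same architecture, then load the saved state into it
restored = nn.Conv2d(3, 2, kernel_size=3, padding=1)
restored_optimizer = torch.optim.Adam(restored.parameters(), lr=1e-4)
checkpoint = torch.load(path, map_location="cpu")
restored.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
restored.eval()  # switch to inference behaviour before making predictions

print(torch.equal(model.weight, restored.weight))  # True
```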

