# Segmentation of Medical Scans using Variational Autoencoders (VAEs)
This notebook goes through the process of creating a PyTorch-compatible dataset and setting up a model for the segmentation of tumors in various organs. It enables reproducibility of our final model and testing results.

## 1. Setup
We import the necessary libraries.

In [1]:
# For ML
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torch.optim as optim
from torch import Tensor
import torchvision.transforms as Transform
from torch.utils.data import Dataset

# For reading raw data.
import json
import nibabel as nib

# For displaying and evaluating results.
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

# For monitoring resource-usage and progress.
from timeit import default_timer as timer
from tqdm import tqdm # Install ipywidgets to remove warning.
import os, sys, psutil
from os.path import join, exists

We also check whether a GPU is available and retrieve some system stats. We need a lot of RAM because the selected datasets are very large.

In [2]:
# Run on the first CUDA GPU if PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using', device)

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))        # GPU model name
    print('CUDA version:', torch.version.cuda)  # CUDA version PyTorch was built with

# Field 0 of psutil.virtual_memory() is the total physical memory in bytes; convert to GB.
# (Despite its name, `available_ram` holds the total, not the currently free, RAM.)
available_ram = round(psutil.virtual_memory()[0] / 1e9, 2)
print('RAM: ' + str(available_ram) + 'GB')

Using cuda
NVIDIA GeForce GTX 1070
CUDA version: 11.7
RAM: 16.74GB


We set up some global constants.

In [3]:
# All paths are relative to the notebook's working directory.
root_dir = '../'                                 # project root
raw_data_dir = join(root_dir, 'raw_data')        # raw NIfTI volumes + each organ's dataset.json
prep_data_dir = join(root_dir, 'augmenteddata')  # preprocessed .pt slice tensors

# Two-colour overlay map: low label values are transparent, high values are drawn red.
cmap_seg = ListedColormap(['none', 'red'])

We also define some utility functions.

In [4]:
def superimpose(image, label):
    """Display a grayscale image with its segmentation label overlaid in red.

    Args:
        image: 2-D array-like slice (e.g. a CPU tensor), drawn with a gray colormap.
        label: 2-D array-like mask with the same shape as ``image``.
    """
    plt.imshow(image, cmap='gray')
    # Pin the label's colour limits to [0, 1] instead of letting matplotlib autoscale
    # per slice: with autoscaling, a mask containing only small (e.g. interpolated)
    # values would be coloured relative to its own maximum and drawn fully red.
    # 'nearest' keeps the mask's pixel boundaries crisp when rescaled on screen.
    plt.imshow(label, cmap=cmap_seg, vmin=0, vmax=1, interpolation='nearest')
    plt.show()

We define a function that loads the raw data for a given organ and stores it in the proper format. Each organ's image and label volumes are preprocessed separately and saved to disk as soon as they are complete, so the process can be interrupted and resumed without losing the organs that have already been processed. Because the datasets are huge, and the slices of each volume are concatenated onto those of the previous volumes, we monitor progress and RAM usage.

In [5]:
def prep_data(organ, type, resolution, overwrite=False):
    """Preprocess one organ's raw volumes into a single tensor of 2-D slices and save it.

    Every entry under 'training' in ``<raw_data_dir>/<organ>/dataset.json`` is loaded with
    nibabel, min-max scaled to [0, 1], permuted to (slice, rows, columns), resized to
    ``resolution`` x ``resolution`` and concatenated along the slice axis. The result is
    saved as ``<prep_data_dir>/<organ>_<type>_slices_<resolution>.pt``.

    Args:
        organ: Name of the organ's sub-directory in ``raw_data_dir``.
        type: Manifest key to read, i.e. 'image' or 'label'. (Shadows the builtin
            ``type``; the name is kept for compatibility with existing callers.)
        resolution: Side length of the square output slices.
        overwrite: If False (default) and the output file already exists, nothing is
            redone, so an interrupted preprocessing run can simply be resumed.

    Pressing Ctrl-C (KeyboardInterrupt) stops the current organ without saving it.
    """
    out_path = join(prep_data_dir, organ + '_' + type + '_slices_' + str(resolution) + '.pt')
    if exists(out_path) and not overwrite:
        print('Skipping ' + organ + ' ' + type + 's: ' + out_path + ' already exists.')
        return

    with open(join(raw_data_dir, organ, 'dataset.json')) as f:
        manifest = json.load(f)['training']

    # Labels are discrete masks: nearest-neighbour resizing keeps them discrete, whereas the
    # default bilinear interpolation would blend classes into fractional values at borders.
    if type == 'label':
        interpolation = Transform.InterpolationMode.NEAREST
    else:
        interpolation = Transform.InterpolationMode.BILINEAR
    resize = Transform.Resize((resolution, resolution), interpolation=interpolation)

    bar = tqdm(total=len(manifest))
    bar.set_description('Prepping ' + organ + ' ' + type + 's')

    try:
        # Collect the volumes and concatenate once at the end: calling torch.cat inside the
        # loop re-copies the ever-growing tensor on every iteration (quadratic copying).
        volumes = []
        for entry in manifest:
            # virtual_memory()[3] is psutil's 'used' memory in bytes, displayed in GB.
            bar.set_postfix(**{'RAM': round(psutil.virtual_memory()[3]/10e8, 2)})

            # Drop the first two characters of the manifest path (presumably './').
            nii_img = nib.load(join(raw_data_dir, organ, entry[type][2:]))

            # Convert to numpy array, then pytorch tensor.
            nii_data = Tensor(nii_img.get_fdata())

            # Scale between 0 and 1. A constant volume (e.g. a label mask without any
            # foreground) would divide by zero and become all-NaN, so it stays all zeros.
            nii_data -= nii_data.min()
            value_range = nii_data.max()
            if value_range > 0:
                nii_data /= value_range

            nii_data = nii_data.permute(2, 0, 1)  # (slice, rows, columns)
            volumes.append(resize(nii_data))
            bar.update()

        if volumes:
            images = torch.cat(volumes, 0)
        else:
            images = torch.zeros((0, resolution, resolution))

        os.makedirs(prep_data_dir, exist_ok=True)
        torch.save(images, out_path)

    except KeyboardInterrupt:
        print('Manually stopped.')

    finally:
        # Always release the progress bar, whatever happened above.
        bar.close()

We call the preprocessing function for the organs we wish to train on.

In [6]:
# Resolution of the preprocessed square slices. VAEModel expects 256x256 inputs
# (w = h = 256) and the dataset is later loaded with resolution = 2**8, so prep at 2**8;
# prepping at 2**7 would produce files the rest of the notebook never loads.
lod = 2**8                 # Level of detail.
resolution = lod           # 2**8 = 256
do_prep = False            # Toggle to prep data.

organs = ['spleen', 'colon', 'pancreas', 'lung', 'liver']

if do_prep:
    for organ in organs:
        prep_data(organ, 'image', resolution)
        prep_data(organ, 'label', resolution)
else:
    print('Data already prepped.')

Data already prepped.


## 2. Creating a pytorch dataset
We define a custom dataset class.

In [7]:
class CT_Dataset(Dataset):
    """Paired (image slice, label slice) dataset backed by preprocessed ``.pt`` tensors.

    Loads ``<organ>_image_slices_<resolution>.pt`` and ``<organ>_label_slices_<resolution>.pt``
    from ``path`` and indexes both tensors in lockstep along their first axis.

    Raises:
        ValueError: if the image and label tensors contain a different number of slices.
    """

    def __init__(self, path, organ, resolution):
        self.images = torch.load(join(path, organ + '_image_slices_' + str(resolution) + '.pt'))

        self.labels = torch.load(join(path, organ + '_label_slices_' + str(resolution) + '.pt'))

        # Images and labels are paired purely by index; a length mismatch means the two
        # files came from different runs/manifests and the pairs would silently misalign.
        if len(self.images) != len(self.labels):
            raise ValueError(
                f'{organ}: {len(self.images)} image slices but {len(self.labels)} '
                f'label slices at resolution {resolution}')

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def show_datapoint(self, index):
        """Display slice ``index`` with its label overlaid in red."""
        image, label = self[index]
        superimpose(image, label)

To construct the dataset, one must specify the organ and the resolution of the images. Further below, we check that the dataset works correctly by retrieving and displaying a single datapoint.

## 3. Defining our model architecture
We define the variational autoencoder architecture as a PyTorch module.

In [8]:
latent_size = 1024  # Dimension of the latent space.
w = h = 256         # Side length of the square input (and reconstructed) images.


class VAEModel(nn.Module):
    """Convolutional variational autoencoder for single-channel 256x256 slices.

    Encoder: six stride-2 convolutions halve the spatial size (256 -> 4) while doubling
    the channels (1 -> 64), giving 64*4*4 = 1024 features that are mapped to the mean and
    log-variance of a diagonal Gaussian over the latent space.
    Decoder: a linear layer back to 1024 features, reshaped to (64, 4, 4), followed by six
    stride-2 transposed convolutions up to (1, 256, 256) and a sigmoid.

    Args:
        latent_size: Dimension of the latent space. Defaults to the module-level
            ``latent_size`` so existing ``VAEModel()`` calls are unchanged.
    """

    # Flattened encoder output size: 64 channels x 4 x 4 spatial positions.
    _feature_size = 64 * 4 * 4

    def __init__(self, latent_size: int = latent_size) -> None:
        super(VAEModel, self).__init__()
        # A single (parameter-free) LeakyReLU instance is reused in every layer.
        self.activation = nn.LeakyReLU(0.05)

        # Encoder: channels 2**i -> 2**(i+1); kernel 4, stride 2, padding 1 halves H and W.
        self.encoder = nn.Sequential()
        for i in range(0, 6):
            self.encoder.append(nn.Sequential(
                nn.Conv2d(2**i, 2**(i+1), 4, 2, 1),
                self.activation,))

        # Decoder: channels 2**(8-i) -> 2**(7-i) (64 -> 2) while doubling H and W (4 -> 128),
        # then a final layer to 1 channel at 256x256 with a sigmoid, so outputs lie in (0, 1).
        self.decoder = nn.Sequential()
        for i in range(2, 7):
            self.decoder.append(nn.Sequential(
                nn.ConvTranspose2d(2**(8-i), 2**(7-i), 4, 2, 1),
                self.activation,))
        self.decoder.append(nn.Sequential(
            nn.ConvTranspose2d(2, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
            )
        )

        # TODO: a separate segmentation decoder head (previously sketched here as a
        # string-commented copy of `self.decoder`) is not implemented yet.

        self.fc_mu = nn.Linear(self._feature_size, latent_size)
        self.fc_log_sigma = nn.Linear(self._feature_size, latent_size)
        # Latent sample -> decoder input features. Its input size must be latent_size;
        # it was hard-coded to 1024, which only worked while latent_size == 1024.
        self.latent_de = nn.Linear(latent_size, self._feature_size)

        # Kaiming init for (transposed) convolutions, small normal init for linear layers,
        # zero biases throughout.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def encode(self, x):
        """Map images of shape (batch, 1, 256, 256) to posterior parameters (mu, log_sigma).

        Despite its name, ``log_sigma`` is used as the log-variance (see ``reparameterize``).
        """
        h = self.encoder(x)
        h = h.view(-1, self._feature_size)
        # NOTE: the last encoder block already ends in self.activation, so LeakyReLU is
        # applied a second time here. Kept as-is so behaviour matches trained models.
        h = self.activation(h)
        mu = self.fc_mu(h)
        log_sigma = self.fc_log_sigma(h)
        return mu, log_sigma

    def decode(self, z):
        """Map latent vectors (batch, latent_size) to images of shape (batch, 1, 256, 256)."""
        z = self.latent_de(z)
        z = self.activation(z)
        z = z.view(-1, 64, 4, 4)
        z = self.decoder(z)
        return z

    def reparameterize(self, mu, log_sigma):
        """Sample z = mu + std * eps with eps ~ N(0, I) and std = exp(0.5 * log_sigma)."""
        std = torch.exp(0.5*log_sigma)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x):
        """Encode, sample and decode; returns (reconstruction (batch, 256, 256), mu, log_sigma)."""
        x = x.view(-1, 1, w, h)  # add channel dimension
        mu, log_sigma = self.encode(x)  # log-variance is numerically more stable than sigma
        z = self.reparameterize(mu, log_sigma)
        x_hat = self.decode(z)
        return x_hat.view(-1, w, h), mu, log_sigma

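We define a custom loss function that combines the binary cross-entropy (BCE) reconstruction loss with the Kullback–Leibler (KL) divergence between the approximate posterior $\mathcal{N}(\mu, \sigma^2)$ and the standard normal prior, $D_{KL} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The KL term is weighted by a small factor $\beta = 10^{-6}$.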

In [9]:
bce_loss = nn.BCELoss(reduction='mean')

def loss_fn(x, x_hat, mu, log_var, beta=1e-6):
    """VAE loss: pixel-averaged BCE reconstruction loss plus a beta-weighted KL term.

    Args:
        x: Input batch; reshaped to ``x_hat.shape``. BCE requires values in [0, 1].
        x_hat: Reconstruction with values in [0, 1] (e.g. a sigmoid output).
        mu: Posterior means; the last dimension is the latent dimension.
        log_var: Posterior log-variances, same shape as ``mu``.
        beta: Weight of the KL term (default 1e-6, the previously hard-coded value).

    Returns:
        Scalar loss tensor.
    """
    x = x.reshape(x_hat.shape)
    BCE = bce_loss(x_hat, x)
    # KL(q(z|x) || N(0, I)) per sample: sum over the latent dimension, then average over
    # the batch. (Summing over everything and calling .mean() on the resulting scalar did
    # not average over the batch, so the KL weight silently scaled with batch size.)
    KLD = -0.5*torch.sum(1+log_var-mu.pow(2)-log_var.exp(), dim=-1)
    KLD = KLD.mean()
    return BCE + beta*KLD

We define a function that randomly splits the data into a training set and a development set and returns a data loader for each.

In [10]:
def make_loaders(data, batch_size, train_frac=0.9, seed=None):
    """Randomly split ``data`` into train/dev subsets and wrap each in a DataLoader.

    Args:
        data: Map-style dataset (anything with ``__len__`` and ``__getitem__``).
        batch_size: Batch size for both loaders.
        train_frac: Fraction of samples used for training (default 0.9); the rest
            form the dev set.
        seed: Optional integer seed that makes the split reproducible. ``None``
            (default) draws the split from PyTorch's global RNG.

    Returns:
        (train_loader, dev_loader). Only the training loader shuffles; the dev loader
        iterates in a fixed order so that evaluation is deterministic.
    """
    N = len(data)
    N_t = int(train_frac * N)
    N_d = N - N_t
    if seed is None:
        train_data, dev_data = D.random_split(data, [N_t, N_d])
    else:
        generator = torch.Generator().manual_seed(seed)
        train_data, dev_data = D.random_split(data, [N_t, N_d], generator=generator)
    train_loader = D.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    dev_loader = D.DataLoader(dev_data, batch_size=batch_size, shuffle=False)
    return train_loader, dev_loader

We define a training routine.

In [11]:
def train_epoch(model, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Only the first element of each batch (the image slices) is used: the VAE learns to
    reconstruct its input. Relies on the notebook-level ``device`` and ``loss_fn``.
    Returns NaN for an empty loader instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        x = batch[0].to(device)
        optimizer.zero_grad()
        # Call the module rather than .forward() so any registered hooks run.
        x_hat, mu, log_var = model(x)
        loss = loss_fn(x, x_hat, mu, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

And an evaluation routine.

In [12]:
def evaluate(model, dev_loader):
    """Return the mean batch loss of ``model`` over ``dev_loader`` without updating it.

    Puts the model in eval mode and disables autograd, since no gradients are needed.
    Relies on the notebook-level ``device`` and ``loss_fn``. Returns NaN for an empty loader.
    NOTE: VAEModel.forward still samples z in eval mode, so this loss is stochastic.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    # Without no_grad every forward pass would build an autograd graph, wasting memory/time.
    with torch.no_grad():
        for batch in dev_loader:
            x = batch[0].to(device)
            x_hat, mu, log_var = model(x)
            total_loss += loss_fn(x, x_hat, mu, log_var).item()
            num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

We use our previously defined class to create an instance of our model. 

In [17]:
# Seed PyTorch's global RNG so the weight initialisation is reproducible. Cells run
# afterwards in order (e.g. the random train/dev split) also draw from this generator.
torch.manual_seed(0)

model = VAEModel().to(device)
batch_size = 64

We create an instance of `CT_Dataset`, specifying the organ on which we wish to train our model. Be careful when running this, as the dataset will take up a lot of RAM.

In [21]:
# 256 (= 2**8) must match the 256x256 input size that VAEModel expects.
resolution = 256
dataset = CT_Dataset(prep_data_dir, 'lung', resolution)
train_loader, dev_loader = make_loaders(data=dataset, batch_size=batch_size)

We inspect a datapoint from our set to make sure it works properly.

In [23]:
# Show one slice with its label overlaid; the index is clamped so it is valid for any dataset size.
dataset.show_datapoint(min(900, len(dataset) - 1))

We specify the optimizer and learning-rate scheduler we will be using.

In [19]:
# AdamW starting at lr = 3e-3; StepLR multiplies the rate by 0.8 every 10 scheduler steps
# (the training loop steps the scheduler once per epoch).
lr = 3e-3
optimizer = optim.AdamW(params=model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

# Per-epoch histories, filled in by the training loop.
train_losses, dev_losses, lrs = [], [], []

In [None]:
NUM_EPOCHS = 500

for epoch in range(1, NUM_EPOCHS + 1):
    # Record the learning rate in effect for this epoch, before the scheduler updates it.
    lrs.append(optimizer.param_groups[0]['lr'])

    # Only the training pass is timed; evaluation runs after the timer stops.
    t_begin = timer()
    train_loss = train_epoch(model, optimizer, train_loader)
    train_losses.append(train_loss)
    t_end = timer()

    dev_loss = evaluate(model, dev_loader)
    dev_losses.append(dev_loss)

    scheduler.step()  # one StepLR step per epoch
    print(f"Epoch: {epoch}, Train loss: {train_loss:.4f}, Dev loss: {dev_loss:.4f}, "
          f"Epoch time = {(t_end - t_begin):.3f}s")

In [None]:
model_no = 49     # Identifier of this run's saved model and loss curves.
baseline_no = 44  # Earlier run whose final dev loss is printed for comparison.

models_dir = join(root_dir, 'saved_models')
losses_dir = join(root_dir, 'losses')
os.makedirs(models_dir, exist_ok=True)
os.makedirs(losses_dir, exist_ok=True)

# Pickles the whole module, so loading it later requires the VAEModel class definition.
torch.save(model, join(models_dir, f'model_{model_no}'))
torch.save(torch.Tensor(train_losses), join(losses_dir, f'train_losses_model_{model_no}'))
torch.save(torch.Tensor(dev_losses), join(losses_dir, f'dev_losses_model_{model_no}'))

# Compare against the baseline run only when its loss curve exists, instead of crashing.
baseline_path = join(losses_dir, f'dev_losses_model_{baseline_no}')
if exists(baseline_path) and dev_losses:
    baseline_dev_losses = torch.load(baseline_path)
    print(baseline_dev_losses[-1], dev_losses[-1])

plt.plot(dev_losses)
plt.xlabel('Epoch')
plt.ylabel('Dev loss')
plt.grid()

In [None]:
def draw(x, x_hat):
    """Show the first image of ``x`` next to the first image of ``x_hat`` on a fixed [0, 1] gray scale.

    Args:
        x: Batch of original images (tensor indexed with [0]).
        x_hat: Batch of reconstructions with the same layout as ``x``.
    """
    fig, axs = plt.subplots(1, 2, figsize=(8, 5))
    # .cpu() makes this work for GPU tensors too; it is a no-op for CPU tensors.
    img_0 = x[0].detach().cpu().numpy()
    img_1 = x_hat[0].detach().cpu().numpy()
    axs[0].imshow(img_0, vmin=0, vmax=1, cmap='gray')
    axs[0].set_title('Input')
    axs[1].imshow(img_1, vmin=0, vmax=1, cmap='gray')
    axs[1].set_title('Reconstruction')
    fig.canvas.draw()

In [None]:
# Reconstruct one dev-set slice and show it next to the original. Inference needs no
# gradients, so autograd is disabled.
model.eval()
with torch.no_grad():
    x_test = next(iter(dev_loader))[0][0].view(1, h, w).to(device)
    x_hat_test = model(x_test)[0]
draw(x_test.cpu(), x_hat_test.cpu())

In [None]:
# Decode a latent vector drawn from the N(0, I) prior to see what the model generates.
model.eval()
with torch.no_grad():
    z = torch.randn(1, latent_size, device=device)
    x_sampled = model.decode(z).view(1, h, w).cpu()
plt.imshow(x_sampled[0].numpy(), vmin=0, vmax=1, cmap='gray');