# Segmentation of Medical Scans using Variational Autoencoders (VAEs)
This notebook goes through the process of creating a PyTorch-compatible dataset and setting up a model for segmentation of tumors in various organs. It enables reproducibility of our final model and testing results.

---

## 0. Setup
We import the necessary libraries.

In [1]:
# For ML
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils as U
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.optim as optim
from torch import Tensor

# For reading raw data.
import json
import nibabel as nib

# For displaying and evaluating results.
import numpy as np
from matplotlib import pyplot as plt

# For monitoring resource-usage and progress.
from tqdm import tqdm
import os, sys, psutil


---

Check if GPU is available and retrieve some system stats.

In [2]:
# Pick the compute device: CUDA when PyTorch can see a GPU, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using', device)

# The extra GPU details only exist when a CUDA device was selected.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('CUDA version:', torch.version.cuda)

# `total` is the first field of psutil's memory tuple: the machine's physical
# memory in bytes, reported here in decimal gigabytes.
available_ram = round(psutil.virtual_memory().total / 1000000000, 2)
print(f'RAM: {available_ram}GB')


Using cuda
NVIDIA GeForce GTX 1070
CUDA version: 11.7
RAM: 16.74GB


---

Making sure we are in the correct working directory.

In [3]:
!pwd

/media/nv/Storage/Data-Science/vae_lung_tumor_segmentation/net_md


---

Setting up some global constants.

In [4]:
# Location of the project, relative to the working directory.
root_dir = './'

# Data folders live below `<root>/data`; everything else sits directly in the root.
raw_data_dir, aug_data_dir = (
    os.path.join(root_dir, 'data', sub) for sub in ('raw_data', 'aug_data')
)
model_dir, stats_dir, output_dir = (
    os.path.join(root_dir, sub) for sub in ('model', 'stats', 'output')
)
# JSON file holding the progress flags read and written by the cells below.
progress_file = os.path.join(stats_dir, 'progression.json')

organs = ['spleen', 'colon', 'lung']
d = 256          # New dimensions (width and height) of datapoints.
num_chunks = 16  # Number of chunks to divide our data into.

# Data transforms.
resize_transform = T.Resize(size=(d, d))
rand_rot_transform = T.RandomRotation(degrees=180)

---

We set up a folder structure for our data - both raw and preprocessed.

In [5]:
# Create the project's folder structure.
# `os.makedirs(..., exist_ok=True)` creates missing parent folders (e.g. `data/`)
# and is a no-op for folders that already exist, so this cell is safe to re-run.

# One raw-data sub-folder per organ; these are populated manually afterwards.
for organ in organs:
    os.makedirs(os.path.join(raw_data_dir, organ), exist_ok=True)

# Folders for preprocessed data, generated images, checkpoints and statistics.
for directory in (aug_data_dir, output_dir, model_dir, stats_dir):
    os.makedirs(directory, exist_ok=True)

The `raw_data/` sub-directories for each organ have to be populated manually using the unzipped files from [medicaldecathlon.com](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2)

---

Load the current progress so that work which has already been completed is not repeated when the entire notebook is run at once. The progress flags are stored externally in `progression.json`.

In [6]:
# Read the stored progress flags. On a fresh checkout the file does not exist
# yet, so start from "nothing done" instead of failing; the cells below write
# the file as soon as a stage completes.
if os.path.exists(progress_file):
    with open(progress_file, 'r') as f:
        progression = json.load(f)
else:
    # Same layout the later cells read: one 'loaded' flag per organ plus a
    # single 'augmented' flag for the chunking stage.
    progression = {
        'loaded': {organ: False for organ in organs},
        'augmented': False,
    }

---

## 1. Preparing Data

We define a function which loads and stores our data in the proper formatting. As the datasets are huge, we monitor progress and RAM-usage.

In [7]:
def augment_data(organ, training_paths, data_tensor, progress):
    """Load, normalise and resize every training volume of one organ, then save the slices.

    Args:
        organ: Name of the organ; used as sub-folder of ``raw_data_dir`` and as
            prefix of the saved file.
        training_paths: Iterable of dicts whose ``'image'`` entry is the path of
            a volume as listed in the organ's ``dataset.json`` (with a leading dot,
            which is stripped before joining).
        data_tensor: Tensor the new slices are appended to along dimension 0.
        progress: Progress bar object providing ``set_postfix``, ``update`` and
            ``close`` (the caller passes a ``tqdm`` instance).

    Side effects:
        Writes ``<organ>_slices_unaugmented.pt`` into ``aug_data_dir`` and closes
        ``progress``.
    """
    # Collect the pieces and concatenate once at the end; concatenating inside
    # the loop would copy all previously loaded slices on every iteration.
    pieces = [data_tensor]

    for path in training_paths:
        # Field 3 of psutil's memory tuple is the used memory in bytes.
        progress.set_postfix(RAM=round(psutil.virtual_memory()[3] / 1000000000, 2))
        progress.update()

        # Get path to images - removed dot in path from json-file.
        nii_img = nib.load(os.path.join(raw_data_dir, organ + path['image'][1:]))

        # Decode the volume once (get_fdata() reads the full array).
        nii_data = Tensor(nii_img.get_fdata())

        # Ensure scale [0; 1], per volume. A constant volume has a peak of 0
        # after the shift; skip the division then to avoid NaNs.
        nii_data -= nii_data.min()
        peak = nii_data.max()
        if peak > 0:
            nii_data /= peak

        # Move the last axis to the front so that each entry along dim 0 is one
        # 2D image (assumes the volume is stored as (rows, columns, slice)).
        nii_data = nii_data.permute(2, 0, 1)
        pieces.append(resize_transform(nii_data))

    data_tensor = torch.cat(pieces, 0)
    torch.save(data_tensor, os.path.join(aug_data_dir, organ + '_slices_unaugmented.pt'))
    progress.close()

We process and format datasets from raw data for each organ, using the above function. We save progress after each organ is completed. The cell can be interrupted and resumed at any time; it accounts for progress that has already been made.

In [8]:
for organ in organs:
    # Organs flagged as loaded by an earlier run are skipped.
    if progression['loaded'][organ]:
        print('The ' + organ + ' set has already been loaded.')
        continue

    # The organ's dataset.json lists the training volumes.
    with open(os.path.join(raw_data_dir, organ, 'dataset.json')) as f:
        data_set = json.load(f)
    training_paths = data_set['training']

    # Start from an empty (0, d, d) tensor that the slices get appended to.
    data_tensor = torch.zeros((0, d, d))
    total_paths = len(training_paths)

    progress = tqdm(total=total_paths)
    progress.set_description(organ)

    try:
        augment_data(organ, training_paths, data_tensor, progress)
        print('The ' + organ + ' was successfully loaded.')

        # Change state of progression.json
        progression['loaded'][organ] = True
        with open(progress_file, "w") as f:
            json.dump(progression, f, indent=4)

    except KeyboardInterrupt:
        # A manual interrupt leaves this organ unflagged and stops the loop.
        print ('Manually stopped.\nOrgan: ' + organ + ' was not saved.')
        progress.close()
        break

The spleen set has already been loaded.
The colon set has already been loaded.
The lung set has already been loaded.


---

The datasets are so large that we need to split them into smaller chunks. We first initialize empty chunks and then fill them with data from each organ, split evenly amongst the chunks.

In [9]:
if progression['augmented']:
    print('Data augmentation and chunking already completed.')

else:
    # One file per chunk; each starts out as an empty (0, d, d) tensor.
    chunk_paths = [
        os.path.join(aug_data_dir, f'unaugmented_chunk_{k}.pt')
        for k in range(num_chunks)
    ]
    for n, path in enumerate(chunk_paths):
        torch.save(torch.zeros(0, d, d), path)

    progress = tqdm(total=len(organs) * num_chunks)
    progress.set_description('Augmentation')

    try:
        for organ in organs:
            data = torch.load(os.path.join(aug_data_dir, organ + '_slices_unaugmented.pt'))
            N = data.shape[0]
            # Shuffle the slices before distributing them.
            data = data[torch.randperm(N)]
            # Every chunk receives the same number of slices from this organ;
            # slices beyond num_chunks * split_idx are not assigned to a chunk.
            split_idx = int(N / num_chunks)

            for n, path in enumerate(chunk_paths):
                start = n * split_idx
                share = data[start:start + split_idx]
                # Append this organ's share to what the chunk already holds.
                chunk = torch.cat((torch.load(path), share), 0)
                torch.save(chunk, path)
                progress.update()

        print('Augmentation was successful.')

        # Change state of progression.json
        progression['augmented'] = True
        with open(progress_file, "w") as f:
            json.dump(progression, f, indent=4)

    except KeyboardInterrupt:
        print ('Manually stopped.\nChunk ' + str(n) + ' was not saved.')
        progress.close()

Data augmentation and chunking already completed.


---

## 2. Dataset Management
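We define a custom dataset class. **TODO:** no class definition follows in this section; the training code in section 4 reads batches as dictionaries with the keys `'x'` (input) and `'t'` (target).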

---

## 3. Creating Model Architecture
We define our model architecture for the variational autoencoder, which consists of an encoder and a decoder. We start with the encoder.

In [10]:
class Encoder(nn.Module):
    """Multi-scale convolutional encoder producing Gaussian latent codes.

    For every level a stride-2 convolution is created together with two linear
    heads (mean and log-variance). Encoding at ``scale`` runs the convolutions
    0..scale and then applies the heads that belong to ``scale``.

    Args:
        imageShape: Indexable as ``(channels, height, ...)``; entry 0 sets the
            input channels, entry 1 sets the number of levels
            (``int(log2(height) - 1)``).
        firstFilterCount: Output channels of the first convolution; the count
            doubles on every further level.
        act: Activation function. NOTE(review): it is stored but not applied
            anywhere in this class - confirm that is intended.
        layerwise: Stored only; not used by the code in this class.
    """

    def __init__(self, imageShape, firstFilterCount, act, layerwise=True):
        super().__init__()
        self.act = act
        self.imageShape = imageShape
        self.firstFilterCount = firstFilterCount
        self.layerwise = layerwise

        self.convDownsamplingLayers = nn.ModuleList()
        self.muEncodingLayers = nn.ModuleList()
        self.logVarEncodingLayers = nn.ModuleList()

        num_levels = int(np.log2(self.imageShape[1]) - 1)
        for level in range(num_levels):
            # The first convolution reads the image channels; later ones double
            # the channel count of their predecessor.
            if level == 0:
                channels_in, channels_out = self.imageShape[0], firstFilterCount
            else:
                channels_in = firstFilterCount * 2 ** (level - 1)
                channels_out = firstFilterCount * 2 ** level
            self.convDownsamplingLayers.append(
                nn.Conv2d(in_channels=channels_in, out_channels=channels_out,
                          kernel_size=4, stride=2, padding=1)
            )

            # Heads for this level. The input width equals the flattened
            # convolution output when that output is 2x2 spatially, i.e. for an
            # input of side 2 ** (level + 2) - NOTE(review): confirm inputs are
            # resized accordingly.
            code_length = int(2 ** (level + 2))
            features = self.firstFilterCount * 2 ** (level + 2)
            self.muEncodingLayers.append(
                nn.Linear(in_features=features, out_features=code_length))
            self.logVarEncodingLayers.append(
                nn.Linear(in_features=features, out_features=code_length))

    def sample(self, mu, logVar):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        std = (0.5 * logVar).exp()
        noise = torch.randn_like(std)
        return mu + std * noise

    def encode(self, x, scale):
        """Return ``(mu, logVar, shape)`` where ``shape`` is the conv output shape before flattening."""
        # Convolutions below the requested scale ...
        for index in range(scale):
            x = self.convDownsamplingLayers[index](x)
        # ... followed by the convolution of the requested scale itself.
        x = self.convDownsamplingLayers[scale](x)

        # Gaussian prior: the heads work on the flattened feature map.
        shape = x.shape
        flat = x.flatten(start_dim=1)
        mu = self.muEncodingLayers[scale](flat)
        logVar = self.logVarEncodingLayers[scale](flat)
        return mu, logVar, shape

    def forward(self, x, scale):
        """Encode ``x`` and sample a latent code; returns ``(z, mu, logVar, shape)``."""
        mu, logVar, shape = self.encode(x, scale)
        return self.sample(mu, logVar), mu, logVar, shape

And then the decoder architecture.

In [11]:
class Decoder(nn.Module):
    """Multi-scale transposed-convolution decoder.

    For every level a stride-2 transposed convolution and a linear layer that
    expands a latent code are created. Decoding at ``scale`` expands the code
    with the linear layer of that scale, reshapes it to ``shape`` and applies
    the transposed convolutions ``scale, scale - 1, ..., 0``.

    Args:
        imageShape: Indexable as ``(channels, height, ...)``; entry 0 sets the
            output channels, entry 1 sets the number of levels
            (``int(log2(height) - 1)``).
        firstFilterCount: Input channels of the last (level 0) transposed
            convolution; the count doubles on every further level.
        act: Activation applied after the linear layer and after every
            transposed convolution except the last one.
        layerwise: Stored only; not used by the code in this class.
    """

    def __init__(self, imageShape, firstFilterCount, act, layerwise=True):
        super().__init__()
        self.act = act
        self.imageShape = imageShape
        self.firstFilterCount = firstFilterCount
        self.layerwise = layerwise

        self.convUpsamplingLayers = nn.ModuleList()
        self.zDecodingLayers = nn.ModuleList()

        num_levels = int(np.log2(self.imageShape[1]) - 1)
        for level in range(num_levels):
            # Level 0 maps back to the image channels; higher levels halve the
            # channel count towards the level below them.
            if level == 0:
                channels_in, channels_out = firstFilterCount, self.imageShape[0]
            else:
                channels_in = int(firstFilterCount * 2 ** level)
                channels_out = int(firstFilterCount * 2 ** (level - 1))
            self.convUpsamplingLayers.append(
                nn.ConvTranspose2d(in_channels=channels_in, out_channels=channels_out,
                                   kernel_size=4, stride=2, padding=1)
            )

            # Linear layer turning a code of this level into a flat feature vector.
            code_length = int(2 ** (level + 2))
            features = self.firstFilterCount * 2 ** (level + 2)
            self.zDecodingLayers.append(
                nn.Linear(in_features=code_length, out_features=features))

    def decode(self, z, scale, shape):
        """Expand ``z`` to ``shape`` and upsample it; the last layer's output is returned without activation."""
        expanded = self.zDecodingLayers[scale](z)
        x = self.act(expanded).reshape(shape)

        # Transpose Convolutions, from the requested scale down to level 1 ...
        for index in range(scale, 0, -1):
            x = self.act(self.convUpsamplingLayers[index](x))
        # ... and the final level 0 layer without activation.
        return self.convUpsamplingLayers[0](x)

    def forward(self, z, scale, shape, detach=False):
        """Decode ``z``; with ``detach=True`` the code is detached first so no gradient flows back through it."""
        code = z.detach() if detach else z
        return self.decode(code, scale, shape)

Finally, we use these two in combination to create our variational autoencoder.

In [12]:
class VAE(nn.Module):
    """Variational autoencoder with a reconstruction head and a segmentation head.

    One encoder feeds two decoders of identical configuration: ``decoder``
    reconstructs the input, ``decoder_segmentation`` predicts the segmentation.

    Args:
        imageShape: Passed on unchanged to the encoder and both decoders.
        firstFilterCount: Passed on unchanged to the encoder and both decoders.
        act: Activation function passed on to the encoder and both decoders.
        layerwise: Passed on unchanged; also stored on the instance.
    """

    def __init__(self, imageShape, firstFilterCount, act, layerwise=True):
        super(VAE, self).__init__()
        self.act = act
        self.imageShape = imageShape
        self.firstFilterCount = firstFilterCount
        self.layerwise = layerwise

        self.encoder = Encoder(imageShape=imageShape, firstFilterCount=firstFilterCount, act=act, layerwise=layerwise)
        self.decoder = Decoder(imageShape=imageShape, firstFilterCount=firstFilterCount, act=act, layerwise=layerwise)
        self.decoder_segmentation = Decoder(imageShape=imageShape, firstFilterCount=firstFilterCount, act=act, layerwise=layerwise)

    def forward(self, x, lod, printCode=False):
        """Encode ``x`` and decode it twice.

        Args:
            x: Input batch handed to the encoder.
            lod: Level of detail; ``lod - 2`` is used as the scale index of the
                encoder and decoder layers.
            printCode: When true, print the sampled latent code.

        Returns:
            ``((x_reconstructed, mu, logVar), x_segmentation)`` where both
            decoder outputs have been passed through a sigmoid.
        """
        # Layer lists are indexed from 0, the level of detail starts at 2.
        scale = lod - 2

        # Call the sub-modules themselves (not `.forward`) so that any hooks
        # registered on them are executed.
        z, mu, logVar, shape = self.encoder(x, scale)
        x_reconstructed = torch.sigmoid(self.decoder(z, scale, shape))
        # The segmentation decoder receives a detached code, so its loss does
        # not propagate gradients into the encoder.
        x_segmentation = torch.sigmoid(self.decoder_segmentation(z, scale, shape, detach=True))

        if printCode:
            print(z)
        return ((x_reconstructed, mu, logVar), x_segmentation)

## 4. Defining Training and Evaluation Routines
We define a class which contains our training routine, checkpoint management, evaluation routine, and some utility functions. This will allow us to easily run tests with different hyperparameters.

In [17]:
class Test():
    """Bundles a model with its optimizer and loss functions for training and evaluation.

    Provides checkpoint loading/saving, a training loop, an evaluation loop that
    also reports an IoU score for the segmentation output, and helpers that
    write example reconstructions to disk.
    """

    def __init__(self, model, dir, load, optimizer, criterions, iou_thresh):
        """Store the components and optionally restore a checkpoint.

        Args:
            model: Module called as ``model(x, lod)``. Its return value is
                iterated alongside ``criterions``; index 0 must itself be
                indexable with the reconstruction first, index 1 is the
                segmentation output.
            dir: Folder that holds ``checkpoint.pt``.
            load: When true, restore model and optimizer from the checkpoint.
            optimizer: Optimizer working on the model's parameters.
            criterions: Loss functions, one per model output, in the same order.
            iou_thresh: Threshold at which a value counts as foreground in
                ``IoU`` and in the saved overlay images.
        """
        self.model = model
        self.dir = dir
        self.optimizer = optimizer
        self.device = device  # Notebook-level device selected during setup.
        self.criterions = criterions
        self.iou_thresh = iou_thresh

        if load:
            self.load_checkpoint(os.path.join(self.dir, 'checkpoint.pt'))

        self.model.to(self.device)

    def load_checkpoint(self, path):
        """Restore model and optimizer state from the file at ``path``."""
        # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
        cp = torch.load(path, map_location=self.device)
        self.model.load_state_dict(cp['model_state_dict'])
        self.optimizer.load_state_dict(cp['optimizer_state_dict'])

        # Keep the optimizer's state tensors on the device that is being used.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.device)

    def save_checkpoint(self):
        """Write model and optimizer state to ``<dir>/checkpoint.pt``."""
        path = os.path.join(self.dir, 'checkpoint.pt')
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def train(self, dataloader, epochs, lod, print_at=10):
        """Train for ``epochs`` passes over ``dataloader``.

        Args:
            dataloader: Iterable of batches. Each batch is read with
                ``batch['x']`` and is iterated to obtain the target key of each
                output, so its iteration order must match the order of the
                model outputs (assumes a mapping such as
                ``{'x': ..., 't': ...}`` - confirm against the dataset).
            epochs: Number of passes.
            lod: Level of detail forwarded to the model.
            print_at: Mean losses are printed every ``print_at`` epochs.
        """
        for epoch in range(epochs):
            batch_count = 0
            losses = np.zeros(len(self.criterions), dtype=np.float32)

            for batch in dataloader:
                y = self.model(batch['x'].to(self.device), lod)

                # Backpropagation: one backward pass per output/criterion pair,
                # gradients accumulate until the optimizer step.
                self.model.zero_grad()
                for i, (output, criterion, key) in enumerate(zip(y, self.criterions, batch)):
                    loss = criterion(output, batch[key].to(self.device))
                    loss.backward()
                    # Store a plain number so no autograd graph is kept alive.
                    losses[i] += loss.item()

                self.optimizer.step()
                batch_count += 1

            if ((epoch % print_at) == 0):
                # max(...) avoids a division by zero for an empty dataloader.
                losses = losses / max(batch_count, 1)
                print('Epoch:', epoch, 'Reconst/kld loss:', losses[0],
                      'Seg loss:', losses[1])

    def IoU(self, label, reconst):
        """Calculates the IoU metric and returns the result within [0, 1].

        Both inputs are binarised at ``iou_thresh``; intersection and union are
        counted over the whole batch. When neither input has any foreground the
        union is empty and 1.0 is returned (both masks agree).
        """
        label_mask = label >= self.iou_thresh
        reconst_mask = reconst >= self.iou_thresh
        intersection = (label_mask & reconst_mask).sum()
        union = (label_mask | reconst_mask).sum()
        if union == 0:
            return torch.ones((), device=intersection.device)
        # The ratio is already normalised; it must not be divided by the batch size.
        return intersection.float() / union.float()

    def evaluate(self, dataloader, lod):
        """Print the mean losses and the mean IoU over all batches of ``dataloader``."""
        batch_count = 0
        losses = np.zeros(len(self.criterions), dtype=np.float32)
        iou = 0.0

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for batch in dataloader:
                y = self.model(batch['x'].to(self.device), lod)

                for i, (output, criterion, key) in enumerate(zip(y, self.criterions, batch)):
                    loss = criterion(output, batch[key].to(self.device))
                    losses[i] += loss.item()

                iou += float(self.IoU(batch['t'].to(self.device), y[1]))
                batch_count += 1

        if batch_count:
            losses /= batch_count
            iou /= batch_count
        print('Reconst/kld loss:', losses[0], 'Seg loss:', losses[1], 'IoU:', iou)

    def reconstruct(self, dataloader, lod, count):
        """Run the model on up to ``count`` samples of the first batch.

        Returns:
            ``(x, t, x_reconst, x_segment)``: inputs, targets, reconstruction
            and segmentation output.
        """
        batch = next(iter(dataloader))
        count = min(count, len(batch['x']))
        x = batch['x'][:count].to(self.device)
        t = batch['t'][:count].to(self.device)
        y = self.model(x, lod)
        x_reconst = y[0][0]
        x_segment = y[1]
        return x, t, x_reconst, x_segment

    def _overlay(self, image, segmentation):
        """Return ``image`` as an RGB array with ``segmentation`` painted over it where it exceeds the threshold."""
        # Replicate the single channel three times and move channels last.
        rgb = np.stack([image] * 3, axis=0).squeeze().transpose((1, 2, 0))
        seg = np.stack([segmentation] * 3, axis=0).squeeze().transpose((1, 2, 0))
        mask = seg[..., 0] > self.iou_thresh
        # Tint the masked pixels, scaled by the segmentation value.
        rgb[mask] = np.array([0, 1, 0.5]) * seg[mask]
        return rgb

    def save_reconst(self, dataloader, lod, count, output_dir):
        """Save input/target and reconstruction/segmentation overlays as PNG files in ``output_dir``."""
        x_ins, t_ins, x_outs, t_outs = self.reconstruct(dataloader, lod, count)
        x_ins = x_ins.detach().cpu().numpy()
        t_ins = t_ins.detach().cpu().numpy()
        x_outs = x_outs.detach().cpu().numpy()
        t_outs = t_outs.detach().cpu().numpy()

        os.makedirs(output_dir, exist_ok=True)

        for i, (x_in, t_in, x_out, t_out) in enumerate(zip(x_ins, t_ins, x_outs, t_outs)):
            # Input image with the target segmentation.
            x = self._overlay(x_in, t_in)
            plt.imsave(os.path.join(output_dir, 'res_{}_sample_{}_in.png'.format(2 ** lod, i)), x)

            # Reconstruction with the predicted segmentation (its own values,
            # not those of the target).
            x_r = self._overlay(x_out, t_out)
            plt.imsave(os.path.join(output_dir, 'res_{}_sample_{}_out.png'.format(2 ** lod, i)), x_r)

We also define a custom loss function.

In [14]:
def loss_function(output, x):
    """Reconstruction error plus KL divergence, both averaged over the batch.

    Args:
        output: Tuple ``(recon_x, mu, logVar)``.
        x: Target compared against ``recon_x``.

    Returns:
        Summed squared reconstruction error plus the KL term
        ``-0.5 * sum(1 + logVar - mu^2 - exp(logVar))``, each divided by the
        size of ``mu``'s first dimension.
    """
    reconstruction, mean, log_variance = output
    n_samples = mean.shape[0]

    squared_error = torch.sum((reconstruction - x) ** 2) / n_samples

    kl_terms = 1 + log_variance - mean ** 2 - torch.exp(log_variance)
    kl_divergence = -0.5 * kl_terms.sum() / n_samples

    return squared_error + kl_divergence

---

## 5. Putting it all together
We pass our custom datasets to PyTorch dataloaders.

In [15]:
batchSize = 64
num_workers = 2

# One stored chunk is used for training and another one for testing.
train_data = torch.load(os.path.join(aug_data_dir, 'unaugmented_chunk_1.pt'))
test_data = torch.load(os.path.join(aug_data_dir, 'unaugmented_chunk_2.pt'))

# NOTE(review): the loaders wrap the stored slice tensors directly, so every
# batch is a plain tensor, whereas the training/evaluation code reads batches
# as batch['x'] / batch['t'] - confirm that a dataset yielding such
# dictionaries is meant to be used here.
loader_options = dict(batch_size=batchSize, shuffle=True, num_workers=num_workers)

train_loader = DataLoader(train_data, **loader_options)
test_loader = DataLoader(test_data, **loader_options)


In [16]:
# Single-channel d x d images, 16 filters in the first layer, ELU activation.
model = VAE((1, d, d), 16, F.elu, False)

# The keyword must match the constructor's parameter name `iou_thresh`.
test = Test(model=model,
            dir=model_dir,
            load=False,
            optimizer=optim.Adam(model.parameters(), lr=0.001),
            criterions=[loss_function, torch.nn.BCELoss()],
            iou_thresh=0.2)

In [None]:
epochs = 10
# NOTE(review): the encoder's linear layers are sized for inputs with a side
# of 2 ** lod pixels (64 for lod = 6), while the stored slices are d = 256
# pixels wide - confirm the inputs are resized before they reach the model.
lod = 6
iterations = 1

# Each iteration trains, evaluates on the test loader and writes example images.
for _ in range(iterations):
    test.train(train_loader, epochs, lod, print_at=1)
    test.evaluate(test_loader, lod)
    test.save_reconst(test_loader, lod, 10, output_dir)
    # test.save_checkpoint()