# Segmentation of Medical Scans using Variational Autoencoders
This notebook goes through the process of creating a PyTorch-compatible dataset and setting up a model for segmentation of tumors in various organs. It enables reproducibility of our final model and testing results.

---

## 0. Setup
We import some necessary libraries.

In [76]:
# For ML
import torch

# For reading raw data.
import json
import nibabel as nib

# For displaying and evaluating results.
import numpy as np
from matplotlib import pyplot as plt

# For monitoring resource-usage and progress.
from tqdm import tqdm
import os, sys, psutil
from os.path import join, exists


Check if GPU is available and retrieve some system stats.

In [77]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using', device)

# The GPU model and CUDA version are only reported when a GPU was selected.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('CUDA version:', torch.version.cuda)

# First field of psutil's memory tuple, scaled by 10**9 and rounded to two
# decimals (presumably total RAM in bytes -> GB; see the psutil docs).
available_ram = round(psutil.virtual_memory()[0] / 1_000_000_000, 2)
print(f'RAM: {available_ram}GB')

Using cuda
NVIDIA GeForce GTX 1070
CUDA version: 11.7
RAM: 16.74GB


Making sure we are in the correct working directory.

In [78]:
!pwd

/media/nv/Storage/Data-Science/vae_lung_tumor_segmentation/net_md


Setting up some global constants.

In [79]:
# Project layout: every directory below is built relative to root_dir.
root_dir = './' # Location of project, relative to the working directory.
raw_data_dir = join(root_dir, 'data', 'raw_data')  # One sub-directory per organ.
aug_data_dir = join(root_dir, 'data', 'aug_data')  # Pre-processed slice tensors and chunks.
model_dir = join(root_dir, 'model')  # Passed to Test as its checkpoint directory.
stats_dir = join(root_dir, 'stats')  # Not referenced by the other cells shown here.
output_dir = join(root_dir, 'output')  # Target of the saved reconstruction images.

# Organs whose datasets are processed; the names double as directory names.
organs = ['spleen', 'colon', 'lung']

d = 256 # New dimensions (width and height) of datapoints.
num_chunks = 16 # Number of chunks to divide our data into.

We define functions to load and store the current progress, so that work which has already been completed is not repeated when running the entire notebook at once. The progress flags are stored externally in `progress.json`.

In [80]:
def get_progress():
    """Read the progress flags stored in ``progress.json``.

    The file is looked up relative to the current working directory.

    Returns:
        The decoded JSON content of the file.
    """
    with open('progress.json', 'r') as handle:
        return json.loads(handle.read())

def set_progress(progress):
    """Write *progress* to ``progress.json`` as JSON indented by four spaces.

    The file is created in (or overwritten in) the current working directory.
    """
    with open('progress.json', "w") as handle:
        handle.write(json.dumps(progress, indent=4))

We setup a folder structure for our data - both raw and preprocessed.

In [81]:
# Create every working directory up front. The original cell checked for
# output_dir but then created aug_data_dir a second time, which raised
# FileExistsError and left output_dir missing. exist_ok=True also makes the
# cell safe to re-run.
for directory in (aug_data_dir, output_dir, model_dir):
    os.makedirs(directory, exist_ok=True)

# One raw-data sub-directory per organ. These are created even when
# raw_data_dir already exists, so an organ added to `organs` later still gets
# its folder.
for organ in organs:
    os.makedirs(join(raw_data_dir, organ), exist_ok=True)

The `raw_data/` sub-directory for each organ has to be populated manually using the unzipped files from [medicaldecathlon.com](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2)

---

## 1. Inspecting the Data

In [None]:
from PIL import Image
import gzip

# Read the dataset manifest of a single organ.
organ = 'lung'
organ_path = join(raw_data_dir, organ)
file = join(organ_path, 'dataset.json')
with open(file, 'r') as f:
    manifest = json.load(f)

# The manifest keeps separate entries for the training and the test set.
train_manifest = manifest['training']
test_manifest = manifest['test']

# Take the second training entry and drop the first two characters of each
# stored path before it is joined onto the organ directory.
image_path, label_path = (train_manifest[1][key][2:] for key in ('image', 'label'))

# Decompress the image file into raw bytes and display them.
with gzip.open(join(organ_path, image_path), 'rb') as f:
    decomp_image = f.read()

decomp_image

---

## 2. Preparing Data

We define a function which loads and stores our data in the proper formatting. As the datasets are huge, we monitor progress and RAM-usage.

In [None]:
from torch import Tensor
import torchvision.transforms as Transform

# Resize transform that brings the last two dimensions to d x d (used by prep_data).
resize = Transform.Resize((d, d))
# Random-rotation transform created with the argument 180; it is defined here
# but not applied in any of the cells shown.
rotate = Transform.RandomRotation(180)

def prep_data(organ, training_paths, data_tensor, bar):
    """Turn every training volume of *organ* into resized slices and save them.

    Each volume is loaded, min-max scaled to [0, 1], rearranged to
    (slice, rows, columns), resized with the module-level ``resize`` transform
    and appended to ``data_tensor``. The result is written to
    ``<aug_data_dir>/<organ>_slices_unaugmented.pt``.

    Args:
        organ: Organ name; used as the sub-directory of ``raw_data_dir`` and
            as the prefix of the output file name.
        training_paths: Iterable of manifest entries. Each entry needs an
            ``'image'`` value whose first character (the leading dot) is
            stripped before the path is appended to the organ name.
        data_tensor: Tensor the new slices are concatenated to along dim 0.
        bar: Progress bar; advanced once per volume and always closed before
            this function returns or raises.
    """
    # Collect the pieces and concatenate once at the end. Re-concatenating the
    # growing tensor inside the loop copies all previous slices every time.
    pieces = [data_tensor]

    try:
        for path in training_paths:
            # Field 3 of psutil's memory tuple, shown in units of 1e9.
            bar.set_postfix(RAM=round(psutil.virtual_memory()[3] / 1e9, 2))
            bar.update()

            # Get path to images - removed dot in path from json-file.
            nii_img = nib.load(join(raw_data_dir, organ + path['image'][1:]))
            nii_data = Tensor(nii_img.get_fdata())

            # Ensure scale [0; 1]. A constant volume has a maximum of zero
            # after the shift; skip the division to avoid filling it with NaN.
            nii_data -= nii_data.min()
            peak = nii_data.max()
            if peak > 0:
                nii_data /= peak

            nii_data = nii_data.permute(2, 0, 1)  # Shape: (slice, rows, columns)
            pieces.append(resize(nii_data))

        torch.save(torch.cat(pieces, 0),
                   join(aug_data_dir, organ + '_slices_unaugmented.pt'))
    finally:
        bar.close()

We prep and format datasets from raw data for each organ, using the above function. We save progress after each organ is completed. The process can be interrupted and resumed at any time, and it accounts for progress which has already been made.

In [None]:
progress = get_progress()

for organ in organs:
    # Organs flagged as loaded in the progress file are skipped.
    if progress['loaded'][organ]:
        print('The', organ, 'set has already been loaded.')
        continue

    # The organ's manifest lists its training volumes.
    path = join(raw_data_dir, organ, 'dataset.json')
    with open(path) as f:
        data_set = json.loads(f.read())
    training_paths = data_set['training']

    # prep_data starts from an empty (0, d, d) tensor.
    data_tensor = torch.zeros((0, d, d))
    total_paths = len(training_paths)

    bar = tqdm(total=total_paths)
    bar.set_description('%s' % organ)

    try:
        prep_data(organ, training_paths, data_tensor, bar)
        print('The', organ, 'was successfully loaded.')

        # Persist the flag right away so a later interruption keeps it.
        progress['loaded'][organ] = True
        set_progress(progress)

    except KeyboardInterrupt:
        # Stop processing the remaining organs as well.
        print('Manually stopped.\nOrgan:', organ, 'was not saved.')
        bar.close()
        break

---

The datasets are so large that we need to split them into smaller chunks. We first initialize empty chunks, and then fill them with augmented data from each organ, evenly split amongst the chunks.

In [None]:
progress = get_progress()

if not progress['augmented']:
    # Start from empty chunk files so that re-running after an interruption
    # does not append the same slices twice.
    for n in range(num_chunks):
        chunk = torch.zeros(0, d, d)
        torch.save(chunk, join(aug_data_dir, f'unaugmented_chunk_{n}.pt'))

    bar = tqdm(total=len(organs)*num_chunks)
    bar.set_description('Augmentation')

    try:
        for organ in organs:
            data = torch.load(join(aug_data_dir, organ + '_slices_unaugmented.pt'))
            N = data.shape[0]

            # Shuffle the slices so every chunk gets a random sample.
            data = data[torch.randperm(N)]

            # tensor_split hands out *all* N slices in num_chunks parts whose
            # sizes differ by at most one. The previous int(N/num_chunks)
            # slicing silently dropped the N % num_chunks remaining slices.
            for n, part in enumerate(torch.tensor_split(data, num_chunks)):
                path = join(aug_data_dir, f'unaugmented_chunk_{n}.pt')
                chunk = torch.load(path)
                chunk = torch.cat((chunk, part), 0)
                torch.save(chunk, path)
                bar.update()

        print('Augmentation was successful.')

        # Change state of progress.json
        progress['augmented'] = True
        set_progress(progress)

    except KeyboardInterrupt:
        print('Manually stopped.\nChunk', str(n), 'was not saved.')

    finally:
        # Close the bar on success, interruption and errors alike.
        bar.close()

else:
    print('Data augmentation and chunking already completed.')

We inspect the dataset.

In [None]:
# Load one of the chunk files written above and display its shape.
some_data = torch.load(os.path.join(aug_data_dir,'unaugmented_chunk_1.pt'))
np.shape(some_data)

---

## 3. Dataset Management
We define a custom dataset class.

In [None]:
from torch.utils.data import Dataset
import gzip

class CT_Dataset(Dataset):
    """Scans ('x') and labels ('t') stored at several levels of detail (LOD).

    Level ``lvl`` holds square images with a side length of ``2**lvl`` pixels,
    read from gzip files in ``raw_data_dir``. All levels from 0 up to the
    requested one are loaded when the dataset is created; ``set_lod`` selects
    the level that indexing returns.
    """

    def __init__(self, lod=9):
        """Load levels ``0..lod`` and make ``lod`` the active level.

        Raises:
            AssertionError: If ``lod`` is outside ``0..9``.
        """
        self._check_range(lod)
        self.data = {}
        self.current_lod = lod

        for lvl in range(lod + 1):
            x = self.load(lvl, 'scans', np.float16, factor=1.0)
            t = self.load(lvl, 'labels', np.uint8, factor=(1.0 / 255.0))
            self.data['lod_{}'.format(lvl)] = {'x': x, 't': t}

    @staticmethod
    def _check_range(lod):
        """Reject levels outside 0..9.

        AssertionError is raised explicitly (rather than via ``assert``) so
        the exception type stays the same and the check survives ``python -O``.
        """
        if not 0 <= lod < 10:
            raise AssertionError('lod must be in the range 0..9, got {}'.format(lod))

    def load(self, lvl, prefix, dtype, factor):
        """Read one gzip file and return it as ``(samples, 1, res, res)`` float32.

        Args:
            lvl: Level of detail; the resolution is ``2**lvl``.
            prefix: File-name part, ``'scans'`` or ``'labels'`` in this class.
            dtype: NumPy dtype the raw bytes are interpreted as.
            factor: Multiplier applied to the values before the float32 cast.
        """
        resolution = 2**lvl
        path = join(raw_data_dir, '_data_{}_res_{}.gz'.format(prefix, resolution))

        with gzip.open(path, 'rb') as f:
            data = f.read()

        data = np.frombuffer(data, dtype=dtype)
        samples = data.shape[0] // resolution**2
        data = data.reshape((samples, resolution, resolution))
        # Insert a singleton axis at position 1.
        data = np.expand_dims((data * factor).astype(np.float32), axis=1)
        return data

    def set_lod(self, lod):
        """Select the level returned by ``__getitem__``.

        Raises:
            AssertionError: If ``lod`` is outside ``0..9``.
            ValueError: If that level was not loaded. Previously this only
                surfaced later as a KeyError while indexing.
        """
        self._check_range(lod)
        if 'lod_{}'.format(lod) not in self.data:
            raise ValueError('level of detail {} has not been loaded'.format(lod))
        self.current_lod = lod

    def __len__(self):
        """Return the number of samples, taken from level 0."""
        return len(self.data['lod_0']['x'])

    def __getitem__(self, idx):
        """Return ``{'x': scan, 't': label}`` at ``idx`` for the active level."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        level = self.data['lod_{}'.format(self.current_lod)]
        return {'x': level['x'][idx], 't': level['t'][idx]}

---

## 4. Creating Model Architecture
We first import the necessary architecture-components from pytorch.

In [None]:
from torch.nn import Module, ModuleList, Conv2d, ConvTranspose2d, Linear

We define our model architecture for the segmenting variational autoencoder, which consists of an encoder and two separate decoders, so it makes sense to define each component as its own class, i.e. a PyTorch module. We start with the encoder.

In [None]:
class Encoder(Module):
    """Strided-convolution encoder with one latent head pair per scale.

    For every level there is one stride-2 convolution plus two linear layers
    that map the flattened feature map to the mean and the log-variance of the
    latent distribution. No activation is applied between the convolutions.
    """

    def __init__(self, shape, initial):
        """Build all levels.

        Args:
            shape: Input shape; ``shape[0]`` is the channel count and
                ``shape[1]`` determines the number of levels.
            initial: Output channels of the first convolution; the width
                doubles with every further level.
        """
        super().__init__()
        self.shape = shape
        self.initial = initial

        self.downsample_block = ModuleList()
        self.mu_encoding_block = ModuleList()
        self.var_encoding_block = ModuleList()

        num_levels = int(np.log2(self.shape[1])-1)
        for level in range(num_levels):
            # Level 0 consumes the raw input channels; later levels consume
            # the output of the level before them.
            if level == 0:
                conv = Conv2d(in_channels=self.shape[0], out_channels=initial,
                              kernel_size=4, stride=2, padding=1)
            else:
                conv = Conv2d(in_channels=initial*2**(level-1),
                              out_channels=initial*2**(level),
                              kernel_size=4, stride=2, padding=1)
            self.downsample_block.append(conv)

            # Both heads share their input and output sizes. They are created
            # in the order mu, var directly after the convolution.
            features = self.initial*2**(level+2)
            length = int(2**(level+2))
            self.mu_encoding_block.append(
                Linear(in_features=features, out_features=length))
            self.var_encoding_block.append(
                Linear(in_features=features, out_features=length))

    def encode(self, x, scale):
        """Apply convolutions ``0..scale`` and the latent heads of ``scale``.

        Returns:
            ``(mu, log_var, shape)`` where ``shape`` is the shape of the
            feature map before it was flattened.
        """
        for index in (*range(scale), scale):
            x = self.downsample_block[index](x)

        shape = x.shape
        flat = torch.flatten(x, start_dim=1)
        mu = self.mu_encoding_block[scale](flat)
        log_var = self.var_encoding_block[scale](flat)

        return mu, log_var, shape

    def sample(self, mu, log_var):
        """Draw ``z = mu + std * noise`` (reparameterization)."""
        std = torch.exp(0.5 * log_var)
        noise = torch.randn_like(std)
        return mu + std * noise

    def forward(self, x, scale):
        """Encode ``x`` and return ``(z, mu, log_var, shape)``."""
        mu, log_var, shape = self.encode(x, scale)
        return self.sample(mu, log_var), mu, log_var, shape

And then the decoder architecture.

In [None]:
class Decoder(Module):
    """Transposed-convolution decoder with one latent projection per scale.

    For every level there is a linear layer that expands a latent vector and a
    stride-2 transposed convolution. Decoding at ``scale`` applies the
    transposed convolutions from index ``scale`` down to index 0.
    """

    def __init__(self, shape, initial, act):
        """Build all levels.

        Args:
            shape: Output shape; ``shape[0]`` is the channel count and
                ``shape[1]`` determines the number of levels.
            initial: Input channels of the last transposed convolution applied
                (index 0); the width doubles with every further level.
            act: Activation applied after the linear layer and after every
                transposed convolution except the one at index 0.
        """
        super().__init__()
        self.act = act
        self.shape = shape
        self.initial = initial

        self.conv_upsample_block = ModuleList()
        self.z_decoding_block = ModuleList()

        num_levels = int(np.log2(self.shape[1])-1)
        for level in range(num_levels):
            # Index 0 maps back to the output channels; higher indices halve
            # the channel count of the level above them.
            if level == 0:
                upsample = ConvTranspose2d(
                    in_channels=initial, out_channels=self.shape[0],
                    kernel_size=4, stride=2, padding=1)
            else:
                upsample = ConvTranspose2d(
                    in_channels=int(initial*2**(level)),
                    out_channels=int(initial*2**(level-1)),
                    kernel_size=4, stride=2, padding=1)
            self.conv_upsample_block.append(upsample)

            features = self.initial*2**(level+2)
            length = int(2**(level+2))
            self.z_decoding_block.append(
                Linear(in_features=length, out_features=features))

    def decode(self, z, scale, shape):
        """Expand ``z``, reshape it to ``shape`` and upsample to the output."""
        expanded = self.z_decoding_block[scale](z)
        x = self.act(expanded).reshape(shape)

        # Transposed convolutions scale, scale-1, ..., 1 are followed by the
        # activation; the one at index 0 is not.
        for index in range(scale, 0, -1):
            x = self.act(self.conv_upsample_block[index](x))

        return self.conv_upsample_block[0](x)

    def forward(self, z, scale, shape, detach=False):
        """Decode ``z``; with ``detach=True`` no gradient flows back into ``z``."""
        latent = z.detach() if detach else z
        return self.decode(latent, scale, shape)

Finally, we use these two in combination to create our variational autoencoder.

In [None]:
class VAE(Module):
    """Variational autoencoder with a reconstruction and a segmentation head.

    One encoder feeds two decoders of identical architecture: ``decoder``
    reconstructs the input and ``segmenter`` predicts the segmentation.
    """

    def __init__(self, shape, initial, act):
        """Create the encoder and both decoders.

        Args:
            shape: Passed to ``Encoder`` and ``Decoder`` as their ``shape``.
            initial: Passed to ``Encoder`` and ``Decoder`` as their ``initial``.
            act: Activation function handed to both decoders.
        """
        super().__init__()
        self.encoder = Encoder(shape, initial)
        self.decoder = Decoder(shape, initial, act)
        self.segmenter = Decoder(shape, initial, act)

    def forward(self, x, lod, print=False):
        """Run the model at level of detail ``lod``.

        Args:
            x: Input batch.
            lod: Level of detail; the sub-modules are called with ``lod - 2``
                as their ``scale``.
            print: When true, the latent sample ``z`` is printed.

        Returns:
            ``((x_reconst, mu, log_var), x_segment)``; both outputs have been
            passed through a sigmoid.
        """
        # The parameter name `print` is part of the signature and shadows the
        # builtin, so the original `print(z)` called the boolean argument and
        # raised TypeError. Reach the real function through `builtins`.
        import builtins

        scale = lod - 2
        z, mu, log_var, shape = self.encoder(x, scale)

        x_reconst = torch.sigmoid(self.decoder(z, scale, shape))

        # detach=True is forwarded to the segmenter so that it works on a
        # detached copy of z.
        x_segment = torch.sigmoid(self.segmenter(z, scale, shape, detach=True))

        if print:
            builtins.print(z)

        return ((x_reconst, mu, log_var), x_segment)

## 5. Defining Training- and Evaluation Routines
We define a class which contains our training routine, checkpoint management, evaluation routine, and some utility functions. This will allow us to easily run tests with different hyperparameters.

In [None]:
class Test():
    """Training, evaluation, checkpointing and visualisation for one model.

    Batches are expected to be mappings with the keys ``'x'`` (input scans)
    and ``'t'`` (segmentation labels). The model must return
    ``((x_reconst, mu, log_var), x_segment)``. The first criterion compares
    the first output with ``'x'``, the second compares ``x_segment`` with
    ``'t'``.
    """

    def __init__(self, model, dir, load, optimizer, criterions, iou_thresh):
        """Store the components and optionally restore a checkpoint.

        Args:
            model: The network to train/evaluate; moved to the global ``device``.
            dir: Directory that holds ``checkpoint.pt``.
            load: When true, ``<dir>/checkpoint.pt`` is loaded immediately.
            optimizer: Optimizer that updates the model's parameters.
            criterions: Sequence of two loss callables (see class docstring).
            iou_thresh: Threshold used to binarise labels and predictions.
        """
        self.model = model
        self.dir = dir
        self.optimizer = optimizer
        self.device = device
        self.criterions = criterions
        # Was stored as `io_thresh`, which made every later read of
        # `self.iou_thresh` fail with AttributeError.
        self.iou_thresh = iou_thresh

        # Move the model first so that restored optimizer state ends up on
        # the same device as the parameters.
        self.model.to(self.device)

        if load:
            self.load_checkpoint(os.path.join(self.dir, 'checkpoint.pt'))

    def _compute_losses(self, batch, lod):
        """Run the model on one batch and return ``(outputs, [loss, ...])``."""
        outputs = self.model(batch['x'].to(self.device), lod)
        # Name the targets explicitly instead of relying on the key order of
        # the batch mapping.
        targets = (batch['x'], batch['t'])
        losses = [criterion(output, target.to(self.device))
                  for output, criterion, target
                  in zip(outputs, self.criterions, targets)]
        return outputs, losses

    def train(self, dataloader, epochs, lod, print_at=10):
        """Train for ``epochs`` epochs at level of detail ``lod``.

        The mean losses of an epoch are printed whenever the epoch index is a
        multiple of ``print_at``.
        """
        self.model.train()

        for epoch in range(epochs):
            batch_count = 0
            losses = np.zeros(len(self.criterions), dtype=np.float32)

            for batch in dataloader:
                _, batch_losses = self._compute_losses(batch, lod)

                # Backpropagation: the gradient of the summed loss equals the
                # sum of the individual gradients, so one backward pass is
                # enough.
                self.model.zero_grad()
                sum(batch_losses).backward()
                self.optimizer.step()

                # .item() yields plain Python floats. Adding the tensors
                # themselves to a NumPy array keeps the graph alive and fails
                # for tensors that require grad or live on the GPU.
                for i, loss in enumerate(batch_losses):
                    losses[i] += loss.item()
                batch_count += 1

            if (epoch % print_at) == 0:
                # max(..., 1) avoids a division by zero for an empty loader.
                mean = losses / max(batch_count, 1)
                # Separate arguments: concatenating str with int/float raised
                # TypeError in the original.
                print('Epoch:', epoch, 'Reconst/kld loss:', mean[0],
                      'Seg loss:', mean[1])

    def IoU(self, label, reconst):
        """Return the mean per-sample intersection-over-union of a batch.

        Both tensors are binarised with ``self.iou_thresh``; the first
        dimension is treated as the batch dimension. A sample whose label and
        prediction are both empty counts as a perfect match (IoU 1).

        The original divided the IoU of the whole batch by the batch size and
        produced NaN when the union was empty.
        """
        label_mask = (label >= self.iou_thresh).flatten(start_dim=1)
        pred_mask = (reconst >= self.iou_thresh).flatten(start_dim=1)

        intersection = (label_mask & pred_mask).sum(dim=1).float()
        union = (label_mask | pred_mask).sum(dim=1).float()

        per_sample = torch.where(union > 0,
                                 intersection / union.clamp(min=1),
                                 torch.ones_like(union))
        return per_sample.mean()

    def evaluate(self, dataloader, lod):
        """Print the mean losses and the mean IoU over ``dataloader``."""
        batch_count = 0
        losses = np.zeros(len(self.criterions), dtype=np.float32)
        iou = 0.0

        was_training = self.model.training
        self.model.eval()
        try:
            # No gradients are needed here; this avoids building graphs.
            with torch.no_grad():
                for batch in dataloader:
                    outputs, batch_losses = self._compute_losses(batch, lod)
                    for i, loss in enumerate(batch_losses):
                        losses[i] += loss.item()

                    iou += self.IoU(batch['t'].to(self.device), outputs[1]).item()
                    batch_count += 1
        finally:
            # Leave the model in the mode it was in before.
            self.model.train(was_training)

        denominator = max(batch_count, 1)
        losses /= denominator
        iou /= denominator
        print('Reconst/kld loss:', losses[0], 'Seg loss:', losses[1], 'IoU:', iou)

    def load_checkpoint(self, path):
        """Restore model and optimizer state from ``path``.

        NOTE: torch.load unpickles the file; only load checkpoints from a
        trusted source.
        """
        # map_location lets a checkpoint written on a GPU machine be loaded on
        # a CPU-only one.
        cp = torch.load(path, map_location=self.device)
        self.model.load_state_dict(cp['model_state_dict'])
        self.optimizer.load_state_dict(cp['optimizer_state_dict'])

        # Make sure the optimizer state lives on the active device (the
        # original called .cuda() unconditionally, which fails without a GPU).
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.device)

    def save_checkpoint(self):
        """Write model and optimizer state to ``<dir>/checkpoint.pt``."""
        path = os.path.join(self.dir, 'checkpoint.pt')
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def reconstruct(self, dataloader, lod, count):
        """Run the model on up to ``count`` samples of the first batch.

        Returns:
            ``(x, t, x_reconst, x_segment)``.
        """
        batch = next(iter(dataloader))
        count = min(count, len(batch['x']))
        x = batch['x'][:count].to(self.device)
        t = batch['t'][:count].to(self.device)
        with torch.no_grad():
            y = self.model(x, lod)
        x_reconst = y[0][0]
        x_segment = y[1]
        return x, t, x_reconst, x_segment

    def _overlay(self, image, segmentation):
        """Return ``image`` as RGB with the thresholded segmentation tinted."""
        rgb = np.stack([image]*3, axis=0).squeeze().transpose((1, 2, 0))
        seg = np.stack([segmentation]*3, axis=0).squeeze().transpose((1, 2, 0))

        mask = seg[..., 0] > self.iou_thresh
        rgb[mask] = np.array([0, 1, 0.5])*seg[mask]
        return rgb

    def save_reconst(self, dataloader, lod, count, output_dir):
        """Save input/label and reconstruction/prediction overlays as PNGs."""
        tensors = self.reconstruct(dataloader, lod, count)
        x_ins, t_ins, x_outs, t_outs = (
            tensor.detach().cpu().numpy() for tensor in tensors)

        for i, value in enumerate(zip(x_ins, t_ins, x_outs, t_outs)):
            x_in, t_in, x_out, t_out = value
            suffix = str(i) + 'resolution_' + str(2**lod) + '.png'

            plt.imsave(join(output_dir, 'input_' + suffix),
                       self._overlay(x_in, t_in))

            # The output overlay uses the *predicted* segmentation for both
            # the mask and the tint; the original tinted with the ground-truth
            # values (t[mask]) by mistake.
            plt.imsave(join(output_dir, 'output_' + suffix),
                       self._overlay(x_out, t_out))

We also define a custom loss function.

In [None]:
def loss_function(output, x):
    """Return the VAE loss: reconstruction error plus KL divergence.

    Args:
        output: Tuple ``(recon_x, mu, log_var)`` as produced by the model.
        x: Target the reconstruction is compared with.

    Returns:
        Summed squared error and KL term, each divided by the batch size
        (taken from the first dimension of ``mu``).
    """
    recon_x, mu, log_var = output
    batch = mu.shape[0]

    squared_error = ((recon_x - x) ** 2).sum()
    kl_terms = 1 + log_var - mu ** 2 - torch.exp(log_var)

    reconstruction = squared_error / batch
    divergence = -0.5 * kl_terms.sum() / batch
    return reconstruction + divergence

---

## 6. Putting it all together
We pass our custom datasets to pytorch dataloaders. 

In [None]:
from torch.utils.data import DataLoader

# Two of the chunk files written earlier serve as training and test data.
# NOTE(review): these loaders iterate over plain tensors, while Test.train and
# Test.evaluate index every batch with the keys 'x' and 't' - confirm that the
# data is wrapped in a dataset returning such mappings before training.
train_data = torch.load(os.path.join(aug_data_dir,'unaugmented_chunk_1.pt'))
test_data = torch.load(os.path.join(aug_data_dir,'unaugmented_chunk_2.pt'))

# Loader settings shared by both loaders.
batchSize = 64
num_workers = 2

# Both loaders reshuffle their data on every pass.
train_loader = DataLoader(
    dataset=train_data, batch_size=batchSize,
    shuffle=True, num_workers=num_workers
)

test_loader = DataLoader(
    dataset=test_data, batch_size=batchSize, 
    shuffle=True, num_workers=num_workers
)


In [None]:
from torch.nn.functional import elu
from torch.optim import Adam

# VAE for single-channel d x d inputs, with 16 as the `initial` width and ELU
# as the activation handed to both decoders.
model = VAE((1, d, d), 16, elu)

# Bundle the model with its optimizer, the two loss criteria (the custom VAE
# loss for the reconstruction and binary cross-entropy for the segmentation)
# and the IoU threshold. load=False: no checkpoint is read on creation.
test = Test(
    model=model,
    dir=model_dir,
    load=False, 
    optimizer=Adam(model.parameters(), lr=0.001), 
    criterions=[loss_function, torch.nn.BCELoss()], 
    iou_thresh=0.2)

In [None]:
# Run settings. VAE.forward passes lod - 2 to its sub-modules as their scale.
# NOTE(review): the encoder's linear layers look sized for inputs with a side
# of 2**lod pixels (64 here), whereas the chunks hold d x d = 256 x 256
# slices - confirm the inputs are downscaled to match.
epochs = 10
lod = 6
iterations = 1

# Each iteration trains (printing every epoch), evaluates on the test loader
# and saves up to 10 reconstruction images to output_dir.
for i in range(iterations):
    test.train(train_loader, epochs, lod, 1)
    test.evaluate(test_loader, lod)
    test.save_reconst(test_loader, lod, 10, output_dir)
    # Checkpoint saving is currently disabled.
    #test.save_checkpoint()