# Segmentation of Medical Scans using Variational Autoencoders (VAEs) - Part 3/3
This series of notebooks enables reproducibility of our final models and test results.

The third notebook goes through the process of creating, training and tuning a new decoder that maps the latent vectors of the pretrained VAE to segmentation masks, i.e. the decoder acts as a segmenter.

We import the necessary libraries, check whether a GPU is available and retrieve some system stats. We need a lot of RAM because our selected datasets are very large. We also set up some global constants, such as the data and model paths.

In [1]:
# For ML
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torch.optim as optim
from torch import Tensor

# For displaying and evaluating results.
import numpy as np
from matplotlib import pyplot as plt

# For monitoring resource-usage and progress.
from timeit import default_timer as timer
import psutil
from os.path import join

# Our own utility functions, constants and classes.
from utility import CT_Dataset, superimpose, draw

# Our own DL models.
from models import VAEModel, SegmentationModel, Conv, ConvTranspose

# Paths.
root_dir = '../' # Relative to the working directory.
raw_data_dir = join(root_dir, 'raw_data')
prep_data_dir = join(root_dir, 'prep_data')
losses_dir = join(root_dir, 'losses')
models_dir = join(root_dir, 'saved_models')
checkpoint_dir = join(root_dir, 'checkpoints')

# Seed torch's global RNG (all devices). torch.utils.data.random_split and shuffling
# DataLoaders draw from it when no explicit generator is given, and so does weight
# initialisation. Without this seed, the split and the trained model differ on every run.
seed = 42
torch.manual_seed(seed)

# Setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device)

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('CUDA version:', torch.version.cuda)

# psutil reports bytes. `.total` is the machine's total physical RAM (the same field as
# index 0 of the named tuple), not the currently free amount, despite the variable name.
available_ram = round(psutil.virtual_memory().total / 1e9, 2)
print('RAM: ' + str(available_ram) + 'GB')

Using cuda
NVIDIA GeForce GTX 1070
CUDA version: 11.7
RAM: 16.74GB


We define a new function to create data loaders, this time splitting the dataset into a training, a development and a test set (80/10/10).

In [2]:
def make_seg_loaders(data, batch_size, seed=None):
    """Split ``data`` 80/10/10 into train/dev/test subsets and wrap each in a DataLoader.

    Args:
        data: map-style dataset supporting ``len()``.
        batch_size: batch size used by all three loaders.
        seed: optional int. When given, the split is drawn from a dedicated
            ``torch.Generator`` seeded with it, so the split is reproducible
            regardless of the global RNG state. When ``None`` (the default), the
            global torch RNG is used.

    Returns:
        ``(train_loader, dev_loader, test_loader)``. Only the training loader
        shuffles; the dev and test loaders iterate in a fixed order, so repeated
        evaluations see identical batches.
    """
    n_total = len(data)
    n_train = int(0.8 * n_total)
    n_dev = (n_total - n_train) // 2
    n_test = n_total - n_train - n_dev  # Remainder, so the three sizes sum to n_total.

    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    train_data, dev_data, test_data = D.random_split(data, [n_train, n_dev, n_test], **split_kwargs)

    train_loader = D.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation does not need shuffling; a fixed order keeps dev/test metrics deterministic.
    dev_loader = D.DataLoader(dev_data, batch_size=batch_size, shuffle=False)
    test_loader = D.DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, dev_loader, test_loader

In [3]:
# NOTE(review): `resolution` is handed straight to CT_Dataset; presumably it is the
# side length the scans are resized to — confirm in utility.CT_Dataset.
resolution = 2 ** 8
batch_size = 32

dataset = CT_Dataset(prep_data_dir, 'lung', resolution)
train_loader, dev_loader, test_loader = make_seg_loaders(dataset, batch_size=batch_size)

Debugging NaN loss:
- Tried a different loss function -> the loss was no longer NaN.
- Tried adding eps to both BCE inputs -> no difference.
- Swapped the arguments from `loss_fn(y, y_hat)` to `loss_fn(y_hat, y)` in the training and evaluation loops, matching `nn.BCELoss`'s `(input, target)` order -> appears to fix it (to be confirmed).

In [4]:
# Binary cross-entropy averaged over every element ('mean' is nn.BCELoss's default reduction).
# NOTE(review): nn.BCELoss expects probabilities in [0, 1] as its first argument, so this
# presumes SegmentationModel ends in a sigmoid — confirm in models.py.
loss_fn = nn.BCELoss()

In [None]:
# Load the whole pickled VAE trained in the previous notebook and move it to `device`.
# map_location lets a model saved on a GPU be restored on a CPU-only machine.
# weights_only=False is required to unpickle a full nn.Module (newer PyTorch defaults it
# to True). This executes pickle, so only load checkpoint files you trust.
vae_model = torch.load(join(models_dir, 'vae_model.pt'), map_location=device, weights_only=False).to(device)

In [5]:
lr = 3e-3  # Initial learning rate; StepLR multiplies it by 0.8 every 5 scheduler steps.
seg_model = SegmentationModel(base=16).to(device)
# Only the new decoder's parameters are optimised; the VAE is not part of this optimizer.
optimizer = optim.AdamW(params=seg_model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)

In [None]:
def train_epoch(vae_model, seg_model, optimizer, train_loader):
    """Train ``seg_model`` for one pass over ``train_loader``; return the mean batch loss.

    The VAE is used only as a fixed encoder: each image batch is mapped to a latent with
    ``vae_model.forward_latent``, ``seg_model`` decodes that latent into a prediction, and
    the prediction is compared with the label using the module-level ``loss_fn``. Only the
    parameters held by ``optimizer`` are updated.

    Args:
        vae_model: pretrained model exposing ``forward_latent(x)``; switched to eval mode.
        seg_model: decoder being trained; switched to train mode.
        optimizer: optimizer over ``seg_model``'s parameters.
        train_loader: sized iterable of batches where ``batch[0]`` are images and
            ``batch[1]`` are labels.

    Returns:
        float: the sum of per-batch losses divided by the number of batches.
    """
    seg_model.train()
    vae_model.eval()
    total_loss = 0.0

    for batch in train_loader:
        x = batch[0].to(device)  # Batch of images.
        y = batch[1].to(device)  # Batch of labels.
        optimizer.zero_grad()
        # The VAE is frozen. Encoding without autograd stops backward() from
        # propagating into its parameters, which would waste compute and memory
        # and leave unused .grad tensors on them.
        with torch.no_grad():
            z = vae_model.forward_latent(x)
        y_hat = seg_model(z)  # Prediction from the new decoder.
        loss = loss_fn(y_hat, y)  # nn.BCELoss argument order: (prediction, target).
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)

In [None]:
def evaluate(vae_model, seg_model, dev_loader):
    """Return the mean per-batch loss of the VAE-encoder + ``seg_model`` pipeline.

    Both models are switched to eval mode, and the whole pass runs under
    ``torch.no_grad()``: no parameters are updated here, so no autograd graph is built.
    The loss is the module-level ``loss_fn`` applied as ``loss_fn(prediction, label)``.

    Args:
        vae_model: model exposing ``forward_latent(x)``.
        seg_model: decoder mapping latents to predictions.
        dev_loader: sized iterable of batches where ``batch[0]`` are images and
            ``batch[1]`` are labels.

    Returns:
        float: the sum of per-batch losses divided by the number of batches.
    """
    seg_model.eval()
    vae_model.eval()
    total_loss = 0.0

    # Skipping graph construction saves memory and time during evaluation.
    with torch.no_grad():
        for batch in dev_loader:
            x = batch[0].to(device)
            y = batch[1].to(device)
            z = vae_model.forward_latent(x)
            y_hat = seg_model(z)
            total_loss += loss_fn(y_hat, y).item()

    return total_loss / len(dev_loader)

Training only the decoder (the VAE stays fixed) is a lot faster than training the full VAE, so we do not need checkpointing.

In [None]:
total_epochs = 10

train_losses = []
dev_losses = []
lrs = []

# Epochs are numbered from 1, so the upper bound is total_epochs + 1 to run exactly
# `total_epochs` epochs. range(1, total_epochs) would silently run one epoch fewer.
for epoch in range(1, total_epochs + 1):
    lrs.append(optimizer.param_groups[0]['lr'])  # Learning rate used during this epoch.
    start_time = timer()
    train_loss = train_epoch(vae_model, seg_model, optimizer, train_loader)
    train_losses.append(train_loss)
    end_time = timer()  # The timing covers training only, not the dev evaluation below.
    dev_loss = evaluate(vae_model, seg_model, dev_loader)
    dev_losses.append(dev_loss)
    scheduler.step()  # One StepLR step per epoch.

    print(f"Epoch {epoch}: Train-loss: {train_loss:.4f}, Dev-loss: {dev_loss:.4f}, "
          f"Epoch-time = {(end_time - start_time):.3f}s")

In [None]:
# Segmentation evaluation.
def IoU(label, recon, thresh):
    """Mean per-sample Intersection-over-Union between two batches of masks.

    Both ``label`` and ``recon`` are binarised with ``value >= thresh``. For every sample
    along the first (batch) axis, the IoU is |label AND recon| / |label OR recon|, counted
    over all remaining axes. The result is the mean of these per-sample scores, so it
    always lies in [0, 1]. A sample where both masks are empty counts as a perfect match
    (IoU = 1).

    Accepts torch tensors or numpy arrays of identical shape (batch first).

    Args:
        label: ground-truth masks, batch dimension first.
        recon: predicted masks/probabilities, same shape as ``label``.
        thresh: binarisation threshold applied to both inputs.

    Returns:
        0-dim tensor (or numpy scalar) with the batch-mean IoU.
    """
    batch = label.shape[0]
    label_mask = (label >= thresh).reshape(batch, -1)
    recon_mask = (recon >= thresh).reshape(batch, -1)
    inter = ((label_mask & recon_mask) * 1.0).sum(1)
    union = ((label_mask | recon_mask) * 1.0).sum(1)
    # Unions are whole-pixel counts, so clipping to >= 1 only affects empty unions (where
    # inter is also 0). This avoids 0/0; those empty/empty samples are then scored as 1.
    iou = inter / union.clip(min=1)
    iou[union == 0] = 1.0
    return iou.mean()