# Segmentation of Medical Scans using Variational VAE's - Part 3/3
This series of notebooks enables reproduceability of our final models and testing results.

The third notebook goes through the process of creating, training and tuning a variational decoder, which will act as a segmenter.

We import some necessary libraries, and check if our GPU is available, while also retrieving some system stats. We need a lot of RAM, because our selected datasets are very large. We setup up some global constants.

In [1]:
# For ML
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset

# For displaying and evaluating results.
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

# For monitoring resource-usage and progress.
from timeit import default_timer as timer
from tqdm import tqdm # Install ipywidgets to remove warning.
import psutil
from os.path import join


root_dir = '../' # Location of project, relative to the working directory.
raw_data_dir = join(root_dir, 'raw_data')
prep_data_dir = join(root_dir, 'prep_data')
losses_dir = join(root_dir, 'losses')
models_dir = join(root_dir, 'saved_models')
checkpoint_dir = join(root_dir, 'checkpoints')
cmap_seg = ListedColormap(['none', 'red']) # For drawing tumors in red.


# Setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device)

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('CUDA version:', torch.version.cuda)

available_ram = round(psutil.virtual_memory()[0]/1000000000,2)
print('RAM: ' + str(available_ram) + 'GB')

  from .autonotebook import tqdm as notebook_tqdm


Using cuda
NVIDIA GeForce RTX 2060 with Max-Q Design
CUDA version: 11.7
RAM: 33.74GB


In [2]:
def superimpose(image, label):
    plt.imshow(torch.squeeze(image), cmap='gray')
    plt.imshow(torch.squeeze(label), cmap=cmap_seg)
    plt.show()

def draw(x, x_hat):
    fig, axs = plt.subplots(1, 2, figsize=(8,5))
    img_0 = x[0].detach().numpy()
    img_1 = x_hat[0].detach().numpy()
    axs[0].imshow(img_0, vmin=0, vmax=1, cmap='gray')
    axs[1].imshow(img_1, vmin=0, vmax=1, cmap='gray')
    fig.canvas.draw()

In [3]:
class CT_Dataset(Dataset):
    def __init__(self, path, organ, resolution):
        self.images = torch.load(join(path, organ + '_image_slices_' + str(resolution) + '.pt'))
        
        self.labels = torch.load(join(path, organ + '_label_slices_' + str(resolution) + '.pt'))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]

        # Adding channel dimension.
        return image[None, :], label[None, :]

    def show_datapoint(self, index):
        image, label = self.__getitem__(index)
        superimpose(image, label)

In [4]:
def make_seg_loaders(data, batch_size):
    N = len(data); N_train = int(0.8*N); 
    N_dev = int((N - N_train)/2); N_test = int(N - N_train - N_dev)
    train_data, dev_data, test_data = D.random_split(data, [N_train, N_dev, N_test])
    train_loader = D.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    dev_loader = D.DataLoader(dev_data, batch_size=batch_size, shuffle=True)
    test_loader = D.DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, dev_loader, test_loader

In [30]:
bce_loss = nn.BCELoss(reduction='mean')

def train_epoch(vae_model, seg_model, optimizer, train_loader):
    seg_model.train()
    vae_model.eval()
    losses = 0
    for batch in train_loader:
        x = batch[0].to(device)
        y = batch[1].to(device)
        optimizer.zero_grad()
        z = vae_model.forward_latent(x)
        y_hat = seg_model.forward(z)
        loss = bce_loss(y, y_hat)
        print(loss)
        loss.backward()
        optimizer.step()
        losses += loss.item()   
    return losses / len(train_loader)  # average loss

In [31]:
def evaluate(vae_model, seg_model, dev_loader):
    seg_model.eval()
    vae_model.eval()
    losses = 0

    for data in dev_loader:
        x = data[0].to(device)
        y = data[1].to(device)
        z = vae_model.forward_latent(x)
        y_hat = seg_model.forward(z)
        loss = bce_loss(y, y_hat)
        losses += loss.item() 
    return losses / len(dev_loader)

In [25]:
resolution = 2**8
batch_size = 32;
dataset = CT_Dataset(prep_data_dir, 'lung', resolution)
train_loader, dev_loader, test_loader = make_seg_loaders(dataset, batch_size)

In [32]:
from models import VAEModel, SegmentationModel, Conv, ConvTranspose
vae_model = torch.load('../saved_models/vae_model.pt').to(device)
seg_model = SegmentationModel(base=16).to(device)
lr = 1e-9
optimizer = optim.AdamW(seg_model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)

In [33]:
total_epochs = 50

train_losses = []; dev_losses = []; lrs = []

for epoch in range(total_epochs):
    lrs.append(optimizer.param_groups[0]['lr'])
    start_time = timer()
    train_loss = train_epoch(vae_model, seg_model, optimizer, train_loader)
    train_losses.append(train_loss)
    end_time = timer()
    dev_loss = evaluate(vae_model, seg_model, dev_loader)
    dev_losses.append(dev_loss)
    scheduler.step()
    print((f"Epoch {epoch}:, Train-loss: {train_loss:.4f}, Dev-loss: {dev_loss:.4f}, "f"Epoch-time = {(end_time - start_time):.3f}s"))

OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 6.00 GiB total capacity; 5.26 GiB already allocated; 0 bytes free; 5.30 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
'''' Segmentation evaluation '''
def IoU(label, recon, thresh):
    inter = ((label >= thresh) & (recon >= thresh)) * 1.0
    union = ((label >= thresh) | (recon >= thresh)) * 1.0
    return inter.sum() / union.sum() / label.shape[0]