# VAE-Segmentation
---

This notebook goes through the process of creating a PyTorch-compatible dataset and setting up a model for segmentation of tumors in various organs. It enables reproducibility of our final model and test results.

We import the necessary libraries.

In [18]:
# For ML
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils as U
import torch.utils.data as D
import torchvision.transforms as T
import torch.optim as optim
from torch import Tensor

# For reading raw data.
import json
import nibabel as nib

# For displaying and evaluating results.
import numpy as np
import matplotlib.pyplot as plt

# For monitoring resource-usage and progress.
from tqdm import tqdm
import os, psutil

---

Check if GPU is available and retrieve some system stats.

In [4]:
# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device)

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('CUDA version:', torch.version.cuda)

# Total physical memory in (decimal) GB; `.total` is field 0 of psutil.virtual_memory().
available_ram = round(psutil.virtual_memory().total / 1e9, 2)
print(f'RAM: {available_ram}GB')


Using cuda
NVIDIA GeForce GTX 1070
CUDA version: 11.7
RAM: 16.74GB


---

Make sure we are in the correct working directory.

In [5]:
# Show the working directory; root_dir and the other relative paths below are resolved against it.
!pwd

/media/nv/Storage/Data-Science/vae_lung_tumor_segmentation


---

Setting up some global constants.

In [6]:
# Location of project directory, relative to the working directory.
root_dir = './'

# The organs we wish our model to consider.
organs = ['spleen', 'colon', 'lung']

# New dimensions (width and height) of datapoints.
d = 256

# Number of chunks to divide our data into.
num_chunks = 10

# Seed torch's global RNG (CPU and CUDA) so that slice shuffling, chunk order,
# weight initialisation and sampling are repeatable across runs.
# NOTE: some cuDNN kernels may still be nondeterministic on the GPU.
seed = 0
torch.manual_seed(seed)

# Data transforms.
resize_transform = T.Resize((d, d))
# NOTE(review): rand_rot_transform is not applied anywhere in this notebook.
rand_rot_transform = T.RandomRotation(180)

---

We set up a folder structure for our data, both raw and preprocessed.

In [10]:
# Create the raw-data tree (one sub-folder per organ) plus every output folder the
# notebook writes to. exist_ok makes the cell idempotent and also adds organ
# sub-folders that are missing from an already existing raw_data/.
path = os.path.join(root_dir, 'raw_data')
for organ in organs:
    os.makedirs(os.path.join(path, organ), exist_ok=True)

# Outputs: preprocessed slices/chunks, saved models and loss histories.
for folder in ('augmented_data', 'saved_models', 'losses'):
    os.makedirs(os.path.join(root_dir, folder), exist_ok=True)

The per-organ sub-directories of `raw_data/` have to be populated manually with the unzipped files from [medicaldecathlon.com](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2). Each organ folder must contain the task's `dataset.json`, whose `training` entries are read below.

---

Load the current progress, so that work which has already been completed is not repeated when running the entire notebook at once. The progress flags are stored externally in `progression.json`.

In [11]:
progress_file_path = f'{root_dir}progression.json'

# Flags recording which preprocessing stages have completed. Start from an
# all-False state when the file does not exist yet (e.g. a fresh checkout).
if os.path.exists(progress_file_path):
    with open(progress_file_path, 'r') as f:
        progression = json.load(f)
else:
    progression = {}

# Fill in any missing flags, including organs added to `organs` after the file
# was written, so the lookups below never raise KeyError.
progression.setdefault('augmented', False)
loaded_flags = progression.setdefault('loaded', {})
for organ in organs:
    loaded_flags.setdefault(organ, False)

---

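We define a function that loads our data, converts it to the proper format, and stores it. Despite its name, `augment_data` performs no augmentation: it rescales each volume to $[0, 1]$ and resizes its slices to $d \times d$. As the datasets are huge, we monitor progress and RAM usage.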

In [12]:
def augment_data(organ, training_paths, data_tensor, progress):
    """Load, normalise and resize all training volumes of one organ, then save the slices.

    Despite its name, no augmentation is performed. Each NIfTI volume listed in
    ``training_paths`` is:

    * min-max scaled to [0, 1] (per volume);
    * reordered to (slice, rows, columns);
    * resized by ``resize_transform``;
    * appended to ``data_tensor``.

    The result is saved to ``{root_dir}augmented_data/{organ}_slices_unaugmented.pt``.

    Args:
        organ: Organ name; selects ``raw_data/{organ}/`` and the output file name.
        training_paths: The ``'training'`` entries of the organ's ``dataset.json``.
            Each is a dict whose ``'image'`` value is a path starting with ``.``.
        data_tensor: Tensor of slices to append to (the caller passes an empty
            ``(0, d, d)`` tensor).
        progress: tqdm bar, advanced once per volume and closed at the end.

    Returns:
        The concatenated slice tensor that was saved.
    """
    slices = [data_tensor]
    for path in training_paths:
        # Show current used RAM (psutil field 3 = `used`) in GB.
        progress.set_postfix(RAM=round(psutil.virtual_memory()[3] / 1e9, 2))
        progress.update()

        # The json paths start with '.', which is dropped before joining.
        nii_img = nib.load(f'{root_dir}raw_data/{organ}' + path['image'][1:])
        # get_fdata() decodes the whole volume, so call it only once.
        volume = Tensor(nii_img.get_fdata())

        # Per-volume min-max scaling to [0, 1]. A constant volume has range 0;
        # leave it at all-zeros instead of producing NaNs.
        volume -= volume.min()
        vmax = volume.max()
        if vmax > 0:
            volume /= vmax
        volume = volume.permute(2, 0, 1)  # Shape: (slice, rows, columns)
        slices.append(resize_transform(volume))

    # Concatenate once at the end; concatenating inside the loop copies all
    # previously collected slices on every iteration.
    data_tensor = torch.cat(slices, 0)
    torch.save(data_tensor, f'{root_dir}augmented_data/{organ}_slices_unaugmented.pt')
    progress.close()
    return data_tensor

We process and format the raw data for each organ using the function above, and save progress after each organ is completed. The cell can be interrupted and resumed at any time; organs that have already been processed are skipped.

In [13]:
for organ in organs:
    # Skip organs whose slices were already extracted in an earlier run.
    if progression['loaded'][organ]:
        print('The ' + organ + ' set has already been loaded.')
        continue

    path = f'{root_dir}raw_data/{organ}/dataset.json'
    with open(path) as f:
        training_paths = json.load(f)['training']

    progress = tqdm(total=len(training_paths))
    progress.set_description(organ)

    try:
        augment_data(organ, training_paths, torch.zeros((0, d, d)), progress)
        print('The ' + organ + ' was successfully loaded.')

        # Persist the flag so that re-running the notebook skips this organ.
        progression['loaded'][organ] = True
        with open(progress_file_path, "w") as f:
            json.dump(progression, f, indent=4)
    except KeyboardInterrupt:
        print ('Manually stopped.\nOrgan: ' + organ + ' was not saved.')
        progress.close()
        break

The spleen set has already been loaded.
The colon set has already been loaded.
The lung set has already been loaded.


---

The datasets are so large that we need to split them into smaller chunks. We first initialize empty chunks.

In [14]:
# Reset the chunk files only while chunking is still pending. Once
# progression['augmented'] is True, the next cell skips refilling. An
# unconditional reset here would therefore overwrite the finished chunks with
# empty tensors every time the notebook is re-run.
if not progression['augmented']:
    for n in range(num_chunks):
        chunk = torch.zeros(0, d, d)
        torch.save(chunk, f'{root_dir}augmented_data/unaugmented_chunk_{n}.pt')

We then fill them with the slices of each organ, shuffled and split evenly among the chunks. (No augmentation is actually applied, despite the wording of the progress messages.)

In [15]:
if not progression['augmented']:
    progress = tqdm(total=len(organs)*num_chunks)
    progress.set_description('Augmentation')
    n = None  # Index of the chunk being written; reported if interrupted.

    try:
        for organ in organs:
            data = torch.load(f'{root_dir}augmented_data/{organ}_slices_unaugmented.pt')
            # Shuffle the organ's slices so every chunk receives a random subset.
            data = data[torch.randperm(data.shape[0])]

            # tensor_split spreads the N % num_chunks leftover slices over the first
            # chunks instead of dropping them, as fixed-size slicing of
            # int(N / num_chunks) would.
            for n, part in enumerate(torch.tensor_split(data, num_chunks)):
                path = f'{root_dir}augmented_data/unaugmented_chunk_{n}.pt'
                chunk = torch.cat((torch.load(path), part), 0)
                torch.save(chunk, path)
                progress.update()

        print('Augmentation was successful.')

        # Change state of progression.json
        progression['augmented'] = True
        with open(progress_file_path, "w") as f:
            json.dump(progression, f, indent=4)

    except KeyboardInterrupt:
        # progression['augmented'] stays False, so the chunks are reset and rebuilt on the next run.
        print ('Manually stopped.\nChunk ' + str(n) + ' was not saved.')
    finally:
        progress.close()

else:
    print('Data augmentation and chunking already completed.')

Data augmentation and chunking already completed.


We define a few utility functions for checkpointing, as training on such large data points may require us to break up the training into several sessions.

In [16]:
def list_checkpoints(dir):
    """Return the epochs of all ``ckpt_<epoch>.pth`` files in ``dir``, sorted ascending.

    Only file names of the form written by ``save_checkpoint`` are considered.
    Unrelated ``.pth`` files in the same directory are therefore ignored instead
    of crashing the integer parse. (``str.strip`` removes a *set of characters*,
    not a prefix/suffix, so it cannot be used for this.)
    """
    prefix, suffix = 'ckpt_', '.pth'
    epochs = []
    for name in os.listdir(dir):
        if name.startswith(prefix) and name.endswith(suffix):
            number = name[len(prefix):-len(suffix)]
            if number.isdecimal():
                epochs.append(int(number))
    return sorted(epochs)

def save_checkpoint(dir, epoch, model, optimizer=None):
    """Write the state for ``epoch`` to ``dir/ckpt_<epoch>.pth`` (epoch zero-padded to 2 digits).

    The stored dict has the keys ``'epoch'``, ``'model'`` (state_dict) and
    ``'optimizer'`` (state_dict, or ``None`` when no optimizer is given). A
    ``DataParallel`` wrapper is unwrapped first, so the model keys are those of
    the underlying module.
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    checkpoint = {
        'epoch': epoch,
        'model': net.state_dict(),
        'optimizer': None if optimizer is None else optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(dir, 'ckpt_%02d.pth' % epoch))

def load_checkpoint(dir, epoch=0):
    """Load a checkpoint dict written by ``save_checkpoint``.

    Args:
        dir: Directory containing ``ckpt_<epoch>.pth`` files.
        epoch: Epoch to load. ``0`` (the default) selects the highest saved epoch,
            so an epoch-0 checkpoint is only loaded when it is the only one.

    Returns:
        The checkpoint dict, with all tensors mapped to the CPU.

    Raises:
        ValueError: If ``epoch`` is 0 and ``dir`` contains no checkpoints.
    """
    if epoch == 0:
        epochs = list_checkpoints(dir)
        if not epochs:
            raise ValueError(f'No ckpt_*.pth checkpoints found in {dir!r}')
        epoch = max(epochs)
    checkpoint_path = os.path.join(dir, 'ckpt_%02d.pth' % epoch)
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
    return torch.load(checkpoint_path, map_location='cpu')

def load_model(dir, model, epoch=0):
    """Restore ``model``'s weights from the checkpoint for ``epoch`` in ``dir`` and return it.

    ``epoch`` is passed through to ``load_checkpoint``. For a ``DataParallel``
    wrapper, the weights are loaded into the wrapped module, mirroring how
    ``save_checkpoint`` stores them.
    """
    weights = load_checkpoint(dir, epoch)['model']
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    target.load_state_dict(weights)
    return model

def load_optimizer(dir, optimizer, epoch=0):
    """Restore ``optimizer``'s state from the checkpoint for ``epoch`` in ``dir`` and return it.

    Raises:
        ValueError: If the checkpoint was saved without optimizer state
            (``save_checkpoint(..., optimizer=None)`` stores ``None``).
    """
    state = load_checkpoint(dir, epoch)['optimizer']
    if state is None:
        raise ValueError(f'Checkpoint in {dir!r} (epoch={epoch}) has no optimizer state')
    optimizer.load_state_dict(state)
    return optimizer

---

We define our model architecture for the variational autoencoder.

In [19]:
# Dimensionality of the latent code z.
latent_size = 256
w = h = 256 # Flattened encoder output size; forward() also uses w, h as the image width/height.

class VAEModel(nn.Module):
    """Convolutional variational autoencoder for single-channel 256x256 images.

    Encoder: 8 blocks of Conv2d(kernel 4, stride 2, padding 1) + LeakyReLU(0.05).
    Each block halves the spatial size and doubles the channels (1 -> ... -> 256),
    so a 256x256 input ends as a 256x1x1 map, which is flattened to ``w`` features.
    Two linear heads map these features to the latent mean and log-variance.

    Decoder: ``forward`` first maps z back to ``w`` features with ``latent_de``.
    ``decode`` then views them as a (w, 1, 1) map and upsamples it with 8
    ConvTranspose2d(kernel 4, stride 2, padding 1) layers (channels 256 -> ... -> 2 -> 1,
    size 1x1 -> 256x256), ending in a Sigmoid so reconstructions lie in (0, 1).

    NOTE: the module-level ``w`` serves both as the flattened encoder size and as
    the image width in ``forward``. This only works because both equal 256.
    """
    def __init__(self) -> None:
        super(VAEModel, self).__init__()
        # One LeakyReLU module instance, reused by every encoder/decoder block.
        self.activation = nn.LeakyReLU(0.05)
        # Encoder: channels 2**i -> 2**(i+1) (1 -> 256); each layer halves H and W.
        self.encoder = nn.Sequential()
        for i in range(0, 8):
            self.encoder.append(nn.Sequential(
                nn.Conv2d(2**i, 2**(i+1), 4, 2, 1),
                #nn.BatchNorm2d(2**(i+1)),
                self.activation,))

        # Decoder: channels 2**(8-i) -> 2**(7-i) (256 -> 2); each layer doubles H and W
        # (1x1 -> 128x128 after these 7 layers).
        self.decoder = nn.Sequential()
        for i in range(0, 7):
            self.decoder.append(nn.Sequential(
                nn.ConvTranspose2d(2**(8-i), 2**(7-i), 4, 2, 1),
                #nn.BatchNorm2d(2**(7-i)),
                self.activation,))
        # Final 2 -> 1 channel upsampling (128x128 -> 256x256). The Sigmoid keeps
        # outputs in (0, 1), which the BCE reconstruction loss requires.
        self.decoder.append(nn.Sequential(
            nn.ConvTranspose2d(2, 1, kernel_size = 4, stride = 2, padding=1),
            #nn.BatchNorm2d(1),
            nn.Sigmoid()
            )
        )

        # Posterior heads (mean and log-variance) and the latent -> decoder-input projection.
        self.fc_mu = nn.Linear(w, latent_size)
        self.fc_log_sigma = nn.Linear(w, latent_size)
        self.latent_de = nn.Linear(latent_size, w)
        
        # Weight init: Kaiming-normal (fan_in, leaky_relu gain) for all (transposed)
        # convolutions, normal with std 0.01 for linear weights, zeros for biases.
        # NOTE(review): kaiming_normal_ uses its default negative slope a=0 here,
        # not the 0.05 of self.activation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
        
    
    def encode(self, x: Tensor):
        """Encode (batch, 1, 256, 256) images into posterior parameters.

        Returns:
            ``(mu, log_sigma)``, each of shape (batch, latent_size). ``log_sigma`` is
            used as a log-variance by ``reparameterize``.
        """
        h = self.encoder(x)
        h = h.view(-1, w)  # (batch, 256, 1, 1) -> (batch, w)
        mu = self.fc_mu(h)
        log_sigma = self.fc_log_sigma(h)
        return mu, log_sigma

    def decode(self, z: Tensor) -> Tensor:
        """Decode (batch, w) features into (batch, 1, 256, 256) images in (0, 1).

        Expects the output of ``latent_de``, not a raw latent sample; ``forward``
        applies ``latent_de`` before calling this.
        """
        z = z.view(-1, w, 1, 1)
        z = self.decoder(z)
        return z

    def reparameterize(self, mu: Tensor, log_sigma: Tensor) -> Tensor:
        """Sample z = mu + eps * exp(0.5 * log_sigma), with eps ~ N(0, I).

        Because of the 0.5 factor, ``log_sigma`` is the log-variance log(sigma^2).
        """
        std = torch.exp(0.5*log_sigma)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x: Tensor):
        """Reconstruct a batch of images.

        Args:
            x: Images with w*h elements each (e.g. (batch, 256, 256)); viewed as (batch, 1, w, h).

        Returns:
            ``(x_hat, mu, log_sigma)``, where ``x_hat`` has shape (batch, w, h) and
            ``mu``/``log_sigma`` have shape (batch, latent_size).
        """
        x = x.view(-1, 1, w, h) # Adds channel dimension.
        mu, log_sigma = self.encode(x) # Predicting a log-variance keeps the variance positive and is numerically stable.
        z = self.reparameterize(mu, log_sigma)
        z = self.latent_de(z)
        x_hat = self.decode(z)
        return x_hat.view(-1, w, h), mu, log_sigma

We check if the model output dimensions make sense.

In [20]:
# Sanity check on random input: the reconstruction must match the input shape and
# the posterior parameters must be (batch, latent_size).
model = VAEModel()
A = torch.randn((10, 256, 256))
with torch.no_grad():
    B, mu_B, log_sigma_B = model(A)
assert B.shape == A.shape
assert mu_B.shape == log_sigma_B.shape == (10, latent_size)
B.shape, mu_B.shape, log_sigma_B.shape

(torch.Size([10, 256, 256]), torch.Size([10, 256]), torch.Size([10, 256]))

We define a loss function that compares the actual data to the data reconstructed by the variational autoencoder. It combines a binary cross-entropy reconstruction term with a heavily down-weighted KL-divergence term.

In [21]:
# Element-wise reconstruction criteria, averaged over all elements.
mse_loss = nn.MSELoss(reduction='mean')
bce_loss = nn.BCELoss(reduction='mean')

# beta scaled by the latent-input width over the pixel count. It is computed and
# printed for reference only; loss_fn does not use it.
beta = 1
N = w * h
M = w
beta_norm = beta * M / N
print(f'{beta_norm = }')

def loss_fn(x, x_hat, mu, log_var, kl_weight=1e-6):
    """VAE training loss: mean pixel-wise BCE reconstruction + weighted KL divergence.

    Args:
        x: Target images with values in [0, 1].
        x_hat: Reconstructions in [0, 1], same shape as ``x``.
        mu: Posterior means, presumably (batch, latent_size); the last dim is the latent dim.
        log_var: Posterior log-variances, same shape as ``mu``.
        kl_weight: Weight of the KL term.

    Returns:
        Scalar tensor ``BCE + kl_weight * KLD``. KLD is summed over latent dims
        and averaged over the batch, so its scale does not grow with batch size.
    """
    # Mean binary cross-entropy over every pixel in the batch.
    bce = F.binary_cross_entropy(x_hat, x, reduction='mean')

    # Analytic KL(q(z|x) || N(0, I)) per sample, then the batch average.
    kld_per_sample = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
    kld = kld_per_sample.mean()
    return bce + kl_weight * kld

beta_norm = 0.00390625


We define a function for drawing a datapoint along with its prediction.

In [23]:
def draw(x, x_hat):
    """Show the first image of ``x`` next to the first image of ``x_hat`` in grayscale.

    Both arguments are indexed with ``[0]``, so they are expected to be batches of
    2-D images. The colour range is fixed to [0, 1]. Tensors may live on any
    device; they are detached and copied to the CPU before plotting.
    """
    fig, axs = plt.subplots(1, 2, figsize=(8, 5))
    panels = (('Input', x), ('Reconstruction', x_hat))
    for ax, (title, batch) in zip(axs, panels):
        ax.imshow(batch[0].detach().cpu().numpy(), vmin=0, vmax=1, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.canvas.draw()

In [None]:
def make_train_loader(idx, batch_size=256):
    """Return a shuffling DataLoader over the slices stored in chunk ``idx``.

    Args:
        idx: Chunk index; loads ``augmented_data/unaugmented_chunk_{idx}.pt``.
        batch_size: Slices per batch (default 256).
    """
    train_data = torch.load(f'{root_dir}augmented_data/unaugmented_chunk_{idx}.pt')
    return D.DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [22]:
def make_dev_loader(idx=None, batch_size=256):
    """Return a DataLoader over the dev chunk.

    Args:
        idx: Chunk index to use. Defaults to the last chunk, ``num_chunks - 1``
            (chunk 9 with the default 10 chunks).
        batch_size: Slices per batch (default 256).

    The loader is not shuffled: evaluation does not need it, and a fixed order
    makes the dev batches reproducible.
    """
    if idx is None:
        idx = num_chunks - 1
    dev_data = torch.load(f'{root_dir}augmented_data/unaugmented_chunk_{idx}.pt')
    return D.DataLoader(dev_data, batch_size=batch_size, shuffle=False)

In [None]:
def make_loaders_2(batch_size=64, train_fraction=0.9):
    """Build train/dev loaders by randomly splitting one pre-built data file.

    Loads ``augmented_data/data_chunk_unaugmented.pt``.
    NOTE(review): no cell in this notebook writes that file, so it must come from
    an earlier pipeline; confirm before use.

    Args:
        batch_size: Slices per batch for both loaders (default 64).
        train_fraction: Share of the data that goes to training (default 0.9).

    Returns:
        ``(train_loader, dev_loader)``, both shuffling.
    """
    data = torch.load(f'{root_dir}augmented_data/data_chunk_unaugmented.pt')
    n_total = len(data)
    n_train = int(train_fraction * n_total)
    train_data, dev_data = D.random_split(data, [n_train, n_total - n_train])
    train_loader = D.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    dev_loader = D.DataLoader(dev_data, batch_size=batch_size, shuffle=True)
    return train_loader, dev_loader

In [None]:
def plot_grad_flow(named_parameters):
    """Plot the mean absolute gradient of every trainable, non-bias parameter.

    Call after ``loss.backward()``, e.g. ``plot_grad_flow(model.named_parameters())``.
    Parameters whose ``.grad`` is ``None`` (not reached by the backward pass) are
    skipped. Each mean is converted to a Python float with ``.item()``, so the plot
    also works when the model lives on the GPU.
    """
    ave_grads = []
    layers = []
    for name, param in named_parameters:
        if param.requires_grad and 'bias' not in name and param.grad is not None:
            layers.append(name)
            ave_grads.append(param.grad.abs().mean().item())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    plt.xticks(range(len(ave_grads)), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.show()

We move our model to the selected device (the GPU, if available) and create the optimizer.

In [None]:
# Move the parameters to the target device before the optimizer is built over them.
model = model.to(device)
optimizer = optim.AdamW(params=model.parameters(), lr=3e-4)

In [None]:
def train_epoch(model, optimizer):
    """Run one training epoch over the training chunks and return the mean batch loss.

    The last chunk (index ``num_chunks - 1``) is the dev chunk loaded by
    ``make_dev_loader``, so it is excluded here to avoid training on dev data.
    The remaining chunks are visited in random order.

    Raises:
        RuntimeError: If the training chunks contain no batches (e.g. empty chunk files).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    dev_idx = num_chunks - 1
    for i in torch.randperm(dev_idx).tolist():
        train_loader = make_train_loader(i)
        for x in train_loader:
            x = x.to(device)
            optimizer.zero_grad()
            x_hat, mu, log_var = model(x)
            loss = loss_fn(x, x_hat, mu, log_var)
            loss.backward()
            # plot_grad_flow(model.named_parameters())  # uncomment to inspect gradients
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise RuntimeError('No training batches: are the chunk files empty?')
    # Average over the batches actually seen; chunk sizes can differ.
    return total_loss / num_batches

In [None]:
def evaluate(model, dev_loader):
    """Return the mean batch loss of ``model`` over ``dev_loader`` without updating it.

    Runs under ``torch.no_grad()``, so no autograd graph is kept for the dev
    batches. For the VAEModel above, ``forward`` still samples z through
    ``reparameterize``, so the value is stochastic.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for x in dev_loader:
            x = x.to(device)
            x_hat, mu, log_var = model(x)
            total_loss += loss_fn(x, x_hat, mu, log_var).item()
    return total_loss / len(dev_loader)

In [None]:
# Per-epoch mean losses, appended to by the training loop.
train_losses, dev_losses = [], []

In [None]:
# Optional: resume from a previously saved full model instead of the freshly
# initialised one. torch.save(model, ...) further below pickles the whole module.
# NOTE(review): this path is relative to the parent directory, unlike the
# f'{root_dir}saved_models/' path used when saving. Newer torch versions may also
# need torch.load(..., weights_only=False) for full-module pickles; confirm.
#model = torch.load('../saved_models/model_13')
#model = model.to(device)

In [None]:
from timeit import default_timer as timer
NUM_EPOCHS = 1000

# One dev loader, reused for every epoch's evaluation.
dev_loader = make_dev_loader()
for epoch in range(1, NUM_EPOCHS + 1):
    # Only the training pass is timed; evaluation is excluded from "Epoch time".
    tic = timer()
    train_loss = train_epoch(model, optimizer)
    train_losses.append(train_loss)
    elapsed = timer() - tic

    dev_loss = evaluate(model, dev_loader)
    dev_losses.append(dev_loss)
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Dev loss: {dev_loss:.3f}, Epoch time = {elapsed:.3f}s")

In [None]:
model_name = 'model_22'
baseline_name = 'model_15'  # Earlier run whose dev losses are plotted for comparison.

os.makedirs(f'{root_dir}saved_models', exist_ok=True)
os.makedirs(f'{root_dir}losses', exist_ok=True)
# Pickles the whole module (not just the state_dict); loading it later needs VAEModel in scope.
torch.save(model, f'{root_dir}saved_models/{model_name}')
torch.save(torch.Tensor(train_losses), f'{root_dir}losses/train_losses_{model_name}')
torch.save(torch.Tensor(dev_losses), f'{root_dir}losses/dev_losses_{model_name}')

baseline_path = f'{root_dir}losses/dev_losses_{baseline_name}'
if os.path.exists(baseline_path):
    dev_losses_2 = torch.load(baseline_path)
    print(dev_losses_2[-1], dev_losses[-1])
    print(len(dev_losses_2), len(dev_losses))
    plt.plot(dev_losses_2, label=baseline_name)
else:
    print(f'No baseline losses at {baseline_path}; plotting only {model_name}.')
plt.plot(dev_losses, label=model_name)
# Hard-coded zoom window, presumably chosen for the model_15 vs model_22 comparison.
plt.xlim([80, 120])
plt.ylim([0.303, 0.306])
plt.xlabel('Epoch index (0 = first epoch)')
plt.ylabel('Dev loss')
plt.legend();

In [None]:
model.eval()
with torch.no_grad():
    # One image from the first dev batch, as a batch of size 1.
    x_test = next(iter(dev_loader))[0].view(1, 256, 256).to(device)
    x_hat_test = model(x_test)[0]
draw(x_test.cpu(), x_hat_test.cpu())
mse_loss(x_test, x_hat_test)

In [None]:
model.eval()
with torch.no_grad():
    # Draw z ~ N(0, I) and pass it through latent_de before decode(), exactly as
    # forward() does; decode() expects latent_de's output, not a raw latent sample.
    z = torch.randn(1, latent_size).to(device)
    x_sampled = model.decode(model.latent_de(z)).view(1, 256, 256).cpu()
plt.imshow(x_sampled[0].numpy(), vmin=0, vmax=1, cmap='gray');