# VAE-Segmentation
---

This notebook walks through creating a PyTorch-compatible dataset and setting up a model for segmenting tumors in various organs.

Check whether a GPU is available.

In [59]:
import torch

# Prefer the first CUDA device when PyTorch can see one; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('Using', device)

if device.type == 'cuda':
    # Report which GPU index 0 refers to.
    print(torch.cuda.get_device_name(0))


Using cuda
NVIDIA GeForce GTX 1070


---

Import necessary libraries.

In [60]:
import torch.nn.functional as F
import torch.utils as U
import torchvision.transforms as T
from torch import Tensor
import json
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm # Progress bar.
import os

---

Make sure we are in the correct directory.

In [61]:
# IPython shell escape: print the kernel's current working directory.
!pwd

/media/nv/Storage/Data-Science/vae_lung_tumor_segmentation


Specify the location of the project's root directory relative to the working directory.

In [62]:
# Project root relative to the working directory; used as a plain string prefix for paths below.
rel_root_dir = "./"

---

Setting up some global constants.

In [67]:
# Organs whose datasets are processed in this notebook.
organs = ['spleen', 'colon', 'lung']
# New side length (width and height) of every resized datapoint.
d = 256

# Data transform definitions: a fixed resize and a random rotation
# (a scalar 180 means the angle is drawn from [-180, 180] degrees).
resize_transform = T.Resize(size=(d, d))
rand_rot_transform = T.RandomRotation(degrees=180)

---

Load the current progress so that work which has already been completed is not repeated when the entire notebook is run at once. The progress is stored in `progression.json`.

In [64]:
progress_file_path = f'{rel_root_dir}progression.json'

# On a fresh checkout progression.json does not exist yet; start from an empty
# state instead of failing with FileNotFoundError. The generation loop below
# writes this file after each organ is finished.
if os.path.exists(progress_file_path):
    with open(progress_file_path, 'r') as f:
        progression = json.load(f)
else:
    progression = {}

# Ensure every organ has a 'loaded' flag so later lookups cannot raise KeyError.
loaded_state = progression.setdefault('loaded', {})
for organ_name in organs:
    loaded_state.setdefault(organ_name, False)


---

We define a function that loads each raw volume of an organ, scales it to [0, 1], resizes its slices, and stores all slices in a single tensor file.

In [65]:
def augment_data(organ, training_paths, data_tensor, progress):
    """Load, normalize and resize every training volume of one organ, then save the slices.

    Each NIfTI volume listed in ``training_paths`` is min-max scaled to [0, 1],
    permuted to (slice, rows, columns), resized with the module-level
    ``resize_transform`` and appended to ``data_tensor`` along dim 0. The
    concatenated result is written to
    ``{rel_root_dir}augmented_data/{organ}_slices_unaugmented.pt``.

    Args:
        organ: Organ name; selects the ``raw_data/{organ}`` sub-directory.
        training_paths: Entries of the dataset's ``'training'`` list. Each is a
            dict whose ``'image'`` value is expected to start with a leading
            ``.`` (it is stripped before joining).
        data_tensor: Tensor the resized slices are concatenated onto (dim 0).
        progress: tqdm progress bar; advanced once per volume and always closed.

    Returns:
        The concatenated slice tensor that was saved to disk.
    """
    slices = [data_tensor]
    try:
        # Iterate the dicts directly; enumerate() would yield (index, dict) tuples.
        for path in training_paths:
            progress.update()

            # Get path to image - removed dot in path from json-file.
            nii_img = nib.load(f'{rel_root_dir}raw_data/{organ}' + path['image'][1:])
            nii_data = Tensor(nii_img.get_fdata())

            # Ensure scale [0; 1]. A constant volume has range 0; leave it at
            # all zeros instead of dividing by zero (which would produce NaNs).
            nii_data -= nii_data.min()
            max_val = nii_data.max()
            if max_val > 0:
                nii_data /= max_val
            nii_data = nii_data.permute(2, 0, 1)  # Shape: (slice, rows, columns); assumes a 3-D volume
            slices.append(resize_transform(nii_data))

        # Concatenate once instead of re-allocating the whole tensor per volume.
        data_tensor = torch.cat(slices, 0)
        out_dir = f'{rel_root_dir}augmented_data'
        os.makedirs(out_dir, exist_ok=True)
        torch.save(data_tensor, f'{out_dir}/{organ}_slices_unaugmented.pt')
    finally:
        progress.close()
    return data_tensor

We generate and save the slice dataset (currently unaugmented) from the raw data for each organ, using the above function. Progress is saved after each organ is completed, so the process can be interrupted and resumed at any time; organs that have already been processed are skipped.

In [69]:
for organ in organs:
    # Skip organs whose slices were already saved in a previous run.
    # .get() tolerates organs that are missing from progression.json.
    if progression['loaded'].get(organ, False):
        print('The ' + organ + ' set has already been loaded.')
        continue

    path = f'{rel_root_dir}raw_data/{organ}/dataset.json'
    with open(path) as f:
        data_set = json.load(f)
    training_paths = data_set['training']

    # Empty (0, d, d) tensor that the per-volume slices are concatenated onto.
    data_tensor = torch.zeros((0, d, d))
    total_paths = len(training_paths)

    progress = tqdm(total=total_paths)
    progress.set_description(organ)

    try:
        augment_data(organ, training_paths, data_tensor, progress)
        print('The ' + organ + ' set was successfully loaded.')

        # Persist progress so an interrupted run resumes after this organ.
        progression['loaded'][organ] = True
        with open(progress_file_path, "w") as f:
            json.dump(progression, f, indent=4)

    except KeyboardInterrupt:
        print('Manually stopped.\nOrgan: ' + organ + ' was not saved.')
        progress.close()
        break


The spleen set has already been loaded.
The colon set has already been loaded.
The lung set has already been loaded.


---

We define utility functions for saving and loading model and optimizer checkpoints.

In [71]:
def list_checkpoints(dir):
    """Return the sorted epoch numbers of the checkpoints stored in ``dir``.

    Only files named ``ckpt_<digits>.pth`` (the naming used by
    ``save_checkpoint``) are counted; any other file, including other
    ``.pth`` files, is ignored.

    Args:
        dir: Directory to scan.

    Returns:
        Sorted list of integer epochs; empty if there are no checkpoints.
    """
    prefix, suffix = 'ckpt_', '.pth'
    epochs = []
    for name in os.listdir(dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        # str.strip('ckpt_.pth') strips a *set of characters* from both ends,
        # so e.g. 'best.pth' became 'bes' and crashed int(). Slice exactly instead.
        stem = name[len(prefix):-len(suffix)]
        if stem.isdecimal():
            epochs.append(int(stem))
    return sorted(epochs)

def save_checkpoint(dir, epoch, model, optimizer=None):
    """Write the model (and optionally optimizer) state for ``epoch`` into ``dir``.

    The file is named ``ckpt_<epoch>.pth`` with the epoch zero-padded to at
    least two digits. A ``torch.nn.DataParallel`` wrapper is unwrapped so the
    saved state-dict keys match the bare model.

    Args:
        dir: Directory the checkpoint file is written into.
        epoch: Epoch number stored in the checkpoint and used in the filename.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optional optimizer; ``None`` is stored when it is omitted.
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    payload = {
        'epoch': epoch,
        'model': net.state_dict(),
        'optimizer': None if optimizer is None else optimizer.state_dict(),
    }
    filename = 'ckpt_%02d.pth' % epoch
    torch.save(payload, os.path.join(dir, filename))

def load_checkpoint(dir, epoch=0):
    """Load the checkpoint for ``epoch`` from ``dir``, mapping tensors to the CPU.

    Args:
        dir: Directory containing ``ckpt_<epoch>.pth`` files.
        epoch: Epoch to load; ``0`` selects the highest epoch reported by
            ``list_checkpoints(dir)`` (``max()`` raises ValueError if there is none).

    Returns:
        Whatever object was stored with ``torch.save`` in that file.
    """
    target_epoch = max(list_checkpoints(dir)) if epoch == 0 else epoch
    ckpt_file = os.path.join(dir, 'ckpt_%02d.pth' % target_epoch)
    return torch.load(ckpt_file, map_location='cpu')

def load_model(dir, model, epoch=0):
    """Restore ``model``'s weights from the checkpoint for ``epoch`` in ``dir``.

    A ``torch.nn.DataParallel`` wrapper is unwrapped before loading, matching
    how the state dict is unwrapped when it is saved.

    Args:
        dir: Checkpoint directory.
        model: Module (optionally DataParallel-wrapped) to load weights into.
        epoch: Epoch to load; ``0`` means the latest checkpoint.

    Returns:
        The same ``model`` object that was passed in.
    """
    weights = load_checkpoint(dir, epoch)['model']
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    target.load_state_dict(weights)
    return model

def load_optimizer(dir, optimizer, epoch=0):
    """Restore ``optimizer``'s state from the checkpoint for ``epoch`` in ``dir``.

    Args:
        dir: Checkpoint directory.
        optimizer: Optimizer to load the saved state into.
        epoch: Epoch to load; ``0`` means the latest checkpoint.

    Returns:
        The same ``optimizer`` object that was passed in.

    Raises:
        ValueError: If the checkpoint was saved without optimizer state.
    """
    state = load_checkpoint(dir, epoch)['optimizer']
    if state is None:
        # save_checkpoint stores None when called without an optimizer; handing
        # None to load_state_dict would fail with an unhelpful error.
        raise ValueError(
            f'Checkpoint (epoch={epoch}, 0 = latest) in {dir!r} has no optimizer state.'
        )
    optimizer.load_state_dict(state)
    return optimizer

---