# MultiMorph Demo on OASIS-1 (2D)


In [None]:
# Clone the MultiMorph repository and make its `src` directory the working
# directory. Later cells import project-local modules by bare name
# (`dataloader`, `models`, `layers`, `losses`, `utils`), presumably from there.
# NOTE(review): re-running this cell repeats both commands relative to the new
# working directory.
!git clone https://github.com/mabulnaga/multimorph.git
%cd multimorph/src

### Load Libraries

In [None]:
# Install the extra packages used by this notebook.
# The PyPI release of neurite is left disabled; the GitHub version is installed below instead.
#!pip install neurite
# --no-deps: install monai only, without pulling in its dependencies.
!pip install monai --no-deps
# --force-reinstall replaces any neurite already present; --no-deps leaves the other installed packages untouched.
!pip install git+https://github.com/adalca/neurite.git --force-reinstall --no-deps
!pip install pystrum
# voxelmorph is not installed by this cell (line kept disabled).
#!pip install voxelmorph

In [None]:
# imports
import pathlib
import os

import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from tqdm.notebook import trange, tqdm

# Request the PyTorch backend for neurite and voxelmorph. The variables are set
# before `import neurite` below -- presumably because the backend is selected
# when the package is imported (TODO confirm against the neurite docs).
os.environ['NEURITE_BACKEND'] = 'pytorch'
os.environ['VXM_BACKEND'] = 'pytorch'
import neurite as ne

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# OASIS-1 Data
Download the OASIS-1 2D data.

In [None]:
# Fetch the neurite OASIS 2D archive, unpack it into ./oasisdata, then delete
# the archive so that only the extracted files remain.
!mkdir -p oasisdata
# -q: quiet download (no progress output)
!wget -q https://surfer.nmr.mgh.harvard.edu/ftp/data/neurite/data/neurite-oasis.2d.v1.0.tar
# -C: extract into the oasisdata/ directory created above
!tar xf neurite-oasis.2d.v1.0.tar -C oasisdata/
!rm neurite-oasis.2d.v1.0.tar

### Load the OASIS Data
Stack the data into a single tensor from which the data loaders are created. Specify the path where the OASIS data lives.

In [None]:
# specify the full path to the OASIS data
oasis_data_path ='oasisdata'

# List the subject directories ONCE and sort them. Both the image list and the
# segmentation list are derived from this single listing, so entry k of each
# always belongs to the same subject, and the ordering is reproducible between
# runs (pathlib's iterdir() yields entries in arbitrary order).
subject_dirs = sorted(d for d in pathlib.Path(oasis_data_path).iterdir() if d.is_dir())
if not subject_dirs:
    # torch.stack on an empty list fails with an unhelpful message; fail early instead.
    raise FileNotFoundError(f'no subject directories found in {oasis_data_path!r}')


def _load_first_slices(paths):
    """Load every NIfTI file in `paths` and keep index 0 of its last axis."""
    return [torch.from_numpy(nib.load(p).get_fdata())[..., 0] for p in paths]


# intensity images, stacked along a new leading axis and moved to `device`
files = [d / 'slice_norm.nii.gz' for d in subject_dirs]
slices = _load_first_slices(files)
oasis_data = torch.stack(slices, dim=0).float().to(device)

# get the segmentations (same subject order as the images)
seg_files = [d / 'slice_seg4.nii.gz' for d in subject_dirs]
seg_slices = _load_first_slices(seg_files)
oasis_data_segmentation = torch.stack(seg_slices, dim=0).float().to(device)

# swap axes 1 and 2 of every slice; the segmentations additionally get a
# singleton axis inserted at position 1 (presumably a channel axis -- TODO confirm)
oasis_data = oasis_data.transpose(2, 1)
oasis_data_segmentation = oasis_data_segmentation.transpose(2, 1).unsqueeze(1)

print(oasis_data.shape)
print(oasis_data_segmentation.shape)

### Create an 80/20 split between train and test.
The train data loader will randomly sample between 2 and 12 images per iteration. The test data loader will construct an atlas on the entire test set.

The `GroupDataLoader` samples random subsamples of the dataset. The `SubGroupLoader` samples the entire dataset. We use the `SubGroupLoader` to construct an atlas for the test set.
The `SubGroupLoader` assumes the data and segmentation tensors are each wrapped in a list.

In [None]:
from dataloader import GroupDataLoader, SubGroupLoader
from torch.utils.data import DataLoader

# Every training iteration draws a group of images; the group size is sampled
# between these two bounds.
n_inputs_range = [2,12]

# ---- 80/20 partition of the sample indices ----
train_pct = 0.8
N_data = len(oasis_data)
N_train = int(N_data * train_pct)
N_test = N_data - N_train
range_data = np.arange(0, N_data)
# training indices are drawn without replacement; the test set is whatever is left
train_idx = np.random.choice(range_data, N_train, replace=False)
test_idx = np.setdiff1d(range_data, train_idx)

# ---- apply the partition to the images and to the segmentations ----
oasis_data_train = oasis_data[train_idx]
oasis_data_test = oasis_data[test_idx]
oasis_data_segmentation_train = oasis_data_segmentation[train_idx]
oasis_data_segmentation_test = oasis_data_segmentation[test_idx]

# ---- training loader: a random group of images per iteration ----
dataset_oasis_train = GroupDataLoader(
    data=oasis_data_train,
    labels=np.zeros(N_train),
    class_labels=[0],
    segmentations=oasis_data_segmentation_train,
    n_inputs_range=n_inputs_range,
    transform=None,
)
dataloader_oasis_train = DataLoader(dataset_oasis_train, batch_size=1, shuffle=True)

# ---- test loader: the whole test set at once; tensors are passed wrapped in lists ----
dataset_oasis_test = SubGroupLoader(
    data=[oasis_data_test],
    labels=None,  # labels=[np.zeros(N_test)],
    segmentations=[oasis_data_segmentation_test],
    transform=None,
)
dataloader_oasis_test = DataLoader(dataset_oasis_test, batch_size=1, shuffle=False)



#### Visualize a few samples from the dataloader

In [None]:
# Draw two groups from the training loader and display their images and label maps.
for i in range(2):
    # a new iterator is built on every pass, so each pass takes the first
    # batch of a freshly shuffled loader
    sample = next(iter(dataloader_oasis_train))
    images, segmentations = sample['image'], sample['segmentation']
    # undo the one-hot encoding: argmax over axis 2, kept as a singleton axis
    segmentations = segmentations.argmax(dim=2, keepdim=True)

    # one 2D array per image of the group (batch entry 0, axis-2 index 0)
    slices = list(images[0, :, 0].cpu().detach().numpy())
    ne.py.plot.slices(slices, do_colorbars=True)

    # the matching label maps, drawn with the 'turbo' colormap
    slices_seg = list(segmentations[0, :, 0].cpu().detach().numpy())
    ne.py.plot.slices(slices_seg, do_colorbars=True, cmaps=['turbo'] * len(slices_seg))

# Setup the Model and Losses

In [None]:
import models as models
import layers as layers
import losses as losses
import torch.optim as optim


#### get the image size to set up the Spatial Transformer

In [None]:
# Size of one image (every axis after the leading sample axis) as plain Python ints.
oasis_img_size = [int(dim) for dim in oasis_data.shape[1:]]
print(oasis_img_size)

### setup the loss functions
We will use a combination of a local NCC loss on image similarity, an l2 regularization on the determinant of the deformation field, and a Dice loss on the brain structures.

In [None]:
# image size as a torch.Size, shared by the losses and the network
img_size = torch.Size(oasis_img_size)

# Image-similarity + regularization term.
# (the original notes mention 0.01 next to `lbd`, presumably an alternative value)
criterion = losses.local_NCC_2d(
    volshape=img_size,
    lbd=1,
)
# disabled alternative kept from the original:
#criterion = losses.MinVarAndGrad2d(volshape=img_size, lbd=0.1) #0.01

# Segmentation (Dice) term and the weight it receives in the total loss.
seg_loss = losses.DiceWarpLoss2d(img_size)
lambda_seg = 0.5

# The network, created directly on `device`.
# original note: "have vxms.models.MultiMorph now."
mmnet = models.GroupNet(
    in_channels=1,
    out_channels=2,
    img_size=img_size,
).to(device)

# Adam with learning rate 1e-3 (the original notes mention 0.01 here as well)
optimizer = optim.Adam(mmnet.parameters(), lr=0.001)


# Train

In [None]:
nb_epochs = 100
batch_size = 1
data_loader = dataloader_oasis_train
# make sure the model lives on the training device
mmnet = mmnet.to(device)

# loss_hist[k] receives the mean loss of epoch k; the progress bar shows the latest one
loss_hist = np.zeros(nb_epochs)
pbar = trange(nb_epochs)

for epoch in pbar:
    mmnet.train()
    epoch_losses = []

    for sample in data_loader:
        images = sample['image'].to(device)
        segmentations = sample['segmentation'].to(device)

        # clear gradients left over from the previous step
        optimizer.zero_grad()

        # forward pass: predict the warp fields for the group
        predw = mmnet(images)

        # total loss = image term + lambda_seg * segmentation term
        loss_img = criterion(images, predw)
        loss_seg = seg_loss(segmentations, predw)
        loss = loss_img + lambda_seg * loss_seg
        epoch_losses.append(loss.item())

        # backward pass and parameter update
        loss.backward()
        optimizer.step()

    # record and display the epoch's mean loss
    mean_loss = np.mean(epoch_losses)
    pbar.set_description(f'{mean_loss:.5f}')
    loss_hist[epoch] = mean_loss

print('Finished Training')

#### Visualize the training loss

In [None]:
# Plot the mean training loss recorded for each epoch.
plt.figure(figsize=(15, 7))
plt.subplot(1, 2, 1)  # first cell of a 1x2 grid; the second cell is not drawn in
plt.plot(loss_hist)
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.grid()


# Visualize the warps and atlas on samples from the train set

### setup a few plotting and warping functions

In [None]:
from utils import warp_seg, warp_image, warp_grid, setup_grid_tensor, plot_row_slices

### load a sample and plot the images and segmentations after warping

In [None]:
img_size = oasis_img_size
dataset = dataset_oasis_train
warp_dim = img_size


def _to_label_map(onehot):
    """Collapse axis 2 with argmax (kept as a singleton axis) and round the result."""
    return torch.round(torch.argmax(onehot, dim=2, keepdim=True))


# unshuffled loader over the training dataset; only its first group is used
mmnet.to(device)
gen = iter(DataLoader(dataset, batch_size=1, shuffle=False))

# take one group and move it to the device
sample = next(gen)
images = sample['image'].to(device)
segs = sample['segmentation'].to(device)
N_images = images.shape[1]  # number of images in the group (axis 1)

# run the network and apply the predicted warp to images and segmentations
warped, predicted_warp = warp_image(images, mmnet, img_size)
warped_seg = warp_seg(segs, predicted_warp, img_size=img_size)

# images before / after warping
plot_row_slices(images, do_colorbars=False, suptitle='Original Images')
plot_row_slices(warped, do_colorbars=False, suptitle='Warped Images')

# label maps before / after warping, one 'turbo' colormap per image
seg_cmaps = ['turbo'] * N_images
plot_row_slices(_to_label_map(segs), do_colorbars=False, cmaps=seg_cmaps, suptitle='Original Segmentations')
plot_row_slices(_to_label_map(warped_seg), do_colorbars=False, cmaps=seg_cmaps, suptitle='Warped Segmentations')

# a regular grid per image, deformed by the same warp, to show the deformation itself
grids = setup_grid_tensor(N_slices=N_images, spacing=5, img_size=img_size).to(device)
warped_grid = warp_grid(grids, predicted_warp, warp_dim=img_size)
plot_row_slices(warped_grid, do_colorbars=False, suptitle='Warped Grid')

### visualize the resultant atlases, and compare to constructing an atlas by taking the mean

In [None]:
# Intensity atlases: average over the group axis (axis 1), without and with warping.
atl = torch.mean(images, dim=1, keepdim=True)
atlw = torch.mean(warped, dim=1, keepdim=True)
slices = [vol.cpu().detach().numpy() for vol in (atl, atlw)]
titles = ['Atlas Computed by Taking the Mean', 'MultiMorph Atlas']
ne.plot.slices(slices, do_colorbars=True, titles=titles)

# Segmentation atlases: average over the group axis first, then take the
# argmax over axis 2 to obtain one label per pixel.
avg_seg = torch.mean(segs, dim=1, keepdim=True)
avg_warped_seg = torch.mean(warped_seg, dim=1, keepdim=True)
mean_seg = torch.round(torch.argmax(avg_seg, dim=2, keepdim=True))
mean_warped_seg = torch.round(torch.argmax(avg_warped_seg, dim=2, keepdim=True))
slices = [vol.cpu().detach().numpy() for vol in (mean_seg, mean_warped_seg)]
titles = ['Segmentations of Mean Atlas', 'Segmentations of MultiMorph Atlas']
ne.plot.slices(slices, do_colorbars=True, titles=titles, cmaps=['turbo'] * len(slices))

# Construct an Atlas on the Entire Test Set

In [None]:
import time

# Atlas construction is run on the CPU with the model in evaluation mode.
mmnet = mmnet.to('cpu')
mmnet.eval()
warp_layer = layers.group.Warp2d(img_size)

# Inference only: disable autograd so that no computation graph is built or
# kept alive for the whole test set (saves memory and time, results unchanged).
with torch.no_grad():
    for sample in dataloader_oasis_test:
        st = time.time()
        images = sample['image'].to('cpu')
        segmentations = sample['segmentation'].to('cpu')
        # predict the warp fields
        predw = mmnet(images)

        # warp the images and the segmentations with the predicted fields
        warped = warp_layer(images, predw)
        warped_seg = warp_layer(segmentations, predw)
        # (disabled alternative from the original: warp_seg(segmentations, predicted_warp, img_size=img_size))

    # Construct the atlases from the last batch produced by the loop above
    # (the test loader is set up to deliver the entire test set as one group).
    # Image atlas: mean over the group axis. Segmentation atlas: mean over the
    # group axis followed by an argmax over axis 2.
    atlas = torch.mean(warped, dim=1, keepdim=True)
    atlas_segmentation = torch.round(torch.argmax(torch.mean(warped_seg, dim=1, keepdim=True), dim=2, keepdim=True))

# `st` was taken at the start of the last loop iteration, so the reported time
# covers that iteration plus the atlas averaging.
et = time.time()
print(f'Atlas Construction took {et-st:.3f} seconds for {images.shape[1]} images')


#### plot the constructed atlases

In [None]:

# Show the image atlas and its segmentation next to each other.
# Index [0, 0, 0] selects the single 2D map held by each atlas tensor.
atlas_img = atlas[0, 0, 0, ...].cpu().detach().numpy()
atlas_seg_img = atlas_segmentation[0, 0, 0, ...].cpu().detach().numpy()

plt.figure(figsize=(15, 7))

# left panel: intensity atlas
plt.subplot(1, 2, 1)
plt.imshow(atlas_img, cmap='gray')
plt.axis('off')
plt.title('MultiMorph Atlas')

# right panel: label atlas
plt.subplot(1, 2, 2)
plt.imshow(atlas_seg_img, cmap='turbo')
plt.axis('off')
plt.title('Segmentations of MultiMorph Atlas')


### save the atlases (optional)

In [None]:
# Output file names for the image atlas and the segmentation atlas.
atlas_path = 'oasis_atlas_2d.nii.gz'
atlas_segmentation_path = 'oasis_atlas_segmentation_2d.nii.gz'

# Pull the 2D maps out as NumPy arrays; the label map is cast to float before saving.
atlas_array = atlas[0, 0, 0, ...].cpu().detach().numpy()
atlas_segmentation_array = atlas_segmentation[0, 0, 0, ...].float().cpu().detach().numpy()

# Write both as NIfTI files with an identity affine.
for array_to_save, target_path in (
    (atlas_array, atlas_path),
    (atlas_segmentation_array, atlas_segmentation_path),
):
    nib.save(nib.Nifti1Image(array_to_save, np.eye(4)), target_path)
