# MultiMorph 3D Atlas Construction Demo


This demo shows how to build a 3D atlas with a pre-trained MultiMorph model.

#### Load libraries and dependencies

In [None]:
!pip install neurite

In [None]:
!git clone https://github.com/mabulnaga/multimorph.git
%cd multimorph
import sys
sys.path.append('/content/multimorph/src')

In [None]:
# Imports for the demo.
# NOTE(review): time, nn, F, DataLoader, Dataset, argparse, models,
# PadtoDivisible and Tuple are not referenced directly in the cells shown here.
import os
# Select the PyTorch backend via environment variables. They are assigned before
# the library/project imports below, presumably because those libraries read
# them at import time (NEURITE_BACKEND for neurite; VXM_BACKEND presumably for
# voxelmorph) -- confirm before reordering this cell.
os.environ['NEURITE_BACKEND'] = 'pytorch'
os.environ['VXM_BACKEND'] = 'pytorch'
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import argparse
# `models`, `dataloader` and `build_atlas_inference` are project modules,
# resolved through the src/ directory appended to sys.path earlier.
import models as models
import pandas as pd
from dataloader import SubGroupLoader3D, PadtoDivisible
from typing import Tuple
import nibabel as nib
from build_atlas_inference import load_model, build_atlas
import matplotlib.pyplot as plt

## Specify the locations of your data
The data should be organized as a CSV file with two columns: one holding the full path to each 3D image and one holding the path to its segmentation. The segmentations are optional. We will use a small dataset of 4 OASIS-3 images for demonstration. Below, we also specify the names of the CSV columns that point to the images and to the segmentations.

In [None]:
# Metadata CSV listing the input volumes, plus the names of its columns that
# hold the image paths and the segmentation paths.
csv_path = 'data/oasis_3d_data/metadata.csv'
img_header_name, segmentation_header_name = 'img_path', 'segmentation_path'

# Directory the generated atlas files are written to.
atlas_save_path = 'results/'

Specify the location of the pre-trained model weights.

In [None]:
# Pre-trained model weights, relative to the current working directory.
model_path = 'models/model_cvpr.pt'

## Prepare the dataset

In [None]:
csv_data = pd.read_csv(csv_path)

# Segmentations are optional: only read them when the CSV actually has the
# segmentation column; indexing a missing column would raise a KeyError.
# NOTE(review): assumes SubGroupLoader3D accepts segmentations=None the same
# way it is given labels=None / file_names=None below -- confirm.
if segmentation_header_name in csv_data.columns:
    segmentation_paths = csv_data[segmentation_header_name].tolist()
else:
    segmentation_paths = None

dataset = SubGroupLoader3D(data=csv_data[img_header_name].tolist(), labels=None,
                           segmentations=segmentation_paths,
                           file_names=None, segmentation_to_one_hot=False,
                           )

# Size of the input volumes; this is passed to load_model when building the network.
img_size = dataset._get_img_size()
print(f'Image size: {img_size}')

### Load the model

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the network for this image size from the weights at `model_path`,
# then move it to the selected device.
mmnet = load_model(model_path, img_size).to(device)

### Build the Atlas

In [None]:
# Run the pre-trained network over `dataset` on `device`; it yields the atlas
# image together with its segmentation.
atlas, atlas_segmentation = build_atlas(
    mmnet,
    dataset,
    device,
)

### Save and visualize the atlas

In [None]:
os.makedirs(atlas_save_path, exist_ok=True)

# Convert to NumPy via .detach().cpu(): Tensor.numpy() raises for tensors on the
# GPU (this notebook selects CUDA when available) and for tensors that require
# grad. Both calls are no-ops for plain CPU tensors.
atlas_np = atlas.detach().cpu().squeeze().numpy()
atlas_segmentation_np = atlas_segmentation.detach().cpu().squeeze().numpy()

# NOTE(review): np.eye(4) writes an identity affine (unit voxel spacing, no
# orientation carried over from the input scans) -- confirm this is the intended
# space for the saved atlas.
nib.save(nib.Nifti1Image(atlas_np, np.eye(4)),
         os.path.join(atlas_save_path, 'atlas.nii.gz'))
nib.save(nib.Nifti1Image(atlas_segmentation_np, np.eye(4)),
         os.path.join(atlas_save_path, 'atlas_segmentation.nii.gz'))


In [None]:
# Plot a central slice of the atlas and atlas_segmentation as a 1x2 subfigure.
# Convert to NumPy first. .detach().cpu() makes this work even if the tensors
# are on the GPU or track gradients, which matplotlib cannot consume directly.
# New names are used so the `atlas` / `atlas_segmentation` tensors stay intact
# (e.g. if the save cell is re-run afterwards).
atlas_vol = atlas.detach().cpu().squeeze().numpy()
seg_vol = atlas_segmentation.detach().cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 2)
# Index of the central slice along axis 1 of the squeezed volume (the y-axis,
# assuming the volume is stored as x, y, z -- confirm for your data).
central_slice = atlas_vol.shape[1] // 2

# Plot atlas
axes[0].imshow(atlas_vol[:, central_slice, :], cmap='gray')
axes[0].set_title('Atlas')
axes[0].axis('off')

# Plot atlas_segmentation. Nearest-neighbour interpolation keeps discrete label
# values from being blended into intermediate colours when the image is resampled.
axes[1].imshow(seg_vol[:, central_slice, :], cmap='turbo', interpolation='nearest')
axes[1].set_title('Atlas Segmentation')
axes[1].axis('off')

plt.tight_layout()
plt.show()