# SynthMorph PyTorch Demo

## Purpose
Reproduce the original (TensorFlow) SynthMorph demo in PyTorch:
- Data generation
- Registration model training
- Registration (inference) examples

In [None]:
# Automatically reload edited modules (e.g. the local `synthmorph` package)
# before executing code, so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import urllib.request
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np 
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl

# local code
from synthmorph import models, layers, losses, datamodule as dm, utils


In [None]:
# Use the GPU when one is available. NOTE: only the GPU path has been tested so far.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# DataLoader workers below are started with 'spawn' (presumably because samples are
# generated with torch, possibly on CUDA, which does not survive a fork).
# `force=True` keeps this cell re-runnable: without it, executing it a second time
# raises "RuntimeError: context has already been set".
mp.set_start_method('spawn', force=True)

## SynthMorph Generation Demo

### Generate Label (i.e. Segmentation) Map

In [None]:
# Input geometry for the generation demo: a 2-D, 256 x 256 grid.
in_shape = (256, 256)
num_dim = len(in_shape)  # spatial dimensionality
# Number of distinct labels in the generated label map.
num_label = 2

In [None]:
# Sample a random label (segmentation) map on `device` and display it.
label_map = dm.generate_map(in_shape, num_label, device=device)
# Copy to host memory for matplotlib. This works for any device string
# ('cpu', 'cuda', 'cuda:1', ...); a bare `.numpy()` raises on GPU tensors.
label_map_viz = label_map.detach().cpu().numpy()

plt.imshow(label_map_viz, cmap='tab20c')
plt.gca().set_xticks([])
plt.gca().set_yticks([])
plt.show()

### Generate Image from Label Map

In [None]:
# Draw `n` synthetic images (with their labels) from the same label map and
# show up to four of them.
n = 4
gen_arg = dict(
    warp_std=3,
    warp_res=(8, 16, 32),
)

gen = [dm.labels_to_image(label_map, **gen_arg) for _ in tqdm(range(n))]
gen_images = [g['image'] for g in gen]
gen_labels = [g['label'] for g in gen]

# Set up the subplot layout. `squeeze=False` always yields a 2-D array of axes,
# so indexing also works when only a single image is plotted (n == 1).
plot_num = min(n, 4)
fig, axes = plt.subplots(1, plot_num, figsize=(plot_num * 8, 8), squeeze=False)

for ax, gen_image in zip(axes[0], gen_images[:plot_num]):
    ax.imshow(gen_image.squeeze().tolist(), cmap='gray')
    ax.axis('off')

# Adjust the spacing between subplots and display them.
plt.subplots_adjust(wspace=0.05)
plt.show()


In [None]:
# Plot the first generated image followed by each of its label channels.
plot_num = num_label + 1
# Convert the label tensor once (not once per subplot) and force a
# (channels, H, W) layout, so a single-channel map squeezed to (H, W) still
# indexes by channel rather than by image row.
first_labels = np.asarray(gen_labels[0].squeeze().tolist())
first_labels = first_labels.reshape(-1, *first_labels.shape[-2:])

fig, axes = plt.subplots(1, plot_num, figsize=(plot_num * 8, 8), squeeze=False)
axes = axes[0]
axes[0].imshow(gen_images[0].squeeze().tolist(), cmap='gray')
axes[0].axis('off')
for ax, label_channel in zip(axes[1:], first_labels):
    ax.imshow(label_channel, cmap='gray')
    ax.axis('off')

plt.subplots_adjust(wspace=0.05)
plt.show()

## Model Training

In [None]:
# Training data: a synthetic SynthMorph dataset of `size` samples on a
# 256 x 256 grid with `num_labels` labels.
size = 40
in_shape = (256, 256)
num_labels = 2
gen_arg = {'warp_std': 3, 'warp_res': (8, 16, 32)}

train_data = dm.SynthMorphDataset(
    size=size, input_size=in_shape, num_labels=num_labels, gen_arg=gen_arg,
)
# persistent_workers keeps the 8 worker processes alive between epochs
# instead of shutting them down after each pass over the dataset.
dataloader = DataLoader(
    train_data,
    batch_size=1,
    shuffle=True,
    num_workers=8,
    persistent_workers=True,
)

In [None]:
# Registration-model weights as a plain state dict (different from a Lightning
# checkpoint). A Torch version of the original author's weights can be generated
# with tf2torch.ipynb. Set to None for no weight loading.
reg_weights = Path(".") / "torch_weights.pth"
if reg_weights is not None and not reg_weights.exists():
    # The weights are optional: rather than handing the model a path that does
    # not exist, fall back to a fresh model and say so explicitly.
    print(f"{reg_weights} not found; building a fresh model (reg_weights=None).")
    reg_weights = None

# Fresh model.
in_shape = (256,) * 2
# U-Net encoder/decoder feature counts (VoxelMorph-style enc_nf / dec_nf).
unet_enc_nf = [256] * 4
unet_dec_nf = [256] * 8
model = models.SynthMorph(
    vol_size=in_shape,
    num_labels=train_data.num_labels,
    enc_nf=unet_enc_nf,
    dec_nf=unet_dec_nf,
    lmd=1,
    reg_weights=reg_weights,
)
n_param = utils.torch_model_parameters(model.reg_model)

In [None]:
# # Model from checkpoint.
# # `load_from_checkpoint` is a classmethod: call it on the class, not on an
# # existing instance (Lightning 2.x raises a TypeError for instance calls).
# checkpoint_path = './lightning_logs/version_67/checkpoints/epoch=9999-step=400000.ckpt'
# in_shape = (256,) * 2
# unet_enc_nf = [256] * 4
# unet_dec_nf = [256] * 8
# model = models.SynthMorph.load_from_checkpoint(
#     checkpoint_path, 
#     vol_size=in_shape,
#     num_labels=train_data.num_labels,
#     enc_nf=unet_enc_nf,
#     dec_nf=unet_dec_nf,
#     lmd=1, 
# )

In [None]:
# Training schedule: with batch_size=1 one epoch is `train_data.size` steps.
max_epochs = 100
steps = train_data.size
max_steps = max_epochs * steps
trainer = pl.Trainer(
    # Follow the device chosen at the top of the notebook; a hard-coded 'gpu'
    # fails on CPU-only machines.
    accelerator='gpu' if device == 'cuda' else 'cpu',
    max_epochs=max_epochs,
    max_steps=max_steps,
    log_every_n_steps=steps,  # log once per epoch
)


In [None]:
# Train the registration model on the synthetic dataloader with the Trainer configured above.
trainer.fit(model=model, train_dataloaders=dataloader)

## Model Evaluation

In [None]:
# # Model from best checkpoint
# # NOTE(review): num_labels=16 and the omitted `lmd` differ from the
# # configuration trained above; presumably they match this specific
# # checkpoint's hyperparameters — confirm before use.
# checkpoint_path = './lightning_logs/version_70/checkpoints/epoch=4913-step=196560.ckpt'
# in_shape = (256,) * 2
# unet_enc_nf = [256] * 4
# unet_dec_nf = [256] * 8
# model = models.SynthMorph.load_from_checkpoint(
#     checkpoint_path, 
#     vol_size=in_shape,
#     num_labels=16,
#     enc_nf=unet_enc_nf,
#     dec_nf=unet_dec_nf,
# )

In [None]:
# Put the model on the selected device (a hard-coded .cuda() crashes on CPU-only
# machines; Lightning may also have moved the model back to CPU after fit) and
# switch to evaluation mode.
model = model.to(device)
model = model.eval()

### SynthMorph Images

In [None]:
# Register one synthetic (moving, fixed) pair drawn from the training data.
gen = next(iter(dataloader))
moving = gen['moving']
fixed = gen['fixed']
# Inference only: disable autograd bookkeeping to save memory.
with torch.no_grad():
    moved, warp = model.predict_step(moving, fixed)
moved_np, warp_np = dm.torch2numpy(moved), dm.torch2numpy(warp)
moving_np, fixed_np = dm.torch2numpy(moving), dm.torch2numpy(fixed)

In [None]:
# Side-by-side view of the synthetic registration: moving, fixed and moved images.
movement_headers = ['Moving', 'Fixed', 'Moved']
movement_plot = [moving_np, fixed_np, moved_np]
utils.plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Show the first two channels of the predicted warp (labelled X and Y).
warp_headers = ['Warp X-Axis', 'Warp Y-Axis']
warp_plot = [warp_np[axis] for axis in (0, 1)]
utils.plot_array_row(warp_plot, warp_headers, cmap='gray')

In [None]:
# Evaluate the Dice score: warp the moving label map with the predicted
# deformation and compare it with the fixed label map.
moving_map = gen['moving_map']
fixed_map = gen['fixed_map']
with torch.no_grad():  # evaluation only, no gradients needed
    _, warp = model.predict_step(moving, fixed)
    # Round and clip the interpolated labels back to {0, 1}.
    moved_map = layers.SpatialTransformer(fill_value=0)([moving_map, warp]).round().clip(0, 1)
# The loss is negated to report a score (presumably Dice().loss returns -Dice).
dice = -losses.Dice().loss(fixed_map, moved_map)
dice.tolist()

In [None]:
# Plot the moved image followed by each channel of the warped label map.
image = moved_np
labels = np.asarray(dm.torch2numpy(moved_map))
# Force a (channels, H, W) layout. A single-channel map squeezed to (H, W) would
# otherwise make `labels.shape[0]` the image height instead of the channel count.
# The array is also converted only once, not on every loop iteration.
labels = labels.reshape(-1, *labels.shape[-2:])
plot_num = labels.shape[0] + 1
fig, axes = plt.subplots(1, plot_num, figsize=(plot_num * 8, 8), squeeze=False)
axes = axes[0]
axes[0].imshow(np.squeeze(image), cmap='gray')
axes[0].axis('off')
for ax, label in zip(axes[1:], labels):
    ax.imshow(label, cmap='gray')
    ax.axis('off')

plt.subplots_adjust(wspace=0.05)
plt.show()

### MNIST

In [None]:
# MNIST test split (downloaded to ./data if not already present); its digits
# serve as real-image registration examples below.
mnist = MNIST(root= "./data", train=False, download=True)

In [None]:
# MNIST images/targets as NumPy arrays, plus a lookup
# {digit: [indices of all samples of that digit]}.
images = np.asarray(mnist.data)
labels = np.asarray(mnist.targets)
label_indices_dict = {
    int(digit): np.flatnonzero(labels == digit).tolist()
    for digit in np.unique(labels)
}

In [None]:
# Example prediction: register two MNIST samples of the same digit.
in_shape = (256,) * 2
digit = 6
indices = label_indices_dict[digit]
ori_moving = images[indices[455]]
ori_fixed = images[indices[32]]
# Conform the digits to the model input shape on `device`.
moving = dm.conform(x=ori_moving, in_shape=in_shape, device=device)
fixed = dm.conform(x=ori_fixed, in_shape=in_shape, device=device)
with torch.no_grad():  # inference only
    moved, warp = model.predict_step(moving, fixed)
moved, warp = dm.torch2numpy(moved), dm.torch2numpy(warp)
# Host-side NumPy copies for plotting (avoids a slow round-trip through Python lists).
moving, fixed = moving.detach().cpu().numpy(), fixed.detach().cpu().numpy()
moving, fixed = np.squeeze(moving), np.squeeze(fixed)

In [None]:
# Side-by-side view of the MNIST registration: moving, fixed and moved digits.
movement_headers = ['Moving', 'Fixed', 'Moved']
movement_plot = [moving, fixed, moved]
utils.plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Show the first two channels of the predicted MNIST warp (labelled X and Y).
warp_headers = ['Warp X-Axis', 'Warp Y-Axis']
warp_plot = [warp[axis] for axis in (0, 1)]
utils.plot_array_row(warp_plot, warp_headers, cmap='gray')

### OASIS-1 2D (brain dataset)

In [None]:
# Fetch the VoxelMorph 2-D OASIS tutorial data once and cache it under ~/oasis_2d.
oasis_path = Path.home() / "oasis_2d"
oasis_path.mkdir(parents=True, exist_ok=True)
filename = oasis_path / '2D-OASIS-TUTORIAL.npz'
if not filename.exists():
    url = 'https://surfer.nmr.mgh.harvard.edu/pub/data/voxelmorph/2D-OASIS-TUTORIAL.npz'
    # Download under a temporary name and rename only on success. An interrupted
    # download then never leaves a truncated file that the exists() check accepts.
    partial = filename.with_name(filename.name + '.part')
    urllib.request.urlretrieve(url, partial)
    partial.replace(filename)
# NpzFile keeps the archive open; use it as a context manager so it gets closed.
with np.load(filename) as npz:
    oasis_data = npz['images']


In [None]:
# Register two OASIS slices (indices 2 -> 7).
in_shape = (256,) * 2
ori_moving = oasis_data[2]
ori_fixed = oasis_data[7]
moving = dm.conform(x=ori_moving, in_shape=in_shape, device=device)
fixed = dm.conform(x=ori_fixed, in_shape=in_shape, device=device)
with torch.no_grad():  # inference only
    moved, warp = model.predict_step(moving, fixed)
moved, warp = dm.torch2numpy(moved), dm.torch2numpy(warp)
# Host-side NumPy copies for plotting.
moving, fixed = moving.detach().cpu().numpy(), fixed.detach().cpu().numpy()
moving, fixed = np.squeeze(moving), np.squeeze(fixed)

In [None]:
# Side-by-side view of the OASIS registration: moving, fixed and moved slices.
movement_headers = ['Moving', 'Fixed', 'Moved']
movement_plot = [moving, fixed, moved]
utils.plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Show the first two channels of the predicted OASIS warp (labelled X and Y).
warp_headers = ['Warp X-Axis', 'Warp Y-Axis']
warp_plot = [warp[axis] for axis in (0, 1)]
utils.plot_array_row(warp_plot, warp_headers, cmap='gray')

### Affine registration test
Evaluate how well registration recovers the following properties of target shapes:
- Size (scaling)
- Coordinates (translation, rotation)

#### Superimposed circles on synthetic image
Note: the label map is used only to create the synthetic images that serve as background; only the superimposed circles are treated as labels (i.e. target shapes).

In [None]:
# Label generation for the affine test: a 16-label map on a 256 x 256 grid.
# NOTE(review): unlike the first demo cell, no `device` is passed here —
# presumably so the NumPy post-processing in the next cells works; confirm.
in_shape = (256, 256)
num_dim = len(in_shape)
num_label = 16
label_map = dm.generate_map(in_shape, num_label)

In [None]:
# Image generation: synthesise a background image from the label map.
# NOTE(review): element [1] of dm.map_to_image's result is taken as the image —
# confirm its return layout.
ori_image = np.squeeze(dm.map_to_image(label_map)[1])
# Work on a copy so `ori_image` stays untouched for the comparison plot, in case
# superimpose_circles modifies its input in place.
image = ori_image.copy()


In [None]:
# Example: superimpose circles of intensity `pixel_value` on the synthetic image
# and compare with the original.
pixel_value = 255
size_range = (0.030, 0.030)
dist_range = (70, 71)
rotate = 0
x_shift = y_shift = 0

superimposed = utils.superimpose_circles(
    image,
    pixel_value=pixel_value,
    size_range=size_range,
    dist_range=dist_range,
    rotate=rotate,
    x_shift=x_shift,
    y_shift=y_shift,
)

superimposed_headers = ['Original', 'Superimposed']
superimposed_array = [ori_image, superimposed]
utils.plot_array_row(superimposed_array, superimposed_headers, cmap='gray')

In [None]:
# Create a moving/fixed pair: each is a fresh synthetic image from `label_map`
# with circles superimposed. The fixed circles use wider size/distance ranges
# and a (3, 3) shift, and `fixed_mask` marks where they were drawn.
def moving_superimpose(moving):
    """Superimpose the reference circle pattern (narrow ranges, no shift)."""
    return utils.superimpose_circles(
        moving,
        pixel_value=255,
        size_range=(0.030, 0.030),
        dist_range=(70, 71),
        rotate=0,
        x_shift=0,
        y_shift=0,
    )


def fixed_superimpose(fixed):
    """Superimpose circles with wider size/distance ranges and a (3, 3) shift."""
    return utils.superimpose_circles(
        fixed,
        pixel_value=255,
        size_range=(0.025, 0.035),
        dist_range=(65, 75),
        rotate=0,
        x_shift=3,
        y_shift=3,
    )


moving = moving_superimpose(np.squeeze(dm.map_to_image(label_map)[1]))

# Draw the fixed circles ONCE, on a blank canvas, and reuse that mask for the
# fixed image. Calling fixed_superimpose a second time would (if it samples
# sizes/distances from the ranges) draw different circles, so the mask would not
# match the circles actually present in `fixed`.
fixed_mask = fixed_superimpose(np.zeros(shape=in_shape, dtype=np.float32))
fixed_background = np.squeeze(dm.map_to_image(label_map)[1])
# NOTE(review): assumes superimpose_circles paints circle pixels with exactly
# `pixel_value` (255) and leaves all other pixels unchanged — confirm.
fixed = np.where(fixed_mask > 0, np.float32(255), fixed_background)

moving = dm.conform(x=moving, in_shape=in_shape, device=device)
fixed = dm.conform(x=fixed, in_shape=in_shape, device=device)

with torch.no_grad():  # inference only
    moved, warp = model.predict_step(moving, fixed)
moved, warp = dm.torch2numpy(moved), dm.torch2numpy(warp)
# Post-process for plotting: host-side NumPy copies.
moving, fixed = moving.detach().cpu().numpy(), fixed.detach().cpu().numpy()
moving, fixed = np.squeeze(moving), np.squeeze(fixed)

In [None]:
# Side-by-side view of the circle-pair registration: moving, fixed and moved images.
movement_headers = ['Moving', 'Fixed', 'Moved']
movement_plot = [moving, fixed, moved]
utils.plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Colour-code fixed (red), moving (green) and moved (blue), then overlay fixed with
# each of the other two. The labels are meant to be the circles only; background
# areas that share the same value also show up for moving/moved and should be ignored.
rgb_fixed = utils.convert_to_single_rgb(fixed_mask, 'red')
rgb_moving, rgb_moved = (
    utils.convert_to_single_rgb(img, colour)
    for img, colour in ((moving, 'green'), (moved, 'blue'))
)

overlay_before, overlay_after = (
    utils.overlay_images(rgb_fixed, other) for other in (rgb_moving, rgb_moved)
)

overlay_headers = ['Fixed and Moving', 'Fixed and Moved']
overlay_plot = [overlay_before, overlay_after]
utils.plot_array_row(overlay_plot, overlay_headers, cmap=None)

In [None]:
# Show the first two channels of the predicted warp for the circle pair (labelled X and Y).
warp_headers = ['Warp X-Axis', 'Warp Y-Axis']
warp_plot = [warp[axis] for axis in (0, 1)]
utils.plot_array_row(warp_plot, warp_headers, cmap='gray')