# TF to Torch Conversion

## Purpose
Reproduce the original (TensorFlow/Keras) SynthMorph demo in PyTorch:
- Data generation
- Registration model training
- Registration (inference) examples  

In [None]:
%load_ext autoreload
%autoreload 2

: 

In [None]:
import vte.experiments.voxel_morph.model.synthmorph as models
import vte.experiments.voxel_morph.model.synthmorph_new as new
import vte.experiments.voxel_morph.datamodule.synth as datamodule
from vte.experiments.voxel_morph.synthmorph_utils import(
    conform, post_predict, image_to_numpy,\
    invert_grayscale, overlay_images,\
    plot_array_row, superimpose_circles,\
    convert_to_single_rgb, rotate
)
# import vte.experiments.voxel_morph.model2 as model2
# import vte.experiments.voxel_morph.layers as layers
from pathlib import Path
import urllib.request
from PIL import Image
from matplotlib import pyplot as plt
from cv2 import resize
import numpy as np 
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl

## 1. Synthetic Images
Synthetic labels and images are currently still generated with NumPy; the Torch implementation is still a work in progress.

### 1.1. Generate Label (i.e. Segmentation) Map

In [None]:
# Input geometry for the 2-D synthetic label maps and images.
num_label = 16
in_shape = (256, 256)
num_dim = len(in_shape)  # 2 for 2-D data

In [None]:
label_map = datamodule.generate_map(
    in_shape,
    num_label,
)
# 'tab20c' is a qualitative colormap, so distinct labels get distinct colours.
plt.imshow(label_map, cmap='tab20c')

### 1.2. Generate Image from Label Map

In [None]:
# Render `n` images from the same label map and show them in one row.
n = 4
gen_images = [datamodule.map_to_image(label_map) for _ in range(n)]

# squeeze=False always returns a 2-D grid of Axes (a bare Axes is returned for
# n == 1 otherwise); take the single row so `axes` is 1-D for any n.
fig, axes = plt.subplots(1, n, figsize=(n * 3, 3), squeeze=False)
axes = axes[0]

# Element [1] of map_to_image's result is the image being displayed; squeeze
# singleton dims as the image-generation cell further down does.
for ax, generated in zip(axes, gen_images):
    ax.imshow(np.squeeze(generated[1]), cmap='gray')
    ax.axis('off')

# Tighten the horizontal spacing between subplots.
plt.subplots_adjust(wspace=0.05)

plt.show()


In [None]:
# Warning: generating one label map takes about 6 seconds.
# TODO: port label-map generation from NumPy to Torch.
size = 40
in_shape = (256,) * 2
num_labels = 16
train_data = datamodule.SynthMorphOnlineDataset(
    size=size,
    input_size=in_shape,
    num_labels=num_labels,
)

# An epoch is only `size` samples, but training runs for thousands of epochs.
# By default the DataLoader tears down and re-spawns its worker processes every
# epoch, which adds noticeable overhead. Persistent workers keep the pool alive
# between epochs. The option is only valid with num_workers > 0.
num_workers = 8
dataloader = DataLoader(
    dataset=train_data,
    batch_size=1,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)

## Model Training

In [None]:
# Build an untrained SynthMorph network to train from scratch.
# enc_nf / dec_nf are per-stage feature counts for the U-Net encoder/decoder.
in_shape = (256, 256)
unet_enc_nf = [256, 256, 256, 256]
unet_dec_nf = [256] * 8
model = new.SynthMorph(
    vol_size=in_shape,
    enc_nf=unet_enc_nf,
    dec_nf=unet_dec_nf,
    num_labels=train_data.num_labels,
    lmd=1,
)

In [None]:
# # Model from checkpoint (alternative to the fresh model above: resume training)
# # NOTE: `load_from_checkpoint` is a classmethod that returns a new model.
# # Call it on the class and use the result; Lightning 2.x raises a TypeError
# # when it is called on an instance.
# checkpoint_path = './lightning_logs/version_67/checkpoints/epoch=9999-step=400000.ckpt'
# in_shape = (256,) * 2
# unet_enc_nf = [256] * 4
# unet_dec_nf = [256] * 8
# model = new.SynthMorph.load_from_checkpoint(
#     checkpoint_path,
#     vol_size=in_shape,
#     num_labels=train_data.num_labels,
#     enc_nf=unet_enc_nf,
#     dec_nf=unet_dec_nf,
#     lmd=1,
# )

In [None]:
max_epochs = 10000
# Batches per epoch, derived from the dataloader rather than hard-coded.
# This keeps max_steps == max_epochs * steps_per_epoch when `size` or
# `batch_size` change. With a stale constant, a larger dataset would silently
# hit max_steps and stop before max_epochs.
steps = len(dataloader)
max_steps = max_epochs * steps
trainer = pl.Trainer(
    accelerator='gpu',
    max_epochs=max_epochs,
    max_steps=max_steps,
)


In [None]:
# Train the registration network on the online synthetic dataloader.
trainer.fit(model, train_dataloaders=dataloader)

## Model Inference and Evaluation

In [None]:
# Untrained model with the same architecture as the training section.
# The next cell replaces it with weights loaded from a checkpoint.
in_shape = (256, 256)
unet_enc_nf = [256 for _ in range(4)]
unet_dec_nf = [256 for _ in range(8)]
model = new.SynthMorph(
    vol_size=in_shape,
    num_labels=train_data.num_labels,
    enc_nf=unet_enc_nf,
    dec_nf=unet_dec_nf,
    lmd=1,
)

In [None]:
# Model from best checkpoint.
# `load_from_checkpoint` is a classmethod that builds and returns a new model.
# Call it on the class and keep the return value; Lightning 2.x raises a
# TypeError when it is called on an instance such as `model`.
checkpoint_path = './lightning_logs/version_64/checkpoints/epoch=4999-step=200000.ckpt'
in_shape = (256,) * 2
unet_enc_nf = [256] * 4
unet_dec_nf = [256] * 8
model = new.SynthMorph.load_from_checkpoint(
    checkpoint_path,
    vol_size=in_shape,
    num_labels=16,
    enc_nf=unet_enc_nf,
    dec_nf=unet_dec_nf,
    lmd=1,
)

In [None]:
# Move the model to the GPU (temporary fix) and switch it to eval mode.
# Both calls return the module itself, so they can be chained.
model = model.cuda().eval()

### MNIST

In [None]:
mnist = MNIST(
    root="./data",
    train=False,  # train=False selects the MNIST test split
    download=True,
)

In [None]:
# Convert the MNIST images and targets to NumPy arrays.
images = np.array(mnist.data)
labels = np.array(mnist.targets)
# Map each digit class to the list of sample indices with that label, so
# same-class moving/fixed pairs can be picked below.
unique_labels = np.unique(labels)
label_indices_dict = {
    label: np.where(labels == label)[0].tolist() for label in unique_labels
}

In [None]:
# Example prediction: register two MNIST samples of the same digit.
in_shape = (256,) * 2
digit = 0
indices = label_indices_dict[digit]
ori_moving = images[indices[342]]  # arbitrary picks within the class
ori_fixed = images[indices[233]]
moving = conform(x=ori_moving, in_shape=in_shape)
fixed = conform(x=ori_fixed, in_shape=in_shape)
# Inference only: eval() does not disable autograd, so turn it off explicitly
# to avoid building and holding a gradient graph for the outputs.
with torch.no_grad():
    moved, warp = model.predict_step(moving, fixed)
moved, warp = post_predict(moved), post_predict(warp)
# Drop singleton dims (presumably added by `conform`) for plotting.
moving, fixed = np.squeeze(moving), np.squeeze(fixed)

In [None]:
# Side by side: input, target, and the input after warping.
movement_plot, movement_headers = [moving, fixed, moved], ['Moving', 'Fixed', 'Moved']
plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# `scipy.ndimage.filters` is a deprecated alias namespace (SciPy >= 1.8);
# import the function from `scipy.ndimage` directly. Not used in this cell.
from scipy.ndimage import gaussian_filter
# Show the first two channels of the predicted warp.
warp_plot = [warp[0, ...], warp[1, ...]]
warp_headers = ['Warp X-Axis', 'Warp Y-Axis']
plot_array_row(warp_plot, warp_headers, cmap='gray')

### OASIS-1 2D (brain MRI dataset)

In [None]:
# Download the 2-D OASIS tutorial archive once and cache it under ~/oasis_2d.
oasis_path = Path.home() / "oasis_2d"
oasis_path.mkdir(parents=True, exist_ok=True)
filename = oasis_path / '2D-OASIS-TUTORIAL.npz'
if not filename.exists():
    url = 'https://surfer.nmr.mgh.harvard.edu/pub/data/voxelmorph/2D-OASIS-TUTORIAL.npz'
    # Download to a temporary name and rename only after success. Otherwise an
    # interrupted download leaves a truncated file that the exists() check
    # above would accept on the next run.
    partial = filename.with_name(filename.name + '.part')
    urllib.request.urlretrieve(url, partial)
    partial.replace(filename)
# np.load on an .npz returns an archive object that keeps the file open. The
# context manager closes it once the 'images' array has been read into memory.
with np.load(filename) as npz:
    oasis_data = npz['images']


In [None]:
# Register two OASIS slices (arbitrary indices into the archive).
in_shape = (256,) * 2
ori_moving = oasis_data[20]
ori_fixed = oasis_data[1]
moving = conform(x=ori_moving, in_shape=in_shape)
fixed = conform(x=ori_fixed, in_shape=in_shape)
# Inference only: disable autograd so no gradient graph is built for the outputs.
with torch.no_grad():
    moved, warp = model.predict_step(moving, fixed)
moved, warp = post_predict(moved), post_predict(warp)
# Drop singleton dims (presumably added by `conform`) for plotting.
moving, fixed = np.squeeze(moving), np.squeeze(fixed)

In [None]:
# Side by side: moving slice, fixed slice, and the warped moving slice.
movement_headers = ['Moving', 'Fixed', 'Moved']
movement_plot = [moving, fixed, moved]
plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Show the first two channels of the predicted warp as grayscale images.
# NOTE(review): confirm channel 0 really is the x-axis; in voxelmorph-style
# flows channel 0 is often the first (row) dimension.
warp_headers = ['Warp X-Axis', 'Warp Y-Axis']
warp_plot = [warp[channel, ...] for channel in (0, 1)]
plot_array_row(warp_plot, warp_headers, cmap='gray')

### Affine registration test

#### Superimposed circles on synthetic image

In [None]:
# Generate a label map for the circle-superimposition experiment.
num_label = 16
in_shape = (256, 256)
num_dim = len(in_shape)
label_map = datamodule.generate_map(in_shape, num_label)

In [None]:
# Render an image from the label map (element [1] of map_to_image's result)
# and drop singleton dimensions.
ori_image = datamodule.map_to_image(label_map)[1]
ori_image = np.squeeze(ori_image)
# Work on a copy so the original stays available for comparison.
image = ori_image.copy()


In [None]:
# Superimpose circles on the synthetic image.
pixel_value = 255
size_range = (0.030, 0.030)
dist_range = (70, 71)
# Named `rotation`, not `rotate`: a module-level `rotate` would shadow the
# `rotate` helper imported from synthmorph_utils at the top of the notebook.
rotation = 0
x_shift = 0
y_shift = 0
superimposed = superimpose_circles(
    image,
    pixel_value=pixel_value,
    size_range=size_range,
    dist_range=dist_range,
    rotate=rotation,
    x_shift=x_shift,
    y_shift=y_shift,
)

superimposed_array = [ori_image, superimposed]
superimposed_headers = ['Original', 'Superimposed']
plot_array_row(superimposed_array, superimposed_headers, cmap='gray')

In [None]:
# Registration using a synthetic mask.
# - Moving: a blank canvas with circles drawn using the same size/distance
#   settings as the superimposed image, but with x_shift=3, y_shift=3
#   (presumably a small translation).
# - Fixed: the circle-superimposed synthetic image.
# The mask is sized from `ori_image` rather than hard-coded to 256x256, so
# both inputs share the same geometry if `in_shape` changes.
mask = np.zeros(shape=ori_image.shape, dtype=np.float32)
mask = superimpose_circles(
    mask,
    pixel_value=255,
    size_range=(0.030, 0.030),
    dist_range=(70, 71),
    rotate=0,
    x_shift=3,
    y_shift=3,
)
moving = conform(x=mask, in_shape=in_shape)
fixed = conform(x=superimposed, in_shape=in_shape)

# Inference only: disable autograd so no gradient graph is built for the outputs.
with torch.no_grad():
    moved, warp = model.predict_step(moving, fixed)
moved, warp = post_predict(moved), post_predict(warp)
moving, fixed = np.squeeze(moving), np.squeeze(fixed)  # post-process for plotting

In [None]:
# Moving mask, fixed image, and the mask after warping.
movement_plot = [moving, fixed, moved]
plot_array_row(
    movement_plot,
    ['Moving', 'Fixed', 'Moved'],
    cmap='gray',
)
movement_headers = ['Moving', 'Fixed', 'Moved']

In [None]:
# Colour-code each image (fixed=red, moving=green, moved=blue), then overlay
# the fixed image with the moving image before and after registration.
rgb_fixed, rgb_moving, rgb_moved = (
    convert_to_single_rgb(fixed, 'red'),
    convert_to_single_rgb(moving, 'green'),
    convert_to_single_rgb(moved, 'blue'),
)

overlay_before = overlay_images(rgb_fixed, rgb_moving)
overlay_after = overlay_images(rgb_fixed, rgb_moved)

overlay_plot = [overlay_before, overlay_after]
overlay_headers = ['Fixed and Moving', 'Fixed and Moved']
# cmap=None: the overlays are presumably already RGB, so no colormap is applied.
plot_array_row(overlay_plot, overlay_headers, cmap=None)

In [None]:
# Visualise the first two channels of the warp predicted for the mask pair.
warp_plot = list(warp[c, ...] for c in range(2))
warp_headers = ['Warp X-Axis', 'Warp Y-Axis']
plot_array_row(warp_plot, warp_headers, cmap='gray')


### Test Pre-trained Weights
Weights link (Keras .h5): <br>
https://drive.google.com/uc?id=1xridvtyEWgWsWJPYVrQfDCtSgbj2beRz