# SynthMorph Affine PyTorch Demo
## Purpose
Reproduce the affine components of the original SynthMorph demo in PyTorch.
- Data generation with affine augmentations
- Affine registration model training
- Registration (inference) examples  

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import urllib
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np 
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl

# local code
from synthmorph import networks, models, layers, losses, datamodule as dm, utils

In [None]:
# Select the compute device once; later cells reuse `device` for generated
# tensors and for the DataLoader random generators.
device = 'cuda' if torch.cuda.is_available() else 'cpu'     # note: only gpu has been tested so far
if device == 'cuda':
    # Release cached CUDA memory left over from an earlier run of the notebook.
    # No autograd context is required here: emptying the cache runs no tensor ops.
    torch.cuda.empty_cache()
# Make newly created tensors default to the selected device.
torch.set_default_device(device)
# mp.set_start_method('spawn')  # left disabled by the author

## SynthMorph Affine Generation

### Generate Label (i.e. Segmentation) Map

In [None]:
# Generate one label map, synthesize `n` images from it, and show up to four.

# Input shapes.
in_shape = (256,) * 2
num_dim = len(in_shape)
num_label = 2
label_map = dm.generate_map(in_shape, num_label, device=device)
n = 4   # number of images synthesized from the same label map
affine_args = dict(
    translate=(0.15, 0.15),
    scale=(0.5, 0.7)
)
gen_args = dict(
    warp_std=0,
    warp_res=(8, 16, 32),
    zero_background=1,
    affine_args=affine_args,
    # Fixed intensity settings, flagged "remove later" by the author.
    # Presumably they switch off intensity randomization so only the affine
    # augmentation is visible -- confirm against dm.labels_to_image.
    mean_min=255,
    mean_max=255,
    std_min=0,
    std_max=0,
    bias_std=0,
    blur_std=0,
    gamma_std=0,
    dc_offset=0,
)

gen = [dm.labels_to_image(label_map, **gen_args) for _ in tqdm(range(n))]
gen_images = [g['image'] for g in gen]
gen_labels = [g['label'] for g in gen]

plot_num = min(n, 4)
# squeeze=False always returns a 2-D array of Axes; taking row 0 gives a 1-D
# array even when plot_num == 1 (the default would return a bare Axes object
# and the indexing below would fail).
fig, axes = plt.subplots(1, plot_num, figsize=(plot_num*8, 8), squeeze=False)
axes = axes[0]

for i in range(plot_num):
    image = gen_images[i].squeeze().tolist()
    axes[i].imshow(image, cmap='gray')
    axes[i].axis('off')

plt.subplots_adjust(wspace=0.05)
plt.show()

In [None]:
# Show one generated image followed by each entry of its label tensor.
ind = 0
image = gen_images[ind].squeeze().tolist()
labels = gen_labels[ind].squeeze().tolist()
# One panel for the image plus one per entry along the first label axis.
plot_num = 1 + gen_labels[ind].shape[0]
fig, axes = plt.subplots(nrows=1, ncols=plot_num, figsize=(8 * plot_num, 8))

image_ax, *label_axes = axes
image_ax.imshow(image, cmap='gray')
image_ax.axis('off')

for label_ax, label_plane in zip(label_axes, labels):
    label_ax.imshow(label_plane, cmap='gray')
    label_ax.set_xticks([])
    label_ax.set_yticks([])

plt.subplots_adjust(wspace=0.05)
plt.show()

## Model Training

In [None]:
# Training data generator: a synthetic-shapes dataset wrapped in a DataLoader.
size = 100
in_shape = (256, 256)
num_labels = 16
scale = (0.5, 0.7)
# Same shift bound on both axes: (1 - 0.7) / 2 = 0.15.
max_shift = round((1 - max(scale)) / 2, 3)
affine_args = {
    'scale': scale,
    'translate': (max_shift, max_shift),
}
# The intensity parameters (mean_min, mean_max, std_min, std_max, bias_std,
# blur_std, gamma_std, dc_offset) are not passed here, so the generator's
# defaults apply.
gen_args = {
    'warp_std': 0,
    'warp_res': (8, 16, 32),
    'zero_background': 1,
    'affine_args': affine_args,
}

train_data = dm.SMShapesDataset(
    size=size, input_size=in_shape, num_labels=num_labels, gen_args=gen_args,
)

# Multi-worker loading is left disabled. The alternative the author kept was:
#   {'num_workers': 8, 'persistent_workers': True, 'pin_memory': True}
#   if device == 'cuda' else {}
dataloader_kwargs = {}

# The sampling generator is created on `device`.
shuffle_rng = torch.Generator(device=device)
dataloader = DataLoader(
    train_data,
    batch_size=4,
    shuffle=True,
    generator=shuffle_rng,
    **dataloader_kwargs,
)


In [None]:
# Sanity check for the "num_workers > 0 yields all-ones/all-zeros tensors" issue:
# pull one batch and print the absolute sum of its "fixed" entry.
first_batch = next(iter(dataloader))
dataloader_out = first_batch["fixed"]
abs_total = dataloader_out.abs().sum()
print(abs_total)   # should be a positive number > 0

In [None]:
# Build the registration model, then resume from a Lightning checkpoint if one exists.

# You can generate the Torch version of the original author's weights from tf2torch.ipynb
# State dict weights for the registration model, different from PL checkpoint
weights_path = Path(".") / 'weights'
# reg_weights = weights_path / 'torch' / "authors.pth"   # 'None' for no weight loading
reg_weights = None
# Fresh model
in_shape = (256,) * 2
enc_nf = [256] * 4
dec_nf = [256] * 0      # evaluates to an empty list
add_nf = [256] * 4
model = models.SynthMorphAffine(
    vol_size=in_shape,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    add_nf=add_nf,
    lr=1e-04,
    reg_weights=reg_weights,
)
n_param = utils.torch_model_parameters(model.reg_model)

# Model from checkpoint
# When the checkpoint file exists it replaces the fresh model above and training
# resumes with lr=1e-05. When it is missing (e.g. a clean checkout), keep the
# fresh model and say so, instead of failing inside load_from_checkpoint.
checkpoint_path = './lightning_logs/version_81/checkpoints/epoch=499-step=12500.ckpt'
if Path(checkpoint_path).is_file():
    model = models.SynthMorphAffine.load_from_checkpoint(
        checkpoint_path,
        lr=1e-05,
    )
else:
    print(f"Checkpoint not found at {checkpoint_path!r}; keeping the freshly initialised model.")

In [None]:
# Train the model with PyTorch Lightning.
max_epochs = 500
# Batches per epoch as the DataLoader actually yields them. len(dataloader)
# also counts a final partial batch, which `size // batch_size` would drop,
# making max_steps end training early.
steps = len(dataloader)
max_steps = max_epochs * steps
trainer = pl.Trainer(
    # Follow the device chosen at the top of the notebook rather than
    # hard-coding 'gpu', which fails on machines without CUDA.
    accelerator='gpu' if device == 'cuda' else 'cpu',
    max_epochs=max_epochs,
    max_steps=max_steps,
    log_every_n_steps=steps,    # log once per epoch
    # detect_anomaly=True
)
trainer.fit(model=model, train_dataloaders=dataloader)

## Evaluation

### Affine SynthMorph

In [None]:
# Testing data generator: a synthetic-shapes dataset served one sample per batch.
size = 100
in_shape = (256, 256)
num_labels = 16
scale = (0.5, 0.7)
# Same shift bound on both axes: (1 - 0.7) / 2 = 0.15.
max_shift = round((1 - max(scale)) / 2, 3)
affine_args = {
    'scale': scale,
    'translate': (max_shift, max_shift),
}
# The intensity parameters (mean_min, mean_max, std_min, std_max, bias_std,
# blur_std, gamma_std, dc_offset) are not passed here, so the generator's
# defaults apply.
gen_args = {
    'warp_std': 0,
    'warp_res': (8, 16, 32),
    'zero_background': 1,
    'affine_args': affine_args,
}

test_data = dm.SMShapesDataset(
    size=size, input_size=in_shape, num_labels=num_labels, gen_args=gen_args,
)

# Multi-worker loading is left disabled. The alternative the author kept was:
#   {'num_workers': 8, 'persistent_workers': True, 'pin_memory': True}
#   if device == 'cuda' else {}
dataloader_kwargs = {}

# The sampling generator is created on `device`.
shuffle_rng = torch.Generator(device=device)
test_dataloader = DataLoader(
    test_data,
    batch_size=1,
    shuffle=True,
    generator=shuffle_rng,
    **dataloader_kwargs,
)


In [None]:
# Model from checkpoint
# Load the trained weights for evaluation. `map_location` and `.to(device)`
# follow the device chosen at the top of the notebook, so the cell also works
# on a CPU-only machine (a hard-coded .cuda() raises there).
checkpoint_path = './lightning_logs/version_82/checkpoints/epoch=499-step=12500.ckpt'
model = models.SynthMorphAffine.load_from_checkpoint(
    checkpoint_path,
    map_location=device,
)
# eval() switches the module to inference mode.
model = model.to(device).eval()

In [None]:
# Evaluate dice over label maps
# For each of `eval_size` random test samples: register moving -> fixed, apply
# the predicted warp to the moving label map, and score it against the fixed one.
eval_size = int(1e3)
dice_arr = np.zeros(shape=(eval_size,))
# predict_step is called directly (not through a Trainer), so autograd would
# otherwise stay enabled. no_grad avoids recording a graph on every iteration.
with torch.no_grad():
    for i in tqdm(range(eval_size)):
        # A fresh iterator over the shuffled loader gives one random sample.
        gen = next(iter(test_dataloader))
        moving = gen['moving']
        fixed = gen['fixed']
        moving_map = gen['moving_map']
        fixed_map = gen['fixed_map']
        moved, warp = model.predict_step(moving, fixed)
        moved_map = layers.SpatialTransformer(fill_value=0)([networks.torch_to_tf(moving_map), networks.torch_to_tf(warp)])
        # Clamp to [0, 1] and round so the warped map holds only 0/1 values.
        moved_map = networks.tf_to_torch(moved_map.clip(0, 1).round())
        # The loss value is negated to obtain the reported Dice score.
        dice = -losses.Dice().loss(fixed_map, moved_map)
        dice_arr[i] = dice.tolist()
dice_arr.mean()

In [None]:
# Register one random test pair and plot moving / fixed / moved side by side.
gen = next(iter(test_dataloader))
moving, fixed = gen['moving'], gen['fixed']
moving_map, fixed_map = gen['moving_map'], gen['fixed_map']

moved, warp = model.predict_step(moving, fixed)

# NumPy copies for plotting.
moved_np = dm.torch2numpy(moved)
warp_np = dm.torch2numpy(warp)
moving_np = dm.torch2numpy(moving)
fixed_np = dm.torch2numpy(fixed)

movement_headers = ['Moving', 'Fixed', 'Moved ']
movement_plot = [moving_np, fixed_np, moved_np]
utils.plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Warp the moving label map with the predicted transform, compute Dice against
# the fixed label map, and show colour overlays before and after registration.
warp_layer = layers.SpatialTransformer(fill_value=0)
moved_map = warp_layer([networks.torch_to_tf(moving_map), networks.torch_to_tf(warp)])
# Clamp to [0, 1] and round so the warped map holds only 0/1 values.
moved_map = networks.tf_to_torch(moved_map).clip(0, 1).round()
dice = -(losses.Dice().loss(fixed_map, moved_map).tolist())

# Sum every channel except channel 0 into one 2-D mask per map
# (channel 0 is presumably the background label -- TODO confirm).
moving_map_np, fixed_map_np, moved_map_np = (
    dm.torch2numpy(onehot[:, 1:, ...].sum(dim=1))
    for onehot in (moving_map, fixed_map, moved_map)
)

rgb_fixed = utils.convert_to_single_rgb(fixed_map_np, 'red')
rgb_moving = utils.convert_to_single_rgb(moving_map_np, 'green')
rgb_moved = utils.convert_to_single_rgb(moved_map_np, 'blue')

overlay_before = utils.overlay_images(rgb_fixed, rgb_moving)
overlay_after = utils.overlay_images(rgb_fixed, rgb_moved)

movement_headers = [
    'Moving (green) and Fixed (red)',
    f'Fixed (red) and Moved (blue)\nDice: {dice:.4f}',
]
movement_plot = [overlay_before, overlay_after]
utils.plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Plot the 'com_1' / 'com_2' outputs of the registration network for the
# moving and fixed images, on axes centred on the image midpoint.
results = model.reg_model(moving, fixed)
keypoints = [dm.torch2numpy(results[key]) for key in ('com_1', 'com_2')]
moving_keypoints, fixed_keypoints = keypoints
headers = ["Moving", "Fixed"]
x_mid, y_mid = in_shape[0] // 2, in_shape[1] // 2
utils.plot_keypoints(
    keypoints,
    headers,
    xlim=(-x_mid, x_mid),
    ylim=(-y_mid, y_mid),
)


### Superimposed circles


In [None]:
# Background for circles
size = 40
in_shape = (256, 256)
num_labels = 16
# Note: no affine augmentation is applied to the backgrounds (affine_args=None).
gen_args = {
    'warp_std': 0,              # no deformable
    'warp_res': (8, 16, 32),    # ignore when warp_std=0
    'zero_background': 1,
    'mean_max': 200,            # Prevent sharing too similar intensities as circles (e.g. 255)
    'affine_args': None,
}

bg_data = dm.SMShapesDataset(
    size=size, input_size=in_shape, num_labels=num_labels, gen_args=gen_args,
)

# Worker options; 'persistent_workers' and 'pin_memory' are left disabled.
if device == 'cuda':
    dataloader_kwargs = {'num_workers': 0}
else:
    dataloader_kwargs = {}

# The sampling generator is created on `device`.
shuffle_rng = torch.Generator(device=device)
bg_dataloader = DataLoader(
    bg_data,
    batch_size=1,
    shuffle=True,
    generator=shuffle_rng,
    **dataloader_kwargs,
)

In [None]:
# Create random image-mask pairs for moving and fixed
# Build a moving/fixed pair of circle images, register them with the trained
# model, and convert everything to NumPy for the plotting cells that follow.
bg = next(iter(bg_dataloader))

# Draws circles with pixel value 255 onto the given array, with no shift.
# NOTE(review): dist_range=(50, 51) suggests the helper draws random values, so
# each call may give a different layout -- confirm in utils.superimpose_circles.
moving_superimpose = lambda moving: utils.superimpose_circles(
    moving, 
    pixel_value=255, 
    size_range=(0.030, 0.030), 
    dist_range=(50, 51), 
    rotate=0,
    x_shift=0,
    y_shift=0,
)
# NOTE(review): this result is overwritten by the next line, so the generated
# background never reaches `moving`; only the blank-canvas version is used.
moving = moving_superimpose(dm.torch2numpy(bg['moving']))
moving = moving_superimpose(np.zeros(in_shape))
# Produced by a separate call and not used later in this cell.
moving_map = moving_superimpose(np.zeros(in_shape))

# Same settings as above except for a shift of 5 along x and y.
fixed_superimpose = lambda fixed: utils.superimpose_circles(
    fixed, 
    pixel_value=255, 
    size_range=(0.030, 0.030), 
    dist_range=(50, 51), 
    rotate=0,
    x_shift=5,
    y_shift=5,
)
# NOTE(review): as with `moving`, the background version is overwritten at once.
fixed = fixed_superimpose(dm.torch2numpy(bg['fixed']))
fixed = fixed_superimpose(np.zeros(in_shape))
# Produced by a separate call and not used later in this cell.
fixed_map = fixed_superimpose(np.zeros(in_shape))

# Convert the NumPy images into model inputs on `device`
# (exact output layout is defined by dm.conform -- not visible here).
moving = dm.conform(x=moving, in_shape=in_shape, device=device)
fixed = dm.conform(x=fixed, in_shape=in_shape, device=device)

moved, warp = model.predict_step(moving, fixed)
moved, warp = dm.torch2numpy(moved), dm.torch2numpy(warp)
# post-process for plotting
moving, fixed = dm.torch2numpy(moving.squeeze()), dm.torch2numpy(fixed.squeeze())

In [None]:
# Plot the circle images: moving, fixed, and the registered (moved) result.
panels = (('Moving', moving), ('Fixed', fixed), ('Moved', moved))
movement_headers = [title for title, _ in panels]
movement_plot = [array for _, array in panels]
utils.plot_array_row(movement_plot, movement_headers, cmap='gray')

In [None]:
# Colour overlays of the circle images before and after registration.
# The labels here are supposed to be the circles only,
#  ignore areas which share the same value
rgb_fixed, rgb_moving, rgb_moved = (
    utils.convert_to_single_rgb(array, colour)
    for array, colour in ((fixed, 'red'), (moving, 'green'), (moved, 'blue'))
)

overlay_before, overlay_after = (
    utils.overlay_images(rgb_fixed, other) for other in (rgb_moving, rgb_moved)
)

overlay_headers = ['Fixed and Moving', 'Fixed and Moved']
overlay_plot = [overlay_before, overlay_after]
utils.plot_array_row(overlay_plot, overlay_headers, cmap=None)