# SynthMorph Affine PyTorch Demo
## Purpose
Reproduce the affine components of the original SynthMorph demo in PyTorch:
- Data generation with affine augmentations
- Affine registration model training
- Registration (inference) examples  

In [None]:
# Load IPython's autoreload extension and switch it to mode 2, which reloads
# imported modules before each cell executes so that edits to the local code
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import urllib
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np 
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl

# local code
from synthmorph import networks, models, layers, losses, datamodule as dm, utils

In [None]:
# Release cached CUDA allocator memory left over from earlier runs of this
# notebook. No autograd context is needed here: empty_cache() only talks to
# the caching allocator and never records gradient operations.
torch.cuda.empty_cache()

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'     # note: only gpu has been tested so far

# Newly created tensors and modules default to the chosen device from here on.
torch.set_default_device(device)

# NOTE(review): the author left the start-method override below disabled;
# presumably it is needed only if DataLoader workers must use CUDA — confirm.
# mp.set_start_method('spawn')

## SynthMorph Affine Generation

### Generate Label (i.e. Segmentation) Map

In [None]:
# Number of labels and the square 2-D image shape (256 x 256).
num_labels = 16
in_shape = (256, 256)

In [None]:
# ---- Configuration ----------------------------------------------------------
# Spatial shape of the generated maps (2-D, 256 x 256) and its dimensionality.
in_shape = (256, 256)
num_dim = len(in_shape)
num_label = 4

# A single label map is generated once; every image below is synthesized from it.
label_map = dm.generate_map(in_shape, num_label, device=device)

# Number of images to synthesize from the label map.
n = 4

# Options forwarded to the generator's affine augmentation.
affine_args = {
    'translate': (0.05, 0.05),
    'scale': (0.9, 0.9),
}
# Options forwarded to dm.labels_to_image.
# NOTE(review): warp_std=0 presumably disables the deformable warp — confirm
# against dm.labels_to_image.
gen_args = {
    'warp_std': 0,
    'warp_res': (8, 16, 32),
    'zero_background': 1,
    'affine_args': affine_args,
}

# ---- Generation -------------------------------------------------------------
# Each call returns a mapping that holds an 'image' entry and a 'label' entry.
gen = [dm.labels_to_image(label_map, **gen_args) for _ in tqdm(range(n))]
gen_images = [sample['image'] for sample in gen]
gen_labels = [sample['label'] for sample in gen]

# ---- Visualization ----------------------------------------------------------
# Show at most four of the generated images side by side.
plot_num = min(n, 4)
fig, axes = plt.subplots(1, plot_num, figsize=(plot_num * 8, 8))

for i, ax in enumerate(axes):
    image = gen_images[i].squeeze().tolist()
    ax.imshow(image, cmap='gray')
    ax.axis('off')

plt.subplots_adjust(wspace=0.05)
plt.show()

In [None]:
# Show one generated image followed by each entry of its label tensor.
ind = 1
image = gen_images[ind].squeeze().tolist()
labels = gen_labels[ind].squeeze().tolist()

# One panel for the image plus one per entry along the label tensor's first axis.
plot_num = gen_labels[ind].shape[0] + 1
fig, axes = plt.subplots(1, plot_num, figsize=(plot_num * 8, 8))

# Panel 0: the synthesized image itself.
image_ax = axes[0]
image_ax.imshow(image, cmap='gray')
image_ax.axis('off')

# Panels 1..plot_num-1: the individual label planes (panel c shows labels[c - 1]).
for c, ax in enumerate(axes[1:], start=1):
    label_plane = labels[c - 1]
    ax.imshow(label_plane, cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])

plt.subplots_adjust(wspace=0.05)
plt.show()

## Affine Model Training

In [None]:
# Training data generator
size = 40
in_shape = (256, 256)
num_labels = 16

# Options forwarded to the generator's affine augmentation.
affine_args = {
    'translate': (0.05, 0.05),
    'scale': (0.9, 0.9),
}
gen_args = {
    'warp_std': 0,  # no deformable
    'warp_res': (8, 16, 32),
    'zero_background': 1,
    'affine_args': affine_args,
}

# Dataset of synthetic shapes configured with the options above.
train_data = dm.SMShapesDataset(
    size=size,
    input_size=in_shape,
    num_labels=num_labels,
    gen_args=gen_args,
)

# Worker processes are requested only when running on CUDA; on CPU the
# DataLoader defaults are used.
if device == 'cuda':
    dataloader_kwargs = {'num_workers': 8, 'persistent_workers': True}
else:
    dataloader_kwargs = {}

# The shuffling generator is created on `device`.
# NOTE(review): presumably required because the default tensor device was
# changed globally with torch.set_default_device — confirm.
dataloader = DataLoader(
    dataset=train_data,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator(device=device),
    **dataloader_kwargs,
)


In [None]:
# Optional pretrained weights for the registration network.
# tf2torch.ipynb can produce a Torch version of the original author's weights.
# These are plain state-dict weights for the registration model, which is a
# different format from a PyTorch Lightning checkpoint.
weights_path = Path(".").joinpath('weights')
# reg_weights = weights_path / 'torch' / "authors.pth"   # 'None' for no weight loading
reg_weights = None

# Fresh model
in_shape = (256, 256)
enc_nf = [256, 256, 256, 256]
dec_nf = []
add_nf = [256, 256, 256, 256]

model = models.SynthMorphAffine(
    vol_size=in_shape,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    add_nf=add_nf,
    lr=1e-4,
    reg_weights=reg_weights,
)

# Result of the project's parameter helper for the wrapped registration network.
n_param = utils.torch_model_parameters(model.reg_model)

In [None]:
# Training budget: a fixed number of epochs, with the step limit derived from it.
max_epochs = 2500
# NOTE(review): assumes train_data.size is the number of samples per epoch,
# which equals the steps per epoch at batch_size=1 — confirm against
# dm.SMShapesDataset.
steps = train_data.size
max_steps = max_epochs * steps

trainer = pl.Trainer(
    # Follow the `device` chosen during setup instead of hard-coding 'gpu',
    # so the notebook's CPU fallback does not fail at Trainer construction
    # on a machine without a GPU.
    accelerator='gpu' if device == 'cuda' else 'cpu',
    max_epochs=max_epochs,
    max_steps=max_steps,
    # Log once every `steps` optimisation steps.
    log_every_n_steps=steps,
)
trainer.fit(model=model, train_dataloaders=dataloader)

### Testing

In [None]:
# Feature-detector network used for the registration examples below.
# NOTE(review): this constructs a new network and does not reuse
# `model.reg_model` from the training section, so it presumably runs with
# whatever weights the constructor initializes — confirm this is intended.
in_shape = (256, 256)
reg_model = networks.VxmAffineFeatureDetector(in_shape=in_shape)

In [None]:
from torchvision.transforms import RandomAffine
def create_square_image(image_size: int, square_size: int) -> np.ndarray:
    """Return a black square image containing a centred white square.

    Args:
        image_size: Side length of the output image in pixels.
        square_size: Side length of the white square in pixels. Must satisfy
            ``0 <= square_size <= image_size``.

    Returns:
        A ``(image_size, image_size)`` array of dtype ``uint8`` whose pixels
        are 255 inside the square and 0 elsewhere. When the margin cannot be
        split evenly, the square sits one pixel closer to the top-left corner.

    Raises:
        ValueError: If ``square_size`` is negative or larger than
            ``image_size``.
    """
    # Without this guard a too-large square yields a negative start offset,
    # which NumPy slicing interprets as "count from the end" and silently
    # paints the wrong region.
    if not 0 <= square_size <= image_size:
        raise ValueError(
            f"square_size must be between 0 and image_size ({image_size}), "
            f"got {square_size}"
        )

    # Create a black background
    image = np.zeros((image_size, image_size), dtype=np.uint8)

    # The square is centred, so rows and columns share the same bounds.
    start = (image_size - square_size) // 2
    stop = start + square_size

    # Set pixels in the square to 255
    image[start:stop, start:stop] = 255

    return image

# Random translation only (no rotation); per torchvision's RandomAffine, the
# shift is at most 20% of the image size along each axis.
aff_transformer = RandomAffine(degrees=0, translate=(0.20, 0.20))

# Moving image: a 256 x 256 canvas with a centred 100-pixel square, given two
# leading singleton axes and then passed through the project's min-max
# normalization.
square = create_square_image(256, 100)
moving = torch.as_tensor(square, device=device)[None, None]
moving = dm.minmax_norm(moving)

# Fixed image: a randomly shifted copy of the moving image.
fixed = aff_transformer(moving)

# Run the network on the pair; the bare expression displays the result in the notebook.
out = reg_model(moving, fixed)
out

In [None]:
# Display the randomly translated (fixed) image in grayscale.
fixed_image = fixed.squeeze().tolist()
plt.imshow(fixed_image, cmap='gray')

In [None]:
def plot_keypoints(coords):
    """Scatter-plot keypoints as red circles using ``plt.scatter``.

    Args:
        coords: Array-like of points with at least two columns; column 0 is
            plotted as x and column 1 as y. A single flat point
            ``[x, y, ...]`` is also accepted.
    """
    coords = np.asarray(coords)
    if coords.ndim == 1:
        # Callers that `.squeeze()` their tensors collapse a lone keypoint to
        # 1-D; promote it to a single row so column indexing still works.
        coords = coords[np.newaxis, :]
    x_coords = coords[:, 0]
    y_coords = coords[:, 1]
    plt.scatter(x_coords, y_coords, marker='o', color='r')

# The model output unpacks into exactly two items, named source and target by
# the author (presumably keypoint centres of the moving and fixed images — confirm).
cen_source, cen_target = out
source_points = cen_source.squeeze().tolist()
plot_keypoints(source_points)

In [None]:
# Same plot for the second item of the model output.
target_points = cen_target.squeeze().tolist()
plot_keypoints(target_points)