# SynthMorph Affine PyTorch Demo
## Purpose
Reproduce the affine components of the original SynthMorph demo in PyTorch:
- Data generation with affine augmentations
- Affine registration model training
- Registration (inference) examples

In [1]:
# Reload edited modules (e.g. the local `synthmorph` package) automatically before
# each cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import urllib
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np 
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl

# local code
from synthmorph import models, layers, losses, datamodule as dm, utils

In [3]:
# Release cached GPU memory left over from earlier runs in this kernel. This needs no
# `torch.no_grad()` context, and it is a no-op when CUDA has not been initialised.
torch.cuda.empty_cache()

# Note: only the GPU path has been tested so far.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Tensors created without an explicit `device=` are placed on `device`.
torch.set_default_device(device)

# Seed the RNGs so the randomly generated label maps / images are reproducible
# on Restart & Run All. `torch.manual_seed` seeds all devices, including CUDA.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# NOTE(review): CUDA cannot be re-initialised in forked subprocesses. If DataLoader
# workers later touch CUDA tensors, call `mp.set_start_method('spawn')` once per kernel.

## SynthMorph Affine Generation

### Generate a Label (i.e., Segmentation) Map and Affine-Augmented Images

In [4]:
# Demo-wide constants: a 2-D spatial shape and a label count.
# NOTE(review): the next cell re-assigns `in_shape` and uses its own `num_label`;
# confirm whether any later cell reads `num_labels` before deleting this cell.
num_labels = 16
in_shape = (256, 256)

In [4]:
# Generate `n` synthetic images from a single random label map, then display the first few.

# Input shapes.
in_shape = (256,) * 2
num_dim = len(in_shape)
num_label = 4
label_map = dm.generate_map(in_shape, num_label, device=device)
n = 16
gen_args = dict(
    warp_std=0,            # presumably disables the deformable warp — TODO confirm
    warp_res=(8, 16, 32),
    zero_background=1,
    affine=True,           # presumably toggles affine augmentation — TODO confirm
)

# NOTE(review): the saved output of this cell is
# "AssertionError: Torch not compiled with CUDA enabled", although `device` is only
# forwarded to `generate_map`. Presumably `labels_to_image` (or another `dm` helper)
# hard-codes CUDA. Confirm, and pass `device` there too if it accepts one.
gen = [dm.labels_to_image(label_map, **gen_args) for _ in tqdm(range(n))]
gen_images = [g['image'] for g in gen]
gen_labels = [g['label'] for g in gen]


def _to_numpy(x):
    """Return `x` as a CPU NumPy array; accepts torch tensors or array-likes."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


plot_num = min(n, 4)
if plot_num > 0:
    # squeeze=False keeps `axes` indexable even when only one panel is drawn.
    fig, axes = plt.subplots(1, plot_num, figsize=(plot_num * 8, 8), squeeze=False)
    axes = axes.ravel()

    for i in range(plot_num):
        image = _to_numpy(gen_images[i]).squeeze()
        axes[i].imshow(image, cmap='gray')
        axes[i].set_title(f'Sample {i}')
        axes[i].axis('off')

    plt.subplots_adjust(wspace=0.05)
    plt.show()

AssertionError: Torch not compiled with CUDA enabled