Plan: use the synthetic data generator (`synthesize_image` from `main_demo`) to create a dataset, then create a registration UNet that outputs vector fields, which for now we interpret as velocity fields. Use DVF2DDF to integrate these velocity fields into displacement fields, and train the registration network with an LNCC similarity loss plus a basic norm-based regularization loss.

In [None]:
from main_demo import synthesize_image
import matplotlib.pyplot as plt
import monai
import torch
import numpy as np
import tempfile, os, glob, random

In [None]:
# Create a temporary directory to work in.
# NOTE(review): mkdtemp() directories are never removed automatically; delete root_dir
# when finished if disk space matters.
root_dir = tempfile.mkdtemp()
data_dir = os.path.join(root_dir, "synthetic_data")
print(f"Working in the following directory: {data_dir}")
save_image = monai.transforms.SaveImage(
    output_dir=data_dir,
    output_ext='.png',
    scale=255,  # applies to png format only
    separate_folder=False,
    print_log=False,
)

# Save a bunch of synthetic images in the temporary directory
number_of_images_to_generate = 300
for _ in range(number_of_images_to_generate):
    image = synthesize_image((128, 128), 32, 96)
    image = np.expand_dims(image, axis=0)  # add channel dimension, which save_image expects
    save_image(image)
# glob.glob returns paths in an arbitrary, filesystem-dependent order. Sorting makes the
# seeded shuffle below produce the same train/validation split on every run.
image_paths = sorted(glob.glob(os.path.join(data_dir, '*.png')))

# Split training vs validation data (ratios 2:8, i.e. 20% validation / 80% training), and
# create data list made out of image pairs (including pairs in a symmetric and reflexive fashion)
image_paths_valid, image_paths_train = monai.data.utils.partition_dataset(
    image_paths, ratios=(2, 8), shuffle=True, seed=0
)
def create_image_pairs_data(paths):
    """Given a list of image paths, create a data list where each data item is a dictionary containing a pair of images.

    Every ordered pair of paths is produced: both (a, b) and (b, a), as well as the
    self-pair (a, a). A list of n paths therefore yields n**2 items, ordered with
    'img0' varying slowest.

    Args:
        paths: iterable of image file paths (consumed once).

    Returns:
        List of dicts ``{'img0': p0, 'img1': p1}``; empty when ``paths`` is empty.
    """
    # Materialize once so the nested iteration below also works for one-shot iterators.
    paths = list(paths)
    return [{'img0': p0, 'img1': p1} for p0 in paths for p1 in paths]
# Paired data lists for the two splits (training first, then validation).
data_train, data_valid = (
    create_image_pairs_data(split_paths)
    for split_paths in (image_paths_train, image_paths_valid)
)

# Loading pipeline: read both images of a pair, convert them to tensors, give each a
# leading channel axis, and stack them along that axis into one 2-channel item 'img01'.
# The two channels are the target/fixed image and the moving image of the registration
# task; the individual 'img0'/'img1' entries are dropped at the end.
pair_keys = ('img0', 'img1')
transform = monai.transforms.Compose(
    transforms=[
        monai.transforms.LoadImageD(keys=pair_keys, image_only=True),
        monai.transforms.ToTensorD(keys=pair_keys),
        monai.transforms.AddChannelD(keys=pair_keys),
        monai.transforms.ConcatItemsD(keys=pair_keys, name='img01', dim=0),
        monai.transforms.DeleteItemsD(keys=pair_keys),
    ]
)

# Datasets.
# TODO: Replace these by CacheDataset when training infrastructure is in place.
# For now they are regular datasets to prevent memory overhead while in early development.
dataset_train, dataset_valid = (
    monai.data.Dataset(data=items, transform=transform)
    for items in (data_train, data_valid)
)

In [None]:
spatial_dims = 2

# Registration network: the fixed and moving images enter as 2 stacked input channels, and
# the output has one channel per spatial dimension. In this notebook that output is
# interpreted as a velocity vector field (not a displacement field); `integrate` below
# converts it into a displacement field.
reg_net = monai.networks.nets.UNet(
    spatial_dims=spatial_dims,  # dimensionality of input and output image domain
    in_channels=2,  # one channel for the fixed image and one for the moving image
    out_channels=spatial_dims,  # components of the velocity vector field
    channels=(16, 32, 32, 32, 32),  # channel sequence, top block first
    strides=(1, 2, 2, 2),  # convolutional strides
    dropout=0.2,
    norm="batch",
)


# Velocity field (DVF) -> displacement field (DDF); MONAI's DVF2DDF does this by scaling
# and squaring with num_steps integration steps.
integrate = monai.networks.blocks.DVF2DDF(num_steps=7, mode='bilinear', padding_mode='zeros')
# Resamples an image through a displacement field.
warp = monai.networks.blocks.Warp(mode="bilinear", padding_mode="zeros")

In [None]:
# Image similarity loss: local normalized cross-correlation over 3x3 rectangular windows,
# averaged into a single scalar.
lncc_loss = monai.losses.LocalNormalizedCrossCorrelationLoss(
    spatial_dims=spatial_dims,
    reduction="mean",
    kernel_type="rectangular",
    kernel_size=3,
)

In [None]:
# Plotting tools

def plot_2D_vector_field(vector_field, downsampling, axes=None):
    """Plot a 2D vector field given as a tensor of shape (2,H,W) as a quiver plot.

    The field is average-pooled over non-overlapping ``downsampling x downsampling``
    windows, and one arrow is drawn per window at the window's centre.
    The plot origin will be in the lower left. Axis coordinates are pixel indices of the
    original (H,W) grid, so the plot lines up with ``imshow(..., origin='lower')`` panels.
    Using "x" and "y" for the rightward and upward directions respectively,
      the vector at location (x,y) in the plot image will have
      vector_field[1,y,x] as its x-component and
      vector_field[0,y,x] as its y-component.
    Arrows are drawn to scale: an arrow's length in the plot equals the (window-averaged)
    vector's length in pixels.

    Args:
        vector_field: tensor of shape (2,H,W). It may require grad or live on any device;
            it is detached and moved to the CPU before plotting.
        downsampling: pooling window side length in pixels (also the arrow spacing).
            Trailing rows/columns that do not fill a whole window are dropped.
        axes: matplotlib axes to draw into; if None, ``plt.axes()`` is used.
    """
    if axes is None:
        axes = plt.axes()

    # quiver converts its inputs to NumPy, which fails for tensors that require grad or
    # are on a GPU, so detach and move to CPU first.
    vf = vector_field.detach().cpu()
    # Average pooling with kernel = stride = downsampling and no padding: each output cell
    # summarizes one non-overlapping window.
    vf_downsampled = torch.nn.functional.avg_pool2d(
        vf.unsqueeze(0), kernel_size=downsampling
    )[0].numpy()
    _, h_ds, w_ds = vf_downsampled.shape
    # Centre of window i along an axis, in original pixel coordinates.
    x_centres = np.arange(w_ds) * downsampling + (downsampling - 1) / 2
    y_centres = np.arange(h_ds) * downsampling + (downsampling - 1) / 2
    axes.quiver(
        x_centres, y_centres,
        vf_downsampled[1, :, :], vf_downsampled[0, :, :],
        # Data coordinates are pixels now, so scale=1 draws arrows at true pixel length.
        angles='xy', scale_units='xy', scale=1,
        headwidth=4.,
    )

def plot_2D_deformation(vector_field, grid_spacing, axes=None, **kwargs):
    """
    Interpret vector_field as a displacement vector field defining a deformation,
    and plot an x-y grid warped by this deformation.

    A binary image with lines every ``grid_spacing`` pixels is resampled through the
    displacement field with MONAI's ``Warp`` (bilinear, zero padding) and drawn with
    ``imshow`` (origin in the lower left, grey colormap).

    Args:
        vector_field: tensor of shape (2,H,W). It may require grad or live on any device;
            it is detached and moved to the CPU before plotting.
        grid_spacing: distance in pixels between consecutive grid lines.
        axes: matplotlib axes to draw into; if None, ``plt.axes()`` is used.
        **kwargs: extra keyword arguments forwarded to ``axes.imshow``; they may
            override the default ``origin`` and ``cmap``.
    """
    if axes is None:
        axes = plt.axes()
    # Work on a detached CPU copy so the grid image (built on CPU) and the field share a
    # device, and so the result can be handed to matplotlib.
    ddf = vector_field.detach().cpu()
    _, H, W = ddf.shape
    # Grid image with a channel dimension, (C,H,W) = (1,H,W): 1 on every
    # grid_spacing-th row and column, 0 elsewhere.
    grid_img = torch.zeros((1, H, W), dtype=ddf.dtype)
    grid_img[:, ::grid_spacing, :] = 1
    grid_img[:, :, ::grid_spacing] = 1
    grid_warper = monai.networks.blocks.Warp(mode="bilinear", padding_mode="zeros")
    # Warp works on batched inputs: add a batch dimension, then strip it again.
    grid_img_warped = grid_warper(grid_img.unsqueeze(0), ddf.unsqueeze(0))[0]
    imshow_kwargs = {'origin': 'lower', 'cmap': 'gist_gray', **kwargs}
    axes.imshow(grid_img_warped[0], **imshow_kwargs)

In [None]:
# Try out a forward pass on one randomly chosen (fixed, moving) pair from the training set.
# NOTE(review): reg_net is not switched to eval() in the visible code, so dropout and
# batch-norm run in training mode here.

data_item = random.choice(dataset_train)
reg_net_input = data_item['img01'].unsqueeze(0)  # add a batch dimension
fixed_img = reg_net_input[:,[0]]  # channel 0: fixed image (list index keeps the channel axis)
moving_img = reg_net_input[:,[1]]  # channel 1: moving image
velocity_field = reg_net(reg_net_input)  # network output, interpreted as a velocity field
displacement_field = integrate(velocity_field)  # velocity -> displacement via DVF2DDF
warped_moving_img = warp(moving_img, displacement_field)
loss = lncc_loss(warped_moving_img, fixed_img)

# Explicit (name, tensor) pairs instead of a globals() lookup: this keeps working if the
# code is moved into a function, and a misspelled name fails immediately.
for varname, var in {
    'fixed_img': fixed_img,
    'moving_img': moving_img,
    'reg_net_input': reg_net_input,
    'velocity_field': velocity_field,
    'displacement_field': displacement_field,
    'warped_moving_img': warped_moving_img,
}.items():
    print(f"Shape of {varname}: {var.shape}")
print(f"Loss: {loss}")

In [None]:
# Preview everything in that forward pass: the two inputs, the predicted velocity and
# displacement fields, a grid deformed by the displacement, and the warped moving image.

fig, axs = plt.subplots(2, 3, figsize=(15, 10))
axs = axs.ravel()

# One (title, drawing function) entry per panel, in row-major panel order.
panels = [
    ('fixed_img', lambda ax: ax.imshow(fixed_img[0, 0], origin='lower')),
    ('moving_img', lambda ax: ax.imshow(moving_img[0, 0], origin='lower')),
    ('velocity_field',
     lambda ax: plot_2D_vector_field(velocity_field[0].detach(), 4, ax)),
    ('displacement_field',
     lambda ax: plot_2D_vector_field(displacement_field[0].detach(), 4, ax)),
    ('grid warped by displacement_field',
     lambda ax: plot_2D_deformation(displacement_field[0].detach(), 4, ax)),
    ('warped_moving_img',
     lambda ax: ax.imshow(warped_moving_img[0, 0].detach(), origin='lower')),
]
for ax, (title, draw) in zip(axs, panels):
    draw(ax)
    ax.set_title(title)

plt.show()