Plan: use the synthetic data generator (`synthesize_image` from `main_demo`) to create a dataset, then build a registration U-Net that outputs vector fields, which, for now, we interpret as velocity fields. Use DVF2DDF to convert these velocity fields into displacement fields. Train the registration network with a basic norm-based regularization loss and an LNCC (local normalized cross-correlation) similarity loss.

In [None]:
from main_demo import synthesize_image
import matplotlib.pyplot as plt
import monai
import torch
import numpy as np
import tempfile, os, glob

In [None]:
# Create a temporary directory to work in. Nothing in this cell removes it,
# so the generated images stay available to later cells.
root_dir = tempfile.mkdtemp()
data_dir = os.path.join(root_dir, "synthetic_data")
print(f"Working in the following directory: {data_dir}")
save_image = monai.transforms.SaveImage(
    output_dir=data_dir,
    output_ext='.png',
    scale=255,  # applies to png format only
    separate_folder=False,
    print_log=False,
)

# Save a bunch of synthetic images in the temporary directory
number_of_images_to_generate = 300
for _ in range(number_of_images_to_generate):
    image = synthesize_image((128, 128), 32, 96)
    image = np.expand_dims(image, axis=0)  # add channel dimension, which save_image expects
    save_image(image)

# glob returns files in a filesystem-dependent order; sort so that the
# downstream shuffle/partition always receives the same input order.
image_paths = sorted(glob.glob(os.path.join(data_dir, '*.png')))
# Every generated image should have produced its own file; a mismatch means
# files were overwritten or not written, which would silently shrink the dataset.
if len(image_paths) != number_of_images_to_generate:
    raise RuntimeError(
        f"Expected {number_of_images_to_generate} PNG files in {data_dir}, "
        f"found {len(image_paths)}"
    )

# Shuffle the image paths and split them into a validation subset (ratio 2)
# and a training subset (ratio 8). Image pairs are then formed within each
# subset: both orderings of two images, plus each image paired with itself.
image_paths_valid, image_paths_train = monai.data.utils.partition_dataset(
    data=image_paths,
    ratios=(2, 8),
    shuffle=True,
)
def create_image_pairs_data(paths):
    """Build a registration data list from every ordered pair of image paths.

    Each item is a dictionary ``{'img0': path_a, 'img1': path_b}``. All ordered
    pairs are produced: both orderings of two distinct images, and each image
    paired with itself. For ``n`` input paths the result therefore has ``n ** 2``
    items. Items are ordered by ``img0`` first and then ``img1``, following the
    order of ``paths``.

    Args:
        paths: Collection of image file paths. It is iterated repeatedly, so it
            must be re-iterable (e.g. a list or tuple, not a generator).

    Returns:
        List of dictionaries with keys ``'img0'`` and ``'img1'``.
    """
    return [{'img0': path0, 'img1': path1} for path0 in paths for path1 in paths]
# Form pairs separately within each split, so no pair mixes a training image
# with a validation image.
data_train, data_valid = map(create_image_pairs_data, (image_paths_train, image_paths_valid))

# Loading pipeline for one pair record:
#   1. read 'img0' and 'img1' from disk,
#   2. convert both to tensors,
#   3. give each a leading channel axis,
#   4. stack them along dim 0 into a single 2-channel item 'img01'
#      (channel 0 from 'img0', channel 1 from 'img1'); the two channels serve
#      as the fixed/target and moving images of the registration task,
#   5. drop the now-redundant 'img0'/'img1' entries.
# NOTE(review): AddChanneld is deprecated in newer MONAI releases
# (EnsureChannelFirstd with channel_dim="no_channel" is the suggested
# replacement); confirm against the installed MONAI version before changing it.
transform = monai.transforms.Compose([
    monai.transforms.LoadImaged(keys=['img0', 'img1'], image_only=True),
    monai.transforms.ToTensord(keys=['img0', 'img1']),
    monai.transforms.AddChanneld(keys=['img0', 'img1']),
    monai.transforms.ConcatItemsd(keys=['img0', 'img1'], name='img01', dim=0),
    monai.transforms.DeleteItemsd(keys=['img0', 'img1']),
])

# Wrap the training and validation pair lists in plain (non-caching) datasets
# that share the same loading transform.
# TODO: switch to CacheDataset once the training infrastructure is in place;
# plain Dataset is used for now to avoid memory overhead during early development.
dataset_train, dataset_valid = (
    monai.data.Dataset(data=pairs, transform=transform)
    for pairs in (data_train, data_valid)
)

In [None]:
# Sanity check: display the shape of the stacked image pair for the first
# validation item. Channel count should be 2 (img0 and img1 concatenated on
# dim 0); spatial size presumably (128, 128) from the generator call - confirm.
dataset_valid[0]['img01'].shape