GradICON deformable registration of fractional anisotropy (FA) images (work in progress).

In [None]:
import os
import glob
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import monai
import torch

In [None]:
fa_dir = 'dti_fit_images/fa'
# glob returns paths in an arbitrary, filesystem-dependent order. partition_dataset
# does not shuffle by default, so sort the paths to make the train/validation split
# reproducible across machines (the split then follows sorted filename order).
data = [
    {"fa": path, "filename": os.path.basename(path)}
    for path in sorted(glob.glob(os.path.join(fa_dir, '*')))
]
if not data:
    # Fail early with a clear message instead of an IndexError further down.
    raise FileNotFoundError(f"No FA images found in {fa_dir!r}")
# 80/20 split into lists of {"fa": path, "filename": name} dicts.
data_train, data_valid = monai.data.utils.partition_dataset(data, ratios=(8, 2))

In [None]:
# Device that the cached image tensors (ToDeviceD) and the registration network are placed on.
device = torch.device(type="cpu")

In [None]:
# Preprocessing applied to every item before it is cached by CacheDataset:
# load the FA volume, add a leading channel axis, pad it with a constant, then
# convert it to a torch tensor on `device`.
transform = monai.transforms.Compose(
    transforms=[
        monai.transforms.LoadImageD(keys="fa"),
        # NOTE(review): AddChannelD is deprecated in newer MONAI releases in favour of
        # EnsureChannelFirstD -- confirm against the installed MONAI version.
        monai.transforms.AddChannelD(keys="fa"),
        # Inputs are known to be (140, 140, 140). Pad each spatial dim to 144, which is
        # divisible by 16, the product of the UNet strides (1*2*2*2*2).
        monai.transforms.SpatialPadD(keys="fa", spatial_size=(144,) * 3, mode="constant"),
        monai.transforms.ToTensorD(keys="fa"),
        monai.transforms.ToDeviceD(keys="fa", device=device),
    ]
)

In [None]:
# Build cached datasets for the training and validation partitions (training first).
ds_train, ds_valid = (
    monai.data.CacheDataset(partition, transform)
    for partition in (data_train, data_valid)
)

In [None]:
# 3D UNet mapping a stacked image pair to a 3-channel displacement field.
reg_net = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=2,  # fixed image and moving image, stacked along the channel axis
    out_channels=3,  # one displacement component per spatial axis
    channels=(16, 32, 32, 32, 32, 64),
    strides=(1, 2, 2, 2, 2),  # overall downsampling factor of 16
    dropout=0.2,
    norm="batch",
).to(device)

In [None]:
# Resampler shared by image warping and displacement-field composition (compose_ddf).
# Bilinear interpolation; samples falling outside the volume read as zero.
warp = monai.networks.blocks.Warp(padding_mode="zeros", mode="bilinear")

def sim_loss(b1, b2):
    """Mean squared intensity difference between two image batches.

    Both inputs are expected to share a shape such as (batch_size, channels, H, W, D).
    Lower values mean the images are more alike. Returns a scalar tensor.
    """
    residual = b1 - b2
    return torch.mean(residual.pow(2))

def compose_ddf(u, v):
    """Compose two displacement fields.

    Returns the displacement whose warp equals warping by ``v`` first and then by ``u``.
    It is ``u`` plus ``v`` resampled (via the module-level ``warp``) at the points
    displaced by ``u``.
    """
    v_resampled = warp(v, u)
    return u + v_resampled

# Spatial size of the preprocessed volumes, read from the first cached training item
# (the leading channel axis is discarded).
_, H, W, D = ds_train[0]["fa"].size()

# Compute discrete spatial derivatives
def diff_and_trim(array, axis):
    """Take a forward difference of a 5D tensor along one spatial axis, cropped to a common grid.

    ``array`` has shape (batch, channels, H, W, D), and ``axis`` should be 2, 3 or 4.
    The difference along ``axis`` shortens only that axis by one. Every spatial axis is
    then trimmed to its length minus one, so differences taken along different axes line
    up voxel-for-voxel and can be summed.

    The trim sizes are read from ``array`` itself rather than from module-level globals,
    so the function works for any spatial size.
    """
    h, w, d = array.shape[2:]
    return torch.diff(array, dim=axis)[:, :, : h - 1, : w - 1, : d - 1]

def size_of_spatial_derivative(u):
    """Return the voxel-averaged squared Frobenius norm of a displacement field's spatial derivative.

    ``u`` is expected to have shape (batch, 3, H, W, D). The spatial derivative is
    approximated with forward differences (unit voxel spacing) along each of the three
    spatial axes. At each voxel, the squared Frobenius norm of the resulting Jacobian is
    the sum of squared differences over all components and directions. That norm is then
    averaged over the trimmed (H-1, W-1, D-1) voxel grid.

    This concerns the displacement field itself, not the deformation x -> x + u(x) it
    defines; their Jacobians differ by the identity.

    Returns a tensor of shape (batch,).
    """
    dx = diff_and_trim(u, 2)
    dy = diff_and_trim(u, 3)
    dz = diff_and_trim(u, 4)
    # Sum over the displacement components (axis 1), then average over voxels.
    squared_norm = (dx**2 + dy**2 + dz**2).sum(dim=1)
    return squared_norm.mean(dim=[1, 2, 3])

In [None]:
dl_train = monai.data.DataLoader(ds_train, shuffle=True, batch_size=2, drop_last=True)
max_epochs = 1
for e in range(max_epochs):
    # Consume the loader two batches at a time: one batch supplies the "A" images and
    # the next the "B" images. zip() over the same iterator pairs consecutive batches
    # and stops once the loader is exhausted, dropping an unpaired final batch.
    dl_train_iter = iter(dl_train)
    for b1, b2 in zip(dl_train_iter, dl_train_iter):
        img_A = b1['fa']
        img_B = b2['fa']
        img_pair_AB = torch.cat((img_A, img_B), dim=1)
        img_pair_BA = img_pair_AB[:,[1,0]]  # swap the two channels: (B, A)
        
        deformation_AB = reg_net(img_pair_AB) # deforms img_B to the space of img_A
        deformation_BA = reg_net(img_pair_BA) # deforms img_A to the space of img_B
        
        img_B_warped = warp(img_B, deformation_AB)
        img_A_warped = warp(img_A, deformation_BA)
        sim_loss_A = sim_loss(img_A, img_B_warped)
        sim_loss_B = sim_loss(img_B, img_A_warped)
        # Round-trip displacements: warping by AB then BA (composite_B), or by BA then AB
        # (composite_A). For an inverse-consistent pair these round trips are the identity,
        # so their displacement fields should have zero spatial derivative. The GradICON
        # terms below penalize that derivative.
        composite_deformation_B = compose_ddf(deformation_BA, deformation_AB)
        composite_deformation_A = compose_ddf(deformation_AB, deformation_BA)
        gradicon_loss_A = size_of_spatial_derivative(composite_deformation_A).mean()
        gradicon_loss_B = size_of_spatial_derivative(composite_deformation_B).mean()
        # NOTE(review): in the visible code the similarity and GradICON terms are not yet
        # combined into a total loss, and no optimizer/backward step exists -- TODO.