GradICON deformable registration of fractional anisotropy (FA) images (work in progress).

In [None]:
import os
import glob
import random
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import monai
import torch
import torch.nn

import footsteps
import pickle

In [None]:
fa_dir = 'dti_fit_images/fa'
# Sort the listing: glob's order depends on the filesystem, and
# partition_dataset (no shuffling here) splits by list position, so an
# unsorted listing would give a different train/valid split per machine.
data = [
    {"fa": path, "filename": os.path.basename(path)}
    for path in sorted(glob.glob(os.path.join(fa_dir, '*')))
]
# 80% / 20% train / validation split.
data_train, data_valid = monai.data.utils.partition_dataset(data, ratios=(8, 2))

In [None]:
# Prefer the GPU. Fall back to the CPU (with a notice) so the notebook still
# runs on machines without CUDA, although training will be far slower there.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    print("CUDA is not available; falling back to CPU.")
    device = torch.device('cpu')

In [None]:
# Per-sample preprocessing of the "fa" entry. The steps are:
#   1. load the volume from its path;
#   2. add a leading channel axis;
#   3. pad spatially to 144^3 (constant mode);
#   4. convert to a tensor;
#   5. move it to `device`.
# The input images are known to be (140,140,140). 144 = 9 * 16, so the four
# stride-2 levels of the registration UNet divide every spatial size evenly.
# NOTE(review): AddChannelD is deprecated in newer MONAI releases
# (EnsureChannelFirstD replaces it). Confirm against the installed version.
preprocessing_steps = [
    monai.transforms.LoadImageD(keys="fa"),
    monai.transforms.AddChannelD(keys="fa"),
    monai.transforms.SpatialPadD(keys="fa", spatial_size=(144, 144, 144), mode="constant"),
    monai.transforms.ToTensorD(keys="fa"),
    monai.transforms.ToDeviceD(keys="fa", device=device),
]
transform = monai.transforms.Compose(preprocessing_steps)

In [None]:
# Cached datasets: `transform` is applied up front and the results are kept in
# memory. Presumably the cached volumes already live on `device`, since the
# pipeline ends with ToDeviceD.
ds_train, ds_valid = (
    monai.data.CacheDataset(split, transform) for split in (data_train, data_valid)
)

In [None]:
# Shared resampler used throughout: warp(image, ddf) samples `image` at
# x + ddf(x). It uses "bilinear" interpolation (trilinear on 3D volumes) and
# reads zeros outside the image grid.
warp = monai.networks.blocks.Warp(padding_mode="zeros", mode="bilinear")

def mse_loss(b1, b2):
    """Image-similarity loss: mean squared error between two batches, times 10000.

    Args:
        b1, b2: tensors of matching shape, e.g. (batch_size, channels, H, W, D).

    Returns:
        Scalar tensor. The 10000 factor only rescales the loss; presumably it
        balances the similarity term against the regularizer (TODO confirm).
    """
    squared_error = (b1 - b2) ** 2
    return 10000 * squared_error.mean()

def compose_ddf(u, v):
    """Compose two displacement fields so that warping by the result warps by v, then by u.

    The result is w(x) = u(x) + v(x + u(x)). Here v is resampled at the
    displaced positions with the module-level `warp`.

    Args:
        u, v: displacement fields of the same shape, e.g. (batch, 3, H, W, D).

    Returns:
        The composite displacement field, same shape as u.
    """
    v_at_displaced_points = warp(v, u)
    return u + v_at_displaced_points

# Spatial size of the preprocessed volumes, read from the first training
# sample, whose 'fa' tensor is (channels, H, W, D).
(_, H, W, D) = ds_train[0]['fa'].shape

# Compute discrete spatial derivatives
def diff_and_trim(array, axis):
    """Forward difference of a 5D tensor along one spatial axis, cropped to a common grid.

    Differencing along ``axis`` shortens that axis by one. All three spatial
    axes are then cropped to (size - 1), so differences taken along different
    axes line up voxel-for-voxel and can be combined elementwise. The crop
    sizes come from ``array`` itself, so inputs of any spatial size work.

    Args:
        array: tensor of shape (batch, channels, H, W, D).
        axis: spatial axis to difference along; should be 2, 3, or 4.

    Returns:
        Tensor of shape (batch, channels, H-1, W-1, D-1).
    """
    size_h, size_w, size_d = array.shape[2:]
    return torch.diff(array, dim=axis)[:, :, : size_h - 1, : size_w - 1, : size_d - 1]

def size_of_spatial_derivative(u):
    """Mean squared Frobenius norm of the finite-difference Jacobian of a displacement field.

    This concerns the displacement field u itself, not the map x -> x + u(x)
    that it defines. The process is:

    1. Take forward differences of u along the three spatial axes (2, 3, 4)
       with `diff_and_trim`.
    2. Square them and sum over the differencing direction and over the
       displacement components (axis 1). This gives, at each voxel, the
       squared Frobenius norm of the finite-difference Jacobian.
    3. Average over the trimmed spatial grid.

    Args:
        u: displacement field of shape (batch, 3, H, W, D).

    Returns:
        Tensor of shape (batch,).
    """
    squared_differences = sum(diff_and_trim(u, spatial_axis) ** 2 for spatial_axis in (2, 3, 4))
    per_voxel = squared_differences.sum(axis=1)
    return per_voxel.mean(axis=[1, 2, 3])

In [None]:
# Everything Model.forward returns: the total loss, its two components, and
# the predicted displacement field for the (A, B) direction.
ModelOutput = namedtuple(
    "ModelOutput",
    ["all_loss", "sim_loss", "gradicon_loss", "deformation_AB"],
)

class Model(torch.nn.Module):
    """Symmetric GradICON registration model built around one MONAI UNet.

    The same network is run in both directions. For the channel-stacked pair
    (A, B) it predicts deformation_AB, which resamples B into A's space. For
    the swapped pair (B, A) it predicts deformation_BA.

    The loss is sim_loss + lambda_reg * gradicon_loss, where:
      * sim_loss sums the similarity losses of both warped images against
        their targets;
      * gradicon_loss sums, over both composition orders, the batch-mean
        spatial-derivative size of the composed displacement fields.

    Args:
        lambda_reg: weight of the GradICON term in the total loss.
        compute_sim_loss: callable (target, warped_moving) -> scalar tensor.
    """

    def __init__(self, lambda_reg, compute_sim_loss):
        super().__init__()
        # Positional UNet arguments, in order:
        #   spatial dims = 3;
        #   input channels = 2 (fixed + moving image);
        #   output channels = 3 (one per displacement component);
        #   channel sequence and stride-2 convolution strides;
        # plus dropout and batch norm.
        self.reg_net = monai.networks.nets.UNet(
            3,
            2,
            3,
            (32, 32, 32, 32, 64),
            (2, 2, 2, 2),
            dropout=0.2,
            norm="batch",
        )
        self.lambda_reg = lambda_reg
        self.compute_sim_loss = compute_sim_loss

    def forward(self, img_A, img_B) -> ModelOutput:
        # Assumes img_A and img_B each have a single channel. The swap below
        # reorders channels 0 and 1, matching the UNet's 2 input channels.
        pair_AB = torch.cat((img_A, img_B), dim=1)
        pair_BA = pair_AB[:, [1, 0]]

        # The AB call must come before the BA call (e.g. for dropout RNG order).
        ddf_AB = self.reg_net(pair_AB)  # deforms img_B to the space of img_A
        ddf_BA = self.reg_net(pair_BA)  # deforms img_A to the space of img_B

        sim_loss = (
            self.compute_sim_loss(img_A, warp(img_B, ddf_AB))
            + self.compute_sim_loss(img_B, warp(img_A, ddf_BA))
        )
        # For a composed displacement w, grad(x + w) - I == grad(w). Penalizing
        # grad(w) therefore pushes each composition toward an identity Jacobian.
        gradicon_loss = (
            size_of_spatial_derivative(compose_ddf(ddf_AB, ddf_BA)).mean()
            + size_of_spatial_derivative(compose_ddf(ddf_BA, ddf_AB)).mean()
        )

        return ModelOutput(
            all_loss=sim_loss + self.lambda_reg * gradicon_loss,
            sim_loss=sim_loss,
            gradicon_loss=gradicon_loss,
            deformation_AB=ddf_AB,
        )

In [None]:
class LossCurves:
    """Accumulate per-batch losses and keep their per-epoch means.

    Usage: call add_to_buffer once per batch and aggregate_buffers_for_epoch
    at the end of each epoch. The latter appends the buffer means to the
    history lists and clears the buffers. plot draws the three histories.
    """

    def __init__(self, name: str):
        self.name = name  # label used in plot titles

        # Per-epoch histories: one entry per aggregate_buffers_for_epoch call.
        self.epochs = []
        self.all_losses = []
        self.sim_losses = []
        self.gradicon_losses = []

        self.clear_buffers()

    def clear_buffers(self):
        """Reset the per-batch buffers."""
        self.all_losses_buffer = []
        self.sim_losses_buffer = []
        self.gradicon_losses_buffer = []

    def add_to_buffer(self, model_output: "ModelOutput"):
        """Record the losses of one batch.

        Args:
            model_output: object whose all_loss, sim_loss and gradicon_loss
                are one-element tensors. They are stored as Python floats via
                ``.item()``, so no autograd graph is retained.
        """
        self.all_losses_buffer.append(model_output.all_loss.item())
        self.sim_losses_buffer.append(model_output.sim_loss.item())
        self.gradicon_losses_buffer.append(model_output.gradicon_loss.item())

    @staticmethod
    def _buffer_mean(buffer):
        """Return the mean of buffer, or NaN when it is empty.

        np.mean([]) would also give NaN, but with a RuntimeWarning; an empty
        epoch (e.g. too few samples to form a pair) is reported as NaN
        explicitly instead.
        """
        return float(np.mean(buffer)) if buffer else float("nan")

    def aggregate_buffers_for_epoch(self, epoch: int):
        """Append the buffer means for ``epoch`` to the histories, then clear the buffers."""
        self.epochs.append(epoch)
        self.all_losses.append(self._buffer_mean(self.all_losses_buffer))
        self.sim_losses.append(self._buffer_mean(self.sim_losses_buffer))
        self.gradicon_losses.append(self._buffer_mean(self.gradicon_losses_buffer))
        self.clear_buffers()

    def plot(self, savepath=None):
        """Plot the three loss histories side by side.

        If savepath is given, the figure is saved there before being shown.
        """
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        histories = (
            (self.all_losses, "overall loss"),
            (self.sim_losses, "similarity loss"),
            (self.gradicon_losses, "gradicon loss"),
        )
        for ax, (values, description) in zip(axs, histories):
            ax.plot(self.epochs, values)
            ax.set_title(f"{self.name}: {description}")
            ax.set_xlabel("epoch")
        if savepath is not None:
            # Save this figure explicitly rather than whatever pyplot considers current.
            fig.savefig(savepath)
        plt.show()

In [None]:
# Local normalized cross-correlation similarity over 5x5x5 windows.
# NOTE(review): this notebook's training setup passes mse_loss as
# compute_sim_loss, so this object is not used there. It is presumably kept
# as an alternative similarity measure.
lncc_loss = monai.losses.LocalNormalizedCrossCorrelationLoss(
    spatial_dims=3,
    kernel_size=5,
    smooth_nr=0,
    smooth_dr=1e-4,
)

In [None]:
# The training loop pairs up consecutive batches. Each training pair is
# therefore two single volumes, and each validation pair is two batches of
# two volumes.
dl_train = monai.data.DataLoader(ds_train, batch_size=1, shuffle=True, drop_last=True)
dl_valid = monai.data.DataLoader(ds_valid, batch_size=2, shuffle=True, drop_last=True)
max_epochs = 600


def validate_when(e):
    """Return True if validation should run after (0-based) epoch ``e``.

    Validation runs on every even epoch except the first, and on the final epoch.
    """
    if e == max_epochs - 1:
        return True
    return e % 2 == 0 and e != 0


model = Model(
    lambda_reg=90,  # weight of the GradICON term relative to the similarity loss
    compute_sim_loss=mse_loss,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_curves_train = LossCurves("training")
loss_curves_valid = LossCurves("validation")

In [None]:
for e in range(max_epochs):
    print(f'Epoch {e+1}/{max_epochs}:')

    # --- Train. Consecutive batches of a fresh iterator form one (A, B)
    # pair; zip(it, it) stops as soon as a full pair can no longer be drawn.
    model.train()
    train_batches = iter(dl_train)
    for b1, b2 in zip(train_batches, train_batches):
        optimizer.zero_grad()
        model_output = model(b1['fa'], b2['fa'])
        model_output.all_loss.backward()
        optimizer.step()

        loss_curves_train.add_to_buffer(model_output)
        del model_output  # release the graph-holding output before the next pair
    loss_curves_train.aggregate_buffers_for_epoch(e)
    print(f"\tTraining loss: {loss_curves_train.all_losses[-1]:.4f} ({loss_curves_train.sim_losses[-1]:.4f},{loss_curves_train.gradicon_losses[-1]:.4f})")

    # --- Validate, only on the epochs selected by validate_when, without gradients.
    if not validate_when(e):
        continue
    model.eval()
    valid_batches = iter(dl_valid)
    for b1, b2 in zip(valid_batches, valid_batches):
        with torch.no_grad():
            model_output = model(b1['fa'], b2['fa'])
            loss_curves_valid.add_to_buffer(model_output)
    loss_curves_valid.aggregate_buffers_for_epoch(e)
    print("\tValidation loss:", loss_curves_valid.all_losses[-1])
        
    

In [None]:
# SAVE
# Write the loss plots, the pickled LossCurves objects and the model weights
# into footsteps' output directory.
out_dir = footsteps.output_dir

loss_curves_train.plot(savepath=out_dir + 'loss_plot_train.png')
loss_curves_valid.plot(savepath=out_dir + 'loss_plot_valid.png')

with open(out_dir + 'loss_curves.p', 'wb') as curves_file:
    pickle.dump([loss_curves_train, loss_curves_valid], curves_file)

torch.save(model.state_dict(), out_dir + 'model_state_dict.pth')

In [None]:
# LOAD
# Restore the weights and loss histories written by the save cell.
# map_location=device: by default torch.load recreates tensors on the device
# they were saved from. That fails when that device is missing here, e.g. a
# CUDA-saved checkpoint opened on a CPU-only machine.
model.load_state_dict(
    torch.load(footsteps.output_dir + 'model_state_dict.pth', map_location=device)
)

# NOTE(security): pickle.load can execute arbitrary code from the file.
# Only load loss-curve files this notebook produced itself.
with open(footsteps.output_dir + 'loss_curves.p', 'rb') as f:
    loss_curves_train, loss_curves_valid = pickle.load(f)

In [None]:
import util

# Pick two *distinct* validation volumes. Calling random.choice twice can
# return the same sample, which would just register an image to itself.
# With fewer than two validation samples, fall back to sample 0 for both.
n_valid = len(ds_valid)
idx_A, idx_B = random.sample(range(n_valid), 2) if n_valid >= 2 else (0, 0)
d1 = ds_valid[idx_A]
d2 = ds_valid[idx_B]
img_A = d1['fa'].unsqueeze(0)  # add a batch axis
img_B = d2['fa'].unsqueeze(0)
model.eval()
with torch.no_grad():
    model_output = model(img_A, img_B)
    # Resample the moving image B into A's space with the predicted field.
    img_B_warped = warp(img_B, model_output.deformation_AB)

preview_slices = (80, 80, 80)

print("moving:")
util.preview_image(img_B[0, 0].cpu(), figsize=(18, 10), slices=preview_slices)
print("warped moving:")
util.preview_image(img_B_warped[0, 0].cpu(), figsize=(18, 10), slices=preview_slices)
print("target:")
util.preview_image(img_A[0, 0].cpu(), figsize=(18, 10), slices=preview_slices)
print("checkerboard of warped moving and target:")
util.preview_checkerboard(img_A[0, 0].cpu(), img_B_warped[0, 0].cpu(), figsize=(18, 10), slices=preview_slices)
print("deformation vector field:")
util.preview_3D_vector_field(model_output.deformation_AB[0].cpu(), slices=preview_slices)
print("deformed grid:")
util.preview_3D_deformation(model_output.deformation_AB[0].cpu(), 5, slices=preview_slices)
print("jacobian determinant:")
# Assumes util.jacobian_determinant returns per-voxel determinants of the
# Jacobian of x -> x + u(x) (TODO confirm). Negative values then mark folds.
det = util.jacobian_determinant(model_output.deformation_AB[0].cpu())
util.preview_image(det, normalize_by='slice', threshold=0, slices=preview_slices)
print("sim loss:", model_output.sim_loss.item())
print("gradicon loss:", model_output.gradicon_loss.item())
print("overall loss:", model_output.all_loss.item())
print("Number of folds:", (det < 0).sum())
