ICON or GradICON deformable registration of DTI images. (WIP)

In [None]:
import os
import glob
import random
import shutil
from collections import namedtuple
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import monai
import torch
import torch.nn

import footsteps
import pickle
import util

from dti_warp import WarpDTI, TensorTransformType, MseLossDTI
from util import ComposeDDF

In [None]:
# Device that the loaded images and the registration model are placed on.
device = torch.device('cuda')
# Spatial shape (H, W, D) that every volume is padded out to before it is fed to the network.
spatial_size = (144,144,144)

The input images are known to be $140\times140\times140$, and we can pad them out to $144$ in each dimension.

In [None]:
# Folder holding the fitted images; the 'fa' and 'dti' subfolders are expected to contain
# files with matching names.
data_dir = Path('/data/ebrahim-data/abcd/registration-experiments/ISBI-2023-project/dti_fit_images_nontest/')
fa_dir = data_dir/'fa'
dti_dir = data_dir/'dti'

# One record per DTI file, paired with the FA file of the same name.
# The paths are sorted because glob follows directory-listing order, which is not guaranteed to
# be stable; since the partition below is taken in list order, an unsorted list would make the
# train/validation split irreproducible.
data = [
    {'dti': str(path), 'fa': str(fa_dir/path.name), "filename": path.name}
    for path in sorted(dti_dir.glob('*'))
]

# 80% / 20% split into training and validation records.
data_train, data_valid = monai.data.utils.partition_dataset(data, ratios=(8,2))

In [None]:
# Dictionary keys of the two image types that every transform below is applied to.
k = ['fa', 'dti']

# Preprocessing pipeline applied to each data record: load the images, put the channel axis
# first, pad the spatial axes out to `spatial_size`, convert to tensors, and move them to `device`.
base_transforms = [
    monai.transforms.LoadImageD(keys=k),
    monai.transforms.EnsureChannelFirstD(keys=k),
    monai.transforms.SpatialPadD(keys=k, spatial_size=spatial_size, mode="constant"),
    monai.transforms.ToTensorD(keys=k),
    monai.transforms.ToDeviceD(keys=k, device=device),
]

In [None]:
# Training and validation use the same transform list; no train-only augmentation is added here.
transform_valid = monai.transforms.Compose(base_transforms)
transform_train = monai.transforms.Compose(base_transforms)

In [None]:
def mse_loss(b1, b2):
    """Mean squared error between two image batches, used as an image similarity loss.

    Both b1 and b2 have shape (batch_size, channels, H,W,D). The mean is taken over every
    element, and the result is multiplied by 10000 to bring it to a more convenient scale."""
    residual = b1 - b2
    mean_squared_error = (residual * residual).mean()
    return 10000 * mean_squared_error

def mse_tensors(b1, b2):
    """Return the mean of the squared distances between tensors in two batches of DTIs,
    each of shape (batch_size, channels, H,W,D).

    The channel axis (dim 1) holds the stored components of the tensor at each voxel, so the
    squared distance at a voxel is the sum over channels of the squared component differences.
    That per-voxel value is then averaged over the batch and over all spatial locations.

    NOTE(review): every stored channel is weighted equally. If the channels are the unique
    entries of a symmetric 3x3 tensor and a true Frobenius distance is wanted, the off-diagonal
    channels would need double weight -- confirm the channel layout before relying on this.
    """
    squared_distance = ((b1 - b2)**2).sum(dim=1)  # shape (batch_size, H,W,D)
    return squared_distance.mean()

def ncc_loss(b1, b2):
    """Negative normalized cross-correlation between two image batches.

    Both inputs have shape (batch_size, channels, H,W,D). The correlation is computed separately
    for every (batch, channel) pair over the three spatial axes; the value returned is the
    negated average over all of those pairs."""
    spatial_axes = (2, 3, 4)

    def spatial_mean(t):
        return t.mean(dim=spatial_axes)

    mean_1, mean_2 = spatial_mean(b1), spatial_mean(b2)

    # Covariance and variances expressed through first and second moments
    covariance = spatial_mean(b1*b2) - mean_1*mean_2
    variance_1 = spatial_mean(b1**2) - mean_1**2
    variance_2 = spatial_mean(b2**2) - mean_2**2

    correlation = covariance / torch.sqrt(variance_1 * variance_2)
    return -correlation.mean()

# Spatial extents of the padded volumes, unpacked into module-level names for convenience.
H, W, D = spatial_size

# Compute discrete spatial derivatives
def diff_and_trim(array, axis):
    """Take the discrete difference along a spatial axis, which should be 2,3, or 4.
    Return a difference tensor with all spatial axes trimmed by 1.

    Args:
        array: tensor of shape (batch, channels, H, W, D).
        axis: the spatial axis (2, 3, or 4) to take the difference along.

    Returns:
        Tensor of shape (batch, channels, H-1, W-1, D-1), where H, W, D are the spatial
        extents of `array` itself.
    """
    # Trim relative to the input's own shape rather than a fixed module-level size, so that
    # the result really is "trimmed by 1" for volumes of any size.
    h, w, d = array.shape[2:5]
    return torch.diff(array, dim=axis)[:, :, :(h-1), :(w-1), :(d-1)]

def size_of_spatial_derivative(u):
    """Squared Frobenius norm of the spatial derivative of a displacement field.

    The derivative in question is that of the displacement field map itself, not of the
    deformation which the displacement field defines. Input shape is (batch,3,H,W,D);
    the squared norm is averaged over spatial locations, so the output shape is (batch)."""
    # Accumulate the squared finite differences taken along each of the three spatial axes
    squared_differences = sum(diff_and_trim(u, spatial_axis)**2 for spatial_axis in (2, 3, 4))
    # Sum over the vector components to get the squared norm at each location
    squared_norm = squared_differences.sum(dim=1)
    return squared_norm.mean(dim=(1, 2, 3))

In [None]:
from enum import Enum
from typing import Union

# Result bundle returned by RegModel.forward: the combined loss, the unweighted similarity and
# ICON/GradICON losses, the predicted displacement field, and the two weighted loss terms.
ModelOutput = namedtuple("ModelOutput", "all_loss, sim_loss, icon_loss, deformation_AB, sim_loss_weighted, icon_loss_weighted")

class IconLossType(Enum):
    """Selects which inverse-consistency penalty RegModel applies to the composed deformations."""
    ICON = 0  # penalize the mean squared value of the composed displacement
    GRADICON = 1  # penalize the squared norm of the spatial derivative of the composed displacement
    
# A deformable registration model
class RegModel(torch.nn.Module):
    """Deformable registration network trained with a similarity loss plus an ICON or GradICON penalty.

    A single UNet predicts a displacement field from a channel-wise concatenated pair of DTIs.
    During training it is evaluated in both directions, and the two predicted fields are composed
    in order to measure inverse consistency.
    """

    def __init__(self,
                 device,
                 lambda_reg,
                 lambda_sim,
                 down_convolutions,
                 depth,
                 max_channels,
                 init_channels,
                 icon_loss_type : IconLossType = IconLossType.GRADICON,
                 downsample_early = True
                ):
        """
        Create a deformable registration network

        The similarity loss is fixed to MseLossDTI and the warp to WarpDTI with the finite strain
        tensor transform; neither is configurable through the constructor.

        Args:
            device: The device on which the model and its loss/warp helpers are placed
            lambda_reg: Hyperparameter for weight of icon/gradicon loss
            lambda_sim: Hyperparameter for weight of similarity loss (not independent from lambda_reg so
                not really a new hyperparameter)
            down_convolutions: How many stride=2 convolutions to include in the down-convolution part of the unets
                               when at the original image scale. We assume the original image size is divisible by
                               2**down_convolutions
            depth: Total number of layers in half of the unet. Increase this to increase model capacity.
                   Must be >= down_convolutions
            max_channels: As you go to deeper layers, channels grow by powers of two... up to a maximum given here.
            init_channels: how many channels in the first layer
            icon_loss_type: whether to use ICON or GradICON
            downsample_early: The CNN can contain strided and unstrided convolutions to achieve the requested
                              depth; this parameter decides whether to prefer putting strided convolutions earlier
                              or later.

        Raises:
            ValueError: if depth < down_convolutions
        """
        super().__init__()
        self.icon_loss_type = icon_loss_type
        if depth < down_convolutions:
            raise ValueError("Must have depth >= down_convolutions")
        # (We will assume that the original image size is divisible by 2**n.)

        num_twos = down_convolutions # The number of 2's we will put in the sequence of convolutional strides.
        num_ones = depth-down_convolutions # The number of 1's
        num_one_two_pairs = min(num_ones, num_twos) # The number of 1,2 pairs to stick in the middle
        if downsample_early:
            stride_sequence = (2,)*(num_twos-num_one_two_pairs) + (1,2)*num_one_two_pairs + (1,)*(num_ones-num_one_two_pairs)
        else:
            stride_sequence = (1,)*(num_ones-num_one_two_pairs) + (1,2)*num_one_two_pairs + (2,)*(num_twos-num_one_two_pairs)
        # One channel count per level: doubles at each level until capped at max_channels
        channel_sequence = [min(init_channels*2**c,max_channels) for c in range(num_twos+num_ones+1)]

        self.reg_net = monai.networks.nets.UNet(
            3,  # spatial dims
            12, # input channels (6 for lower triangular entries of fixed image and 6 for moving image)
            3,  # output channels (to represent 3D displacement vector field)
            channel_sequence,
            stride_sequence,
            dropout=0.2,
            norm="batch"
        )
        self.strides = stride_sequence
        self.channels = channel_sequence

        self.lambda_reg = lambda_reg
        self.lambda_sim = lambda_sim
        self.compute_sim_loss = MseLossDTI(device=device)

        self.warp = WarpDTI(device=device, tensor_transform_type=TensorTransformType.FINITE_STRAIN)
        self.compose_ddf = ComposeDDF()
        self.to(device)

    def update_lambda_reg(self, new_lambda_reg):
        """Replace the weight of the icon/gradicon loss; takes effect on the next forward pass."""
        self.lambda_reg = new_lambda_reg

    def forward(self, img_A, img_B, return_warp_only = False) -> Union[ModelOutput,torch.Tensor]:
        """
        Register img_B to img_A and, unless only the warp is requested, compute the training losses.

        Args:
            img_A: batch of fixed DTIs; concatenated with img_B along dim 1 to form the 12-channel network input
            img_B: batch of moving DTIs, same shape as img_A
            return_warp_only: if True, skip the reverse pass and all losses and return only the
                displacement field that deforms img_B to the space of img_A

        Returns:
            The displacement field tensor if return_warp_only is True; otherwise a ModelOutput with the
            total loss, the weighted and unweighted loss terms, and the displacement field.

        Raises:
            ValueError: if self.icon_loss_type is not a supported IconLossType
        """
        deformation_AB = self.reg_net(torch.cat([img_A, img_B], dim=1)) # deforms img_B to the space of img_A
        if return_warp_only:
            return deformation_AB
        deformation_BA = self.reg_net(torch.cat([img_B, img_A], dim=1)) # deforms img_A to the space of img_B

        img_B_warped = self.warp(img_B, deformation_AB)
        img_A_warped = self.warp(img_A, deformation_BA)
        sim_loss_A = self.compute_sim_loss(img_A, img_B_warped)
        sim_loss_B = self.compute_sim_loss(img_B, img_A_warped)
        composite_deformation_A = self.compose_ddf(deformation_AB, deformation_BA)
        composite_deformation_B = self.compose_ddf(deformation_BA, deformation_AB)

        if self.icon_loss_type == IconLossType.GRADICON:
            icon_loss_A = size_of_spatial_derivative(composite_deformation_A).mean()
            icon_loss_B = size_of_spatial_derivative(composite_deformation_B).mean()
        elif self.icon_loss_type == IconLossType.ICON:
            icon_loss_A = (composite_deformation_A**2).mean()
            icon_loss_B = (composite_deformation_B**2).mean()
        else:
            # Fail with a clear message rather than an UnboundLocalError on icon_loss_A below
            raise ValueError(f"Unsupported icon_loss_type: {self.icon_loss_type!r}")

        # Both losses are symmetrized over the two registration directions
        sim_loss = sim_loss_A + sim_loss_B
        icon_loss = icon_loss_A + icon_loss_B

        return ModelOutput(
            all_loss = self.lambda_sim * sim_loss + self.lambda_reg * icon_loss,
            sim_loss = sim_loss,
            icon_loss = icon_loss,
            sim_loss_weighted = self.lambda_sim * sim_loss,
            icon_loss_weighted = self.lambda_reg * icon_loss,
            deformation_AB = deformation_AB
        )

In [None]:
class JacobianDeterminant(torch.nn.Module):
    """Module that maps a batch of displacement vector fields to their jacobian determinant scalar fields."""

    def __init__(self, spatial_dims):
        super().__init__()
        self.spatial_dims = spatial_dims

    def diff_and_trim(self, array, axis):
        """Discrete difference of `array` along `axis`, cropped to the first (H-1, W-1, D-1) spatial
        entries, where (H, W, D) is the `spatial_dims` given at construction."""
        height, width, depth = self.spatial_dims
        differences = torch.diff(array, dim=axis)
        return differences[:, :, :height-1, :width-1, :depth-1]

    def forward(self, vf):
        """
        Compute the jacobian determinant of the deformation defined by a displacement field.

        vf has shape (B,3,H,W,D) and is read as a displacement: for batch index b it represents the
        discretely sampled map sending the point (x,y,z) to (x,y,z)+vf[b,:,x,y,z]. Derivatives are
        approximated by discrete differences along each spatial direction.

        Returns a tensor of shape (B,H-1,W-1,D-1).
        """
        # Partial derivatives of the displacement along x, y and z: the columns of its jacobian
        col_x, col_y, col_z = (self.diff_and_trim(vf, spatial_axis) for spatial_axis in (2, 3, 4))

        # The map is identity + displacement, so add the identity matrix to the jacobian.
        # The differences are freshly computed tensors, so updating them in place leaves vf untouched.
        for component, column in enumerate((col_x, col_y, col_z)):
            column[:, component] += 1

        a0, a1, a2 = col_x[:, 0], col_x[:, 1], col_x[:, 2]
        b0, b1, b2 = col_y[:, 0], col_y[:, 1], col_y[:, 2]
        c0, c1, c2 = col_z[:, 0], col_z[:, 1], col_z[:, 2]

        # Cofactor expansion of the 3x3 determinant at each spatial location
        first_term = a0*(b1*c2 - c1*b2)
        second_term = b0*(a1*c2 - c1*a2)
        third_term = c0*(a1*b2 - b1*a2)
        return first_term - second_term + third_term

In [None]:
class LossCurves:
    """Collects per-batch losses and aggregates them into per-epoch curves that can be plotted.

    Usage: call add_to_buffer once per batch, then aggregate_buffers_for_epoch once per epoch.
    Optionally also tracks the number of "folds", i.e. locations where the jacobian determinant
    of the predicted deformation is negative.
    """

    def __init__(self, name : str, include_folds : bool = False, spatial_dims : tuple = None):
        """
        Args:
            name: label used as a prefix in the plot titles
            include_folds: whether to also record the average fold count
            spatial_dims: spatial shape (H,W,D) of the displacement fields; required if include_folds is True

        Raises:
            ValueError: if include_folds is True but spatial_dims is not given
        """
        self.name = name

        # Per-epoch aggregated values; index i of each list belongs to self.epochs[i]
        self.epochs =[]
        self.all_losses = []
        self.sim_losses = []
        self.icon_losses = []

        self.include_folds = include_folds
        if include_folds:
            if spatial_dims is None:
                raise ValueError("Need argument spatial_dims to include fold count.")
            self.jacobian_determinant = JacobianDeterminant(spatial_dims)
            self.fold_counts = []

        self.clear_buffers()

    def clear_buffers(self):
        """Discard the per-batch values accumulated since the last aggregation."""
        self.all_losses_buffer = []
        self.sim_losses_buffer = []
        self.icon_losses_buffer = []

        if self.include_folds:
            self.fold_counts_buffer = []

    def add_to_buffer(self, model_output : ModelOutput):
        """Record the losses (and, if enabled, the batch-averaged fold count) of one batch."""
        self.all_losses_buffer.append(model_output.all_loss.item())
        self.sim_losses_buffer.append(model_output.sim_loss.item())
        self.icon_losses_buffer.append(model_output.icon_loss.item())

        if self.include_folds:
            det = self.jacobian_determinant(model_output.deformation_AB)
            num_folds = (det<0).sum(dim=(1,2,3)) # fold count per batch element
            num_folds_mean = num_folds.to(dtype=torch.float).mean().item() # average over batch
            self.fold_counts_buffer.append(num_folds_mean)

    def aggregate_buffers_for_epoch(self, epoch : int):
        """Append the mean of each buffer as the value for `epoch`, then clear the buffers."""
        self.epochs.append(epoch)
        self.all_losses.append(np.mean(self.all_losses_buffer))
        self.sim_losses.append(np.mean(self.sim_losses_buffer))
        self.icon_losses.append(np.mean(self.icon_losses_buffer))
        if self.include_folds:
            self.fold_counts.append(np.mean(self.fold_counts_buffer))
        self.clear_buffers()

    def plot(self, savepath=None):
        """Plot the recorded curves against epoch, one panel per quantity.

        Args:
            savepath: if given, the figure is also written to this path before being shown
        """
        num_panels = 4 if self.include_folds else 3
        fig, axs = plt.subplots(1, num_panels, figsize = (15,5))
        axs[0].plot(self.epochs, self.all_losses)
        axs[0].set_title(f"{self.name}: overall loss")
        axs[1].plot(self.epochs, self.sim_losses)
        axs[1].set_title(f"{self.name}: similarity loss")
        axs[2].plot(self.epochs, self.icon_losses, label="icon loss")
        axs[2].set_title(f"{self.name}: icon loss")
        if self.include_folds:
            axs[3].plot(self.epochs, self.fold_counts, label="average folds")
            axs[3].set_title(f"{self.name}: average folds")
        for ax in axs:
            ax.set_xlabel("epoch")
        if savepath is not None:
            # Save through the figure handle rather than relying on pyplot's current-figure state
            fig.savefig(savepath)
        plt.show()

In [None]:
# The SVD in the tensor transformations takes a lot longer than it takes to copy into GPU,
# so there's not much sense in using CacheDataset. Might as well save the GPU memory and use it for model size.

# Start every run from an empty on-disk cache: any existing cache folder is deleted and recreated.
cache_dir = Path('./PersistentDatasetCacheDir')
if cache_dir.exists():
    shutil.rmtree(cache_dir)
cache_dir.mkdir(exist_ok=True)
# Training and validation data get separate cache subfolders.
ds_train = monai.data.PersistentDataset(data_train, transform_train, cache_dir=cache_dir/'train')
ds_valid = monai.data.PersistentDataset(data_valid, transform_valid, cache_dir=cache_dir/'valid')

---

Model creation and training for the single-scale approach (the affine augmentation is currently commented out)

---

In [None]:
# Each training/validation step draws two consecutive batches from these loaders,
# one serving as the fixed images and the other as the moving images.
dl_train = monai.data.DataLoader(ds_train, shuffle=True, batch_size=2, drop_last=True)
dl_valid = monai.data.DataLoader(ds_valid, shuffle=True, batch_size=2, drop_last=True)

# Validate on every 5th epoch (but not epoch 0) and on the final epoch.
# max_epochs is looked up when the lambda is called; it is assigned in the training cell.
validate_when = lambda e : ((e%5==0) and (e!=0)) or (e==max_epochs-1)
# Print the aggregated training loss after every epoch.
print_aggregate_when = lambda e : True
last_printed_epoch = -1

# Schedule for gradually increasing the weight of the icon/gradicon loss during training
schedule_lambda_reg = True
lambda_reg_step_size = 0.1 # How much to increase lambda_reg each time it advances
cooldown = 3 # How many epochs to allow before checking whether training loss increases and advancing lambda_reg if so
cooldown_counter = cooldown
lambda_reg_goal = 2 # Stop training once lambda_reg advances past this

# from customRandAffine import AffineAugmentation
# affine_aug = AffineAugmentation(spatial_size, 0.8)

# depth == down_convolutions, so every convolution in the stride sequence has stride 2
model = RegModel(
    lambda_reg = 0.1,
    lambda_sim = 1e7,
    device=device,
    down_convolutions=4,
    depth=4,
    max_channels=256,
    init_channels=32,
    icon_loss_type=IconLossType.ICON,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Learning rate decay; the scheduler is only stepped by the training loop when schedule_lr is True,
# and then only while the learning rate is still above min_lr.
schedule_lr = False
min_lr=1e-5
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1**(1/3000))

# Fold counts are tracked for validation only
loss_curves_train = LossCurves("training")
loss_curves_valid = LossCurves("validation", include_folds=True, spatial_dims=spatial_size)

# Halved because each training step consumes two batches from dl_train
batches_per_epoch_train = len(dl_train)//2
print("Number of trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
print("Number of batches per epoch:", batches_per_epoch_train)

In [None]:
# Epoch counter used by the training loop. It lives in its own cell, presumably so that the
# training cell can be re-run to continue training without resetting the count.
e=0

In [None]:
# TRAINING
#
# Each epoch: train on pairs of batches, optionally step the LR scheduler, report the aggregated
# training loss, validate on selected epochs, and then possibly increase lambda_reg.
# Relies on names defined in earlier cells (model, optimizer, scheduler, the data loaders,
# the LossCurves objects, the scheduling settings and the epoch counter e).

max_epochs = 10
while e < max_epochs:
#     current_lr = scheduler.get_last_lr()[0]
    current_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print(f'Epoch {e+1}/{max_epochs} (LR = {current_lr:.1e}, lambda_reg = {model.lambda_reg:.1e}):')
    
    # Train
    model.train()
    dl_train_iter = iter(dl_train)
    while True:
        # Draw two batches per step: the first serves as fixed images, the second as moving images.
        # When the loader runs out, a leftover unpaired batch is discarded.
        try:
            b1 = next(dl_train_iter)
            b2 = next(dl_train_iter)
        except StopIteration:
            break
            
        # NOTE(review): the commented-out augmentation line below refers to fa_key, which is not
        # defined in this file.
#         fixed, moving = affine_aug(b1[fa_key], b2[fa_key])
        fixed, moving = b1['dti'], b2['dti']
        
        optimizer.zero_grad()
        model_output = model(fixed, moving)
        model_output.all_loss.backward()
        optimizer.step()
        
        loss_curves_train.add_to_buffer(model_output)
        del(model_output)
        print('.', end='')
    print()

    # The learning rate decays once per epoch, and only until it reaches min_lr
    if schedule_lr:
        if scheduler.get_last_lr()[0] > min_lr:
            scheduler.step()

    loss_curves_train.aggregate_buffers_for_epoch(e)
    if print_aggregate_when(e):
        # Average over all epochs recorded since the last printout
        l_all = np.mean(loss_curves_train.all_losses[last_printed_epoch+1:])
        l_sim = np.mean(loss_curves_train.sim_losses[last_printed_epoch+1:])
        l_icon = np.mean(loss_curves_train.icon_losses[last_printed_epoch+1:])
        print(f"\tTraining loss: {l_all:.4f} (sim (weighted)={model.lambda_sim*l_sim:.4f}, ic (weighted)={model.lambda_reg*l_icon:.4f})")
        print(f"\t(aggregated from epochs {last_printed_epoch+2} to {len(loss_curves_train.all_losses)})")
        last_printed_epoch = e
    
    # Validate
    if validate_when(e):
        model.eval()
        dl_valid_iter = iter(dl_valid)
        while True:
            # Same two-batch pairing as in training
            try:
                b1 = next(dl_valid_iter)
                b2 = next(dl_valid_iter)
            except StopIteration:
                break
            
#             fixed, moving = affine_aug(b1[fa_key], b2[fa_key])
            fixed, moving = b1['dti'], b2['dti']

            with torch.no_grad():
                model_output = model(fixed, moving)
                loss_curves_valid.add_to_buffer(model_output)
        loss_curves_valid.aggregate_buffers_for_epoch(e) 
        print("\tValidation loss:", loss_curves_valid.all_losses[-1])
        print("\tAverage folds:", loss_curves_valid.fold_counts[-1])
    
    if schedule_lambda_reg:
        # After the cooldown has elapsed, lambda_reg is increased as soon as the epoch's training
        # loss is higher than that of the previous epoch; the cooldown then starts over.
        # NOTE(review): the [-2] lookup needs at least two recorded epochs; the cooldown of 3 set in
        # the setup cell ensures this, but a cooldown of 1 with empty loss curves would not.
        cooldown_counter -= 1;
        if cooldown_counter<=0 and loss_curves_train.all_losses[-1] > loss_curves_train.all_losses[-2]:
            print(f"Updating lambda_reg.")
            model.update_lambda_reg(model.lambda_reg + lambda_reg_step_size)
            cooldown_counter = cooldown # reset cooldown_counter
        # Training ends early once lambda_reg has passed its goal; note that this exit happens
        # before e is incremented below.
        if model.lambda_reg > lambda_reg_goal:
            print("Reached goal lambda_reg.")
            break
            
    e += 1
        
        
    

---

Saving, plotting, and loading

---

In [None]:
# SAVE
# Writes the loss plots, the pickled loss curves and the model weights into footsteps.output_dir.
# NOTE(review): the paths are built by plain string concatenation, so this relies on
# footsteps.output_dir ending with a path separator -- confirm.

loss_curves_train.plot(savepath = footsteps.output_dir + 'loss_plot_train.png')
loss_curves_valid.plot(savepath = footsteps.output_dir + 'loss_plot_valid.png')

# Both curve objects are stored together as a two-element list: [train, valid]
with open(footsteps.output_dir + 'loss_curves.p', 'wb') as f:
    pickle.dump([loss_curves_train, loss_curves_valid],f)

torch.save(model.state_dict(), footsteps.output_dir + 'model_state_dict.pth')

In [None]:
# LOAD
# Restores the model weights and loss curves of a previously saved experiment from results/<name>.
# The model must already have been constructed with the same architecture as the saved one.

experiment_name_to_load = "beh-1"
load_dir = os.path.join('results', experiment_name_to_load)

model.load_state_dict(torch.load(os.path.join(load_dir,'model_state_dict.pth')))

# NOTE: unpickling can execute arbitrary code, so only load files from trusted experiment folders.
with open(os.path.join(load_dir,'loss_curves.p'), 'rb') as f:
    loss_curves_train, loss_curves_valid = pickle.load(f)

---

Preview

---

In [None]:
# Quick preview
# Registers one pair of batches using only the forward warp (no losses are computed) and
# visualizes the result on the FA images of the first element of each batch.

dl = dl_valid # Choose whether to view performance on training or on validation data
it = iter(dl)
d1 = next(it)
d2 = next(it)

# The FA images are used for display only; the model itself is given the DTIs
img_A = d1['fa']
img_B = d2['fa']

# img_A, img_B  = affine_aug(img_A, img_B)

model.eval()
with torch.no_grad():
    ddf = model(d1['dti'], d2['dti'], return_warp_only=True)

# Apply the predicted displacement field to the moving FA image
warp = monai.networks.blocks.Warp(mode='bilinear', padding_mode='zeros')
img_B_warped = warp(img_B, ddf)

preview_slices = (80,80,80)

print("moving:")
util.preview_image(img_B[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("warped moving:")
util.preview_image(img_B_warped[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("target:")
util.preview_image(img_A[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("checkerboard of warped moving and target:")
util.preview_checkerboard(img_A[0,0].cpu(), img_B_warped[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("deformation vector field:")
util.preview_3D_vector_field(ddf[0].cpu(), slices=preview_slices)
print("deformed grid:")
util.preview_3D_deformation(ddf[0].cpu(),5, slices=preview_slices)
print("jacobian determinant:")
det = util.jacobian_determinant(ddf[0].cpu())
util.preview_image(det, normalize_by='slice', threshold=0, slices=preview_slices)
# A "fold" is a location where the jacobian determinant is negative
num_folds = (det<0).sum()
print("Number of folds:", num_folds, f"(folding rate {100*num_folds/np.prod(det.shape)}%)")

In [None]:
# Slow but more informative preview
# Same visualization as the quick preview, but runs the full forward pass so that the
# similarity, icon/gradicon and overall losses for the pair can be reported as well.

dl = dl_valid # Choose whether to view performance on training or on validation data
it = iter(dl)
d1 = next(it)
d2 = next(it)

# The FA images are used for display only; the model itself is given the DTIs
img_A = d1['fa']
img_B = d2['fa']

# img_A, img_B  = affine_aug(img_A, img_B)

model.eval()
with torch.no_grad():
    model_output = model(d1['dti'], d2['dti'])

# Apply the predicted displacement field to the moving FA image
warp = monai.networks.blocks.Warp(mode='bilinear', padding_mode='zeros')
img_B_warped = warp(img_B, model_output.deformation_AB)

preview_slices = (80,80,80)

print("moving:")
util.preview_image(img_B[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("warped moving:")
util.preview_image(img_B_warped[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("target:")
util.preview_image(img_A[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("checkerboard of warped moving and target:")
util.preview_checkerboard(img_A[0,0].cpu(), img_B_warped[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("deformation vector field:")
util.preview_3D_vector_field(model_output.deformation_AB[0].cpu(), slices=preview_slices)
print("deformed grid:")
util.preview_3D_deformation(model_output.deformation_AB[0].cpu(),5, slices=preview_slices)
print("jacobian determinant:")
det = util.jacobian_determinant(model_output.deformation_AB[0].cpu())
util.preview_image(det, normalize_by='slice', threshold=0, slices=preview_slices)
print("sim loss:", model_output.sim_loss.item())
print("(grad?)icon loss:", model_output.icon_loss.item())
print("overall loss:", model_output.all_loss.item())
# A "fold" is a location where the jacobian determinant is negative
num_folds = (det<0).sum()
print("Number of folds:", num_folds, f"(folding rate {100*num_folds/np.prod(det.shape)}%)")