Affine registration of FA images. (WIP)

In [None]:
import os
import glob
import random
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import monai
import torch
import torch.nn

import footsteps
import pickle

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Spatial size that every image is padded to before it is fed to the network.
spatial_size = (144,144,144)

The input images are known to be $140\times140\times140$, and we will pad them out to $144$ in each dimension.

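We assume that each dimension in `spatial_size` is divisible by $2^{\texttt{num}\_\texttt{scales}-1}$, because we will repeatedly downsample the images by a factor of $2$ to produce images at different scales. (In this notebook the constraint that is actually checked is the one in `AffineRegModel`, which requires each dimension to be divisible by $2^{\texttt{down}\_\texttt{convolutions}}$.)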

In [None]:
fa_dir = './dti_fit_images/fa'
fa_key = 'fa'
# glob returns paths in arbitrary (filesystem-dependent) order. Sort them so that the
# unshuffled train/validation partition below is the same on every run.
fa_paths = sorted(glob.glob(os.path.join(fa_dir,'*')))
# One dictionary per image: the path under `fa_key`, plus the bare file name.
data = [{fa_key:path, "filename":os.path.basename(path)} for path in fa_paths]
# Split the list into training and validation parts with ratio 8:2.
data_train, data_valid = monai.data.utils.partition_dataset(data, ratios=(8,2))

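`fa_keys` is a list that maps each scale index to the dictionary key of the image at that scale: index $0$ is the base resolution, index $1$ is downscaled by a factor of $2$, index $2$ is downscaled by a further factor of $2$, and so on. (The code in this notebook currently uses only the single base-resolution key `fa_key`.)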

In [None]:
# Preprocessing steps shared by the training and validation pipelines.
# Every step operates on the dictionary entry stored under `fa_key`.
# NOTE(review): AddChannelD is deprecated in newer MONAI releases in favour of
# EnsureChannelFirstD -- confirm against the installed MONAI version.
base_transforms = [
    # read the image file from the stored path
    monai.transforms.LoadImageD(keys=fa_key),
    # add a channel axis
    monai.transforms.AddChannelD(keys=fa_key),
    # pad up to `spatial_size` with a constant value
    monai.transforms.SpatialPadD(
        keys=fa_key,
        spatial_size=spatial_size,
        mode="constant",
    ),
    # convert to a torch tensor and move it to the compute device
    monai.transforms.ToTensorD(keys=fa_key),
    monai.transforms.ToDeviceD(keys=fa_key, device=device),
]

In [None]:
# Overall strength of the random affine augmentation: the rotation, translation
# and scale ranges below are all proportional to it.
a = 0.5

# Size of the padded volume along its first axis; the translation range is a fraction of it.
S = spatial_size[0]

# Keyword arguments for the random affine transform.
rand_affine_params = dict(
    prob=1.,                        # always apply the transform
    mode='bilinear',
    padding_mode='zeros',
    spatial_size=spatial_size,
    cache_grid=True,
    rotate_range=(a*np.pi/2,)*3,    # same range for each of the three rotation parameters
    shear_range=(0,)*6,             # no shearing
    translate_range=(a*S/16,)*3,
    scale_range=(a*0.4,)*3,
)

rand_affine_transform = monai.transforms.RandAffineD(keys=fa_key, **rand_affine_params)

In [None]:
# Training and validation use the same pipeline: the shared preprocessing steps followed
# by the random affine transform (validation samples are therefore randomly transformed too).
transform_valid = monai.transforms.Compose(base_transforms + [rand_affine_transform])
transform_train = monai.transforms.Compose(base_transforms + [rand_affine_transform])

In [None]:
# Wrap each data list together with its transform pipeline in a MONAI CacheDataset.
ds_train = monai.data.CacheDataset(data_train, transform_train)
ds_valid = monai.data.CacheDataset(data_valid, transform_valid)

In [None]:
# MONAI warping block with bilinear interpolation and zero padding; used by compose_ddf.
warp = monai.networks.blocks.Warp(mode="bilinear", padding_mode="zeros")

def mse_loss(b1, b2):
    """Mean squared difference between two image batches, multiplied by 10000.

    b1 and b2 are batches of shape (batch_size, channels, H,W,D). The mean is taken over
    every element, so the result is a scalar tensor."""
    squared_error = torch.square(b1 - b2)
    return 10000 * squared_error.mean()

def ncc_loss(b1, b2, eps=1e-8):
    """Return the negative NCC loss given two batches b1 and b2 of shape (batch_size, channels, H,W,D).
    It is averaged over batches and channels.

    Args:
        b1, b2: image batches of identical shape (batch_size, channels, H,W,D).
        eps: small positive constant that keeps the denominator away from zero, so that a
            constant (zero-variance) image yields a finite loss instead of NaN.

    Returns:
        Scalar tensor: minus the normalized cross correlation, so -1 means perfectly correlated.
    """
    spatial_axes = (2,3,4)
    mu1 = b1.mean(dim=spatial_axes) # means
    mu2 = b2.mean(dim=spatial_axes)
    alpha1 = (b1**2).mean(dim=spatial_axes) # second moments
    alpha2 = (b2**2).mean(dim=spatial_axes)
    alpha12 = (b1*b2).mean(dim=spatial_axes) # cross term
    numerator = alpha12 - mu1*mu2
    # Rounding can push "second moment - mean**2" slightly below zero, which would make the
    # square root NaN; clamp the variances at zero.
    var1 = torch.clamp(alpha1 - mu1**2, min=0)
    var2 = torch.clamp(alpha2 - mu2**2, min=0)
    # eps goes inside the square root so that both the value and the gradient stay finite
    # when one of the variances is exactly zero.
    denominator = torch.sqrt(var1*var2 + eps**2)
    ncc = numerator / denominator
    return -ncc.mean() # average over batches and channels

def compose_ddf(u,v):
    """Compose two displacement fields.

    The returned displacement field is the one that warps by v followed by u: v is first
    resampled with `warp` using u as the displacement, and the result is added to u."""
    v_resampled = warp(v, u)
    return u + v_resampled

# Spatial shape of one transformed training sample; the leading axis (presumably the channel
# axis added by AddChannelD) is discarded.
_, H, W, D = ds_train[0][fa_key].shape

# Compute discrete spatial derivatives
def diff_and_trim(array, axis):
    """Take the discrete difference along a spatial axis, which should be 2,3, or 4.
    Return a difference tensor with all spatial axes trimmed by 1.

    Args:
        array: tensor of shape (batch, channels, H, W, D).
        axis: the spatial axis (2, 3 or 4) along which to take the forward difference.

    Returns:
        Tensor of shape (batch, channels, H-1, W-1, D-1).
    """
    # Take the spatial sizes from the input itself rather than from module-level globals, so
    # the function is correct for tensors of any spatial size.
    h, w, d = array.shape[2:5]
    # torch.diff already shortens `axis` by one; the slice trims the other two axes to match.
    return torch.diff(array, dim=axis)[:, :, :(h-1), :(w-1), :(d-1)]

def size_of_spatial_derivative(u):
    """Squared Frobenius norm of the spatial derivative of a displacement field.

    This measures the derivative of the displacement field map itself, not of the deformation
    that the displacement field defines. The input is expected to have shape (batch,3,H,W,D);
    the output has shape (batch): squared finite differences are summed over the three
    displacement components and averaged over the (trimmed) spatial grid."""
    squared_differences = sum(diff_and_trim(u, spatial_axis)**2 for spatial_axis in (2, 3, 4))
    summed_over_components = squared_differences.sum(dim=1)
    return summed_over_components.mean(dim=(1, 2, 3))

def compose_affine(u,v):
    """Matrix product u.v of two batches of affine transforms.

    u and v are tensors of shape (b,3,4), b being the batch dimension. Each is completed to a
    homogeneous (b,4,4) matrix by appending the row [0,0,0,1]; the two are multiplied, and the
    top three rows of the product are returned, again with shape (b,3,4)."""
    batch_size = u.shape[0]
    # One [0,0,0,1] row per batch element, matching u's dtype and device.
    bottom = torch.tensor([[[0, 0, 0, 1]]], device=u.device, dtype=u.dtype)
    bottom = bottom.expand(batch_size, 1, 4)
    u_hom = torch.cat((u, bottom), dim=1)
    v_hom = torch.cat((v, bottom), dim=1)
    product = torch.matmul(u_hom, v_hom)
    return product[:, :3, :]

In [None]:
# 4x4 homogeneous translation by half of `spatial_size` along each axis, and its inverse.
# center_transform / uncenter_transform conjugate affine matrices with this pair.
translate = torch.eye(4)
translate[:3,3] = torch.tensor(spatial_size)/2
translate_inv = torch.linalg.inv(translate)
# Keep both on the compute device; the helpers below still move them to the device of
# whatever matrix they are given.
translate = translate.to(device)
translate_inv = translate_inv.to(device)
def center_transform(transform):
    """Return translate_inv . transform . translate, computed on the device of `transform`.

    `transform` is a 4x4 matrix or a batch of them; `translate` is the half-image translation
    defined at module level."""
    target_device = transform.device
    shifted = torch.matmul(transform, translate.to(target_device))
    return torch.matmul(translate_inv.to(target_device), shifted)
def uncenter_transform(transform):
    """Return translate . transform . translate_inv, computed on the device of `transform`.

    This undoes center_transform. `transform` is a 4x4 matrix or a batch of them."""
    target_device = transform.device
    shifted = torch.matmul(transform, translate_inv.to(target_device))
    return torch.matmul(translate.to(target_device), shifted)

In [None]:
#TEMP
# Scratch check: resample one training image with the inverse of the affine stored in its
# metadata and display the result.
import util

dl = monai.data.DataLoader(ds_train, shuffle=True, batch_size=1, drop_last=True)
it = iter(dl)
d1 = next(it)[fa_key].cpu()
d2 = next(it)[fa_key].cpu()

affine_transform = monai.networks.layers.AffineTransform(spatial_size)

# Inverse of the recorded affine, passed through center_transform ...
t = center_transform(torch.linalg.inv(d1.meta['affine'].float()))
# ... and converted back with uncenter_transform before it is applied to the image.
resampled = affine_transform(d1, uncenter_transform(t))
util.preview_image(resampled[0, 0])

In [None]:
# Everything AffineRegModel.forward returns, bundled under named fields.
ModelOutput = namedtuple(
    "ModelOutput",
    [
        "affine",
        "warped_moving",
        "sim_loss",
        "regularization_loss",
        "supervised_loss",
        "all_loss",
        "true_theta",
    ],
)

class AffineRegModel(torch.nn.Module):
    """Regress a 3x4 affine matrix from a (target, moving) image pair.

    The two images are concatenated along the channel axis, passed through a stack of 3D
    convolution blocks, flattened, and mapped by a fully connected head to the 12 entries of
    `theta - identity`. In its current state the forward pass is trained with a supervised loss
    only: the prediction is compared with the transform derived from the `affine` entries in the
    metadata of the two inputs.
    """
    def __init__(self,
                 compute_sim_loss,
                 down_convolutions,
                 depth,
                 max_channels,
                 init_channels,
                 spatial_size,
                 lambda_reg = 1.,
                 cnn_dropout=0.1,
                 fc_dropout=0.1,
                 fc_hidden_layers = None
                ):
        """
        Create affine registration model

        Args:
            compute_sim_loss: A function that compares two batches of images and returns a similarity loss.
                              It is stored on the model; the current forward pass does not call it.
            down_convolutions: How many stride=2 convolutions to include in the CNN. We assume the original
                               image size is divisible by 2**down_convolutions
            depth: Total number of convolution layers. Increase this to increase model capacity.
                   Must be >= down_convolutions
            max_channels: As you go to deeper layers, channels grow by powers of two... up to a maximum given here.
            init_channels: how many channels in the first layer
            spatial_size: The spatial size of the input images as a 3-tuple.
            lambda_reg: Weight of the regularization loss. It is stored on the model; the current forward
                        pass does not use it.
            cnn_dropout: Dropout setting passed to every convolution block.
            fc_dropout: Dropout probability used after every hidden fully connected layer.
            fc_hidden_layers: List of hidden layer sizes for the fully connected network at the end. By default
                              it's an empty list, which means the fully connected network simply goes from
                              the flattened CNN output to the entries of an affine matrix.
                              The list is copied; the caller's object is not modified.

        Raises:
            ValueError: if depth < down_convolutions, or if a spatial dimension is not divisible by
                        2**down_convolutions.
        """
        super().__init__()
        self.compute_sim_loss = compute_sim_loss
        self.lambda_reg = lambda_reg

        self.reg_net_architecture_info = []
        if depth < down_convolutions:
            raise ValueError("Must have depth >= down_convolutions")

        self.spatial_size = spatial_size
        cnn_spatial_size_factor = 2**down_convolutions

        # (We will assume that the original image size is divisible by 2**n.)
        for i,d in enumerate(spatial_size):
            if d%cnn_spatial_size_factor != 0:
                raise ValueError(f"Since down_convolutions={down_convolutions} spatial dimension must be divisible by {cnn_spatial_size_factor}, but got size {d} in spatial dimension {i}.")

        self.cnn_output_spatial_size = [s // cnn_spatial_size_factor for s in spatial_size]

        num_twos = down_convolutions # The number of 2's we will put in the sequence of convolutional strides.
        num_ones = depth-down_convolutions # The number of 1's
        num_one_two_pairs = min(num_ones, num_twos) # The number of 1,2 pairs to stick in the middle
        # Strides: leftover 2's first, then alternating (1,2) pairs, then leftover 1's.
        stride_sequence = (2,)*(num_twos-num_one_two_pairs) + (1,2)*num_one_two_pairs + (1,)*(num_ones-num_one_two_pairs)
        # Output channels double at every layer until they hit max_channels.
        channel_sequence = [min(init_channels*2**c,max_channels) for c in range(num_twos+num_ones)]

        # Cast to a plain int: np.prod returns a numpy scalar, which is not what torch.nn.Linear
        # and Tensor.view are documented to take.
        self.cnn_output_flattened_size = int(channel_sequence[-1]*np.prod(self.cnn_output_spatial_size))
        self.stride_sequence = stride_sequence
        self.channel_sequence = channel_sequence

        cnn_layers = []
        for i in range(depth):
            # The first layer sees the two input images stacked along the channel axis.
            in_channels = channel_sequence[i-1] if i>0 else 2
            cnn_layers.append(monai.networks.blocks.Convolution(
                spatial_dims=3,
                in_channels=in_channels,
                out_channels=channel_sequence[i],
                dropout=cnn_dropout,
                strides=stride_sequence[i]
            ))

        self.cnn = torch.nn.Sequential(*cnn_layers)

        # Copy the hidden sizes before appending the output size, so that the list passed in by
        # the caller is not mutated (reusing it for a second model would otherwise add a
        # spurious 12-unit hidden layer).
        fc_layer_sizes = list(fc_hidden_layers) if fc_hidden_layers is not None else []
        fc_layer_sizes.append(4*3) # the 12 entries of a 3x4 affine matrix

        fc_layers = []
        in_features = self.cnn_output_flattened_size
        last_index = len(fc_layer_sizes)-1
        for i, out_features in enumerate(fc_layer_sizes):
            fc_layers.append(torch.nn.Linear(in_features, out_features))
            # Dropout and activation follow every layer except the output layer.
            if i != last_index:
                fc_layers.append(torch.nn.Dropout(fc_dropout))
                fc_layers.append(torch.nn.PReLU())
            in_features = out_features

        self.fc = torch.nn.Sequential(*fc_layers)

        # We interpret the output of self.fc as a difference from the identity matrix, and we want
        # that difference to start training as zero
        self.fc[-1].weight.data.zero_()
        self.fc[-1].bias.data.zero_()

        # Affine matrix for identity transform with shape 1,3,4
        self.id134 = torch.cat([torch.eye(3), torch.zeros(3).unsqueeze(1)], dim=1).unsqueeze(0)

        # Affine transformer that operates in MONAI style coordinates
        self.affine_transform = monai.networks.layers.AffineTransform(self.spatial_size)


    def forward(self, img_A, img_B, compute_warped_B=False) -> ModelOutput:
        """
        Predict the affine transform for an image pair and compute the losses.

        Args:
            img_A: target image batch. Must expose `meta['affine']` (the ground-truth transform is
                   derived from it).
            img_B: moving image batch. Must expose `meta['affine']` as well.
            compute_warped_B: if True, also resample img_B with the predicted transform and return
                              it as `warped_moving`; otherwise `warped_moving` is None.

        Returns:
            ModelOutput with the predicted `affine` (batch,3,4), `warped_moving`, the losses, and
            `true_theta`. `sim_loss` and `regularization_loss` are currently constant zero and
            `all_loss` equals `supervised_loss`.
        """

        cnn_output = self.cnn(torch.cat([img_A,img_B], dim=1))
        cnn_output_flattened = cnn_output.view(-1, self.cnn_output_flattened_size)
        theta_minus_id = self.fc(cnn_output_flattened).view(-1, 3, 4)

        # This sum conveniently broadcasts over the batch dimension
        id134 = self.id134.to(theta_minus_id.device)
        theta = theta_minus_id + id134

        # apply transform with torch coordinates interpretation
#         grid = torch.nn.functional.affine_grid(theta, img_B.size(), align_corners=False)
#         warped_B = torch.nn.functional.grid_sample(img_B, grid, align_corners=False)

        # apply transform with MONAI coordinates interpretation
        if compute_warped_B:
            # Complete theta to a homogeneous 4x4 matrix, undo the centering, then resample.
            last_row = torch.tensor([0,0,0,1],device=theta.device, dtype=theta.dtype).view((1,1,4))
            last_row = torch.repeat_interleave(last_row,theta.shape[0],dim=0)
            theta_uncentered = torch.cat([theta,last_row],dim=1)
            theta_uncentered = uncenter_transform(theta_uncentered)
            warped_B = self.affine_transform(img_B,theta_uncentered)
        else:
            warped_B=None

        # compute image similarity
#         sim_loss = self.compute_sim_loss(img_A, warped_B)

        # get the ground truth correct transform, MONAI coordinates
        with torch.no_grad():
            Ta = img_A.meta['affine'].float()
            Tb = img_B.meta['affine'].float()
            # torch.linalg.solve(Tb, Ta) is Tb^{-1} . Ta
            true_theta = center_transform(torch.linalg.solve(Tb, Ta))[:,:3,:]

        # supervision
        supervised_loss = ((theta - true_theta.to(theta.device))**2).mean()

        # Frobenius norm loss
#         regularization_loss = (theta_minus_id**2).sum()

        # ICon loss
#         cnn_output_rev = self.cnn(torch.cat([img_B,img_A], dim=1))
#         cnn_output_rev_flattened = cnn_output_rev.view(-1, self.cnn_output_flattened_size)
#         theta_rev_minus_id = self.fc(cnn_output_rev_flattened).view(-1, 3, 4)
#         theta_rev = theta_rev_minus_id + id134
#         regularization_loss = ((compose_affine(theta,theta_rev) - id134)**2).mean()
#         regularization_loss += ((compose_affine(theta_rev,theta) - id134)**2).mean()

        # since we computed theta_rev we might as well include a comparison of that to ground truth
#         with torch.no_grad():
#             true_theta_rev = torch.linalg.solve(Ta, Tb)[:,:3,:]
#         sim_loss += ((theta_rev - true_theta_rev.to(theta.device))**2).mean()

        # Placeholders while the similarity and regularization terms are disabled.
        regularization_loss = torch.tensor(0)
        sim_loss = torch.tensor(0)

        return ModelOutput(
            affine = theta,
            warped_moving = warped_B,
            sim_loss = sim_loss,
            regularization_loss = regularization_loss,
            supervised_loss = supervised_loss,
#             all_loss = sim_loss + self.lambda_reg*regularization_loss,
            all_loss = supervised_loss,
            true_theta= true_theta,
        )

In [None]:
# Not necessarily needed and not sufficiently tested

# Conversions between torch and monai coords
S1,S2,S3 = spatial_size
# M_to_T reverses the order of the three coordinates and scales each by 2/(size-1) of its axis;
# T_to_M is the inverse matrix.
M_to_T = torch.tensor([[0, 0, 2/(S3-1), 0], [0,2/(S2-1),0,0], [2/(S1-1),0,0,0], [0,0,0,1]], dtype=torch.float32).to(device)
T_to_M = torch.linalg.inv(M_to_T)

def get_correct_transform_torch_coords(img_A, img_B):
    """Ground-truth transform between two images, converted to torch coordinates.

    Both images must expose `meta['affine']`. The transform Tb^{-1} . Ta is formed from the two
    affines, conjugated with the monai/torch conversion matrices, and its top three rows are
    returned (shape (batch,3,4))."""
    affine_a = img_A.meta['affine'].float()
    affine_b = img_B.meta['affine'].float()

    # Transform in monai coords. torch.linalg.solve(Tb, Ta) solves Tb . X = Ta, i.e. X = Tb^{-1} . Ta
    transform_monai = torch.linalg.solve(affine_b, affine_a)

    # Convert to torch coords: M_to_T . T . T_to_M, on the device of the transform
    target_device = transform_monai.device
    to_torch = M_to_T.to(target_device)
    to_monai = T_to_M.to(target_device)
    transform_torch = torch.matmul(to_torch, torch.matmul(transform_monai, to_monai))

    return transform_torch[:, :3, :]

In [None]:
class LossCurves:
    """Accumulates per-batch losses and keeps one averaged value per epoch for plotting.

    Per-batch values go into the `*_buffer` lists via add_to_buffer; aggregate_buffers_for_epoch
    appends the mean of each buffer to the per-epoch history and empties the buffers."""

    def __init__(self, name : str, spatial_dims : tuple = None):
        """`name` prefixes the plot titles. `spatial_dims` is accepted but not used."""
        self.name = name

        # Per-epoch history
        self.epochs = []
        self.all_losses = []
        self.sim_losses = []
        self.regularization_losses = []
        self.supervised_losses = []

        # Per-batch buffers for the epoch in progress
        self.clear_buffers()

    def clear_buffers(self):
        """Start a fresh, empty set of per-batch buffers."""
        self.all_losses_buffer = []
        self.sim_losses_buffer = []
        self.supervised_losses_buffer = []
        self.regularization_losses_buffer = []

    def add_to_buffer(self, model_output : ModelOutput):
        """Record the four scalar losses of one model output as Python numbers."""
        out = model_output
        self.all_losses_buffer.append(out.all_loss.item())
        self.sim_losses_buffer.append(out.sim_loss.item())
        self.supervised_losses_buffer.append(out.supervised_loss.item())
        self.regularization_losses_buffer.append(out.regularization_loss.item())

    def aggregate_buffers_for_epoch(self, epoch : int):
        """Append the mean of each buffer to the history under `epoch`, then clear the buffers."""
        self.epochs.append(epoch)
        history_and_buffer = (
            (self.all_losses, self.all_losses_buffer),
            (self.sim_losses, self.sim_losses_buffer),
            (self.regularization_losses, self.regularization_losses_buffer),
            (self.supervised_losses, self.supervised_losses_buffer),
        )
        for history, buffer in history_and_buffer:
            history.append(np.mean(buffer))
        self.clear_buffers()

    def plot(self, savepath=None):
        """Draw the four loss histories side by side; save the figure first if `savepath` is given."""
        fig, axs = plt.subplots(1,4,figsize = (15,5))
        panels = (
            (self.all_losses, "overall loss", {}),
            (self.sim_losses, "similarity loss", {}),
            (self.regularization_losses, "regularization loss", {"label": "regularization loss"}),
            (self.supervised_losses, "supervised loss", {"label": "supervised loss"}),
        )
        for ax, (values, title, plot_kwargs) in zip(axs, panels):
            ax.plot(self.epochs, values, **plot_kwargs)
            ax.set_title(f"{self.name}: {title}")
            ax.set_xlabel("epoch")
        if savepath is not None:
            plt.savefig(savepath)
        plt.show()

In [None]:
# Local normalized cross-correlation loss from MONAI for 3D images, with a 5-voxel kernel.
# smooth_nr / smooth_dr are the smoothing constants passed to the loss.
lncc_loss = monai.losses.LocalNormalizedCrossCorrelationLoss(
    spatial_dims=3,
    kernel_size=5,
    smooth_nr = 0,
    smooth_dr = 1e-4
)

In [None]:
# Data loaders. The training loop draws two consecutive batches and uses them as the
# (target, moving) pair.
dl_train = monai.data.DataLoader(ds_train, shuffle=True, batch_size=1, drop_last=True)
dl_valid = monai.data.DataLoader(ds_valid, shuffle=True, batch_size=2, drop_last=True)
max_epochs = 400


def validate_when(e):
    """Return True on every 5th epoch (except epoch 0) and on the final epoch."""
    if e == max_epochs - 1:
        return True
    return e % 5 == 0 and e != 0


model = AffineRegModel(
    compute_sim_loss=mse_loss,
    down_convolutions=4,
    depth=8,
    max_channels=128,
    init_channels=8,
    spatial_size=spatial_size,
    lambda_reg=10,
    fc_hidden_layers=[512],
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Exponential learning-rate decay; the training loop stops stepping the scheduler once the
# learning rate is no longer above min_lr.
schedule_lr = True
min_lr = 1e-6
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.985)

# Loss histories for the two phases
loss_curves_train = LossCurves("training")
loss_curves_valid = LossCurves("validation", spatial_dims=spatial_size)

In [None]:
# TRAINING

for e in range(max_epochs):
    # Read the learning rate currently stored in the optimizer
    # (alternative: scheduler.get_last_lr()[0]).
    current_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print(f'Epoch {e+1}/{max_epochs} (LR = {current_lr:.1e}):')

    # Train
    model.train()
    dl_train_iter = iter(dl_train)
    # Zipping the iterator with itself yields consecutive batches in pairs; a trailing unpaired
    # batch is dropped.
    for b1, b2 in zip(dl_train_iter, dl_train_iter):
        optimizer.zero_grad()
        model_output = model(b1['fa'], b2['fa'])
        model_output.all_loss.backward()
        optimizer.step()

        loss_curves_train.add_to_buffer(model_output)
        del model_output

    # Decay the learning rate until it is no longer above min_lr
    if schedule_lr and scheduler.get_last_lr()[0] > min_lr:
        scheduler.step()

    loss_curves_train.aggregate_buffers_for_epoch(e)
    print(f"\tTraining loss: {loss_curves_train.all_losses[-1]:.4f} (sup={loss_curves_train.supervised_losses[-1]:.4f}, sim={loss_curves_train.sim_losses[-1]:.4f}, reg={loss_curves_train.regularization_losses[-1]:.4f})")

    # Validate
    if validate_when(e):
        model.eval()
        dl_valid_iter = iter(dl_valid)
        for b1, b2 in zip(dl_valid_iter, dl_valid_iter):
            with torch.no_grad():
                model_output = model(b1['fa'], b2['fa'])
                loss_curves_valid.add_to_buffer(model_output)
        loss_curves_valid.aggregate_buffers_for_epoch(e)
        print("\tValidation loss:", loss_curves_valid.all_losses[-1])
    
    

In [None]:
# SAVE

# Build every output path with os.path.join, so the result is correct whether or not
# footsteps.output_dir ends with a path separator.
loss_curves_train.plot(savepath = os.path.join(footsteps.output_dir, 'loss_plot_train.png'))
loss_curves_valid.plot(savepath = os.path.join(footsteps.output_dir, 'loss_plot_valid.png'))

# Loss histories, pickled as a [train, valid] list
with open(os.path.join(footsteps.output_dir, 'loss_curves.p'), 'wb') as f:
    pickle.dump([loss_curves_train, loss_curves_valid],f)

# Model weights only (state dict)
torch.save(model.state_dict(), os.path.join(footsteps.output_dir, 'model_state_dict.pth'))

In [None]:
# LOAD

experiment_name_to_load = "affine with ICon"
load_dir = os.path.join('results', experiment_name_to_load)

# map_location remaps the stored tensors onto the device in use, so that a checkpoint saved
# from a GPU run can also be restored on a machine without that GPU.
model.load_state_dict(torch.load(os.path.join(load_dir,'model_state_dict.pth'), map_location=device))

# NOTE: pickle can execute arbitrary code while loading; only open files that were written by
# this notebook.
with open(os.path.join(load_dir,'loss_curves.p'), 'rb') as f:
    loss_curves_train, loss_curves_valid = pickle.load(f)

In [None]:
import util

# Choose whether to view performance on training or on validation data
dl = monai.data.DataLoader(ds_train, shuffle=True, batch_size=1, drop_last=True)
it = iter(dl)
d1 = next(it)
d2 = next(it)

# Target (A) and moving (B) images
img_A = d1['fa']
img_B = d2['fa']

# Run the model once in evaluation mode, also asking for the warped moving image
model.eval()
with torch.no_grad():
    model_output = model(img_A, img_B, compute_warped_B=True)

img_B_warped = model_output.warped_moving

# Slice indices shown in every preview, and the keyword arguments shared by all previews
preview_slices = (80,80,80)
preview_kwargs = dict(figsize=(18,10), slices=preview_slices)

print("moving:")
util.preview_image(img_B[0,0].cpu(), **preview_kwargs)
print("warped moving:")
util.preview_image(img_B_warped[0,0].cpu(), **preview_kwargs)
print("target:")
util.preview_image(img_A[0,0].cpu(), **preview_kwargs)
print("checkerboard of warped moving and target:")
util.preview_checkerboard(img_A[0,0].cpu(), img_B_warped[0,0].cpu(), **preview_kwargs)

print("affine transform matrix:")
print(model_output.affine[0].as_tensor().cpu())
print("the correct affine transform matrix:")
print(model_output.true_theta[0].cpu())

# Warp the moving image with the ground-truth transform for comparison: complete the 3x4
# matrix to a homogeneous 4x4 one, undo the centering, then resample.
true_theta = model_output.true_theta.to(img_B.device)
last_row = torch.tensor([0,0,0,1], device=true_theta.device, dtype=true_theta.dtype).view((1,1,4))
last_row = torch.repeat_interleave(last_row, true_theta.shape[0], dim=0)
true_theta_uncentered = torch.cat([true_theta, last_row], dim=1)
true_theta_uncentered = uncenter_transform(true_theta_uncentered)
img_B_warped_correct = model.affine_transform(img_B, true_theta_uncentered)
print("moving image warped by the correct affine transform")
util.preview_image(img_B_warped_correct[0,0].cpu(), **preview_kwargs)

print("sim loss:", model_output.sim_loss.item())
print("regularization loss:", model_output.regularization_loss.item())
print("overall loss:", model_output.all_loss.item())