Experimenting with affine-only registration using deep learning, applied to the OASIS-1 dataset.

In [None]:
import monai
import torch
import itk
import numpy as np
import matplotlib.pyplot as plt
import random
import glob
import os.path
import tempfile
from collections import defaultdict
import itertools

from utils import (
    preview_image, preview_3D_vector_field, preview_3D_deformation,
    jacobian_determinant, plot_against_epoch_numbers
)

# Print MONAI's configuration report so the environment is recorded with the run
monai.config.print_config()

# Set deterministic training for reproducibility
monai.utils.set_determinism(seed=2938649572)

In [None]:
# The dataset is expected in an "OASIS-1" folder directly under the directory
# named by the MONAI_DATA_DIRECTORY environment variable.
root_dir = os.getenv("MONAI_DATA_DIRECTORY")
if root_dir is None or root_dir == "":
    raise Exception("Need to set MONAI_DATA_DIRECTORY env var")
data_dir = os.path.join(root_dir, "OASIS-1")
print(
    f"Root directory: {root_dir}",
    f"Data directory: {data_dir}",
    sep="\n",
)

In [None]:
image_path_expression = "PROCESSED/MPRAGE/T88_111/OAS1_*_MR*_mpr_n*_anon_111_t88_masked_gfc.img"

# Expect either of two reasonable ways of organizing extracted data:
# 1) <data_dir>/disc1/OAS1_0031_MR1/...
# 2) <data_dir>/OAS1_0031_MR1/...
image_paths = glob.glob(os.path.join(data_dir, '*', image_path_expression))
image_paths += glob.glob(os.path.join(data_dir, '*/*', image_path_expression))

# glob returns matches in arbitrary (filesystem-dependent) order. Sort them so
# that anything derived from the order of this list is the same on every
# machine and every run.
image_paths.sort()

In [None]:
# Split the image paths into a training and a validation subset in an 8:2 ratio
image_paths_train, image_paths_valid = monai.data.utils.partition_dataset(
    image_paths,
    ratios=(8, 2),
)

In [None]:
# Every ordered (fixed, moving) combination of images within a split,
# including the pairs where both entries are the same image.
data_pairs_train = [
    {'fixed': fixed_path, 'moving': moving_path}
    for fixed_path, moving_path in itertools.product(image_paths_train, repeat=2)
]
data_pairs_valid = [
    {'fixed': fixed_path, 'moving': moving_path}
    for fixed_path, moving_path in itertools.product(image_paths_valid, repeat=2)
]

In [None]:
# Edge length, in voxels, of the cubic volumes the images are resized to
S = 64
resize = (S, S, S)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Control the overall scale of affine transform
a = 0.3

keys = ['fixed', 'moving']

# Keyword arguments shared by both RandAffineD instances below. Every range is
# proportional to `a`, and shearing is switched off.
rand_affine_params = dict(
    prob=1.,
    mode='bilinear',
    padding_mode='zeros',
    spatial_size=resize,
    cache_grid=True,
    rotate_range=3 * (a*np.pi/2,),
    shear_range=6 * (0,),  # no shearing
    translate_range=3 * (a*S/16,),
    scale_range=3 * (a*0.2,),
)

transform = monai.transforms.Compose([
    # Read both image files with the ITK reader
    monai.transforms.LoadImageD(keys=keys, reader='itkreader'),
    # Reverse the order of the three axes
    monai.transforms.TransposeD(keys=keys, indices=(2, 1, 0)),
    monai.transforms.AddChannelD(keys=keys),
    monai.transforms.ToTensorD(keys=keys),
    monai.transforms.ResizeD(keys=keys, spatial_size=resize),
    monai.transforms.ToDeviceD(keys=keys, device=device),
    # Two separate instances: 'fixed' and 'moving' are perturbed independently
    monai.transforms.RandAffineD(keys='fixed', **rand_affine_params),
    monai.transforms.RandAffineD(keys='moving', **rand_affine_params),
    # Concatenate fixed and moving along dim 0 under the key 'fm'
    monai.transforms.ConcatItemsD(keys=keys, name='fm', dim=0),
])

In [None]:
# Suppress the many warnings related to deprecation of the Analyze file format
# (without this, we would see warnings when the LoadImage transform calls itk to load Analyze files)
itk.ProcessObject.SetGlobalWarningDisplay(False)

In [None]:
# We will redefine these as CacheDatasets later
dataset_pairs_train = monai.data.Dataset(data = data_pairs_train, transform=transform)
dataset_pairs_valid = monai.data.Dataset(data = data_pairs_valid, transform=transform)

In [None]:
# Examine: draw one random training pair and preview channel 0 of its 'fm'
# tensor (the 'fixed' image, since 'fm' concatenates keys ['fixed', 'moving'] in that order)

d = random.choice(dataset_pairs_train)
preview_image(d['fm'][0].cpu())

In [None]:
# originally stolen from:
# https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

class AffineNet(torch.nn.Module):
    """Spatial-transformer-style network that regresses a 3D affine transform.

    The input is a two-channel volume of shape (B, 2, S, S, S). A small
    convolutional "localization" network followed by a single linear layer
    predicts one 3x4 affine matrix per batch element. That matrix is then used
    to resample channel 1 of the input (the moving image when the input is the
    'fm' tensor built by this file's transform pipeline).

    Args:
        spatial_size: Edge length S of the cubic input volumes. When omitted,
            the module-level ``S`` is used, so ``AffineNet()`` behaves exactly
            as it did before this argument existed.

    Raises:
        ValueError: If ``spatial_size`` is too small for the localization
            network to produce at least one output voxel per axis.
    """

    def __init__(self, spatial_size=None):
        super().__init__()

        # Fall back to the notebook-wide volume size so AffineNet() keeps working
        if spatial_size is None:
            spatial_size = S
        self.spatial_size = spatial_size

        # Spatial size expected after self.localization:
        # conv k=5 (-4), pool (/2), conv k=3 (-2), pool (/2), conv k=3 (-2)
        self.S0 = ((spatial_size - 4)//2 - 2)//2 - 2
        if self.S0 < 1:
            raise ValueError(
                f"spatial_size={spatial_size} is too small: the localization "
                f"network would output a spatial size of {self.S0}"
            )

        self.dropout_p = 0.2

        # Spatial transformer localization-network
        self.loc_C1 = 8
        self.loc_C2 = 16
        self.loc_C3 = 32
        self.localization = torch.nn.Sequential(  # say input is shape (B,2,S,S,S)
            torch.nn.Conv3d(2, self.loc_C1, kernel_size=5),  # -4
            torch.nn.MaxPool3d(2, stride=2),  # /2
            torch.nn.BatchNorm3d(self.loc_C1), torch.nn.Dropout(p=self.dropout_p),
            torch.nn.PReLU(),
            torch.nn.Conv3d(self.loc_C1, self.loc_C2, kernel_size=3),  # -2
            torch.nn.MaxPool3d(2, stride=2),  # /2
            torch.nn.BatchNorm3d(self.loc_C2), torch.nn.Dropout(p=self.dropout_p),
            torch.nn.PReLU(),  # here: (B, loc_C2, S', S', S') with S' = ((S-4)//2 - 2)//2
            torch.nn.Conv3d(self.loc_C2, self.loc_C3, kernel_size=3),  # -2
            torch.nn.BatchNorm3d(self.loc_C3), torch.nn.Dropout(p=self.dropout_p),
            torch.nn.PReLU(),
        )

        # Regressor for the affine transform parameters (a 3 by 4 matrix)
        self.fc_loc = torch.nn.Sequential(
            torch.nn.Linear(self.loc_C3 * self.S0**3, 3*4),
        )

        # Initialize the weights/bias with identity transformation: with zero
        # weights the output is the bias, i.e. the flattened 3x4 identity.
        self.fc_loc[-1].weight.data.zero_()
        self.fc_loc[-1].bias.data.copy_(torch.tensor([1,0,0,0, 0,1,0,0, 0,0,1,0], dtype=torch.float))

    # Spatial transformer network forward function
    def forward(self, x):
        """Predict an affine transform for ``x`` and warp channel 1 with it.

        Args:
            x: Tensor of shape (B, 2, S, S, S), where S is the ``spatial_size``
                the network was built for.

        Returns:
            A tuple ``(theta, warped)``: ``theta`` has shape (B, 3, 4) and
            ``warped`` is channel 1 of ``x`` resampled with ``theta``, shape
            (B, 1, S, S, S).

        Raises:
            ValueError: If the spatial size of ``x`` does not match the size
                the linear regressor was built for.
        """
        xs = self.localization(x)
        if any(dim != self.S0 for dim in xs.shape[2:]):
            raise ValueError(
                f"localization output has spatial shape {tuple(xs.shape[2:])}, "
                f"expected {(self.S0,) * 3}; was the input resized to "
                f"{self.spatial_size} voxels per axis?"
            )
        xs = xs.flatten(start_dim=1)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 3, 4)

        grid = torch.nn.functional.affine_grid(theta, x.size(), align_corners=False)
        # x[:, [1]] keeps the channel axis while selecting channel 1 only
        x = torch.nn.functional.grid_sample(x[:,[1]], grid, align_corners=False)

        return theta, x


In [None]:
def make_affine_net():
    """Build a new AffineNet with freshly initialised weights."""
    net = AffineNet()
    return net

model = make_affine_net().to(device)
# Total number of scalar entries across all parameter tensors
num_params = sum(map(torch.numel, model.parameters()))
print(f"model has {num_params} parameters")

In [None]:
# Examine: run the current model on one random training pair and preview the
# inputs next to the warped output.
# NOTE(review): model.eval() is not called in this cell, so dropout and batch
# norm presumably run in training mode here -- confirm that is intended.

d=random.choice(dataset_pairs_train)
# Add a leading batch axis of size 1
fixed = d['fixed'].unsqueeze(0)
moving = d['moving'].unsqueeze(0)

print("fixed:"); preview_image(fixed[0,0].cpu())
print("moving:"); preview_image(moving[0,0].cpu())

# The model returns the predicted 3x4 matrix and the moving channel warped by it
theta, warped = model(d['fm'].unsqueeze(0))
print("warped:"); preview_image(warped[0,0].detach().cpu())
print("theta:"); print(theta.detach().cpu().numpy())

In [None]:
def get_perfect_transform_monai_coords(d):
    """Return the known ground-truth affine for sample ``d`` in MONAI coordinates.

    Looks up the affine matrix recorded for the first 'RandAffined' entry in
    ``d['fixed_transforms']`` (phi1) and in ``d['moving_transforms']`` (phi2)
    and returns the solution X of ``phi2 @ X = phi1``. Elsewhere in this file
    the result is used to warp the moving image onto the fixed one.

    Raises:
        IndexError: If either transform list has no 'RandAffined' entry.
    """
    def recorded_affine(applied_transforms):
        hits = [entry for entry in applied_transforms if entry['class'] == 'RandAffined']
        return hits[0]['extra_info']['affine']

    fixed_affine = recorded_affine(d['fixed_transforms'])
    moving_affine = recorded_affine(d['moving_transforms'])
    return torch.linalg.solve(moving_affine, fixed_affine)
 
# Change of basis from MONAI grid coordinates to torch grid coordinates:
# it reverses the order of the three spatial axes and scales each by 2/(S-1).
# NOTE(review): 2/(S-1) is the align_corners=True normalisation, whereas the
# affine_grid/grid_sample calls in this file pass align_corners=False (which
# would correspond to 2/S) -- confirm the small translation mismatch is acceptable.
N = torch.tensor(
    [
        [0, 0, 2/(S-1), 0],
        [0, 2/(S-1), 0, 0],
        [2/(S-1), 0, 0, 0],
        [0, 0, 0, 1],
    ],
    dtype=torch.float32,
).to(device)
def get_perfect_transform_torch_coords(d):
    """Return the ground-truth 4x4 affine for sample ``d`` in torch grid coordinates.

    We need to convert monai coords, which are based on grid values in
    [-(size-1)/2, (size-1)/2], to torch coords, which have different index
    ordering and are based on grid values in [-1,1]. This is done by
    conjugating the MONAI-coordinate matrix with ``N``.
    """
    theta_monai = get_perfect_transform_monai_coords(d).to(device)
    N_inverse = torch.linalg.inv(N)
    return N @ (theta_monai @ N_inverse)

In [None]:
# Image-similarity losses used in later cells to compare a warped moving image
# with the fixed image: local normalized cross-correlation and plain MSE.
lncc_loss = monai.losses.LocalNormalizedCrossCorrelationLoss(
    spatial_dims=3,
    kernel_size=7,
    kernel_type='rectangular',
    reduction="mean",
    smooth_nr = 1e-6, # Make sure to make smooth_nr quite a bit smaller than smooth_dr!
    smooth_dr = 1e-3, # Don't make this too small, for the sake of numerical stability
)
mse_loss = torch.nn.MSELoss()

In [None]:
# Test out by eyeballing and checking similarity losses
d=random.choice(dataset_pairs_train)
# Compare against THIS sample's fixed image (with a batch axis added). Relying
# on a `fixed` variable left over from an earlier cell would score the warp
# against a different, unrelated sample.
fixed = d['fixed'].unsqueeze(0)
print("fixed:"); preview_image(d['fixed'][0].cpu())
print("moving:"); preview_image(d['moving'][0].cpu())

theta_monai = get_perfect_transform_monai_coords(d)
theta_torch = get_perfect_transform_torch_coords(d).unsqueeze(0).type(torch.float32).to(device)
# Drop the last row (always 0,0,0,1) to get the 3x4 form affine_grid expects
theta_torch = theta_torch[:,:3]


# First check: apply the ground-truth transform with MONAI's own Affine
perfectly_transformed_moving =\
    monai.transforms.Affine(affine=theta_monai,padding_mode='zeros')(d['moving'])[0]
print('moving warped by perfect transform:'); preview_image(perfectly_transformed_moving[0].cpu())
print("perfect transform in monai coords:"); print(theta_monai)
l1 = mse_loss(perfectly_transformed_moving.unsqueeze(0), fixed).item()
l2 = lncc_loss(perfectly_transformed_moving.unsqueeze(0), fixed).item()
print("losses:",l1,l2)

# Second check: apply the same transform through torch's affine_grid/grid_sample,
# which is the path the network itself uses
x=d['fm'].unsqueeze(0)
print("perfect transform again, torch coords:"); print(theta_torch)
grid = torch.nn.functional.affine_grid(theta_torch, x[:,[1]].size(), align_corners=False)
x = torch.nn.functional.grid_sample(x[:,[1]], grid, align_corners=False)
print('moving warped by perfect transform, using torch coords this time:'); preview_image(x[0,0].cpu())
l1 = mse_loss(x, fixed).item()
l2 = lncc_loss(x, fixed).item()
print("losses:",l1,l2)

In [None]:
# Rebuild the training set as a CacheDataset (cache_num=256).
# The validation counterpart below is currently disabled.
dataset_pairs_train = monai.data.CacheDataset(data = data_pairs_train, transform=transform, cache_num=256)
# dataset_pairs_valid = monai.data.CacheDataset(data = data_pairs_valid, transform=transform, cache_num=256)

In [None]:
# Training setup: shuffled loader with batches of 4, a freshly initialised model,
# an Adam optimizer, and the bookkeeping variables read by the training loop.
dataloader_pairs_train = monai.data.DataLoader(
    dataset_pairs_train,
    batch_size=4,
    num_workers=0,
    shuffle=True
)
model = make_affine_net().to(device) # Reinitialize weights (convenient to do that in this cell)
opt = torch.optim.Adam(model.parameters(), lr=1e-6)

losses_to_plot = []  # one mean loss per reporting interval, for plotting
print_loss_every = 1  # report every this many epochs
batches_per_epoch = 20  # an "epoch" is capped at this many batches

In [None]:
# Replace the optimizer with a new Adam instance at a lower learning rate
# (1e-7); being a new object, it starts without any accumulated optimizer state.
opt = torch.optim.Adam(model.parameters(), lr=1e-7)

In [None]:
# Supervised training: the target is the known ground-truth affine of each
# pair, and the loss is the summed squared difference between the predicted
# and the ground-truth 3x4 matrices.
losses = []
for epoch in range(200):
    model.train()
    # islice stops after batches_per_epoch batches without pulling one more
    # batch from the loader only to throw it away, which a
    # "fetch, then check the counter and break" loop does on every epoch.
    for batch in itertools.islice(dataloader_pairs_train, batches_per_epoch):
        opt.zero_grad()

        theta, warped = model(batch['fm'])

        # Drop the last row from these transforms-- it is always 0,0,0,1
        # (This leaves us with 3x4 matrices, so it matches theta's dimensions)
        perfect_theta = torch.stack(
            [get_perfect_transform_torch_coords(d) for d in monai.data.decollate_batch(batch)]
        )[:,:3].to(device)

        loss = ((theta-perfect_theta)**2).sum()

        loss.backward()
        opt.step()
        losses.append(loss.item())
    if epoch % print_loss_every == 0 :
        # Mean over every batch seen since the last report
        mean_loss = np.mean(losses)
        print(f'{epoch}: {mean_loss}')
        losses_to_plot.append(mean_loss)
        losses = []

In [None]:
# Plot the mean training loss recorded at each reporting interval
plt.plot(losses_to_plot)

In [None]:
# Save the model weights under a run-specific file name
run_id = '002'
save_path = f'model_{run_id}.pth'
# Refuse to overwrite the weights of an earlier run
if os.path.exists(save_path):
    raise Exception("change run_id before saving")
torch.save(model.state_dict(), save_path)

In [None]:
# Load the weights of a previously saved run into the current model
run_id = '001'
load_path = f'model_{run_id}.pth'
# Check (and report) the file we are about to load, not the save path of
# whichever run was saved last.
if not os.path.exists(load_path):
    raise Exception(f"model {load_path} not found")
# map_location puts the tensors on the device in use, whatever device they were saved from
model.load_state_dict(torch.load(load_path, map_location=device))

In [None]:
# comparison with ants

import ants
import time

model.eval()
d=random.choice(dataset_pairs_train)
fixed = d['fixed'].unsqueeze(0)
moving = d['moving'].unsqueeze(0)

# CUDA kernels are launched asynchronously. Synchronize before starting and
# before stopping the clock, otherwise the timer can stop while the GPU is
# still working and the network looks faster than it is.
if device.type == 'cuda':
    torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
    theta, warped = model(d['fm'].unsqueeze(0))
if device.type == 'cuda':
    torch.cuda.synchronize()
my_time = time.perf_counter() - start_time
print(theta.detach().cpu())
print("warped moving image:")
preview_image(warped[0,0].cpu())
print("target image:")
preview_image(fixed[0,0].cpu())
print("original moving image:")
preview_image(moving[0,0].cpu())
loss = mse_loss(warped,fixed)
print("my mse loss:",loss.item())

# Same pair registered with ANTs' affine registration, timed the same way
print("ants warped image:")
ants_fixed = ants.from_numpy(fixed.cpu().numpy()[0,0])
ants_moving = ants.from_numpy(moving.cpu().numpy()[0,0])
start_time = time.perf_counter()
ants_reg = ants.registration(ants_fixed, ants_moving, type_of_transform='Affine')
ants_time = time.perf_counter() - start_time
preview_image(ants_reg['warpedmovout'].numpy())
# Add batch and channel axes so the shape matches `fixed`
loss = mse_loss(torch.tensor(ants_reg['warpedmovout'].numpy()).unsqueeze(0).unsqueeze(0), fixed.cpu())
print("ants mse loss:",loss.item())

print(f"My time: {my_time}, ants time: {ants_time}")

print("For reference, the known perfect transform:")
perfect_theta = get_perfect_transform_torch_coords(d).unsqueeze(0).type(torch.float32).to(device)
# Drop the last row (always 0,0,0,1) to get the 3x4 form affine_grid expects
perfect_theta = perfect_theta[:,:3]
print(perfect_theta)
x=d['fm'].unsqueeze(0)
grid = torch.nn.functional.affine_grid(perfect_theta, x[:,[1]].size(), align_corners=False)
x = torch.nn.functional.grid_sample(x[:,[1]], grid, align_corners=False)
print('moving warped by perfect transform:'); preview_image(x[0,0].cpu())
loss = mse_loss(x, fixed)
print('perfect transform mse loss:', loss.item())