In [None]:
import os
import matplotlib.pyplot as plt
import nibabel as nib
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import warnings
import cc3d
import numpy as np
from scipy import ndimage

warnings.filterwarnings('ignore')

In [None]:
DATA_DIR = './data_3d/'  # directory holding the <case>_MR2mm_crop / _CT2mm_crop (+ _segcrop) NIfTI files
# NOTE(review): the '-' entries look like placeholder / redacted case IDs. With FOLD == 0 every
# test index points at a '-' entry (and with FOLD == 1 every training index does) — confirm.
CASES = ['-', 'TCGA-B8-5158' , '-', 'TCGA-B8-5545', '-', 'TCGA-B8-5551', '-', 'TCGA-BP-5006', '-', 'TCGA-DD-A1EI', '-', 'TCGA-DD-A4NJ', '-', 'TCGA-G7-7502', '-', 'TCGA-G7-A8LC']
FOLD = 0
# Two-fold split by index parity: fold 0 tests on even indices, fold 1 on odd indices.
if FOLD == 0:
    TEST_CASES = list(range(0, len(CASES), 2))
elif FOLD == 1:
    TEST_CASES = list(range(1, len(CASES), 2))
else:
    raise ValueError('FOLD must be 0 or 1, got {}'.format(FOLD))
TRAIN_CASES = [i for i in range(len(CASES)) if i not in TEST_CASES]
D, H, W = 192, 160, 192  # volume size used for the pre-allocated training buffers and grids
dtype = torch.float32    # dtype the loaded images are cast to
device = 'cuda'

In [None]:
def parameter_count(model):
    """Print the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted. Nothing is returned.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    total = sum(param.numel() for param in trainable)
    print('# parameters:', total)

    
def dice_coeff(outputs, labels, max_label):
    """Per-label Dice overlap between two integer label maps.

    Labels ``1 .. max_label - 1`` are scored; label 0 is skipped. The overlap is
    computed from voxel means (equivalent to counts up to a common factor), with a
    1e-8 term in the denominator so an absent label yields 0 instead of NaN.

    Args:
        outputs: predicted label map (any shape).
        labels: reference label map, same shape as ``outputs``.
        max_label: one past the largest label to score.

    Returns:
        Tensor of length ``max_label - 1`` (default dtype/device of ``torch.zeros``);
        entry ``k`` is the Dice of label ``k + 1``.
    """
    scores = torch.zeros(max_label - 1)
    for k in range(max_label - 1):
        label = k + 1
        pred = (outputs == label).flatten().float()
        target = (labels == label).flatten().float()
        overlap = (pred * target).mean()
        scores[k] = 2. * overlap / (1e-8 + pred.mean() + target.mean())
    return scores


def find_rigid_3d(x, y):
    """Least-squares rigid transform (Kabsch/SVD) mapping points ``x`` onto ``y``.

    Only the first three columns of ``x`` and ``y`` are used, so homogeneous
    ``(N, 4)`` point sets work as well as ``(N, 3)``. A reflection is prevented by
    flipping the sign of the last singular direction when ``det(V U^T) < 0``.

    Returns:
        4x4 matrix ``T`` (on ``x.device``) with ``T[:3, :3]`` the rotation and
        ``T[:3, 3]`` the translation, so that ``y ~= T[:3, :3] @ x + T[:3, 3]``.
    """
    src = x[:, :3]
    dst = y[:, :3]
    src_center = src.mean(0)
    dst_center = dst.mean(0)

    cross_cov = (src - src_center).t() @ (dst - dst_center)
    u, _, v = torch.svd(cross_cov)

    dim = v.shape[0]
    reflect_fix = torch.eye(dim, dim).to(x.device)
    reflect_fix[-1, -1] = torch.det(v @ u.t())
    rot = v @ reflect_fix @ u.t()

    T = torch.eye(4).to(x.device)
    T[:3, :3] = rot
    T[:3, 3] = dst_center - rot @ src_center
    return T


def generate_random_rigid_3d(strength=.3):
    """Draw a random rigid 4x4 transform.

    Twelve standard-normal anchor points are perturbed by Gaussian noise scaled by
    ``strength``; the rigid fit between anchors and perturbed anchors (find_rigid_3d)
    is returned. Larger ``strength`` therefore gives larger rotations/translations.
    Uses the global torch RNG and the module-level ``device``.
    """
    anchors = torch.randn(12, 3).to(device)
    noise = torch.randn(12, 3).to(device)
    return find_rigid_3d(anchors, anchors + strength * noise)


def compute_datacost_grid(mask_fix, mind_fix, mind_mov, grid_step, disp_radius, disp_step, beta=15):
    """Build fixed->moving point correspondences on a masked regular grid.

    A regular grid of normalized points is sampled inside ``mask_fix``. For every
    point the SSD feature cost over a cube of candidate displacements is evaluated
    (``ssd``), and a soft-argmin (softmax of ``-beta * cost``) gives the expected
    displacement. The half of the points with the lowest minimal cost is kept.

    Args:
        mask_fix: float mask tensor of shape ``(1, 1, D, H, W)``; points where the
            interpolated mask is > 0.5 are used.
        mind_fix, mind_mov: fixed / moving feature volumes passed to ``ssd``.
        grid_step: roughly ``size // grid_step`` grid points per axis.
        disp_radius: number of candidate displacement steps on each side.
        disp_step: voxel stride between candidate displacements.
        beta: softmax sharpness for the expected displacement.

    Returns:
        ``(fixed_pts, moving_pts)``, both ``(K, 4)`` homogeneous normalized
        coordinates (last column 1), with ``K = n_points // 2``.
    """
    _, _, D, H, W = mask_fix.shape

    # Regular grid spanning 92.5% of the normalized [-1, 1]^3 volume.
    # NOTE(review): the size tuple is ordered (H, W, D) instead of (D, H, W); this only
    # changes how many points lie along each axis, so it is kept unchanged.
    grid_pts = F.affine_grid(.925 * torch.eye(3, 4).unsqueeze(0).to(device), (1, 1, H // grid_step, W // grid_step, D // grid_step), align_corners=True).view(1, 1, 1, -1, 3)
    mask_bg = F.grid_sample(mask_fix, grid_pts, align_corners=True)
    grid_pts = grid_pts[:, :, :, mask_bg.view(-1) > 0.5, :]
    n_pts = grid_pts.shape[3]

    cost = ssd(grid_pts.view(1, -1, 3), mind_fix, mind_mov, (D, H, W), disp_radius, disp_step, disp_radius + 1)
    # (n_pts, n_disp). Indexing (instead of squeeze()) keeps the point axis when n_pts == 1.
    cost = cost[0]

    # Candidate displacements in the same (ij-meshgrid) order that ssd uses, converted
    # from voxels to normalized (x, y, z) coordinates.
    offsets = torch.arange(-disp_step * disp_radius, disp_step * disp_radius + 1, disp_step)
    disp = torch.stack(torch.meshgrid(offsets, offsets, offsets)).permute(1, 2, 3, 0).contiguous().view(1, -1, 3).float()
    disp = (disp.flip(-1) * 2 / (torch.tensor([W, H, D]) - 1)).to(device)

    # Keep the half of the points whose best match is cheapest.
    best_cost = cost.min(1)[0]
    idx_best = torch.sort(best_cost, dim=0, descending=False)[1][:n_pts // 2]

    # Soft-argmin: expected displacement under softmax(-beta * cost).
    weights = torch.softmax(-beta * cost.unsqueeze(2), 1)
    disp_best = torch.sum(weights * disp, 1)[idx_best, :]

    fixed = grid_pts[0, 0, 0, idx_best, :]
    ones = torch.ones(idx_best.size(0), 1, device=device)
    fixed_pts = torch.cat((fixed, ones), 1)
    moving_pts = torch.cat((fixed + disp_best, ones), 1)
    return fixed_pts, moving_pts


def least_trimmed_rigid(fixed_pts, moving_pts, iter=5):
    """Robust rigid fit by iterated least-trimmed squares.

    The first fit uses all correspondences. Each further iteration refits on the half of
    the points with the smallest residual ``||moving - T fixed||`` under the
    current estimate.

    Args:
        fixed_pts, moving_pts: ``(N, 4)`` homogeneous point sets (row i corresponds).
        iter: total number of rigid fits; must be >= 1.

    Returns:
        4x4 rigid transform mapping ``fixed_pts`` onto ``moving_pts``.

    Raises:
        ValueError: if ``iter < 1``.
    """
    if iter < 1:
        raise ValueError('iter must be >= 1, got {}'.format(iter))
    T = find_rigid_3d(fixed_pts, moving_pts)
    keep = fixed_pts.shape[0] // 2
    for _ in range(iter - 1):
        residual = torch.sqrt(torch.sum(torch.pow(moving_pts - torch.mm(fixed_pts, T.t()), 2), 1))
        _, idx = torch.topk(residual, keep, largest=False)
        T = find_rigid_3d(fixed_pts[idx, :], moving_pts[idx, :])
    return T


def ssd(kpts_fixed, feat_fixed, feat_moving, orig_shape, disp_radius=16, disp_step=2, patch_radius=2, alpha=1.5, unroll_factor=50):
    """Patch-wise sum of squared feature differences over a cube of displacements.

    For every keypoint a cubic feature patch is sampled from ``feat_fixed`` and compared
    with ``feat_moving`` at ``(2 * disp_radius + 1) ** 3`` displacements spaced
    ``disp_step`` voxels apart. ``SSD = sum(P^2) + sum(M^2) - 2 <P, M>`` is evaluated with
    one grouped ``conv3d`` (the correlation term) and ``avg_pool3d`` (the moving energy)
    per chunk of keypoints.

    Args:
        kpts_fixed: ``(1, N, 3)`` keypoints in normalized grid_sample coordinates (x, y, z).
        feat_fixed, feat_moving: 5-D feature volumes ``(1, C, ...)`` — batch size 1 is
            assumed by the reshapes below.
        orig_shape: ``(D, H, W)`` used to convert voxel offsets to normalized units.
        disp_radius: candidate displacement steps on each side.
        disp_step: voxel stride of displacements; also used as the patch stride
            (required by the conv-based implementation).
        patch_radius: patch half-size in voxels.
        alpha: the result is scaled by ``alpha / patch_width ** 3``.
        unroll_factor: number of keypoint chunks; chunks that would be empty
            (when ``N < unroll_factor``) are skipped.

    Returns:
        Float32 cost tensor of shape ``(1, N, (2 * disp_radius + 1) ** 3)``.
    """
    _, N, _ = kpts_fixed.shape
    device = kpts_fixed.device
    D, H, W = orig_shape
    C = feat_fixed.shape[1]
    dtype = feat_fixed.dtype

    patch_step = disp_step  # same stride necessary for fast implementation
    patch_offsets = torch.arange(0, 2 * patch_radius + 1, patch_step)
    patch = torch.stack(torch.meshgrid(patch_offsets, patch_offsets, patch_offsets)).permute(1, 2, 3, 0).contiguous().view(1, 1, -1, 1, 3).float() - patch_radius
    patch = (patch.flip(-1) * 2 / (torch.tensor([W, H, D]) - 1)).to(dtype).to(device)

    patch_width = round(patch.shape[2] ** (1.0 / 3))

    if patch_width % 2 == 0:
        pad = [(patch_width - 1) // 2, (patch_width - 1) // 2 + 1]
    else:
        pad = [(patch_width - 1) // 2, (patch_width - 1) // 2]

    # Displacements are sampled over an extended cube (padded by the patch half-width)
    # so that a 'valid' convolution yields exactly disp_width values per axis.
    half_ext = disp_radius + ((pad[0] + pad[1]) / 2)
    disp_offsets = torch.arange(- disp_step * half_ext, (disp_step * half_ext) + 1, disp_step)
    disp = torch.stack(torch.meshgrid(disp_offsets, disp_offsets, disp_offsets)).permute(1, 2, 3, 0).contiguous().view(1, 1, -1, 1, 3).float()
    disp = (disp.flip(-1) * 2 / (torch.tensor([W, H, D]) - 1)).to(dtype).to(device)

    disp_width = disp_radius * 2 + 1
    ext_width = disp_width + pad[0] + pad[1]
    cost = torch.zeros(1, N, disp_width ** 3, device=device)
    for chunk in np.array_split(np.arange(N), unroll_factor):
        n = chunk.shape[0]
        if n == 0:
            # np.array_split yields empty chunks when N < unroll_factor; view(1, -1, ...)
            # cannot infer a dimension for a 0-element tensor, so skip them.
            continue
        kpts_chunk = kpts_fixed[:, chunk, :].view(1, -1, 1, 1, 3).to(dtype)
        feat_fixed_patch = F.grid_sample(feat_fixed, kpts_chunk + patch, padding_mode='border', align_corners=True)
        feat_moving_disp = F.grid_sample(feat_moving, kpts_chunk + disp, padding_mode='border', align_corners=True)
        # One group per (channel, keypoint): correlate each fixed patch with its own moving cube.
        corr = F.conv3d(feat_moving_disp.view(1, -1, ext_width, ext_width, ext_width), feat_fixed_patch.view(-1, 1, patch_width, patch_width, patch_width), groups=C * n).view(C, n, -1)
        patch_sum = (feat_fixed_patch ** 2).squeeze(0).squeeze(3).sum(dim=2, keepdim=True)
        disp_sum = (patch_width ** 3) * F.avg_pool3d((feat_moving_disp ** 2).view(C, -1, ext_width, ext_width, ext_width), patch_width, stride=1).view(C, n, -1)
        cost[0, chunk, :] = (-2 * corr + patch_sum + disp_sum).sum(0)

    cost *= (alpha / (patch_width ** 3))

    return cost

In [None]:
def _load_nifti(case, suffix):
    """Read ``DATA_DIR/<CASES[case]>_<suffix>.nii.gz`` into a torch tensor (float64 from get_fdata)."""
    path = os.path.join(DATA_DIR, '{}_{}.nii.gz'.format(CASES[case], suffix))
    return torch.from_numpy(nib.load(path).get_fdata())


def _foreground_mask(background):
    """Turn a thresholded background mask into a closed foreground mask.

    Keeps the voxels that cc3d assigns label 0 in ``background`` and applies a
    morphological closing (5 dilations, then 5 erosions).
    NOTE(review): cc3d label 0 presumably marks the voxels that are False in ``background``,
    making this equivalent to ``~background`` before the closing — confirm the intent.
    """
    mask = cc3d.connected_components(background.numpy()) == 0
    mask = ndimage.binary_erosion(ndimage.binary_dilation(mask, iterations=5), iterations=5)
    return torch.from_numpy(mask)


def _plot_overlay(ax, img, seg, cmap):
    """Show the middle slice (last axis) of ``img`` with ``seg`` (labels / 5 into ``cmap``) overlaid."""
    mid = img.shape[2] // 2
    ax.imshow(img[:, :, mid], cmap='gray')
    overlay = cmap(seg[:, :, mid] / 5.)
    overlay[:, :, 3] = (seg[:, :, mid] != 0).numpy()
    ax.imshow(overlay, alpha=0.5)
    ax.axis('off')


def load_case(case, plot=False):
    """Load the MR (fixed) / CT (moving) pair of ``CASES[case]`` with labels and masks.

    Returns:
        img_fix: MR volume (``dtype``), z-score normalized.
        img_mov: CT volume clipped to [-1000, 1500] and rescaled to [0, 1].
        seg_fix, seg_mov: long label maps.
        mask_fix, mask_mov: bool foreground masks (thresholds -0.25 on the normalized MR,
            0.05 on the rescaled CT, followed by a morphological closing).
        If ``plot`` is True, the middle slices with label overlays are shown.
    """
    img_fix = _load_nifti(case, 'MR2mm_crop').to(dtype)
    img_fix -= img_fix.mean()
    img_fix /= img_fix.std()
    img_mov = (_load_nifti(case, 'CT2mm_crop').to(dtype).clip_(-1000, 1500) + 1000) / 2500
    seg_fix = _load_nifti(case, 'MR2mm_segcrop').long()
    seg_mov = _load_nifti(case, 'CT2mm_segcrop').long()

    mask_fix = _foreground_mask(img_fix < -0.25)
    mask_mov = _foreground_mask(img_mov < 0.05)

    if plot:
        cmap = plt.get_cmap('Set1')
        _, (ax_fix, ax_mov) = plt.subplots(1, 2, figsize=(16, 8))
        _plot_overlay(ax_fix, img_fix, seg_fix, cmap)
        _plot_overlay(ax_mov, img_mov, seg_mov, cmap)
        plt.show()

    return img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov

# Baseline: Dice of the unregistered MR (fixed) vs. CT (moving) label maps, labels 1-4.
dice_all = 0
for case in TRAIN_CASES:
    print('Case: ', CASES[case])
    img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov = load_case(case, plot=False)
    dice = dice_coeff(seg_fix, seg_mov, 5)
    dice_all = dice_all + dice
    per_label = dice.tolist()
    print('Initial Dice: {:.2f}, {:.2f}, {:.2f}, {:.2f} (mean: {:.2f})'.format(*per_label, dice.mean().item()))
    print('--\n')
dice_all = dice_all / len(TRAIN_CASES)
print('Initial Dice (all): {:.2f}, {:.2f}, {:.2f}, {:.2f} (mean: {:.2f})'.format(*dice_all.tolist(), dice_all.mean().item()))

In [None]:
# Pre-compute 8 random rigid augmentations per training case into pinned host buffers.
# Index convention: 1 = augmented fixed MR, 2 = original moving CT, 3 = augmented moving CT.
# R21s / R23s store the transforms used for the 1 and 3 copies.
n_train = len(TRAIN_CASES)
imgs_fix1_train = torch.zeros(n_train, 8, D, H, W).float().pin_memory()
imgs_mov2_train = torch.zeros(n_train, 1, D, H, W).float().pin_memory()
imgs_mov3_train = torch.zeros(n_train, 8, D, H, W).float().pin_memory()
segs_fix1_train = torch.zeros(n_train, 8, D, H, W).int().pin_memory()
segs_mov2_train = torch.zeros(n_train, 1, D, H, W).int().pin_memory()
masks_fix1_train = torch.zeros(n_train, 8, D, H, W).bool().pin_memory()
R21s = torch.zeros(n_train, 8, 4, 4).float().pin_memory()
R23s = torch.zeros(n_train, 8, 4, 4).float().pin_memory()

torch.manual_seed(60)

for i, case in enumerate(TRAIN_CASES):
    print('process case', i)
    img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov = load_case(case)
    img_fix = img_fix.to(device, non_blocking=True).unsqueeze(0).unsqueeze(0)
    img_mov = img_mov.to(device, non_blocking=True).unsqueeze(0).unsqueeze(0)
    seg_fix = seg_fix.to(device, non_blocking=True).unsqueeze(0).unsqueeze(0)
    seg_mov = seg_mov.to(device, non_blocking=True).unsqueeze(0).unsqueeze(0)
    mask_fix = mask_fix.to(device, non_blocking=True).unsqueeze(0).unsqueeze(0)

    imgs_mov2_train[i:i+1] = img_mov
    segs_mov2_train[i:i+1] = seg_mov

    # The one-hot label volume is identical for all 8 augmentations: build it once.
    seg_fix_onehot = F.one_hot(seg_fix[0, 0]).permute(3, 0, 1, 2).unsqueeze(0).float()

    # NOTE(review): align_corners=False (the PyTorch default, made explicit here) differs
    # from the align_corners=True convention used in compute_datacost_grid / ssd.
    for j in range(8):
        R = generate_random_rigid_3d()
        grid = F.affine_grid(R[:3, :4].unsqueeze(0), (1, 1, D, H, W), align_corners=False)
        img_fix_ = F.grid_sample(img_fix, grid, padding_mode='border', align_corners=False)
        # Warp one-hot labels (linear interpolation), then take the argmax per voxel.
        seg_fix_ = F.grid_sample(seg_fix_onehot, grid, align_corners=False).argmax(1, keepdim=True).int()
        mask_fix_ = F.grid_sample(mask_fix.float(), grid, align_corners=False) > 0.5

        imgs_fix1_train[i:i+1, j:j+1] = img_fix_
        segs_fix1_train[i:i+1, j:j+1] = seg_fix_
        masks_fix1_train[i:i+1, j:j+1] = mask_fix_
        R21s[i:i+1, j:j+1] = R

        # Moving copy 3: only the image and its transform are needed by the cycle loss,
        # so no warped segmentation is computed for it.
        R = generate_random_rigid_3d()
        grid = F.affine_grid(R[:3, :4].unsqueeze(0), (1, 1, D, H, W), align_corners=False)
        img_mov_ = F.grid_sample(img_mov, grid, padding_mode='border', align_corners=False)

        imgs_mov3_train[i:i+1, j:j+1] = img_mov_
        R23s[i:i+1, j:j+1] = R

In [None]:
# Correspondence-search settings read by `predict` (through compute_datacost_grid).
grid_step, disp_radius, disp_step, beta = (
    12,   # grid density: roughly size // grid_step points per axis
    4,    # displacement steps on each side -> (2 * 4 + 1) ** 3 candidates
    5,    # voxel stride between candidate displacements
    150,  # softmax sharpness for the expected displacement
)

In [None]:
class ModalityNet(nn.Module):
    """Modality-specific encoder stem.

    Two 3x3x3 conv / InstanceNorm / LeakyReLU pairs at full resolution (1 -> ``base``
    channels), followed by a stride-2 stage (``base`` -> ``2 * base`` channels).

    Args:
        base: channel width of the first stage (default 8). The output has ``2 * base``
            channels at half the input resolution.
    """

    def __init__(self, base=8):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(1, base, 3, padding=1, bias=False),
            nn.InstanceNorm3d(base),
            nn.LeakyReLU(),
            nn.Conv3d(base, base, 3, padding=1, bias=False),
            nn.InstanceNorm3d(base),
            nn.LeakyReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv3d(base, base*2, 3, stride=2, padding=1, bias=False),
            nn.InstanceNorm3d(base*2),
            nn.LeakyReLU(),
            nn.Conv3d(base*2, base*2, 3, padding=1, bias=False),
            nn.InstanceNorm3d(base*2),
            nn.LeakyReLU()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
    
class SharedNet(nn.Module):
    """Trunk shared by both modalities.

    A full-resolution double conv at ``2 * base`` channels, a stride-2 double conv to
    ``4 * base`` channels, and a 1x1x1 projection to ``out_channels``.
    """

    def __init__(self, base, out_channels):
        super(SharedNet, self).__init__()
        mid, wide = base * 2, base * 4

        self.conv1 = nn.Sequential(*self._double_conv(mid, mid, first_stride=1))
        self.conv2 = nn.Sequential(*self._double_conv(mid, wide, first_stride=2))
        self.conv3 = nn.Conv3d(wide, out_channels, 1)

    @staticmethod
    def _double_conv(c_in, c_out, first_stride):
        # conv/norm/act twice; only the first conv may downsample.
        return [
            nn.Conv3d(c_in, c_out, 3, stride=first_stride, padding=1, bias=False),
            nn.InstanceNorm3d(c_out),
            nn.LeakyReLU(),
            nn.Conv3d(c_out, c_out, 3, padding=1, bias=False),
            nn.InstanceNorm3d(c_out),
            nn.LeakyReLU(),
        ]

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        return x
    
class FeatureNet(nn.Module):
    """Two-branch feature extractor for multimodal matching.

    ``x`` goes through ``modality1_net`` and ``y`` through ``modality2_net``. Both then
    pass through the same ``shared_net`` (shared weights), and a sigmoid is applied.
    Base width 8 and 16 output channels are fixed.
    """

    def __init__(self):
        super(FeatureNet, self).__init__()
        base, out_channels = 8, 16

        self.modality1_net = ModalityNet(base)
        self.modality2_net = ModalityNet(base)
        self.shared_net = SharedNet(base, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        feat_x = self.shared_net(self.modality1_net(x))
        feat_y = self.shared_net(self.modality2_net(y))
        return self.sigmoid(feat_x), self.sigmoid(feat_y)
    


In [None]:
def predict(net, img_fix, img_mov, mask_fix):
    """Estimate the 4x4 rigid transform mapping fixed-image points to moving-image points.

    Features from ``net`` are matched on a masked grid (compute_datacost_grid, using the
    module-level ``grid_step`` / ``disp_radius`` / ``disp_step`` / ``beta``). A
    least-trimmed rigid fit is then run on the correspondences.
    """
    features = net(img_fix, img_mov)
    correspondences = compute_datacost_grid(mask_fix.float(), *features,
                                            grid_step, disp_radius, disp_step, beta)
    return least_trimmed_rigid(*correspondences)

num_epochs = 100
init_lr = 0.001

net = FeatureNet().to(device)
parameter_count(net)

optimizer = optim.Adam(net.parameters(), lr=init_lr)

# Cosine decay from init_lr down to init_lr / 100 over num_epochs.
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=init_lr/(10**2))

criterion = nn.MSELoss()

losses = torch.zeros(num_epochs)

for epoch in range(num_epochs):
    net.train()

    torch.cuda.synchronize()
    t0 = time.time()
    running_loss = 0
    rand_idx = torch.randperm(len(TRAIN_CASES))
    for idx in rand_idx:
        optimizer.zero_grad()

        # Pick one of the 8 pre-computed augmentations for the fixed copy (1) and for
        # the moving copy (3).
        rand_idx1 = torch.randint(8, (1,))[0]
        rand_idx2 = torch.randint(8, (1,))[0]
        img_fix1 = imgs_fix1_train[idx:idx+1, rand_idx1:rand_idx1+1].to(device, non_blocking=True)
        img_mov2 = imgs_mov2_train[idx:idx+1].to(device, non_blocking=True)
        img_mov3 = imgs_mov3_train[idx:idx+1, rand_idx2:rand_idx2+1].to(device, non_blocking=True)
        seg_fix1 = segs_fix1_train[idx:idx+1, rand_idx1:rand_idx1+1].long().to(device, non_blocking=True)
        seg_mov2 = segs_mov2_train[idx:idx+1].to(device, non_blocking=True).long()
        mask_fix1 = masks_fix1_train[idx:idx+1, rand_idx1:rand_idx1+1].to(device, non_blocking=True)
        R23 = R23s[idx, rand_idx2].to(device, non_blocking=True)

        R21 = predict(net, img_fix1, img_mov2, mask_fix1)
        R31 = predict(net, img_fix1, img_mov3, mask_fix1)

        # Cycle consistency: the predicted 3->2 chain R21 @ R31^-1 must reproduce the
        # known augmentation R23. The comparison is made on the sampling grids.
        R23_ = torch.mm(R21, R31.inverse())

        grid23 = F.affine_grid(R23[:3].unsqueeze(0), (1, 1, D, H, W), align_corners=False)
        grid23_ = F.affine_grid(R23_[:3].unsqueeze(0), (1, 1, D, H, W), align_corners=False)

        loss = criterion(grid23, grid23_)

        if epoch % 10 == 9:
            # Monitoring only: no autograd graph is needed for the warped label volume.
            with torch.no_grad():
                grid21 = F.affine_grid(R21[:3, :4].unsqueeze(0), (1, 1, D, H, W), align_corners=False)
                seg_mov2_warped = F.grid_sample(F.one_hot(seg_mov2, 5).view(1, D, H, W, -1).permute(0, 4, 1, 2, 3).float(), grid21, mode='bilinear', align_corners=False)
                mean_dice = dice_coeff(seg_fix1, seg_mov2_warped.argmax(1, keepdim=True), 5).mean().item()
            print('epoch (train): {:02d} -- mean dice case {:01d}: {:.2f}'.format(epoch, int(idx), mean_dice))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    running_loss /= len(TRAIN_CASES)
    losses[epoch] = running_loss
    torch.cuda.synchronize()
    t1 = time.time()

    lr_scheduler.step()

    print('epoch (train): {:02d} -- loss: {:.4f} -- time(s): {:.1f}'.format(epoch, running_loss, t1-t0))

In [None]:
# Training curve of the cycle-consistency loss.
fig, ax = plt.subplots()
ax.plot(losses.numpy())
ax.set(title='Training loss (cycle consistency, MSE)', xlabel='epoch', ylabel='loss')
plt.show()

# Save CPU weights, then move the network back so later cells still find it on `device`.
torch.save(net.cpu().state_dict(), 'net3d_cycle_fold{}.pth'.format(FOLD))
net.to(device)