In [None]:
import math
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
from scipy.sparse import csr_matrix, csgraph
from skimage.morphology import convex_hull_image
import struct
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import warnings

from evaluation import *
from mindssc import *
from thin_plate_spline import *
from utils import *

warnings.filterwarnings('ignore')
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

In [None]:
data_dir = '/path/to/data/l2r_2020/task_03/'
device = 'cuda'

# (fixed, moving) case ids of the validation pairs. The first CSV row is dropped
# (presumably a header line, which genfromtxt parses as NaN).
pairs = torch.tensor(
    np.genfromtxt(os.path.join(data_dir, 'pairs_val.csv'), delimiter=',')[1:],
    dtype=torch.long,
)
# Sorted unique case ids, plus a lookup table mapping a case id to its position in `cases`.
cases = pairs.unique()
ind = torch.zeros(int(cases.max()) + 1, dtype=torch.long)
ind[cases] = torch.arange(len(cases))

# Expected volume shape in voxels; loaded volumes are reshaped with view(1, 1, D, H, W).
D, H, W = 192, 160, 256

# MIND-SSC feature settings.
mind_delta = mind_sigma = 3              # forwarded to mindssc(img, mind_delta, mind_sigma)
mind_patch_step = mind_patch_radius = 3  # patch offsets: arange(0, 2*radius + 1, step) - radius
mind_stride = 2                          # features live on a grid of (D, H, W) // mind_stride

# Keypoint / belief-propagation settings.
N = 2048         # number of keypoints requested from random_kpts
d = 2            # forwarded to random_kpts (meaning defined there - TODO confirm)
k = 10           # forwarded to lbp_graph (presumably neighbours per keypoint - confirm)
k1 = 12          # random displacement candidates drawn per keypoint at each stage
alpha = 2.5      # weight of the feature-difference (data) cost
alpha1 = 100     # inverse temperature of the softmax over candidate marginals
sps = torch.linspace(.3, .03, 20)  # candidate spread per stage, from coarse to fine
lbp_iter = 3     # message-passing rounds per stage

In [None]:
def load_case(case, plot=False):
    """Load one case: image, label map and a dilated convex-hull region mask.

    Args:
        case: integer case id (a 0-d integer tensor also works); formatted with
            ``{:04d}`` into ``Training/img/img####.nii.gz`` and
            ``Training/label/label####.nii.gz`` under ``data_dir``.
        plot: if True, show the central slice along each axis with the label
            map and the mask overlaid.

    Returns:
        (img, seg, mask): float32 intensities, int64 labels and a bool mask,
        all CPU tensors.
    """
    img_path = os.path.join(data_dir, 'Training/img', 'img{:04d}.nii.gz'.format(case))
    img = torch.from_numpy(nib.load(img_path).get_fdata().astype(np.float32))

    seg_path = os.path.join(data_dir, 'Training/label', 'label{:04d}.nii.gz'.format(case))
    # get_fdata() always returns floats; round before the integer cast so a scaled
    # label value such as 2.9999999 is not truncated to the wrong class.
    seg = torch.from_numpy(np.rint(nib.load(seg_path).get_fdata()).astype(np.int64))

    mask = _hull_mask(seg > 0)

    if plot:
        _plot_case(img, seg, mask)

    return img, seg, mask


def _hull_mask(foreground):
    """Dilate a 3-D bool mask, then fill it to its 2-D convex hull slice-wise along each axis."""
    # 17^3 box filter (count_include_pad): a voxel is kept if more than 0.1% of its
    # 17^3 neighbourhood (i.e. at least 5 of 4913 voxels) is foreground.
    pooled = F.avg_pool3d(foreground[None, None].float().to(device), 17, stride=1, padding=8)
    mask = pooled[0, 0].cpu() > 0.001
    # Replace every slice by its convex hull: axis 0 first, then 1, then 2
    # (each pass sees the result of the previous one).
    for axis in range(3):
        for i in range(mask.shape[axis]):
            index = (slice(None),) * axis + (i,)
            plane = mask[index]
            # The hull of an empty slice is empty; skipping it also avoids
            # scikit-image versions that warn or fail on all-zero input.
            if plane.any():
                mask[index] = torch.from_numpy(convex_hull_image(plane.numpy()))
    return mask


def _plot_case(img, seg, mask, num_labels=13):
    """Show the central slice along each axis with label (alpha 0.5) and mask (alpha 0.1) overlays."""
    cmap = plt.get_cmap('Set1')
    plt.figure(figsize=(24, 8))
    for axis in range(3):
        index = (slice(None),) * axis + (img.shape[axis] // 2,)
        seg_slice = seg[index].numpy()
        plt.subplot(1, 3, axis + 1)
        plt.imshow(img[index].numpy(), cmap='gray')
        seg_plot = cmap(seg_slice / float(num_labels))
        # Background label 0 is fully transparent.
        seg_plot[:, :, 3] = seg_slice != 0
        plt.imshow(seg_plot, alpha=0.5)
        plt.imshow(mask[index].numpy(), alpha=0.1)
        plt.axis('off')
    plt.show()
  
# Identity sampling grid at feature resolution ((D, H, W) // mind_stride) for grid_sample.
feat_shape = (D // mind_stride, H // mind_stride, W // mind_stride)
grid = F.affine_grid(torch.eye(3, 4, device=device).unsqueeze(0), (1, 1) + feat_shape)

# Cubic patch of voxel offsets (arange(0, 2r + 1, step) - r along each axis) around
# every grid point. flow_pt (project helper) presumably converts them to grid_sample's
# normalised coordinates for a volume of size feat_shape.
offsets_1d = torch.arange(0, 2 * mind_patch_radius + 1, mind_patch_step, device=device)
mind_patch = torch.stack(torch.meshgrid(offsets_1d, offsets_1d, offsets_1d), dim=3).view(1, -1, 3) - mind_patch_radius
mind_patch = flow_pt(mind_patch, feat_shape, align_corners=True).view(1, 1, -1, 1, 3)
    
# Pre-load every case once into pinned host memory so the registration loop can
# copy each pair to the GPU with non_blocking=True.
num_cases = len(cases)
mind_shape = (D // mind_stride, H // mind_stride, W // mind_stride)
# 12 MIND-SSC channels (as assumed by the reshape below), each sampled at every patch offset.
mind_channels = 12 * mind_patch.shape[2]

imgs = torch.zeros(num_cases, 1, D, H, W).pin_memory()
segs = torch.zeros(num_cases, 1, D, H, W, dtype=torch.long).pin_memory()
masks = torch.zeros(num_cases, 1, D, H, W, dtype=torch.bool).pin_memory()
feats = torch.zeros(num_cases, mind_channels, *mind_shape).pin_memory()
for i, case in enumerate(cases):
    print('Case {}'.format(case))
    img, seg, mask = load_case(case, plot=True)

    # Sample the descriptor at every grid point shifted by every patch offset,
    # then fold the patch offsets into the channel dimension.
    mind = mindssc(img.view(1, 1, D, H, W).to(device), mind_delta, mind_sigma)
    mind = F.grid_sample(mind, grid.view(1, 1, 1, -1, 3) + mind_patch)
    mind = mind.view(mind_channels, *mind_shape).cpu()

    imgs[i, 0] = img
    segs[i, 0] = seg
    masks[i, 0] = mask
    feats[i] = mind


In [None]:
def iter_lbp(kpts_fixed, feat_kpts_fixed, feat_moving, sps):
    """Estimate one displacement per fixed keypoint by stochastic, coarse-to-fine LBP.

    For every spread ``sp`` in ``sps``, ``k1`` random candidates per keypoint are
    drawn uniformly from a cube of edge ``sp`` centred on the current estimate.
    Each candidate's data cost is ``alpha`` times the mean squared difference between
    the fixed keypoint features and ``feat_moving`` sampled at the displaced position.
    ``lbp_iter`` rounds of message passing over the graph from ``lbp_graph`` then
    regularise the costs (messages are computed by ``sparse_minconv``). The new
    estimate is the softmax(-alpha1 * marginal)-weighted mean of the candidates.

    Args:
        kpts_fixed: (1, n, 3) keypoints, used directly as grid_sample coordinates.
        feat_kpts_fixed: (1, n, C) features sampled at ``kpts_fixed``.
        feat_moving: moving feature volume accepted by ``F.grid_sample`` with C channels.
        sps: iterable of candidate spreads, applied in order.

    Returns:
        (1, n, 3) displacements in the same coordinates as ``kpts_fixed``.
    """
    num_kpts = kpts_fixed.shape[1]
    dev = kpts_fixed.device

    edges, edges_reverse_idx = lbp_graph(kpts_fixed, k)
    num_edges = edges.shape[0]
    # Loop-invariant index tensors: gather at each edge's source, scatter to its target,
    # and look up the message travelling the opposite way along the same edge.
    src_idx = edges[:, 0].view(-1, 1).expand(-1, k1)
    dst_idx = edges[:, 1].view(-1, 1).expand(-1, k1)
    rev_idx = edges_reverse_idx.view(-1, 1).expand(-1, k1)

    flow = torch.zeros_like(kpts_fixed)
    for sp in sps:
        # k1 candidates per keypoint, uniform in [-sp/2, sp/2)^3 around the current flow.
        candidates = flow.view(1, -1, 1, 3) + (torch.rand((1, num_kpts, k1, 3), device=dev) - 0.5) * sp
        candidates_edges0 = candidates[0, edges[:, 0]]
        candidates_edges1 = candidates[0, edges[:, 1]]

        # Data cost per candidate -> (1, n, k1); flattened to (n, k1) for message passing.
        sample_grid = kpts_fixed.view(1, 1, -1, 1, 3) + candidates.view(1, 1, -1, k1, 3)
        feat_kpts_moving = F.grid_sample(feat_moving, sample_grid).view(1, -1, num_kpts, k1).permute(0, 2, 3, 1)
        candidates_cost = alpha * (feat_kpts_fixed.unsqueeze(2) - feat_kpts_moving).pow(2).mean(3)
        unary = candidates_cost.view(-1, k1)

        messages = torch.zeros((num_edges, k1), device=dev)
        temp_messages = torch.zeros((num_kpts, k1), device=dev)
        for _ in range(lbp_iter):
            # Belief at each edge's source, minus what the target sent back along this edge.
            multi_data_cost = torch.gather(temp_messages + unary, 0, src_idx)
            multi_data_cost -= torch.gather(messages, 0, rev_idx)

            messages = sparse_minconv(multi_data_cost, candidates_edges0, candidates_edges1)

            # Sum the incoming messages of every target keypoint.
            temp_messages.zero_()
            temp_messages.scatter_add_(0, dst_idx, messages)

        marginals = temp_messages + unary
        weights = F.softmax(-alpha1 * marginals, 1)
        flow = torch.sum(weights.unsqueeze(2) * candidates.view(num_kpts, k1, 3), 1).view(1, num_kpts, 3)

    return flow
        
num_labels = 13  # number of label classes passed to dice_coeff
num_pairs = len(pairs)

dice_initial = torch.zeros((num_pairs, num_labels), device=device)
dice_affine = torch.zeros((num_pairs, num_labels), device=device)   # not filled by the loop below
dice_affine1 = torch.zeros((num_pairs, num_labels), device=device)  # not filled by the loop below
dice = torch.zeros((num_pairs, num_labels), device=device)
sd_log_j = torch.zeros((num_pairs,), device=device)
runtimes = torch.zeros((num_pairs,), device=device)
runtimes1 = torch.zeros((num_pairs,), device=device)                # not filled by the loop below

# Full-resolution identity grid; loop-invariant, so it is built once instead of twice per pair.
identity_grid = F.affine_grid(torch.eye(3, 4, device=device).unsqueeze(0), (1, 1, D, H, W))

torch.manual_seed(30)
for i, pair in enumerate(pairs):
    idx_fixed = ind[pair[0]]
    idx_moving = ind[pair[1]]

    img_fixed = imgs[idx_fixed:idx_fixed+1].to(device, non_blocking=True)
    seg_fixed = segs[idx_fixed:idx_fixed+1].to(device, non_blocking=True)
    mask_fixed = masks[idx_fixed:idx_fixed+1].to(device, non_blocking=True)
    feat_fixed = feats[idx_fixed:idx_fixed+1].to(device, non_blocking=True)

    img_moving = imgs[idx_moving:idx_moving+1].to(device, non_blocking=True)
    seg_moving = segs[idx_moving:idx_moving+1].to(device, non_blocking=True)
    mask_moving = masks[idx_moving:idx_moving+1].to(device, non_blocking=True)
    feat_moving = feats[idx_moving:idx_moving+1].to(device, non_blocking=True)

    kpts_fixed = random_kpts(mask_fixed, d, num_points=N)
    num_kpts = kpts_fixed.shape[1]
    feat_kpts_fixed = F.grid_sample(feat_fixed, kpts_fixed.view(1, 1, 1, -1, 3)).view(1, -1, num_kpts).permute(0, 2, 1)

    # Only the keypoint optimisation is timed (not feature extraction or densification).
    torch.cuda.synchronize()
    t0 = time.time()

    flow = iter_lbp(kpts_fixed, feat_kpts_fixed, feat_moving, sps)

    torch.cuda.synchronize()
    t1 = time.time()

    # kpts_fixed and flow are already on the GPU (they feed grid_sample with GPU inputs).
    dense_flow = thin_plate_dense(kpts_fixed, flow, (D, H, W), 3, 0.001)
    sample_grid = identity_grid + dense_flow
    seg_moving_warped = F.grid_sample(seg_moving.float(), sample_grid, mode='nearest')
    # Intensities are shifted so samples outside the volume (zero-padded by grid_sample)
    # come out as -3024.0005 instead of 0. Not used by the metrics below.
    img_moving_warped = F.grid_sample(img_moving + 3024.0005, sample_grid, mode='bilinear') - 3024.0005
    jac_det = jacobian_determinant(flow_world(dense_flow.view(1, -1, 3), (D//2, H//2, W//2)).view_as(dense_flow)[:, ::2, ::2, ::2, :]).to(device)

    dice_initial[i] = dice_coeff(seg_moving, seg_fixed, num_labels)
    dice[i] = dice_coeff(seg_moving_warped, seg_fixed, num_labels)
    # The +3 offset and [1e-9, 1e9] clamp are kept unchanged; they presumably mirror
    # the reference evaluation formula - verify before altering.
    sd_log_j[i] = torch.log((jac_det + 3).clamp_(0.000000001, 1000000000)).std()
    runtimes[i] = t1 - t0

    print('Fixed: {}, Moving: {}'.format(pair[0], pair[1]))
    print('Initial Dice: {:.2f}'.format(dice_initial[i].mean()))
    print('Dice: {:.2f}'.format(dice[i].mean()))
    print('SDlogJ: {:.2f}'.format(sd_log_j[i]))
    print('Runtime: {:.2f} s'.format(runtimes[i]))
    print()

print('---')
print('Mean Initial Dice: {:.3f}'.format(dice_initial.mean()))
print('Mean Dice: {:.3f}'.format(dice.mean()))
print('Mean SDlogJ: {:.3f}'.format(sd_log_j.mean()))
print('Mean Runtime: {:.3f} s'.format(runtimes.mean()))
