In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as checkpoint
import numpy as np
import time
import warnings

warnings.filterwarnings('ignore')

In [None]:
FOLD: int = 0  # cross-validation fold to run: 0, 1 or 2 (selects the held-out subjects in the split cell)

In [None]:
# Load data
#
# The real loader is commented out; the zero tensors below are placeholders with the same
# layout: 9 subjects, 1-channel 224x256 images and 224x256 label maps (the crops
# [32:256, 16:272] used by the loader produce that 224x256 size).
# The loader assigns ct_imgs / ct_segs / mr_imgs / mr_segs, the names the fold-split cell reads.
#
# NOTE(review): the CT transform clamp(x + 300, 0, 600) / 600 maps [-300, 300] (presumably
# Hounsfield units) to [0, 1]; MR is clamped to [0, 400] and scaled by 1/400.
# NOTE(review): the MR loader keeps only the first 8 subjects ([:8]) while the fold split
# indexes subjects 0..8 -- confirm the subject count before switching to real data.

#ct_data = torch.load('data2d_ct.pth')
#data_mr = torch.load('data2d_mr.pth')

#ct_imgs = torch.clamp(torch.from_numpy(ct_data['imgs']).float()[:,:,32:192+64,16:16+256]+300,0,600).contiguous().cuda()/600
#ct_segs = torch.from_numpy(ct_data['segs']).long()[:,32:192+64,16:16+256].contiguous().cuda()

#mr_imgs = torch.clamp(torch.from_numpy(data_mr['imgs'][:8,:,32:192+64,16:16+256]).float(),0,400).contiguous().cuda()/400
#mr_segs = torch.from_numpy(data_mr['segs']).long()[:8,32:192+64,16:16+256].contiguous().cuda()

# Label maps are integer (long) tensors, matching the .long() produced by the real loader.
ct_imgs = torch.zeros(9,1,224,256).cuda()
ct_segs = torch.zeros(9,224,256,dtype=torch.long).cuda()

mr_imgs = torch.zeros(9,1,224,256).cuda()
mr_segs = torch.zeros(9,224,256,dtype=torch.long).cuda()

In [None]:
# Three-fold cross-validation over the 9 subjects: each fold holds out three subjects for
# testing and trains on the remaining six. Maps FOLD -> (ids_train, ids_test).
FOLD_SPLITS = {
    0: ([0,1,2,3,4,5], [6,7,8]),
    1: ([0,1,2,6,7,8], [3,4,5]),
    2: ([3,4,5,6,7,8], [0,1,2]),
}
# Reject unknown folds explicitly instead of leaving ids_train / ids_test undefined
# (or silently stale from a previous run of this cell in the same kernel).
if FOLD not in FOLD_SPLITS:
    raise ValueError('FOLD must be one of {}, got {!r}'.format(sorted(FOLD_SPLITS), FOLD))
ids_train, ids_test = FOLD_SPLITS[FOLD]

ct_imgs_train = ct_imgs[ids_train]
ct_segs_train = ct_segs[ids_train]
mr_imgs_train = mr_imgs[ids_train]
mr_segs_train = mr_segs[ids_train]

ct_imgs_test = ct_imgs[ids_test]
ct_segs_test = ct_segs[ids_test]
mr_imgs_test = mr_imgs[ids_test]
mr_segs_test = mr_segs[ids_test]


print(ct_imgs_train.shape, ct_imgs_test.shape)
print(ct_segs_train.shape, ct_segs_test.shape)

print(mr_imgs_train.shape, mr_imgs_test.shape)
print(mr_segs_train.shape, mr_segs_test.shape)


In [None]:
def parameter_count(model):
    """Return the total number of scalar entries in `model`'s trainable parameters.

    Parameters with ``requires_grad == False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def dice_coeff(outputs, labels, max_label):
    """Per-label Dice overlap between two label maps.

    Args:
        outputs, labels: integer label tensors of matching number of elements
            (they are flattened before comparison).
        max_label: labels 1 .. max_label-1 are scored; label 0 is skipped.

    Returns:
        float32 CPU tensor of length max_label-1 whose entry k is the Dice score of
        label k+1. A label absent from both inputs scores 0 (1e-8 guards the division).
    """
    scores = torch.zeros(max_label - 1, dtype=torch.float32)
    for slot, label in enumerate(range(1, max_label)):
        pred_mask = (outputs == label).view(-1).float()
        true_mask = (labels == label).view(-1).float()
        overlap = torch.mean(pred_mask * true_mask)
        # means instead of sums: the common 1/N factor cancels in the ratio
        scores[slot] = (2. * overlap) / (1e-8 + torch.mean(pred_mask) + torch.mean(true_mask))
    return scores


def jacobian_det(est_x2):
    """Jacobian determinant of the deformation x -> x + u(x) for a dense 2-D displacement field.

    Args:
        est_x2: (B, 2, H, W) displacement field in the normalised coordinates used with
            ``F.grid_sample`` in this notebook (``field.permute(0, 2, 3, 1) + identity``),
            i.e. channel 0 is the x (width, W) component and channel 1 the y (height, H)
            component.

    Returns:
        (B, H, W) tensor holding det(I + grad u), with u converted to pixel units.
        Values <= 0 indicate folding. No autograd graph is recorded.
    """
    B, C, H, W = est_x2.size()
    with torch.no_grad():
        # normalised -> pixel displacement: channel 0 spans W, channel 1 spans H.
        # NOTE(review): (N-1)/2 is the align_corners=True factor; for align_corners=False the
        # exact factor is N/2 (difference below 0.5% at these image sizes).
        scale = est_x2.new_tensor([(W - 1) / 2.0, (H - 1) / 2.0]).view(1, 2, 1, 1)
        disp = est_x2 * scale
        # central differences [-0.5, 0, 0.5] with zero padding, applied to each channel separately
        kern = est_x2.new_tensor([-0.5, 0.0, 0.5])
        d_dy = F.conv2d(disp, kern.view(1, 1, 3, 1).repeat(2, 1, 1, 1), padding=(1, 0), groups=2)
        d_dx = F.conv2d(disp, kern.view(1, 1, 1, 3).repeat(2, 1, 1, 1), padding=(0, 1), groups=2)
        # det [[1 + dux/dx, dux/dy], [duy/dx, 1 + duy/dy]]
        J = (1 + d_dx[:, 0]) * (1 + d_dy[:, 1]) - d_dy[:, 0] * d_dx[:, 1]
    return J


def Correlation(pad_size,kernel_size,max_displacement,stride1,stride2,corr_multiply):
    """Build a dense local correlation (cost volume) function between two feature maps.

    Only ``max_displacement`` and ``stride2`` affect the result; ``pad_size``,
    ``kernel_size``, ``stride1`` and ``corr_multiply`` are accepted but ignored.

    Args:
        max_displacement: width of the square search window minus one, i.e. offsets
            range over -max_displacement/2 .. +max_displacement/2 steps. Must be even
            so the window has a centre.
        stride2: step (in pixels) between neighbouring search offsets (Unfold dilation).

    Returns:
        ``applyCorr(feat1, feat2)`` mapping two (B, C, H, W) tensors to a
        (B, (max_displacement+1)**2, H, W) tensor: for every offset, the channel-mean
        of ``feat1`` times ``feat2`` shifted by that offset (zero outside the image).

    Raises:
        ValueError: if ``max_displacement`` is odd.
    """
    disp_hw = max_displacement
    if disp_hw % 2 != 0:
        raise ValueError('max_displacement must be even, got {}'.format(disp_hw))
    win = disp_hw + 1
    # padding must grow with the dilation so the unfolded output keeps the H x W size
    corr_unfold = torch.nn.Unfold((win, win), dilation=(stride2, stride2), padding=(disp_hw // 2) * stride2)

    def applyCorr(feat1,feat2):
        B,C,H,W = feat1.size()
        # (B, C*win*win, H*W) -> (B, C, win*win, H, W); mean over C gives one score per offset
        patches = corr_unfold(feat2).view(B, C, -1, H, W)
        return torch.mean(patches * feat1.unsqueeze(2), 1)

    return applyCorr

In [None]:
class MultimodalNet(nn.Module):
    """Two-branch CNN that regresses a dense 2-D displacement field between two images.

    Each image is encoded by its own branch (``Y1`` / ``Y2``: three conv + InstanceNorm
    + ReLU layers, overall stride 4, 48 output channels). The first 32 channels of the
    two encodings are compared with a 15x15 local correlation (225 channels); the
    remaining 16 + 16 channels are appended unchanged (225 + 32 input channels of
    ``reg``). ``reg`` downsamples once more (stride 2), predicts 2 channels, smooths them
    with average pooling and upsamples x4 then x2 (bicubic), so for inputs whose height
    and width are divisible by 8 the field has the input resolution.

    In this notebook ``Y1`` always receives the CT image and ``Y2`` the MR image (see
    ``swap`` in ``forward``). The field is consumed as
    ``F.grid_sample(x1, field.permute(0, 2, 3, 1) + identity)``, i.e. as an offset in
    grid_sample's normalised coordinates with channel 0 = x and channel 1 = y.
    """

    def __init__(self):
        super(MultimodalNet, self).__init__()
        # Identity sampling grid for a 224x256 image. forward() does not use it; it is kept
        # for compatibility. As a non-persistent buffer it is created without requiring CUDA,
        # follows .cuda()/.cpu() together with the module, and stays out of state_dict().
        self.register_buffer('identity',
                             F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,224,256),align_corners=False),
                             persistent=False)

        # per-modality feature extractors (1-channel input, stride 4, 48 channels)
        self.Y1 = nn.Sequential(nn.Conv2d(1,32,5,stride=2,padding=2),nn.InstanceNorm2d(32),nn.ReLU(),\
                                nn.Conv2d(32,32,3,stride=2,padding=1),nn.InstanceNorm2d(32),nn.ReLU(),\
                                nn.Conv2d(32,48,3,stride=1,padding=1),nn.InstanceNorm2d(48),nn.ReLU())
        self.Y2 = nn.Sequential(nn.Conv2d(1,32,5,stride=2,padding=2),nn.InstanceNorm2d(32),nn.ReLU(),\
                                nn.Conv2d(32,32,3,stride=2,padding=1),nn.InstanceNorm2d(32),nn.ReLU(),\
                                nn.Conv2d(32,48,3,stride=1,padding=1),nn.InstanceNorm2d(48),nn.ReLU())
        
        # 15x15 search window -> 225 correlation channels
        self.corr = Correlation(pad_size=14,kernel_size=1,max_displacement=14,stride1=1,stride2=1,corr_multiply=1)
        
        # regression head: 225 correlation + 16 + 16 feature channels -> smoothed 2-channel field
        self.reg = nn.Sequential(nn.Conv2d(225+32,128,3,stride=2,padding=1),nn.BatchNorm2d(128),nn.ReLU(),\
                                 nn.Conv2d(128,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(),\
                                 nn.Conv2d(64,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64,2,1),\
                                 nn.AvgPool2d(3,stride=1,padding=1),nn.AvgPool2d(3,stride=1,padding=1),\
                                 nn.Upsample(scale_factor=4,mode='bicubic',align_corners=False),\
                                 nn.AvgPool2d(3,stride=1,padding=1),nn.AvgPool2d(3,stride=1,padding=1),
                                 nn.Upsample(scale_factor=2,mode='bicubic',align_corners=False))

    def forward(self, x1,x2, swap):
        """Predict the displacement field that resamples ``x1`` onto ``x2``.

        Args:
            x1, x2: image batches with 1 channel (both branches start with Conv2d(1, ...)),
                presumably (B, 1, H, W) with H, W divisible by 8 -- TODO confirm.
            swap: False -> x1 goes through Y1 and x2 through Y2; True -> x1 goes through Y2
                and x2 through Y1, so each branch keeps seeing the same modality.

        Returns:
            (B, 2, H, W) field; ``grid_sample(x1, field.permute(0,2,3,1) + identity)`` is what
            the notebook compares against ``x2``.
        """
        if(swap):
            y1 = self.Y2(x1)
            y2 = self.Y1(x2)
        else:
            y1 = self.Y1(x1)
            y2 = self.Y2(x2)
        # gradient checkpointing trades compute for memory on the large correlation volume.
        # NOTE(review): checkpoint re-runs self.reg in backward, so its BatchNorm layers
        # presumably update their running statistics twice per training step -- confirm.
        corr_tensor = torch.cat((checkpoint(self.corr,y1[:,:32],y2[:,:32]),y1[:,32:],y2[:,32:]),1)
        field = checkpoint(self.reg,corr_tensor)
        return field

In [None]:
net = MultimodalNet()
net.cuda()
print('#parameters: ', parameter_count(net))

# Parameter-free smoothing operator: turns coarse (H/8 x W/8) Gaussian noise into a smooth
# full-resolution random displacement (average pooling, bicubic x4, pooling, bicubic x2).
smooth = nn.Sequential(
    nn.AvgPool2d(7, stride=1, padding=3), nn.AvgPool2d(7, stride=1, padding=3),
    nn.AvgPool2d(5, stride=1, padding=2),
    nn.Upsample(scale_factor=4, mode='bicubic', align_corners=False),
    nn.AvgPool2d(9, stride=1, padding=4), nn.AvgPool2d(9, stride=1, padding=4),
    nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False))


H,W = ct_imgs_train.shape[-2:]
B = 4                                   # image pairs per iteration
n_iter = 7500
alpha_ = torch.linspace(.9,.6,n_iter)   # per-iteration scale of the random noise in the synthetic fields
beta_ = 1                               # weight of the detached network prediction in the synthetic fields
optimizer = torch.optim.Adam(net.parameters(),lr=0.001)

identity = F.affine_grid(torch.eye(2,3).cuda().unsqueeze(0),(1,1,H,W),align_corners=False)
run_dice = torch.zeros(n_iter)      # dice after warping the MR labels onto CT with field21
run_dice0 = torch.zeros(n_iter)     # dice of the unregistered pair
run_jacdet = torch.zeros(n_iter)    # std of the Jacobian determinant of field12
cycle_loss = torch.zeros(n_iter)

# dice_coeff(..., 8) returns labels 1..7 at indices 0..6; index 2 (label 3) is left out of the averages.
DICE_LABEL_IDX = torch.tensor([0,1,3,4,5,6])
LOG_EVERY = 100    # print a summary every LOG_EVERY iterations ...
LOG_WINDOW = 20    # ... averaged over the last LOG_WINDOW iterations, current one included

t0 = time.time()
# BatchNorm must be in training mode even if an evaluation cell ran earlier in this kernel
net.train()
for i in range(n_iter):
    optimizer.zero_grad()
    idx1 = torch.randperm(len(ids_train))[:B]
    idx2 = torch.randperm(len(ids_train))[:B]
    t1 = ct_imgs_train[idx1]
    t2 = mr_imgs_train[idx2]
    seg_t1 = ct_segs_train[idx1]
    seg_t2 = mr_segs_train[idx2]

    # residual prediction: current network estimate, used without gradient as the base of the synthetic fields
    with torch.no_grad():
        field12_ = net(t1,t2,False).detach()
        field21_ = net(t2,t1,True).detach()

    # synthetic fields = detached prediction + smoothed noise; t3 / t4 are monomodal warps
    # of t1 / t2 whose ground-truth fields synth13 / synth24 are known
    synth13 = beta_*field12_+smooth(alpha_[i]*torch.randn(B,2,H//8,W//8).cuda())
    t3 = F.grid_sample(t1,synth13.permute(0,2,3,1)+identity,align_corners=False,padding_mode='border')
    synth24 = beta_*field21_+smooth(alpha_[i]*torch.randn(B,2,H//8,W//8).cuda())
    t4 = F.grid_sample(t2,synth24.permute(0,2,3,1)+identity,align_corners=False,padding_mode='border')

    # two cycles 12 + 23 = 13 and 21 + 14 = 24 (four estimated multimodal registrations, two known monomodal ones)
    field12 = net(t1,t2,False)
    field23 = net(t2,t3,True)
    field21 = net(t2,t1,True)
    field14 = net(t1,t4,False)
    # compose the two estimated fields: f23 + f12 sampled at (identity + f23)
    combi1223 = F.grid_sample(field12,field23.permute(0,2,3,1)+identity,align_corners=False)+field23
    combi2114 = F.grid_sample(field21,field14.permute(0,2,3,1)+identity,align_corners=False)+field14
    cycle1_loss = F.mse_loss(synth13,combi1223)
    cycle2_loss = F.mse_loss(synth24,combi2114)
    loss = cycle1_loss + cycle2_loss
    
    with torch.no_grad():
        warped = F.grid_sample(seg_t2.float().unsqueeze(1),field21.permute(0,2,3,1)+identity,align_corners=False,mode='nearest',padding_mode='border').squeeze().long()
        run_dice[i] = dice_coeff(seg_t1,warped,8)[DICE_LABEL_IDX].mean().cpu()
        run_dice0[i] = dice_coeff(seg_t1,seg_t2,8)[DICE_LABEL_IDX].mean().cpu()
        J = jacobian_det(field12)
        run_jacdet[i] = J.std().cpu()
    
    cycle_loss[i] = loss.item()
    loss.backward()
    optimizer.step() 
    
    if i % LOG_EVERY == LOG_WINDOW - 1:
        # average of the last LOG_WINDOW iterations, including iteration i itself
        win = slice(i - LOG_WINDOW + 1, i + 1)
        print(i,'dice',run_dice[win].mean().item(),'before',run_dice0[win].mean().item(),'jacdet',run_jacdet[win].mean().item(),'cycle_loss',cycle_loss[win].mean().item(),'t',time.time()-t0)

In [None]:
n_test = len(ids_test)
# dice_coeff(..., 8) returns labels 1..7 at indices 0..6; index 2 (label 3) is left out of the scores.
DICE_LABEL_IDX = torch.tensor([0,1,3,4,5,6])

# one row per (CT i, MR j) test pair: d0 = dice of the unregistered pair,
# d1 = dice after warping the MR labels onto the CT image with field21
d0 = torch.zeros(n_test**2, len(DICE_LABEL_IDX))
d1 = torch.zeros(n_test**2, len(DICE_LABEL_IDX))

print('test FOLD', FOLD)
# Inference mode: BatchNorm in net.reg must use its running statistics rather than the
# statistics of a single test pair (and must not update them). Restored afterwards.
was_training = net.training
net.eval()
idx = 0
with torch.no_grad():
    for i in range(n_test):
        for j in range(n_test):
            t1 = ct_imgs_test[i:i+1]
            t2 = mr_imgs_test[j:j+1]
            seg_t1 = ct_segs_test[i:i+1]
            seg_t2 = mr_segs_test[j:j+1]

            field21 = net(t2,t1,True)

            warped = F.grid_sample(seg_t2.float().unsqueeze(1),field21.permute(0,2,3,1)+identity,align_corners=False,mode='nearest',padding_mode='border').squeeze().long()
            d1[idx,:] = dice_coeff(seg_t1,warped,8)[DICE_LABEL_IDX].cpu()
            d0[idx,:] = dice_coeff(seg_t1,seg_t2,8)[DICE_LABEL_IDX].cpu()
            idx += 1
net.train(was_training)

# per-pair mean over the evaluated labels
d0_mean = d0.mean(dim=1)
d1_mean = d1.mean(dim=1)

print('d0', d0_mean.mean())
print('d1', d1_mean.mean())
print('d0 mean', d0.mean(dim=0))
print('d1 mean', d1.mean(dim=0))

In [None]:
# Save CPU copies of the weights without moving `net` itself off the GPU (net.cpu() would
# relocate the live model and break re-running any later GPU cell). Replacing the values in
# the state_dict keeps its _metadata, so load_state_dict behaves as for a regular checkpoint.
state = net.state_dict()
for key in list(state):
    state[key] = state[key].cpu()
torch.save(state,'net2d_cycle_fold{}.pth'.format(FOLD))