In [None]:
import os
import re
import sys
import time

import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# project-local: `losses` (provides MINDSSC) lives in the voxelmorph checkout
sys.path.append('./voxelmorph/pytorch/')
import losses
def gpu_usage():
    """Print the currently allocated and the peak allocated CUDA memory in GB (1e9 bytes)."""
    to_gb = 1e-9
    current_gb = torch.cuda.memory_allocated() * to_gb
    peak_gb = torch.cuda.max_memory_allocated() * to_gb
    print('gpu usage (current/max): {:.2f} / {:.2f} GB'.format(current_gb, peak_gb))

In [None]:
# Edge length of the cropped working volume: every crop used below (56:-8 or 8:-56, 48:-16,
# 40:-24 of a 128^3 case) is 64 voxels per axis.
H = W = D = 64
def correlate(mind_fix, mind_mov, disp_hw, grid_sp, shape):
    """Correlation layer: dense discretised SSD cost volume between two descriptor volumes.

    Every integer displacement in [-disp_hw, disp_hw]^3 (in low-resolution voxels) is tested.
    The squared descriptor difference is summed over channels and smoothed with two 3x3x3
    box filters.

    Args:
        mind_fix, mind_mov: descriptor tensors of shape (1, C, H//grid_sp, W//grid_sp, D//grid_sp).
            C is taken from mind_fix (the notebook uses 12-channel MIND-SSC).
        disp_hw: half-width of the displacement search window.
        grid_sp: downsampling factor between `shape` and the descriptor grid.
        shape: full-resolution (H, W, D).

    Returns:
        ssd: cost volume of shape ((2*disp_hw+1)**3, h, w, d), dtype/device of mind_fix.
        ssd_argmin: (h, w, d) int64 index of the cheapest displacement per voxel.

    Works on CPU as well as CUDA tensors (no explicit device synchronisation).
    """
    H, W, D = int(shape[0]), int(shape[1]), int(shape[2])
    h, w, d = H // grid_sp, W // grid_sp, D // grid_sp
    disp_size = disp_hw * 2 + 1
    channels = mind_fix.shape[1]

    with torch.no_grad():
        # After squeeze(0) the channels act as the batch of F.unfold, so every (w, d)
        # neighbourhood of the padded moving volume is gathered for all padded h-slices.
        mind_unfold = F.unfold(F.pad(mind_mov, (disp_hw,) * 6).squeeze(0), disp_size)
        mind_unfold = mind_unfold.view(channels, -1, disp_size ** 2, w, d)

        ssd = torch.zeros(disp_size ** 3, h, w, d, dtype=mind_fix.dtype, device=mind_fix.device)
        for i in range(disp_size):
            # i selects the displacement along the first spatial axis
            mind_sum = (mind_fix.permute(1, 2, 0, 3, 4) - mind_unfold[:, i:i + h]).pow(2).sum(0, keepdim=True)
            smoothed = F.avg_pool3d(F.avg_pool3d(mind_sum.transpose(2, 1), 3, stride=1, padding=1),
                                    3, stride=1, padding=1)
            ssd[i::disp_size] = smoothed.squeeze(1)
        # reorder the interleaved labels into a regular (disp_size,)*3 layout
        ssd = ssd.view(disp_size, disp_size, disp_size, h, w, d).transpose(1, 0) \
                 .reshape(disp_size ** 3, h, w, d)
        ssd_argmin = torch.argmin(ssd, 0)
    return ssd, ssd_argmin

def coupled_convex(ssd, ssd_argmin, disp_mesh_t, grid_sp, shape):
    """Solve two coupled convex problems for efficient global regularisation.

    Alternates between a per-voxel argmin of `cost + coupling * |d - d_smooth|^2` over all
    candidate displacements and a 3x3x3 box smoothing of the winning displacements. The
    coupling weight grows over six fixed steps.

    Args:
        ssd: cost volume (n_labels, h, w, d) with h, w, d = shape // grid_sp.
        ssd_argmin: (h, w, d) initial label per voxel.
        disp_mesh_t: (3, n_labels, 1) displacement vector of every label.
        grid_sp: downsampling factor of the cost volume.
        shape: full-resolution (H, W, D).

    Returns:
        (1, 3, h, w, d) smoothed displacement field.
    """
    H, W, D = int(shape[0]), int(shape[1]), int(shape[2])
    h, w, d = H // grid_sp, W // grid_sp, D // grid_sp

    def _smoothed_displacement(labels):
        # look up the displacement of each voxel's label, then box-filter the field
        picked = disp_mesh_t.view(3, -1)[:, labels.view(-1)]
        return F.avg_pool3d(picked.reshape(1, 3, h, w, d), 3, padding=1, stride=1)

    disp_soft = _smoothed_displacement(ssd_argmin)
    for coupling in torch.tensor([0.003, 0.01, 0.03, 0.1, 0.3, 1]):
        coupled_argmin = torch.zeros_like(ssd_argmin)
        with torch.no_grad():
            for x in range(h):
                # processed slice by slice along the first axis to bound memory
                penalty = (disp_mesh_t - disp_soft[:, :, x].view(3, 1, -1)).pow(2).sum(0).view(-1, w, d)
                coupled_argmin[x] = torch.argmin(ssd[:, x, :, :] + coupling * penalty, 0)
        disp_soft = _smoothed_displacement(coupled_argmin)
    return disp_soft

def inverse_consistency(disp_field1s, disp_field2s, iter=20):
    """Enforce inverse consistency of a forward and a backward displacement field.

    Each iteration replaces every field by half of (itself minus the other field sampled at
    the positions it points to). A fixed point satisfies d1(x) + d2(x + d1(x)) = 0, and
    likewise with the roles swapped.

    Args:
        disp_field1s, disp_field2s: (B, 3, H, W, D) displacement fields in grid_sample's
            normalised coordinates (they are added to an affine_grid identity).
        iter: number of iterations (name kept for existing callers).

    Returns:
        Tuple of the two updated fields with the inputs' shape, dtype and device.
        With iter=0 these are copies of the inputs.
    """
    B, C, H, W, D = disp_field1s.size()
    with torch.no_grad():
        disp_field1i = disp_field1s.clone()
        disp_field2i = disp_field2s.clone()

        # identity grid, channel-first like the fields; align_corners=False is the default
        # behaviour, stated explicitly to avoid PyTorch's UserWarning
        identity = F.affine_grid(torch.eye(3, 4).unsqueeze(0), (1, 1, H, W, D), align_corners=False) \
                    .permute(0, 4, 1, 2, 3).to(disp_field1s.device).to(disp_field1s.dtype)
        for _ in range(iter):
            # both updates use the fields from the previous iteration (no in-place ops)
            prev1, prev2 = disp_field1i, disp_field2i
            disp_field1i = 0.5 * (prev1 - F.grid_sample(prev2, (identity + prev1).permute(0, 2, 3, 4, 1),
                                                        align_corners=False))
            disp_field2i = 0.5 * (prev2 - F.grid_sample(prev1, (identity + prev2).permute(0, 2, 3, 4, 1),
                                                        align_corners=False))

    return disp_field1i, disp_field2i

def combineDeformation3d(disp_1st, disp_2nd, identity):
    """Compose two displacement fields: disp_2nd(x) + disp_1st(x + disp_2nd(x)).

    Args:
        disp_1st, disp_2nd: (B, 3, H, W, D) displacements in grid_sample's normalised units.
        identity: identity sampling grid (…, H, W, D, 3). It is expected to come from
            F.affine_grid(..., align_corners=False) to match the sampling below.

    Returns:
        The composed (B, 3, H, W, D) displacement field.
    """
    sample_grid = disp_2nd.permute(0, 2, 3, 4, 1) + identity
    # align_corners=False is the default; stated explicitly to avoid PyTorch's UserWarning
    return disp_2nd + F.grid_sample(disp_1st, sample_grid, align_corners=False)

# Nominal compute device; the functions in this notebook call .cuda() directly.
device = 'cuda'

def dice_coeff(outputs, labels, max_label):
    """Per-label Dice overlap between two label maps.

    Args:
        outputs: predicted/warped label map (any shape or memory layout).
        labels: reference label map with the same number of elements.
        max_label: labels 1 .. max_label-1 are scored; label 0 is skipped.

    Returns:
        float32 CPU tensor of length max_label-1; entry k is the Dice of label k+1.
        A label absent from both maps scores 0 (the 1e-8 only avoids 0/0).
    """
    dice = torch.zeros(max_label - 1, dtype=torch.float32)
    for label_num in range(1, max_label):
        # reshape instead of view: works for non-contiguous inputs (transposed/permuted maps)
        iflat = (outputs == label_num).reshape(-1).float()
        tflat = (labels == label_num).reshape(-1).float()
        intersection = torch.mean(iflat * tflat)
        dice[label_num - 1] = (2. * intersection) / (1e-8 + torch.mean(iflat) + torch.mean(tflat))
    return dice

def convexAdam(img_fixed, img_moving, grid_sp_coarse=4, disp_hw=8, grid_sp_adam=3,
               iterations=40, lambda_weight=.6):
    """Register `img_moving` onto `img_fixed` with the two-stage ConvexAdam scheme.

    Stage 1 (discrete): MIND-SSC descriptors are average-pooled by `grid_sp_coarse`. An SSD
    cost volume over displacements in [-disp_hw, disp_hw]^3 is built in both directions,
    regularised with `coupled_convex`, and made inverse consistent.
    Stage 2 (continuous): the result initialises a displacement grid at spacing
    `grid_sp_adam`. That grid is refined with Adam on descriptor SSD plus a diffusion
    regulariser, with three 3x3x3 box filters acting as smoothing.

    Args:
        img_fixed, img_moving: 3D tensors of identical shape (H, W, D); moved to CUDA here.
        grid_sp_coarse, disp_hw: grid spacing and search half-width of stage 1.
        grid_sp_adam, iterations, lambda_weight: grid spacing, number of Adam steps and
            regularisation weight of stage 2. The defaults reproduce the original settings.

    Returns:
        (1, H, W, D, 3) float32 CUDA displacement in normalised coordinates, to be added to
        an identity F.affine_grid before F.grid_sample.

    Raises:
        ValueError: if iterations < 1 (the fitted grid comes from the optimisation loop).
    """
    if iterations < 1:
        raise ValueError('iterations must be >= 1')
    # volume size from the input itself rather than the notebook-level H, W, D globals
    H, W, D = (int(s) for s in img_fixed.shape[-3:])
    shape = (H, W, D)
    disp_size = disp_hw * 2 + 1

    # ---- stage 1: descriptors, discrete search, coupled convex regularisation ----
    grid_sp = grid_sp_coarse
    with torch.no_grad():
        mindssc_fix = losses.MINDSSC(img_fixed.unsqueeze(0).unsqueeze(0).cuda(), 1, 2).half()
        mindssc_mov = losses.MINDSSC(img_moving.unsqueeze(0).unsqueeze(0).cuda(), 1, 2).half()
        mind_fix = F.avg_pool3d(mindssc_fix, grid_sp, stride=grid_sp)
        mind_mov = F.avg_pool3d(mindssc_mov, grid_sp, stride=grid_sp)

    # integer candidate displacements (low-resolution voxels), one per cost-volume label
    disp_mesh_t = F.affine_grid(disp_hw * torch.eye(3, 4).cuda().half().unsqueeze(0),
                                (1, 1, disp_size, disp_size, disp_size),
                                align_corners=True).permute(0, 4, 1, 2, 3).reshape(3, -1, 1)

    ssd, ssd_argmin = correlate(mind_fix, mind_mov, disp_hw, grid_sp, shape)
    disp_soft = coupled_convex(ssd, ssd_argmin, disp_mesh_t, grid_sp, shape)
    ssd_, ssd_argmin_ = correlate(mind_mov, mind_fix, disp_hw, grid_sp, shape)
    disp_soft_ = coupled_convex(ssd_, ssd_argmin_, disp_mesh_t, grid_sp, shape)

    # voxel units -> normalised units; flip(1) reverses the component order for grid_sample
    scale_ice = torch.tensor([H // grid_sp - 1, W // grid_sp - 1, D // grid_sp - 1]) \
                     .view(1, 3, 1, 1, 1).cuda().half() / 2
    disp_ice, _ = inverse_consistency((disp_soft / scale_ice).flip(1), (disp_soft_ / scale_ice).flip(1), iter=15)
    disp_hr = F.interpolate(disp_ice.flip(1) * scale_ice * grid_sp, size=shape,
                            mode='trilinear', align_corners=False)

    # ---- stage 2: continuous refinement with Adam ----
    grid_sp = grid_sp_adam
    h, w, d = H // grid_sp, W // grid_sp, D // grid_sp
    with torch.no_grad():
        patch_mind_fix = F.avg_pool3d(mindssc_fix, grid_sp, stride=grid_sp)
        patch_mind_mov = F.avg_pool3d(mindssc_mov, grid_sp, stride=grid_sp)

    # optimisable displacement grid, stored as the weight of a Conv3d (in low-res voxels)
    disp_lr = F.interpolate(disp_hr, size=(h, w, d), mode='trilinear', align_corners=False)
    net = nn.Sequential(nn.Conv3d(3, 1, (h, w, d), bias=False))
    net[0].weight.data[:] = disp_lr.float().cpu().data / grid_sp
    net.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1)
    grid0 = F.affine_grid(torch.eye(3, 4).unsqueeze(0).cuda(), (1, 1, h, w, d), align_corners=False)

    # loop invariants
    grid0_flat = grid0.view(-1, 3).cuda().float()
    scale_adam = torch.tensor([(h - 1) / 2, (w - 1) / 2, (d - 1) / 2]).cuda().unsqueeze(0)
    patch_mov = patch_mind_mov.float()

    # lambda_weight: with tps .5, without .7 (original tuning note)
    for _ in range(iterations):
        optimizer.zero_grad()
        disp_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight, 3, stride=1, padding=1),
                                                3, stride=1, padding=1),
                                   3, stride=1, padding=1).permute(0, 2, 3, 4, 1)
        # diffusion regulariser: squared finite differences along the three spatial axes
        reg_loss = lambda_weight * ((disp_sample[0, :, 1:, :] - disp_sample[0, :, :-1, :]) ** 2).mean() + \
            lambda_weight * ((disp_sample[0, 1:, :, :] - disp_sample[0, :-1, :, :]) ** 2).mean() + \
            lambda_weight * ((disp_sample[0, :, :, 1:] - disp_sample[0, :, :, :-1]) ** 2).mean()
        grid_disp = grid0_flat + ((disp_sample.view(-1, 3)) / scale_adam).flip(1).float()
        patch_mov_sampled = F.grid_sample(patch_mov, grid_disp.view(1, h, w, d, 3).cuda(),
                                          align_corners=False, mode='bilinear')
        # mean over channels * 12 (the MIND-SSC channel count)
        sampled_cost = (patch_mov_sampled - patch_mind_fix).pow(2).mean(1) * 12
        loss = sampled_cost.mean()
        (loss + reg_loss).backward()
        optimizer.step()

    # NOTE: as in the reference implementation, the grid from the last forward pass (before the
    # final optimizer.step) is used
    fitted_grid = disp_sample.permute(0, 4, 1, 2, 3).detach()
    disp_hr = F.interpolate(fitted_grid * grid_sp, size=shape, mode='trilinear', align_corners=False)
    disp_smooth = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr, 3, padding=1, stride=1),
                                            3, padding=1, stride=1),
                               3, padding=1, stride=1)
    disp = disp_smooth.cuda().float().permute(0, 2, 3, 4, 1) / \
        torch.tensor([H - 1, W - 1, D - 1]).cuda().view(1, 1, 1, 1, 3) * 2
    return disp.flip(4)

In [None]:
# Source domain: labelled training crops (ceT1 images plus Label_l / Label_r segmentations).
folder = '/share/data_supergrover1/hansen/temp/crossMoDa/preprocessed_new/resampled/localised_crop/source_training/'
files = sorted(os.listdir(folder))
first_four = files[:4]
print(len(files), first_four)

In [None]:
import subprocess

# Count tumour voxels in every source label file with c3d: `-replace 2 0` zeroes label 2,
# `-voxel-sum` then sums the remaining labels. Only cases with more than 50 voxels are kept,
# split by side (Label_l / Label_r).
source_tumour_left = []
source_tumour_right = []
for f in files:
    for tag, bucket in (('Label_l', source_tumour_left), ('Label_r', source_tumour_right)):
        if tag not in f:
            continue
        output = subprocess.run(["/share/data_supergrover1/heinrich/c3d", folder + f,
                                 "-replace", "2", "0", "-voxel-sum"], capture_output=True)
        # decode the bytes instead of parsing their repr (str(b'...'))
        stdout = output.stdout.decode(errors='replace')
        if 'Voxel Sum' in stdout:
            count = int(stdout.splitlines()[0].split(':')[1])
            if count > 50:
                bucket.append([f, count])
    

In [None]:
# Raw stdout of the c3d call made last by a previously executed cell (`output` is left over
# from the label-counting loops), e.g. to check the 'Voxel Sum' format.
print(output.stdout)

In [None]:
n_source_right, n_source_left = len(source_tumour_right), len(source_tumour_left)
print(n_source_right, n_source_left)  # source cases above the voxel threshold: right, left

In [None]:
# Target domain: hrT2 training crops (their labels are read from the same folder below).
folder1 = '/share/data_supergrover1/hansen/temp/crossMoDa/preprocessed_new/resampled/localised_crop/target_training//'
files1 = sorted(os.listdir(folder1))
first_three = files1[:3]
print(len(files1), first_three)

In [None]:
import subprocess

# Same c3d tumour-voxel count as for the source domain (label 2 zeroed, >50 voxels kept).
# Files whose c3d output has no 'Voxel Sum' line are skipped instead of aborting the loop.
target_tumour_left = []
target_tumour_right = []
for f in files1:
    for tag, bucket in (('Label_l', target_tumour_left), ('Label_r', target_tumour_right)):
        if tag not in f:
            continue
        output = subprocess.run(["/share/data_supergrover1/heinrich/c3d", folder1 + f,
                                 "-replace", "2", "0", "-voxel-sum"], capture_output=True)
        stdout = output.stdout.decode(errors='replace')
        if 'Voxel Sum' not in stdout:
            continue
        count = int(stdout.splitlines()[0].split(':')[1])
        if count > 50:
            bucket.append([f, count])
    

In [None]:
n_target_right, n_target_left = len(target_tumour_right), len(target_tumour_left)
print(n_target_right, n_target_left)  # target cases above the voxel threshold: right, left
# one-time setup for the deeds outputs: !mkdir crossmoda_deeds

In [None]:
# Re-score warped labels that were previously computed and saved to disk (right side):
# column 0 = Dice before registration, column 1 = Dice after, for fixed target i / moving source j.
# NOTE(review): segs_t1 / segs_t2 are only created by the volume-loading cell further below;
# run that cell first on a fresh kernel.
dicesi = torch.zeros(30,30,2)
for i in range(30):
    
    #output = torch.zeros(30,128,128,128).short().pin_memory()
    dices = torch.zeros(30,2)
    t0 = time.time()
    for j in range(30):
        # NOTE(review): the file name is built from source case i and target case j, while the
        # Dice below compares fixed target i (segs_t2[i]) with moving source j (segs_t1[j]);
        # the i/j roles look swapped -- confirm the naming of the files in crossmoda_convex/.
        base = source_tumour_right[i][0][:-13]


        warped_seg = torch.from_numpy(nib.load('/share/data_supergrover1/heinrich/crossmoda_convex/'+base+str(target_tumour_right[j][0][:17].split('_')[1])+'.nii.gz').get_fdata())


        #disp = convexAdam(imgs_t2[i,crop[0]:-crop[1],48:-16,40:-24].cuda(),imgs_t1[j,crop[0]:-crop[1],48:-16,40:-24].cuda())
        #warped_seg = F.grid_sample(segs_t1[j:j+1,crop[0]:-crop[1],48:-16,40:-24].float().unsqueeze(1).cuda(),\
        #                           F.affine_grid(torch.eye(3,4).cuda().unsqueeze(0),(1,1,H,W,D))+disp.cuda(),mode='nearest')
        #output[j,crop[0]:-crop[1],48:-16,40:-24] = warped_seg.cpu().squeeze().cpu()
        d0 = dice_coeff(segs_t2[i].contiguous(),segs_t1[j].contiguous().cpu(),3)
        d1 = dice_coeff(segs_t2[i].contiguous(),warped_seg.squeeze().long().contiguous().cpu(),3)
        # keep only label 1 (first entry of dice_coeff)
        dices[j,0] = d0[0]
        dices[j,1] = d1[0]
        
    #print(i,j,d0[1],d1[1])
    dicesi[i] = dices
    t1 = time.time()
    # min / quartiles / max of the after-registration Dice for this fixed case
    print(torch.quantile(dices[:,1],q=torch.linspace(0,1,5)))
    print(i,t1-t0,'sec (dice)',dices.mean(0))
    


In [None]:
import matplotlib.pyplot as plt

# Sorted Dice over all fixed/moving pairs: ConvexAdam vs. deeds vs. no registration.
# NOTE: deeds_dice is produced by the deeds (right side) cell further below; run it first.
print(dicesi.shape)
fig, ax = plt.subplots()
for values, name in ((dicesi[:, :, 1], 'convex'), (deeds_dice, 'deeds'), (dicesi[:, :, 0], 'before')):
    flat = values.reshape(-1)
    # x-axis derived from each curve's own length instead of a hard-coded 30**2
    ax.plot(torch.linspace(0, 1, flat.numel()), torch.sort(flat)[0], label=name)
ax.set(xlabel='fraction of registration pairs (sorted)', ylabel='Dice', title='Sorted pairwise Dice')
ax.legend()
plt.show()

In [None]:
# Re-score existing deeds results for the left side (no registration is run here):
# diceMulti compares the fixed target label with the deeds-warped source label.
deeds_dice_left = torch.zeros(30, 30)
for i in range(30):
    fixed_label = target_tumour_left[i][0]
    for j in range(30):
        moving_id = str(source_tumour_left[j][0][:17].split('_')[1])
        out = 'crossmoda_deeds/F' + fixed_label[:-13] + '_M' + moving_id
        output = subprocess.run(["./deedsBCV/diceMulti", folder1 + fixed_label, out + '_deformed_seg.nii.gz'],
                                capture_output=True)
        # first Dice value reported by diceMulti
        value = float(str(output.stdout).split(':')[1].split(',')[0])
        print('dice', value)
        deeds_dice_left[i, j] = value
        

In [None]:
# Run deeds for all 30x30 left-side pairs (fixed: target hrT2, moving: source ceT1 with its
# label), then score the warped label with diceMulti. Results are checkpointed per fixed case.
deeds_dice_left = torch.zeros(30, 30)
for i in range(30):
    fixed_label = target_tumour_left[i][0]
    for j in range(30):
        moving_label = source_tumour_left[j][0]
        out = 'crossmoda_deeds/F' + fixed_label[:-13] + '_M' + str(moving_label[:17].split('_')[1])
        # argument list instead of an interpolated `!{cmd}` shell line: paths are never shell-parsed
        subprocess.run(['./deedsBCV/deedsBCV',
                        '-F', folder1 + fixed_label[:-14] + 'hrT2_l.nii.gz',
                        '-M', folder + moving_label[:-14] + 'ceT1_l.nii.gz',
                        '-S', folder + moving_label,
                        '-O', out])
        # The intermediate out+'_deformed.nii.gz' / out+'_displacements.dat' files are kept on disk.

        output = subprocess.run(["./deedsBCV/diceMulti", folder1 + fixed_label, out + '_deformed_seg.nii.gz'],
                                capture_output=True)
        value = float(str(output.stdout).split(':')[1].split(',')[0])
        print('dice', value)
        deeds_dice_left[i, j] = value
    torch.save(deeds_dice_left, 'crossmoda_deeds/dice_left.pth')



In [None]:
# Number of target (hrT2) cases whose left-side label passed the 50-voxel threshold.
len(target_tumour_left)

In [None]:
# Run deeds for all 30x30 right-side pairs (fixed: target hrT2, moving: source ceT1 with its
# label), then score the warped label with diceMulti. Results are checkpointed per fixed case.
deeds_dice = torch.zeros(30, 30)
for i in range(30):
    fixed_label = target_tumour_right[i][0]
    for j in range(30):
        moving_label = source_tumour_right[j][0]
        out = 'crossmoda_deeds/F' + fixed_label[:-13] + '_M' + str(moving_label[:17].split('_')[1])
        # argument list instead of an interpolated `!{cmd}` shell line: paths are never shell-parsed
        subprocess.run(['./deedsBCV/deedsBCV',
                        '-F', folder1 + fixed_label[:-14] + 'hrT2_r.nii.gz',
                        '-M', folder + moving_label[:-14] + 'ceT1_r.nii.gz',
                        '-S', folder + moving_label,
                        '-O', out])
        # The intermediate out+'_deformed.nii.gz' / out+'_displacements.dat' files are kept on disk.

        output = subprocess.run(["./deedsBCV/diceMulti", folder1 + fixed_label, out + '_deformed_seg.nii.gz'],
                                capture_output=True)
        value = float(str(output.stdout).split(':')[1].split(',')[0])
        print('dice', value)
        deeds_dice[i, j] = value
    torch.save(deeds_dice, 'crossmoda_deeds/dice_right.pth')



In [None]:
# Pairwise deeds Dice (rows: fixed target case i, columns: moving source case j).
print(deeds_dice)

In [None]:
# List the deeds outputs (warped images/labels, displacement files, dice_*.pth checkpoints).
!ls crossmoda_deeds/*

In [None]:
def _load_volume(path):
    """Read a NIfTI file into a float32 torch tensor."""
    return torch.from_numpy(nib.load(path).get_fdata()).float()


# 30 right-side cases per domain: t1 = source (ceT1), t2 = target (hrT2); images in pinned memory.
segs_t1 = torch.zeros(30, 128, 128, 128)
imgs_t1 = torch.zeros(30, 128, 128, 128).pin_memory()
segs_t2 = torch.zeros(30, 128, 128, 128)
imgs_t2 = torch.zeros(30, 128, 128, 128).pin_memory()

for i in range(30):
    src_label = source_tumour_right[i][0]
    tgt_label = target_tumour_right[i][0]
    segs_t1[i] = _load_volume(folder + src_label)
    segs_t2[i] = _load_volume(folder1 + tgt_label)
    imgs_t1[i] = _load_volume(folder + src_label[:-14] + 'ceT1_r.nii.gz')
    imgs_t2[i] = _load_volume(folder1 + tgt_label[:-14] + 'hrT2_r.nii.gz')
    # left side: use source_tumour_left / target_tumour_left and the '_l' image suffixes
    
        

In [None]:
# Total label voxels vs. voxels inside the left-side crop (8:-56, 48:-16, 40:-24),
# to check that the crop keeps the labelled region.
print(torch.sum(segs_t1),torch.sum(segs_t1[:,8:-56,48:-16,40:-24]))



In [None]:
print(torch.sum(segs_t1), torch.sum(segs_t1[:, 56:-8, 48:-16, 40:-24]))
print(imgs_t1[0, 56:-8, 48:-16, 40:-24].shape)

# crop = [8, 56]  # left
crop = [56, 8]  # right
# crop shared by all cases: the first-axis range depends on the side, the others are fixed
roi = (slice(crop[0], -crop[1]), slice(48, -16), slice(40, -24))
print(torch.sum(segs_t1), torch.sum(segs_t1[(slice(None),) + roi]))

# identity sampling grid of the cropped volume, built once and reused for every pair
roi_shape = tuple(imgs_t2[0][roi].shape)
identity = F.affine_grid(torch.eye(3, 4).cuda().unsqueeze(0), (1, 1) + roi_shape, align_corners=False)

# ConvexAdam for every fixed target i / moving source j; column 0 = Dice before, 1 = after
dicesi = torch.zeros(30, 30, 2)
for i in range(30):
    output = torch.zeros(30, 128, 128, 128).short().pin_memory()
    dices = torch.zeros(30, 2)
    t0 = time.time()
    fixed_img = imgs_t2[i][roi].cuda()
    fixed_seg = segs_t2[i][roi].contiguous()
    for j in range(30):
        disp = convexAdam(fixed_img, imgs_t1[j][roi].cuda())
        moving_seg = segs_t1[j][roi].float().unsqueeze(0).unsqueeze(0).cuda()
        warped_seg = F.grid_sample(moving_seg, identity + disp.cuda(), mode='nearest', align_corners=False)
        output[j][roi] = warped_seg.squeeze().cpu()
        d0 = dice_coeff(fixed_seg, segs_t1[j][roi].contiguous(), 3)
        d1 = dice_coeff(fixed_seg, warped_seg.squeeze().long().contiguous().cpu(), 3)
        dices[j, 0] = d0[0]
        dices[j, 1] = d1[0]

    dicesi[i] = dices
    t1 = time.time()
    print(torch.quantile(dices[:, 1], q=torch.linspace(0, 1, 5)))
    print(i, t1 - t0, 'sec (reg)')

    # save the 30 warped labels of this fixed case (identity affine, as before)
    t0 = time.time()
    fixed = str(target_tumour_right[i][0][:17].split('_')[1])
    for j in range(30):
        moving = re.findall(r'\d{1,3}', source_tumour_right[j][0][:-14])[0]
        base = f"crossmoda_F{fixed}r_M{moving}r"
        fp = '/share/data_supergrover1/weihsbach/shared_data/tmp/tmp_convex/' + base + '.nii.gz'
        nib.save(nib.Nifti1Image(output[j].numpy(), np.eye(4)), fp)
    t1 = time.time()
    print(t1 - t0, 'sec (save)')




In [None]:
results = {
    'dice': dicesi,
    'source_tumour_right': source_tumour_right,
    'target_tumour_right': target_tumour_right,
}
torch.save(results, 'crossmoda_convex/dice_files_right.pth')

In [None]:
# Mean of the best Dice over axis 0 and over axis 1 of the fixed x moving table,
# before (column 0) and after (column 1) registration.
for tag, col in (('before', 0), ('after', 1)):
    for axis in (0, 1):
        print(tag, dicesi[:, :, col].max(axis)[0].mean())

In [None]:
# Min / quartiles / max of each fixed case's Dice row, averaged over the fixed cases.
quantile_levels = torch.linspace(0, 1, 5)
for tag, col in (('before', 0), ('after', 1)):
    print(tag, torch.quantile(dicesi[:, :, col], dim=1, q=quantile_levels).mean(1))



In [None]:
# NOTE(review): stale cell -- img3d, img3d_t2, seg3d, seg3d_t2, fix, combined, mov_all, before
# and after are not defined anywhere in this notebook, so it fails on a fresh kernel; it also
# relies on `i` leaking from an earlier loop.
# Registers img3d_t2[i] (moving) to img3d[fix] (fixed), warps the 3-class one-hot moving label
# bilinearly, accumulates the soft labels in `combined`, and scores the argmax label with
# dice_coeff before/after registration.
disp = convexAdam(img3d[fix,56:-8,48:-16,40:-24],img3d_t2[i,56:-8,48:-16,40:-24])
warped_one_hot = F.grid_sample(F.one_hot(seg3d_t2[i,56:-8,48:-16,40:-24].long(),3).permute(3,0,1,2).float().cuda().view(1,3,H,W,D),F.affine_grid(torch.eye(3,4).cuda().unsqueeze(0),(1,1,H,W,D))+disp.cuda(),mode='bilinear')
combined += warped_one_hot.squeeze().cpu()
warped_seg = warped_one_hot.argmax(1).squeeze()
mov_all[i,56:-8,48:-16,40:-24] = warped_seg
before[i] = (dice_coeff(seg3d[fix,56:-8,48:-16,40:-24],seg3d_t2[i,56:-8,48:-16,40:-24],3))
after[i] = (dice_coeff(seg3d[fix,56:-8,48:-16,40:-24],warped_seg.squeeze().cpu(),3))

In [None]:
# File name stem of source case `i` (index left over from the last executed loop).
source_name = source_tumour_right[i][0][:-13]
print(source_name + '.nii.gz')