In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
import time
import json
import nibabel as nib
import sys
import glob
import csv
import struct
import json
from scipy.ndimage import distance_transform_edt


from networks import RegModel,ShufflePermutation

from utils import *

if is_notebook():
    from tqdm.notebook import tqdm, trange
else:
    from tqdm import tqdm, trange


# GPU selection. An automatic alternative (kept for reference, currently disabled):
#   from meidic_vtach_utils.run_on_recommended_cuda import get_cuda_environ_vars as get_vars
#   os.environ.update(get_vars('*'))
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",   # enumerate GPUs by PCI bus id (same numbering as nvidia-smi)
    CUDA_VISIBLE_DEVICES="10",        # hard-coded GPU index for this machine
)

# Stage numbers to skip; the training cells test `1 not in skip_stage` / `2 not in skip_stage`.
skip_stage = []

In [None]:
def _read_label_volume(path, mapping=None):
    """Read a NIfTI label map from ``path`` into an int64 tensor.

    If ``mapping`` (a 1-D LongTensor lookup table) is given, every voxel value v
    is replaced by ``mapping[v]``.
    """
    label = torch.from_numpy(nib.load(path).get_fdata()).long()
    if mapping is not None:
        label = mapping[label]
    return label


def _onehot_half_res(label, num_labels):
    """One-hot encode an integer (H, W, D) label map and 2x average-pool it.

    Returns a float tensor of shape (1, num_labels, H//2, W//2, D//2) holding the
    fraction of each 2x2x2 cell that belongs to each label. The one-hot tensor is
    cast to float *before* pooling: F.one_hot returns int64, and average pooling an
    int64 tensor on the CPU uses integer division, which floors every mixed cell to
    0 instead of producing soft labels.
    """
    onehot = F.one_hot(label.unsqueeze(0), num_labels).permute(0, 4, 1, 2, 3).float()
    return F.avg_pool3d(onehot, 2)


def load_dataset_seg(task_name, split ='train'):
    """Load the segmentation label maps of one task / split.

    Args:
        task_name: 'AbdomenCTCT', 'AMOS', 'TS_Skeletal' or 'SilverCorpus'.
        split: 'train'    -> ground-truth labels of the training cases as soft
                             one-hot at half resolution;
               'val'      -> ground-truth integer label maps of the validation
                             cases at full resolution (single channel);
               'val_pred' -> predicted labels of the validation cases (for
                             TS_Skeletal / SilverCorpus the ground truth is used)
                             as soft one-hot at half resolution.

    Returns:
        labels: pinned float tensor, (N, num_labels, H//2, W//2, D//2) for
            'train' / 'val_pred', or (N, 1, H, W, D) for 'val'.
        shape: tuple (N, num_labels, H, W, D), H/W/D being the full-resolution size.

    Raises:
        ValueError: for an unknown ``task_name`` or ``split``.
    """
    gt_mapping = None    # optional lookup table applied to ground-truth label ids
    pred_mapping = None  # optional lookup table applied to predicted label ids

    if task_name == 'AbdomenCTCT':
        with open('data_compressed/AbdomenCTCT/AbdomenCTCT_dataset.json') as f:
            dataset = json.load(f)
        H, W, D = dataset['tensorImageShape']['0']
        num_labels = len(dataset['labels']['0'])
        train_ids = [str(x).zfill(4) for x in sorted(list(range(31)[2::3]) + list(range(31)[3::3]))]
        val_ids = [str(x).zfill(4) for x in range(31)[1::3]]
        train_paths = [f'data_compressed/AbdomenCTCT/labelsTr/AbdomenCTCT_{c}_0000.nii.gz' for c in train_ids]
        val_paths = [f'data_compressed/AbdomenCTCT/labelsTr/AbdomenCTCT_{c}_0000.nii.gz' for c in val_ids]
        pred_paths = [f'data_compressed/AbdomenCTCT/predictedlabelsTr/AbdomenCTCT_{c}_0000.nii.gz' for c in val_ids]

    elif task_name == 'AMOS':
        with open('data_compressed/AMOS/AMOS_dataset.json') as f:
            dataset = json.load(f)
        H, W, D = dataset['tensorImageShape']['0']
        num_labels = len(dataset['labels']['0'])
        case_ids = ['0507', '0508', '0510', '0514', '0517', '0518', '0522', '0530', '0532', '0538', '0540', '0541', '0551', '0555', '0557', '0571', '0578', '0580', '0582', '0584', '0585', '0586', '0587', '0588', '0589', '0590', '0592', '0594', '0595', '0596', '0597', '0599']
        train_ids, val_ids = case_ids[:-6], case_ids[-6:]
        train_paths = [f'data_compressed/AMOS/labelsTr/AMOS_{c}_0000.nii.gz' for c in train_ids]
        val_paths = [f'data_compressed/AMOS/labelsTr/AMOS_{c}_0000.nii.gz' for c in val_ids]
        pred_paths = [f'data_compressed/AMOS_pred/predictedlabelsTr/AMOS_{c}_0000.nii.gz' for c in val_ids]
        # re-indexes the predicted label ids into this dataset's label ids
        # (predicted ids 4, 5 and 6 collapse to 0)
        pred_mapping = torch.LongTensor([ 0, 12, 11,  8,  0,  0,  0, 13,  5,  4,  9,  3,  2,  6, 10,  1,  7])

    elif task_name == 'TS_Skeletal':
        files = sorted(glob.glob('data_compressed/TS_Skeletal/labels/*nii.gz'))
        num_labels, H, W, D = 29, 256, 160, 256
        # NOTE(review): training uses *all* files, including the validation files[27:];
        # the hold-out variant would be files[:27].
        train_paths = files
        val_paths = files[27:]
        pred_paths = files[27:]  # no separate predictions: ground truth is used

    elif task_name == 'SilverCorpus':
        # raw label id -> merged id. Merged ids (from the original mapping_dict):
        # 1 aorta, 2 inferior_vena_cava, 3 pulmonary_artery,
        # 4 heart (myocardium, both atria, both ventricles),
        # 5 lung_upper_lobe_left, 6 lung_lower_lobe_left, 7 lung_upper_lobe_right,
        # 8 lung_middle_lobe_right, 9 lung_lower_lobe_right; everything else -> 0.
        gt_mapping = torch.LongTensor([0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4,
        4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0])
        pred_mapping = gt_mapping  # "predictions" are the ground-truth files
        num_labels, H, W, D = gt_mapping.max().item() + 1, 256, 192, 288
        val_ids = [12, 36, 49, 68, 97]
        # NOTE(review): the validation cases are also part of the training set
        train_ids = val_ids + [13, 24, 26, 27, 28, 29, 31, 32, 33, 34, 35, 37, 38, 40, 41, 42, 43, 44, 45, 46, 47, 48, 51, 53, 54, 55, 57, 58, 59, 60, 62, 64, 65, 69, 70, 71, 72, 75, 76, 77, 78, 84, 91, 93, 98, 141]
        train_paths = [f'data_compressed/SilverCorpus/silver{str(c).zfill(3)}.nii.gz' for c in train_ids]
        val_paths = [f'data_compressed/SilverCorpus/silver{str(c).zfill(3)}.nii.gz' for c in val_ids]
        pred_paths = val_paths

    else:
        raise ValueError(f'unknown task_name {task_name!r}')

    if split == 'train':
        paths, mapping, onehot = train_paths, gt_mapping, True
    elif split == 'val':
        paths, mapping, onehot = val_paths, gt_mapping, False
    elif split == 'val_pred':
        paths, mapping, onehot = pred_paths, pred_mapping, True
    else:
        raise ValueError(f'unknown split {split!r}')

    num_train = len(paths)
    if onehot:
        labels = torch.zeros(num_train, num_labels, H // 2, W // 2, D // 2).pin_memory()
    else:
        # full-resolution integer label map stored as a single float channel
        labels = torch.zeros(num_train, 1, H, W, D).pin_memory()
    for i, path in tqdm(enumerate(paths), total=num_train):
        label = _read_label_volume(path, mapping)
        labels[i] = _onehot_half_res(label, num_labels)[0] if onehot else label.unsqueeze(0)

    print('loaded', task_name, split)
    return labels, (num_train, num_labels, H, W, D)

def get_val_pairs(B):
    """Return all index pairs (i, j) with 0 <= i < j < B.

    Rows are ordered row-major (i ascending, then j ascending).

    Args:
        B: number of cases.

    Returns:
        int64 tensor of shape (B*(B-1)/2, 2); shape (0, 2) when B < 2.
    """
    n = int(B)
    if n < 2:
        return torch.empty(0, 2, dtype=torch.long)
    # strict upper triangle (offset=1) of a B x B matrix == all pairs i < j, row-major
    return torch.triu_indices(n, n, offset=1).t().contiguous()


In [None]:
def AdamReg(mind_fix, mind_mov, dense_flow, iterations=50, lambda_weight=.65, grid_sp=2, device='cuda'):
    """Refine a dense displacement field by instance optimisation with Adam.

    A control grid at 1/grid_sp resolution is initialised from ``dense_flow``.
    Adam then minimises the mean squared feature difference between ``mind_fix``
    and the warped ``mind_mov``, plus a diffusion regulariser (squared finite
    differences). The result is upsampled and smoothed.

    Args:
        mind_fix: fixed-image features. They are compared with ``mind_mov``
            sampled on a (H//grid_sp, W//grid_sp, D//grid_sp) grid, so presumably
            (1, C, H//grid_sp, W//grid_sp, D//grid_sp) — TODO confirm.
        mind_mov: moving-image features, (1, C, *spatial).
        dense_flow: initial displacement field, (1, 3, H, W, D) or channel-last
            (1, H, W, D, 3). It is apparently in F.grid_sample's normalised
            (x, y, z) coordinates, given the (size-1)/2 scaling and channel flip below.
        iterations: number of Adam steps (default 50).
        lambda_weight: weight of the diffusion regulariser (default 0.65).
        grid_sp: downsampling factor of the optimised control grid (default 2).
        device: device on which everything runs (default 'cuda').

    Returns:
        Refined displacement field in the same convention as ``dense_flow``,
        shape (1, 3, H, W, D), on ``device``.
    """
    if dense_flow.shape[-1] == 3:
        dense_flow = dense_flow.permute(0, 4, 1, 2, 3)

    H, W, D = dense_flow[0, 0].shape
    h, w, d = H // grid_sp, W // grid_sp, D // grid_sp

    def smooth(x):
        # three 3x3x3 box filters in a row (the "B-spline smoothing" of the original)
        for _ in range(3):
            x = F.avg_pool3d(x, 3, stride=1, padding=1)
        return x

    # normalised (x, y, z) displacement -> voxel displacement in (H, W, D) channel order
    full_size = torch.tensor([H - 1, W - 1, D - 1], device=device).view(1, 3, 1, 1, 1)
    disp_hr = dense_flow.to(device).flip(1) * full_size / 2

    # loop invariants: inputs on the device, identity grid, voxel -> normalised scale
    mind_fix = mind_fix.to(device)
    mind_mov = mind_mov.to(device).float()
    grid0 = F.affine_grid(torch.eye(3, 4, device=device).unsqueeze(0), (1, 1, h, w, d),
                          align_corners=False).view(-1, 3).float()
    scale = torch.tensor([(h - 1) / 2, (w - 1) / 2, (d - 1) / 2], device=device).unsqueeze(0)

    with torch.enable_grad():
        disp_lr = F.interpolate(disp_hr, size=(h, w, d), mode='trilinear', align_corners=False)
        # control-grid displacement (voxel units of the coarse grid) is the only parameter
        disp_param = nn.Parameter(disp_lr.float().detach() / grid_sp)
        optimizer = torch.optim.Adam([disp_param], lr=1)

        def smoothed_disp():
            # smoothed control-grid displacement, channel-last (1, h, w, d, 3)
            return smooth(disp_param).permute(0, 2, 3, 4, 1)

        # Adam optimisation with diffusion regularisation and B-spline-like smoothing
        for _ in range(iterations):
            optimizer.zero_grad()
            disp_sample = smoothed_disp()
            reg_loss = lambda_weight * ((disp_sample[0, :, 1:, :] - disp_sample[0, :, :-1, :]) ** 2).mean() + \
                lambda_weight * ((disp_sample[0, 1:, :, :] - disp_sample[0, :-1, :, :]) ** 2).mean() + \
                lambda_weight * ((disp_sample[0, :, :, 1:] - disp_sample[0, :, :, :-1]) ** 2).mean()
            grid_disp = grid0 + (disp_sample.reshape(-1, 3) / scale).flip(1).float()
            patch_mov_sampled = F.grid_sample(mind_mov, grid_disp.view(1, h, w, d, 3),
                                              align_corners=False, mode='bilinear')
            sampled_cost = (patch_mov_sampled - mind_fix).pow(2).mean(1) * 12
            loss = sampled_cost.mean()
            (loss + reg_loss).backward()
            optimizer.step()

        # re-evaluate after the loop so the result includes the final Adam step
        fitted_grid = smoothed_disp().permute(0, 4, 1, 2, 3).detach()
        disp_hr = F.interpolate(fitted_grid * grid_sp, size=(H, W, D), mode='trilinear', align_corners=False)

    disp_smooth = smooth(disp_hr)
    # voxel displacement -> normalised (x, y, z) displacement
    return torch.flip(disp_smooth / full_size * 2, [1])

In [None]:
tasks = ['SilverCorpus', 'TS_Skeletal']

# one (initially empty) results dict per task
results = {task: {} for task in tasks}


def _load_all_tasks(split):
    """Call load_dataset_seg(task, split) for every task in ``tasks``.

    Returns two lists aligned with ``tasks``: the label tensors and the
    (N, num_labels, H, W, D) shape tuples.
    """
    labels_per_task, shapes_per_task = [], []
    for task_name in tasks:
        labels, shape = load_dataset_seg(task_name, split=split)
        labels_per_task.append(labels)
        shapes_per_task.append(shape)
    return labels_per_task, shapes_per_task


train_label, train_shapes = _load_all_tasks('train')
val_label, val_shapes = _load_all_tasks('val')
pred_label, pred_shapes = _load_all_tasks('val_pred')




In [None]:
def _summed_class_edt(onehot_labels):
    """Sum over all label channels of each channel's Euclidean distance transform.

    For every case b and channel c, distance_transform_edt assigns each non-zero
    voxel of onehot_labels[b, c] its distance to the nearest zero voxel. These
    maps are accumulated into a single channel.

    Args:
        onehot_labels: (B, C, H, W, D) label tensor (e.g. a 'train' split from
            load_dataset_seg).

    Returns:
        (B, 1, H, W, D) float32 tensor.
    """
    num_cases, num_channels, *spatial = onehot_labels.shape
    summed = torch.zeros(num_cases, 1, *spatial)
    for case_idx in tqdm(range(num_cases)):
        for channel_idx in range(num_channels):
            channel = onehot_labels[case_idx, channel_idx].cpu().squeeze()
            summed[case_idx, 0] += torch.from_numpy(distance_transform_edt(channel)).float()
    return summed


# one summed-EDT tensor per task, aligned with train_label
edt_train_label = [_summed_class_edt(train_label[task_idx]) for task_idx in range(len(tasks))]


In [None]:
if 1 not in skip_stage:
    # Stage 1: train RegModel on the summed-EDT volumes (edt_train_label). The loss is
    # soft Dice between the affinely perturbed fixed labels and the warped moving labels.
    model = RegModel(1)
    model.cuda()

    repeats = 3
    iterations = 2000
    # task index drawn at every (repeat, iteration)
    run_dataset = torch.randint(0, len(tasks), [repeats, iterations])
    # strength of the random affine augmentation: sigmoid ramp from ~0.007 to ~1,
    # crossing 0.5 after 1/6 of the iterations
    ramp_up = torch.sigmoid(torch.linspace(-5, 25, iterations))

    for repeat in range(repeats):
        # optimiser and AMP loss scaler are re-created (state reset) for every repeat
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        scaler = torch.cuda.amp.GradScaler()

        # column 0: 1 - Dice of the unregistered pair (baseline), column 1: training loss
        run_loss = torch.zeros(iterations, 2)
        t0 = time.time()

        with trange(iterations) as pbar:
            for i in pbar:
                optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    dataset = run_dataset[repeat, i]
                    B, C, H, W, D = train_shapes[dataset]
                    ii = torch.randperm(B)[:2]  # ii[0]: fixed case, ii[1]: moving case
                    half_res = (1, 1, H // 2, W // 2, D // 2)

                    grid = F.affine_grid(torch.eye(3, 4).unsqueeze(0).cuda(), half_res, align_corners=False)
                    theta = .07 * ramp_up[i] * torch.randn(1, 3, 4) + torch.eye(3, 4).unsqueeze(0)
                    affine = F.affine_grid(theta.cuda(), half_res, align_corners=False)
                    fix_aff = F.grid_sample(edt_train_label[dataset][ii[:1]].cuda(), affine, align_corners=False)
                    fix_aff_img = F.grid_sample(train_label[dataset][ii[:1]].cuda(), affine, align_corners=False)

                    # `level` switches from 1 to 2 halfway through (its meaning is defined by RegModel)
                    disp = model(fix_aff, edt_train_label[dataset][ii[1:2]].cuda(), level=int(i > iterations // 2) + 1)
                    warped_img = F.grid_sample(train_label[dataset][ii[1:2]].cuda(), grid + disp.permute(0, 2, 3, 4, 1),
                                               padding_mode='border', align_corners=False)
                    loss = (1 - soft_dice(fix_aff_img, warped_img)).mean()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                with torch.no_grad():
                    run_loss[i, 0] = (1 - soft_dice(fix_aff_img, train_label[dataset][ii[1:2]].cuda())).mean()
                run_loss[i, 1] = loss.item()

                # running mean over the last (up to) 25 iterations, including the current one
                recent = run_loss[max(0, i - 24):i + 1]
                str1 = f"d: {dataset.item()}, iter: {i}, loss: {'%0.3f'%recent[:,1].mean()} | {'%0.3f'%recent[:,0].mean()}, runtime: {'%0.1f'%(time.time()-t0)} sec, gpumem/max: {'%0.2f'%(torch.cuda.max_memory_allocated()*1e-9)} GB"
                pbar.set_description(str1)

        print(f"Repeat {repeat}, Last 100 Losses {'%0.3f'%run_loss[-100:,1].mean()}")
        # smooth each column separately (viewing the (iterations, 2) tensor as 1-D would interleave them)
        curves = F.avg_pool1d(F.avg_pool1d(run_loss.t().unsqueeze(0), 15, stride=3), 15, stride=1)[0]
        plt.plot(curves[1], label=f'repeat {repeat}: loss')
        plt.plot(curves[0], '--', label=f'repeat {repeat}: no registration')
    plt.title('Stage 1: smoothed 1 - Dice')
    plt.legend()

    os.makedirs('unpaired_models', exist_ok=True)
    torch.save(model, f'unpaired_models/v2_{tasks[0]}_{tasks[1]}_edt_model.pth')

In [None]:
if 2 not in skip_stage:
    # Stage 2: fine-tune the stage-1 model jointly with a ShufflePermutation adapter,
    # which is applied to the (pooled one-hot) label maps before they enter the model.
    # NOTE(review): torch.load unpickles a whole nn.Module. PyTorch >= 2.6 defaults to
    # weights_only=True, so this then needs torch.load(..., weights_only=False).
    # Only load checkpoints you created yourself.
    model = torch.load(f'unpaired_models/v2_{tasks[0]}_{tasks[1]}_edt_model.pth').cuda()

    adapt = ShufflePermutation().cuda()
    model.train(); adapt.train()

    repeats = 3
    iterations = 2000
    # task index drawn at every (repeat, iteration)
    run_dataset = torch.randint(0, len(tasks), [repeats, iterations])
    # strength of the random affine augmentation: sigmoid ramp from ~0.007 to ~1
    ramp_up = torch.sigmoid(torch.linspace(-5, 25, iterations))

    for repeat in range(repeats):
        # optimiser (model + adapter) and AMP loss scaler are re-created every repeat
        optimizer = torch.optim.Adam(list(model.parameters()) + list(adapt.parameters()), lr=0.001)
        scaler = torch.cuda.amp.GradScaler()

        # column 0: 1 - Dice of the unregistered pair (baseline), column 1: training loss
        run_loss = torch.zeros(iterations, 2)
        t0 = time.time()

        with trange(iterations) as pbar:
            for i in pbar:
                optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    dataset = run_dataset[repeat, i]
                    B, C, H, W, D = train_shapes[dataset]
                    ii = torch.randperm(B)[:2]  # ii[0]: fixed case, ii[1]: moving case
                    half_res = (1, 1, H // 2, W // 2, D // 2)

                    grid = F.affine_grid(torch.eye(3, 4).unsqueeze(0).cuda(), half_res, align_corners=False)
                    theta = .07 * ramp_up[i] * torch.randn(1, 3, 4) + torch.eye(3, 4).unsqueeze(0)
                    affine = F.affine_grid(theta.cuda(), half_res, align_corners=False)
                    # affinely warped fixed labels: model input (through adapt) and Dice target.
                    # Assumes adapt does not modify its input in place.
                    fix_aff_img = F.grid_sample(train_label[dataset][ii[:1]].cuda(), affine, align_corners=False)
                    mov_img = train_label[dataset][ii[1:2]].cuda()

                    disp = model(adapt(fix_aff_img), adapt(mov_img), level=int(i > iterations // 2) + 1)
                    warped_img = F.grid_sample(mov_img, grid + disp.permute(0, 2, 3, 4, 1),
                                               padding_mode='border', align_corners=False)
                    loss = (1 - soft_dice(fix_aff_img, warped_img)).mean()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                with torch.no_grad():
                    run_loss[i, 0] = (1 - soft_dice(fix_aff_img, mov_img)).mean()
                run_loss[i, 1] = loss.item()

                # running mean over the last (up to) 25 iterations, including the current one
                recent = run_loss[max(0, i - 24):i + 1]
                str1 = f"d: {dataset.item()}, iter: {i}, loss: {'%0.3f'%recent[:,1].mean()} | {'%0.3f'%recent[:,0].mean()}, runtime: {'%0.1f'%(time.time()-t0)} sec, gpumem/max: {'%0.2f'%(torch.cuda.max_memory_allocated()*1e-9)} GB"
                pbar.set_description(str1)

        print(f"Repeat {repeat}, Last 100 Losses {'%0.3f'%run_loss[-100:,1].mean()}")
        # smooth each column separately (viewing the (iterations, 2) tensor as 1-D would interleave them)
        curves = F.avg_pool1d(F.avg_pool1d(run_loss.t().unsqueeze(0), 15, stride=3), 15, stride=1)[0]
        plt.plot(curves[1], label=f'repeat {repeat}: loss')
        plt.plot(curves[0], '--', label=f'repeat {repeat}: no registration')
    plt.title('Stage 2: smoothed 1 - Dice')
    plt.legend()

    os.makedirs('unpaired_models', exist_ok=True)
    torch.save([model, adapt], f'unpaired_models/v2_{tasks[0]}_{tasks[1]}_adapt_model.pth')
