In [1]:
%matplotlib notebook
from tqdm import tqdm
#import matplotlib.pyplot as plt
# Note! ITK interacts weirdly here.  from lazy_imports import itk does not work.
# Additionally, import itk must occur before lazy_imports for itkwidgets.view (ie itkview) to work.
import itk
#from lazy_imports import itk
from lazy_imports import torch
from lazy_imports import np
from lazy_imports import plt
from lazy_imports import sitk
from lazy_imports import loadmat, savemat
from lazy_imports import sio
from lazy_imports import itkwidgets
from lazy_imports import itkview
from lazy_imports import interactive
from lazy_imports import ipywidgets
from lazy_imports import pv

plt.rcParams.update({"figure.figsize": (6, 6)})  # default figure size: (width, height)

In [2]:
from util.SplitEbinMetric3D import get_karcher_mean, Squared_distance_Ebin_field, logm_invB_A

In [3]:
from util.diffeo import get_idty_3d, get_gradient_3d, compose_function_3d, phi_pullback_3d, compose_function_in_place_3d

In [4]:
from algo.metricMatching import metric_matching

In [5]:
from disp.vis import show_2d, show_2d_tensors
from disp.vis import vis_tensors, vis_path
from disp.vis import view_3d_tensors, tensors_to_mesh, view_3d_paths, path_to_tube
from data.io import readRaw, ReadScalars, ReadTensors, WriteTensorNPArray, WriteScalarNPArray, readPath3D
from data.convert import GetNPArrayFromSITK, GetSITKImageFromNP

In [None]:
    # --- Configuration for building the cubic-phantom atlas -------------------
    # `torch.set_default_tensor_type` is deprecated in recent PyTorch releases;
    # `set_default_dtype` yields the same float32 default for newly created tensors.
    torch.set_default_dtype(torch.float32)

    # Subject ids: inputs are read as `cubic{id}_scaled_tensors.nhdr` / `cubic{id}_filt_mask.nhdr`.
    file_name = [1,2,4,6]
    input_dir = '/usr/sci/projects/HCP/Kris/NSFCRCNS/TestResults/working_3d_python/'
    output_dir = '/usr/sci/projects/HCP/Kris/NSFCRCNS/TestResults/working_3d_python/Cubic1246AtlasJul16/'
    # Grid size of the input volumes; must match the data on disk.
    height, width, depth = 100,100,41
    sample_num = len(file_name)

    # Per-subject state, filled by the loading loop and updated by the atlas loop.
    tensor_lin_list, tensor_met_list, mask_list, mask_thresh_list, fa_list = [], [], [], [], []
    mask_union = torch.zeros(height, width, depth)
    phi_inv_acc_list, phi_acc_list, energy_list = [], [], []

    # When resuming, `start_iter` must be one past a saved checkpoint iteration
    # (checkpoints are written whenever the iteration index is a multiple of 50).
    resume = False
    start_iter = 0
    # Number of outer atlas iterations to run.
    iter_num = 800

    # Load every subject, build its 3x3 metric field (the inverse of the input
    # tensor field), and initialise its accumulated diffeomorphisms -- identity for
    # a fresh run, or the last checkpoint when resuming.

    # Maps the 6 stored tensor components (ordered 00, 01, 02, 11, 12, 22) onto
    # the full symmetric 3x3 layout.
    sym_idx = torch.tensor([[0, 1, 2],
                            [1, 3, 4],
                            [2, 4, 5]])

    for s in range(sample_num):
        tensor_np = sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/cubic{file_name[s]}_scaled_tensors.nhdr'))
        mask_np = sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/cubic{file_name[s]}_filt_mask.nhdr'))

        # Reverse the SimpleITK axis order: tensors -> (6, height, width, depth),
        # mask -> (height, width, depth).
        tensor_lin = torch.from_numpy(tensor_np).float().permute(3,2,1,0)
        mask = torch.from_numpy(mask_np).float().permute(2,1,0)
        # Fail fast (instead of a confusing broadcast error below) when the data
        # does not match the configured grid size.
        assert tensor_lin.shape == (6, height, width, depth), \
            f'cubic{file_name[s]}: unexpected tensor shape {tuple(tensor_lin.shape)}'
        assert mask.shape == (height, width, depth), \
            f'cubic{file_name[s]}: unexpected mask shape {tuple(mask.shape)}'

        tensor_lin_list.append(tensor_lin)
        mask_union += mask  # accumulate; binarised after the loop
        mask_list.append(mask)

        # Full symmetric tensor field, shape (height, width, depth, 3, 3).
        tensor_met_zeros = tensor_lin[sym_idx].permute(2, 3, 4, 0, 1).contiguous()
        # The metric is the inverse of the tensor field.
        tensor_met_list.append(torch.inverse(tensor_met_zeros))

        # Foreground/background re-weighting hook; all ones, so the einsum below
        # currently leaves the metric unchanged.
        fore_back_adaptor = torch.ones((height,width,depth))
        mask_thresh_list.append(fore_back_adaptor)
        tensor_met_list[s] = torch.einsum('ijk...,lijk->ijk...', tensor_met_list[s], mask_thresh_list[s].unsqueeze(0))

        # Initialise the accumulated diffeomorphisms.
        if not resume:
            print('start from identity')
            phi_inv_acc_list.append(get_idty_3d(height, width, depth))
            phi_acc_list.append(get_idty_3d(height, width, depth))
            energy_list.append([])
        else:
            print('start from checkpoint')
            # Must match the checkpoint names written by the atlas loop:
            # `{output_dir}/cubic{file_name}_{i}_{phi_inv,phi,energy}.mat`.
            ckpt_prefix = f'{output_dir}/cubic{file_name[s]}_{start_iter-1}'
            phi_inv_acc_list.append(torch.from_numpy(sio.loadmat(f'{ckpt_prefix}_phi_inv.mat')['diffeo']).float())
            phi_acc_list.append(torch.from_numpy(sio.loadmat(f'{ckpt_prefix}_phi.mat')['diffeo']).float())
            tensor_met_list[s] = phi_pullback_3d(phi_inv_acc_list[s], tensor_met_list[s])
            # Restore the energy history so the curve continues across the restart
            # (savemat stores the list as a 1xN array).
            energy_list.append(sio.loadmat(f'{ckpt_prefix}_energy.mat')['energy'].ravel().tolist())
            # NOTE(review): mask_list is not re-warped by the checkpointed phi_inv on resume.

    # Binarise the union of all subject masks.
    mask_union[mask_union>0] = 1


    print(f'Starting from iteration {start_iter} to iteration {iter_num+start_iter}')

    # Metric-matching hyper-parameters. They are loop-invariant, so they are set once.
    # The inner iteration count is `match_iter_num`, so it no longer overwrites
    # the outer `iter_num`.
    dim, sigma, epsilon, match_iter_num = 3., 0, 4e-3, 1  # epsilon = 3e-3 for orig tensor
    # (row, col) of the six unique entries of a symmetric 3x3 tensor, in the
    # component order written to the *_tens.nhdr files.
    upper_tri_idx = [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]

    for i in tqdm(range(start_iter, start_iter+iter_num)):
        # Current atlas: Karcher mean of all subject metric fields.
        G = torch.stack(tuple(tensor_met_list))
        atlas = get_karcher_mean(G, 1./dim)

        phi_inv_list, phi_list = [], []
        for s in range(sample_num):
            # Energy: squared difference to the atlas, summed over the union mask.
            energy_list[s].append(torch.einsum("ijk...,lijk->",[(tensor_met_list[s] - atlas)**2, mask_union.unsqueeze(0)]).item())
            tensor_met_list[s], phi, phi_inv = metric_matching(tensor_met_list[s], atlas, height, width, depth, mask_union, match_iter_num, epsilon, sigma, dim, use_idty=True)
            phi_inv_list.append(phi_inv)
            phi_list.append(phi)
            # Compose this step into the accumulated maps, and warp the subject mask.
            phi_inv_acc_list[s][:] = compose_function_3d(phi_inv_acc_list[s], phi_inv)
            phi_acc_list[s][:] = compose_function_3d(phi, phi_acc_list[s])
            mask_list[s][:] = compose_function_3d(mask_list[s], phi_inv)

        # Checkpoint every 50 iterations. To resume, set start_iter = i + 1.
        if i % 50 == 0:
            # Export the atlas as a 6-component tensor image (inverse of the metric).
            atlas_inv = torch.inverse(atlas)
            atlas_lin = np.zeros((6, height, width, depth))
            for k, (r, c) in enumerate(upper_tri_idx):
                atlas_lin[k] = atlas_inv[:, :, :, r, c].detach().numpy()
            for s in range(sample_num):
                sio.savemat(f'{output_dir}/cubic{file_name[s]}_{i}_phi_inv.mat', {'diffeo': phi_inv_acc_list[s].detach().numpy()})
                sio.savemat(f'{output_dir}/cubic{file_name[s]}_{i}_phi.mat', {'diffeo': phi_acc_list[s].detach().numpy()})
                sio.savemat(f'{output_dir}/cubic{file_name[s]}_{i}_energy.mat', {'energy': energy_list[s]})
            sitk.WriteImage(sitk.GetImageFromArray(np.transpose(atlas_lin,(3,2,1,0))), f'{output_dir}/atlas_{i}_tens.nhdr')
            # Convert explicitly instead of relying on np.transpose's fallback for torch tensors.
            sitk.WriteImage(sitk.GetImageFromArray(mask_union.permute(2,1,0).numpy()), f'{output_dir}/atlas_{i}_mask.nhdr')

    # --- Final export: accumulated diffeomorphisms, energies, and the atlas ---
    fig, ax = plt.subplots()
    for s in range(sample_num):
        sio.savemat(f'{output_dir}/{file_name[s]}_phi_inv.mat', {'diffeo': phi_inv_acc_list[s].detach().numpy()})
        sio.savemat(f'{output_dir}/{file_name[s]}_phi.mat', {'diffeo': phi_acc_list[s].detach().numpy()})
        sio.savemat(f'{output_dir}/{file_name[s]}_energy.mat', {'energy': energy_list[s]})
        ax.plot(energy_list[s], label=f'cubic{file_name[s]}')
    ax.set(xlabel='iteration', ylabel='masked squared distance to atlas', title='Matching energy per subject')
    ax.legend()

    # The exported atlas is the inverse of the metric atlas. It is kept in its own
    # variable so that `atlas` still holds the metric and is not silently replaced.
    atlas_inv = torch.inverse(atlas)
    # (row, col) of the six unique entries of a symmetric 3x3 tensor.
    upper_tri_idx = [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
    atlas_lin = np.zeros((6, height, width, depth))
    for k, (r, c) in enumerate(upper_tri_idx):
        atlas_lin[k] = atlas_inv[:, :, :, r, c].detach().numpy()

    sitk.WriteImage(sitk.GetImageFromArray(np.transpose(atlas_lin,(3,2,1,0))), f'{output_dir}/atlas_tens.nhdr')
    # Convert explicitly instead of relying on np.transpose's fallback for torch tensors.
    sitk.WriteImage(sitk.GetImageFromArray(mask_union.permute(2,1,0).numpy()), f'{output_dir}/atlas_mask.nhdr')


start from identity
start from identity


  0%|          | 0/800 [00:00<?, ?it/s]

start from identity
start from identity
Starting from iteration 0 to iteration 800
0 13667.572265625
0 13455.9111328125
0 4734.90283203125
0 4638.921875


  0%|          | 1/800 [00:14<3:10:51, 14.33s/it]

0 13332.6875
0 13174.046875
0 4575.45751953125
0 4480.81884765625


  0%|          | 2/800 [00:40<4:40:06, 21.06s/it]

0 13047.1953125
0 12927.84765625
0 4437.4296875
0 4343.53125


  0%|          | 3/800 [01:05<5:06:28, 23.07s/it]

0 12793.1943359375
0 12706.2529296875
0 4314.5439453125
0 4220.8603515625


  0%|          | 4/800 [01:31<5:18:49, 24.03s/it]

In [13]:
# Sanity check on the stacked subject metrics from the last atlas iteration.
# The atlas cell sets the torch default dtype to float32.
assert G.dtype == torch.float32, G.dtype
G.dtype

torch.float32