In [1]:
from glob import glob
import os
import subprocess
import SimpleITK as sitk
import torch
import nibabel as nib
import numpy as np
from monai.networks.blocks import Warp

In [2]:
def convertWarpFromPhysicalToVoxel(warp):
    """Convert a displacement field from physical units to voxel units.

    Every displacement vector has its first two components negated and is
    then multiplied by the linear (3x3) part of the inverse image affine.
    The translation column is zeroed because displacements are vectors,
    not positions.

    Parameters
    ----------
    warp : nibabel image
        Displacement field whose data squeezes to ``(X, Y, Z, 3)``
        (presumably an ANTs warp stored as ``(X, Y, Z, 1, 3)`` -- TODO confirm).

    Returns
    -------
    nibabel.Nifti1Image
        Voxel-unit field laid out as ``(1, 3, X, Y, Z)``, carrying the
        unmodified affine and header of ``warp``.
    """
    affine = warp.affine
    header = warp.header

    # np.linalg.inv returns a fresh array, so the edits below never touch
    # the affine stored on ``warp``.
    phys_to_vox = np.linalg.inv(affine)
    phys_to_vox[:3, 3] = 0
    # Matrix product, not element-wise ``*``: multiplying element-wise by a
    # diagonal matrix zeroes every off-diagonal term and therefore drops any
    # rotation/shear in the affine. Both forms agree for axis-aligned affines.
    # NOTE(review): the (-1, -1, 1) flip looks like the ITK/ANTs LPS ->
    # nibabel RAS axis convention -- confirm.
    phys_to_vox[:3, :3] = phys_to_vox[:3, :3] @ np.diag((-1, -1, 1))

    disp_phys = np.asarray(warp.get_fdata()).squeeze()
    disp_vox = nib.affines.apply_affine(phys_to_vox, disp_phys)
    # (X, Y, Z, 3) -> (X, Y, Z, 1, 3) -> (1, 3, X, Y, Z)
    disp_vox = np.expand_dims(disp_vox, axis=3).transpose(3, 4, 0, 1, 2)

    return nib.Nifti1Image(disp_vox, affine, header)

In [3]:
def convertWarpFromVoxelToPhysical(warp):
    """Convert a displacement field from voxel units to physical units.

    Every displacement vector is multiplied by the linear (3x3) part of the
    image affine and then has its first two components negated. The
    translation column is zeroed because displacements are vectors, not
    positions.

    Parameters
    ----------
    warp : nibabel image
        Displacement field whose data squeezes to ``(3, X, Y, Z)``
        (presumably the ``(1, 3, X, Y, Z)`` layout used for MONAI --
        TODO confirm).

    Returns
    -------
    nibabel.Nifti1Image
        Physical-unit field laid out as ``(X, Y, Z, 1, 3)``, carrying the
        unmodified affine and header of ``warp``.
    """
    affine = warp.affine
    header = warp.header

    # Work on a copy: editing ``warp.affine`` in place would corrupt the
    # caller's image and leak the zeroed/flipped matrix into the returned
    # image's affine.
    vox_to_phys = np.array(affine, dtype=float)
    vox_to_phys[:3, 3] = 0
    # Matrix product, not element-wise ``*``: multiplying element-wise by a
    # diagonal matrix zeroes every off-diagonal term and therefore drops any
    # rotation/shear in the affine. Both forms agree for axis-aligned affines.
    # NOTE(review): the (-1, -1, 1) flip looks like the nibabel RAS ->
    # ITK/ANTs LPS axis convention -- confirm.
    vox_to_phys[:3, :3] = np.diag((-1, -1, 1)) @ vox_to_phys[:3, :3]

    # (3, X, Y, Z) -> (X, Y, Z, 3) so the vector components sit on the last axis.
    disp_vox = np.asarray(warp.get_fdata()).squeeze().transpose(1, 2, 3, 0)
    disp_phys = nib.affines.apply_affine(vox_to_phys, disp_vox)
    # (X, Y, Z, 3) -> (X, Y, Z, 1, 3)
    disp_phys = np.expand_dims(disp_phys, axis=3)

    return nib.Nifti1Image(disp_phys, affine, header)

In [4]:
def applyMonaiWarp(mri, transform, ref):
    """Resample an MRI volume through a displacement field with MONAI's ``Warp``.

    Parameters
    ----------
    mri : str
        Path of the image to warp (read with ``nib.load``).
    transform : str
        Path of the displacement field (read with ``nib.load`` and converted
        with ``convertWarpFromPhysicalToVoxel``).
    ref : str
        Path of the image whose affine and header are attached to the result.

    Returns
    -------
    nibabel.Nifti1Image
        The warped volume with the affine and header of ``ref``.
    """
    # Bilinear sampling; samples falling outside the volume use border values.
    resampler = Warp("bilinear", "border")

    ddf_nib = convertWarpFromPhysicalToVoxel(nib.load(transform))
    ddf = torch.from_numpy(ddf_nib.get_fdata())

    # Add the two leading singleton axes so the volume is 5-D like the field.
    volume = torch.from_numpy(nib.load(mri).get_fdata())
    volume = volume.unsqueeze(dim=0).unsqueeze(dim=0)
    warped = resampler(volume, ddf)

    # Load the reference once and reuse it for both affine and header.
    ref_img = nib.load(ref)
    return nib.Nifti1Image(warped.squeeze().numpy(), ref_img.affine, ref_img.header)

In [27]:
import scipy.interpolate  # explicit submodule import; bare `import scipy` may not load it


def interpolate_warp(warp_data, point):
    """Linearly interpolate a 3-component field at a fractional voxel position.

    Parameters
    ----------
    warp_data : numpy.ndarray
        Array indexed as ``[x, y, z, component]`` with at least three
        components on the last axis.
    point : tuple of float
        ``(x, y, z)`` position in voxel-index coordinates (the grid along each
        axis is ``0 .. shape - 1``).

    Returns
    -------
    tuple
        ``(dx, dy, dz)``: the interpolated value of components 0, 1 and 2.
        Positions outside the grid yield 0 for every component.
    """
    grid = tuple(np.arange(0, size, 1) for size in warp_data.shape[:3])

    components = []
    for axis in range(3):
        # bounds_error=False is required for fill_value to take effect; with
        # the default (True) an out-of-grid point raises instead of giving 0.
        interpolator = scipy.interpolate.RegularGridInterpolator(
            grid, warp_data[:, :, :, axis], bounds_error=False, fill_value=0
        )
        components.append(interpolator(point))

    dx, dy, dz = components
    return dx, dy, dz

In [29]:
# Compare the ANTs forward/inverse warps with the deep-learning warp for one
# subject by sampling each displacement field at a fixed fractional voxel.
ref = "/home/valentini/dev/Mousenet/dataset2/Atlas/Deformable_Feminad_template0_sameaffine.nii.gz"

# Query positions in voxel-index coordinates. The ANTs position used to come
# from a `point` left over in the kernel (the cell raised NameError on a fresh
# kernel); its value is taken from the recorded cell output.
# NOTE(review): confirm the two fields are meant to be sampled at different
# positions.
ants_query_point = (51.5, 43.0, 89.5)
deep_query_point = (51.94, 41.54, 88.95)

for dataset in ["Feminad", "Painfact", "IRIS"]:
    print(dataset)
    data_dir = "/home/valentini/dev/Mousenet/dataset2/" + str(dataset) + "/"
    # NOTE(review): the deep-learning output directory is fixed to Feminad
    # regardless of `dataset` -- confirm this is intended.
    out_dir = "/home/valentini/dev/Mousenet/output/Feminad/DL_Localnet_Reg/"
    mris = sorted(glob(os.path.join(data_dir, 'MRI_N4_Resample_Norm', "*.nii.gz")))
    # The pattern matches both "...InverseWarp" and "...Warp"; the indexing
    # below assumes two sorted entries per MRI (inverse first, forward second).
    warps_ants = sorted(glob(os.path.join(data_dir, 'MRI_N4_Resample_Norm_Deformable', "*Warp.nii.gz")))
    warps_deep = sorted(glob(os.path.join(out_dir, 'DeformableWarp', "*.nii.gz")))
    for i, mri in enumerate(mris):
        if "2_6516" not in mri:
            continue

        print(str(i) + "/" + str(len(mris)), end="\r")
        inverse_warp_path = warps_ants[2*i]
        forward_warp_path = warps_ants[2*i+1]
        print(inverse_warp_path)
        print(forward_warp_path)
        print(warps_deep[i])
        print(mris[i])
        outname = data_dir + 'test_Affine/Affine_' + mri.split('/')[-1]
        print(outname)

        warp_inv_ants = np.asarray(nib.load(inverse_warp_path).get_fdata())
        print(warp_inv_ants.shape)
        warp_ants = np.asarray(nib.load(forward_warp_path).get_fdata())
        warp_deep = np.asarray(nib.load(warps_deep[i]).get_fdata())

        point = ants_query_point
        print(point)
        print(interpolate_warp(warp_inv_ants[:,:,:,0,:], point))
        print(interpolate_warp(warp_ants[:,:,:,0,:], point))

        point = deep_query_point
        print(interpolate_warp(warp_deep[:,:,:,0,:], point))

        #affine_mri_nib = applyMonaiWarp(mris[i], warps[i], ref)
        #nib.save(affine_mri_nib, outname)
        break  # only the first matching subject is inspected
    print('-'*40)
    break  # only the first dataset is inspected

Feminad
/home/valentini/dev/Mousenet/dataset2/Feminad/MRI_N4_Resample_Norm_Deformable/Deformable_Norm_Resampled_N4_2_6516131InverseWarp.nii.gz
/home/valentini/dev/Mousenet/dataset2/Feminad/MRI_N4_Resample_Norm_Deformable/Deformable_Norm_Resampled_N4_2_6516131Warp.nii.gz
/home/valentini/dev/Mousenet/output/Feminad/DL_Localnet_Reg/DeformableWarp/Affine_template0Norm_Resampled_N4_2_651613WarpedToTemplate.nii.gz
/home/valentini/dev/Mousenet/dataset2/Feminad/MRI_N4_Resample_Norm/Norm_Resampled_N4_2_6516.nii.gz
/home/valentini/dev/Mousenet/dataset2/Feminad/test_Affine/Affine_Norm_Resampled_N4_2_6516.nii.gz
(128, 128, 128, 1, 3)
(51.5, 43.0, 89.5)
(array(-0.01639375), array(0.10580799), array(0.06736265))
(array(0.02296625), array(-0.10385632), array(-0.07041248))
(array(-0.0454352), array(3.13162197), array(-0.24716698))
----------------------------------------
