In [7]:
import reg_mri
import os
from glob import glob
from utils import compute_mean_dice
import nibabel as nib
from scipy.spatial.distance import dice
import numpy as np
import itk
import SimpleITK as sitk
import scipy.ndimage
import scipy
import matplotlib.pyplot as plt
from transforms_dict import getRegistrationEvalInverseTransformForMRI, SaveTransformForMRI
from tqdm import tqdm
import monai
import subprocess
import torch

In [2]:
def getJacobian_deformable_ants(warp_img, outfolder, outdataset):
    """Run ANTs ``CreateJacobianDeterminantImage`` on a displacement-field image.

    The result is written to
    ``output/<outdataset>/<outfolder>/ANTSJacobian_Deformable/ANTSJacobianDeformable_<name>.nii.gz``
    where ``<name>`` is the part of ``warp_img``'s file name before the first
    occurrence of ``"Warp"``.

    Args:
        warp_img: path to the warp (displacement field) image.
        outfolder: experiment sub-folder under ``output/<outdataset>``.
        outdataset: dataset name, e.g. ``'Feminad'``.

    Returns:
        The exit code of the ANTs command (0 on success). A warning is issued
        on a non-zero exit instead of failing silently.
    """
    import warnings

    outdir = os.path.join('output', outdataset, outfolder, 'ANTSJacobian_Deformable')
    # The original relied on this directory already existing; create it so a
    # fresh run does not fail to write its output.
    os.makedirs(outdir, exist_ok=True)

    # NOTE(review): if the file name contains no "Warp" (e.g. "subj.nii.gz"),
    # the whole name including ".nii.gz" is kept, giving a double extension.
    # Kept as-is so existing output names do not change.
    stem = os.path.basename(warp_img).split('Warp')[0]
    outname = os.path.join(outdir, 'ANTSJacobianDeformable_' + stem + '.nii.gz')

    # Pass an argument list (not a space-split string) so paths containing
    # spaces survive intact. Trailing "0 0" per ANTs usage: doLogJacobian=0,
    # useGeometric=0.
    jacobian_command = [
        "CreateJacobianDeterminantImage", "3", str(warp_img), str(outname), "0", "0",
    ]
    returncode = subprocess.call(jacobian_command)
    if returncode != 0:
        warnings.warn(
            "CreateJacobianDeterminantImage exited with code %d for %s" % (returncode, warp_img)
        )
    return returncode

In [29]:
def JacobianDet(y_pred):
    """Voxel-wise Jacobian determinant of the mapping ``p -> p + y_pred(p)``.

    Args:
        y_pred: torch tensor of shape (B, 3, D0, D1, D2) holding a dense
            displacement field in voxel units. Channel ``c`` is assumed to be
            the displacement along spatial axis ``c`` (the MONAI DDF / ``Warp``
            convention, whose reference grid uses ``indexing="ij"``) — confirm
            against the model that produced the field.

    Returns:
        Tensor of shape (B, D0-1, D1-1, D2-1) with the forward-difference
        Jacobian determinant at each voxel. The identity field gives 1;
        values <= 0 indicate folding.

    Raises:
        ValueError: if ``y_pred`` is not a 5-D tensor with 3 channels.
    """
    if y_pred.dim() != 5 or y_pred.shape[1] != 3:
        raise ValueError(
            "expected y_pred of shape (B, 3, D0, D1, D2), got %s" % (tuple(y_pred.shape),)
        )

    # Identity grid: channel c holds the coordinate along spatial axis c.
    # indexing="ij" is required so that channels line up with tensor axes;
    # numpy's default "xy" indexing would swap the first two axes.
    axes = [np.arange(size) for size in y_pred.shape[2:]]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"))[np.newaxis]
    grid = torch.from_numpy(grid).float().to(y_pred.device)

    J = y_pred + grid
    # Forward differences along spatial axes 0, 1, 2, cropped to a common
    # (D0-1, D1-1, D2-1) region: dx[:, c] ~ dJ_c/dp_0, dy ~ d/dp_1, dz ~ d/dp_2.
    dx = J[:, :, 1:, :-1, :-1] - J[:, :, :-1, :-1, :-1]
    dy = J[:, :, :-1, 1:, :-1] - J[:, :, :-1, :-1, :-1]
    dz = J[:, :, :-1, :-1, 1:] - J[:, :, :-1, :-1, :-1]

    # Cofactor expansion along the first row of [[dx], [dy], [dz]]:
    #   det = dx0 * (dy1*dz2 - dy2*dz1)
    #       - dx1 * (dy0*dz2 - dy2*dz0)
    #       + dx2 * (dy0*dz1 - dy1*dz0)
    Jdet0 = dx[:, 0] * (dy[:, 1] * dz[:, 2] - dy[:, 2] * dz[:, 1])
    Jdet1 = dx[:, 1] * (dy[:, 0] * dz[:, 2] - dy[:, 2] * dz[:, 0])
    Jdet2 = dx[:, 2] * (dy[:, 0] * dz[:, 1] - dy[:, 1] * dz[:, 0])

    Jdet = Jdet0 - Jdet1 + Jdet2

    return Jdet

In [30]:
# Models to evaluate, paired position-by-position with the output folder each
# one writes into.
models = [
    "scenario-108-freeze/template_local_scenario4_painfact_finetune_newmodel_freeze2_1.0-0.0-8.0.pth",
]
outfolders = [
    "TestJacobian",
]
assert len(models) == len(outfolders), "each model needs exactly one output folder"

outdataset = 'Feminad'
#outdataset = 'Painfact'
atlas_name = "dataset2/Atlas/Identity_Feminad_Template.nii.gz"
# Number of scans registered per model. The original cell stopped after the
# first scan via a `break`; set to None to process every scan.
MAX_SUBJECTS = 1

# Every saved volume reuses the atlas geometry, so read the atlas once.
atlas_img = nib.load(atlas_name)
affine = atlas_img.affine
header = atlas_img.header

for model, outfolder in zip(models, outfolders):
    mridir = 'output/' + outdataset + '/' + outfolder + '/'
    reg_dir = os.path.join(mridir, 'MRI_N4_Registration_Deformable')
    warp_dir = os.path.join(mridir, 'DeformableWarp')
    torch_jac_dir = os.path.join(mridir, 'TorchJacobian_Deformable')
    ants_jac_dir = os.path.join(mridir, 'ANTSJacobian_Deformable')
    # nib.save and the ANTs tool write into these folders; create them so a
    # fresh run (Restart & Run All on a clean checkout) does not fail.
    for directory in (reg_dir, warp_dir, torch_jac_dir, ants_jac_dir):
        os.makedirs(directory, exist_ok=True)

    mris = sorted(glob(os.path.join('dataset2', outdataset, 'MRI_N4_Resample_Norm_Identity_Affine', "*.nii.gz")))
    subjects = mris if MAX_SUBJECTS is None else mris[:MAX_SUBJECTS]

    for mri in subjects:
        subject = os.path.basename(mri).split('.')[0]
        outname = os.path.join(reg_dir, subject)
        pred_image, ddfs = reg_mri.main(model, mri, outname, False, True, "local", newmodel=True)

        # NOTE(review): ddfs[0] is not used below; kept only so the notebook
        # namespace matches the original cell.
        moving_image = ddfs[0].cpu()
        # detach() so .numpy() works even if the field still tracks gradients.
        deformable_ddf = ddfs[1].detach().cpu()

        # Jacobian determinant computed in torch: (B, X-1, Y-1, Z-1) -> add a
        # leading axis -> (1, B, ...) -> spatial axes first for NIfTI.
        ddf_jacdet = JacobianDet(deformable_ddf).unsqueeze(0).cpu().numpy()
        ddf_jacdet = np.transpose(ddf_jacdet, (2, 3, 4, 0, 1))
        # Presumably (B, 3, X, Y, Z) -> (X, Y, Z, B, 3): a 5-D vector-field
        # layout — confirm against the shape reg_mri.main returns.
        deformable_ddf = np.transpose(deformable_ddf.numpy(), (2, 3, 4, 0, 1))

        deformable_image = nib.Nifti1Image(deformable_ddf, affine, header)
        outname_deformable = os.path.join(warp_dir, subject + ".nii.gz")
        nib.save(deformable_image, outname_deformable)

        ddf_jacdet_image = nib.Nifti1Image(ddf_jacdet, affine, header)
        outname_jacdet = os.path.join(torch_jac_dir, "TorchJacobian_" + subject + ".nii.gz")
        nib.save(ddf_jacdet_image, outname_jacdet)

    # Not used in this cell; kept in case later cells read it.
    affine_warps = sorted(glob(os.path.join(mridir, 'AffineWarp', "*.nii.gz")))
    # The ANTs Jacobian runs on every warp in the folder, including ones
    # written by earlier runs, not only the subjects registered above.
    deformable_warps = sorted(glob(os.path.join(mridir, 'DeformableWarp', "*.nii.gz")))
    for warp in deformable_warps:
        getJacobian_deformable_ants(warp, outfolder, outdataset)

monai.networks.blocks.Warp: Using PyTorch native grid_sample.


=> Saved to output/Feminad/TestJacobian/MRI_N4_Registration_Deformable/Affine_Identity_Norm_Resampled_N4_12_6517_MRI.nii.gz
=> Saved to output/Feminad/TestJacobian/MRI_N4_Registration_Deformable/Affine_Identity_Norm_Resampled_N4_12_6517_Resample.nii.gz
=> Saved to output/Feminad/TestJacobian/MRI_N4_Registration_Deformable/Affine_Identity_Norm_Resampled_N4_12_6517_Resample_Reg.nii.gz
=> Saved to output/Feminad/TestJacobian/MRI_N4_Registration_Deformable/Affine_Identity_Norm_Resampled_N4_12_6517_Reg.nii.gz
1.80523005900001
