In [1]:
%matplotlib qt

import numpy as np
import mcubes
import math
from scipy.spatial import distance
from skimage import measure
from scipy.stats import entropy
import numexpr as ne
from numpy.linalg import norm as nrm 

from scipy.ndimage import _ni_support
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
    generate_binary_structure
from scipy.ndimage.measurements import label, find_objects
from scipy.stats import pearsonr
from math import exp
from numpy import dot 


from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from skimage import measure
from skimage.draw import ellipsoid

import sys
import shutil
import os
import numpy as np
import math
from scipy import ndimage
import matplotlib.pyplot as plt
import torch as th
from skimage import morphology, filters
from scipy.ndimage import gaussian_filter
from dipy.io.image import load_nifti, save_nifti

from numpy import inf

from scipy.linalg.blas import sgemm

import time

In [2]:
def mesh(vol,step_size=1):
    """Triangulate a volume and return per-face centroids and unit normals.

    The surface is extracted with ``measure.marching_cubes`` at iso-level 0
    (``allow_degenerate=False``).  For every triangle ``(v0, v1, v2)`` of that
    surface the centroid ``(v0 + v1 + v2) / 3`` and the unit normal
    ``cross(v1 - v0, v2 - v0) / |cross(v1 - v0, v2 - v0)|`` are computed.

    Parameters
    ----------
    vol : array
        Volume handed unchanged to ``measure.marching_cubes``.
    step_size : int, optional
        Forwarded to ``measure.marching_cubes``.

    Returns
    -------
    cts : (n_faces, 3) float32 ndarray
        Triangle centroids.
    cts_normals : (n_faces, 3) float32 ndarray
        Unit triangle normals.
    verts, faces : ndarray
        Vertex and face arrays exactly as returned by marching cubes.
    """
    verts, faces, _vertex_normals, _vals = measure.marching_cubes(
        vol, level=0, step_size=step_size, allow_degenerate=False)

    # Gather the triangle corners once: tris[i, k, :] is corner k of face i.
    # (Previously this fancy-indexing copy was rebuilt several times per face,
    # and the loop ran to max(shape), which overruns when there are < 3 faces.)
    tris = verts[faces]
    corner0 = tris[:, 0, :]
    corner1 = tris[:, 1, :]
    corner2 = tris[:, 2, :]

    face_normals = np.cross(corner1 - corner0, corner2 - corner0)
    face_normals = face_normals / np.linalg.norm(face_normals, axis=1, keepdims=True)

    cts_normals = face_normals.astype(np.float32)
    cts = ((corner0 + corner1 + corner2) / 3).astype(np.float32)

    return cts, cts_normals, verts, faces




def hilbert_distance(cntr1,norm1,cntr2,norm2,sigma):
    """Combine three ``inner_hilbert_w`` products into ``<1,1> + <2,2> - 2<1,2>``.

    ``cntr1``/``norm1`` describe the first shape and ``cntr2``/``norm2`` the
    second; ``sigma`` is passed through to ``inner_hilbert_w`` unchanged.
    """
    self_product_1 = inner_hilbert_w(cntr1, norm1, cntr1, norm1, sigma)
    self_product_2 = inner_hilbert_w(cntr2, norm2, cntr2, norm2, sigma)
    cross_product = inner_hilbert_w(cntr1, norm1, cntr2, norm2, sigma)
    return self_product_1 + self_product_2 - 2*cross_product


def hilbert_distance_explicit(cntr1,norm1,cntr2,norm2,sigma):
    """Combine three ``inner_hilbert_w_FAST`` products into ``<1,1> + <2,2> - 2<1,2>``.

    Same combination as ``hilbert_distance`` but built on
    ``inner_hilbert_w_FAST``; ``sigma`` is passed through unchanged.
    """
    self_product_1 = inner_hilbert_w_FAST(cntr1, norm1, cntr1, norm1, sigma)
    self_product_2 = inner_hilbert_w_FAST(cntr2, norm2, cntr2, norm2, sigma)
    cross_product = inner_hilbert_w_FAST(cntr1, norm1, cntr2, norm2, sigma)
    return self_product_1 + self_product_2 - 2*cross_product



def inner_hilbert_w(cntr1,norm1,cntr2,norm2,sigma):
    """Reference double-loop inner product between two sets of oriented faces.

    Returns::

        sum_p sum_q exp(-|c1[p] - c2[q]|**2 / (2*sigma**2)) * dot(n1[p], n2[q])**2

    Parameters
    ----------
    cntr1, norm1 : array-like
        Centres and normals of the first set, one row per face.
    cntr2, norm2 : array-like
        Centres and normals of the second set, one row per face.
    sigma : float
        Width of the Gaussian weight on the centre-to-centre distance.

    Any argument that is 3-dimensional is reduced to its first slice along
    axis 0 before use.

    Notes
    -----
    The exponent uses the *squared* distance, matching ``inner_hilbert_w_FAST``
    and ``RBF_kernel_fast``.  An earlier revision divided the un-squared norm
    by ``2*sigma**2``.
    """
    cntr1, norm1, cntr2, norm2 = (
        arr[0, ...] if arr.ndim == 3 else arr
        for arr in (np.asarray(a) for a in (cntr1, norm1, cntr2, norm2))
    )

    s_squared = 2*(sigma**2)

    dw = 0
    # Iterate over rows (faces); shape[0] rather than max(shape) so that sets
    # with fewer than three faces do not index past the end.
    for p in range(cntr1.shape[0]):
        for q in range(cntr2.shape[0]):
            sq_dist = nrm(cntr1[p,:]-cntr2[q,:])**2
            dw += exp(-sq_dist/s_squared)*(dot(norm1[p,:],norm2[q,:])**2)
    return dw


def inner_hilbert_w_FAST(cntr1,norm1,cntr2,norm2,sigma):
    """Gaussian-weighted inner product between two sets of oriented faces.

    Returns::

        sum_p sum_q exp(-|c1[p] - c2[q]|**2 / (2*sigma**2)) * dot(n1[p], n2[q])**2

    Parameters
    ----------
    cntr1, norm1 : array-like
        Centres and normals of the first set, one row per face.
    cntr2, norm2 : array-like
        Centres and normals of the second set, one row per face.
    sigma : float
        Width of the Gaussian weight on the centre-to-centre distance.

    Any argument that is 3-dimensional is reduced to its first slice along
    axis 0 before use.

    The sum over the second set is evaluated with array operations, one row of
    the first set at a time, so the extra memory is proportional to the size
    of the second set only.
    """
    cntr1, norm1, cntr2, norm2 = (
        arr[0, ...] if arr.ndim == 3 else arr
        for arr in (np.asarray(a) for a in (cntr1, norm1, cntr2, norm2))
    )

    s_squared = 2*(sigma**2)

    dw = 0
    # shape[0] (number of rows) rather than max(shape): the old bound indexed
    # past the end whenever a set had fewer than three faces.
    for p in range(cntr1.shape[0]):
        sq_dists = np.sum((cntr2 - cntr1[p, :])**2, axis=1)
        gauss_weights = np.exp(-sq_dists/s_squared)
        normal_terms = np.dot(norm2, norm1[p, :])**2
        dw += np.dot(gauss_weights, normal_terms)
    return dw


def RBF_kernel(pt1,pt2,sigma):
    """Gaussian (RBF) kernel between two points.

    Returns ``exp(-|pt1 - pt2|**2 / (2*sigma**2))`` as a Python float.

    Parameters
    ----------
    pt1, pt2 : array-like
        Coordinates of the two points (same shape).
    sigma : float
        Kernel width.

    Notes
    -----
    The exponent uses the squared Euclidean distance, in agreement with
    ``RBF_kernel_fast``.  The previous code stored the un-squared norm in a
    variable called ``sqdist``.
    """
    diff = np.asarray(pt1) - np.asarray(pt2)
    sqdist = float(np.sum(diff*diff))
    return math.exp(-sqdist/(2*(sigma**2)))



def RBF_kernel_fast(Cts,sigma):
    """Pairwise Gaussian kernel matrix of the rows of ``Cts``.

    Entry ``[i, j]`` is ``exp(s * (|x_i|**2 + |x_j|**2 - 2 * x_i . x_j))`` with
    ``s = -1 / (2*sigma**2)``, where ``x_i`` is row ``i`` of ``Cts``.  The
    expression is evaluated by ``numexpr`` (``ne.evaluate``).
    """
    row_sq_norms = np.sum(Cts**2, axis = -1)
    gram = np.dot(Cts, Cts.T)
    scale = -(1/(2*(sigma**2)))

    # Operand names must match the symbols used in the expression string.
    operands = {
        'A' : row_sq_norms[:,None],
        'B' : row_sq_norms[None,:],
        'C' : gram,
        's' : scale,
    }
    return ne.evaluate('exp(s * (A + B - 2 * C))', operands)



def varifold_distance_FAST(norms1,norms2,cts1,cts2,sigma=3.0):
    """Kernel distance ``<1,1> + <2,2> - 2<1,2>`` between two face sets.

    Both sets are stacked, and one matrix
    ``RBF_kernel_fast(centres, sigma) * (normals @ normals.T)**2`` is built for
    the stacked data.  Its two diagonal blocks give the self terms and one
    off-diagonal block gives the cross term.

    Parameters
    ----------
    norms1, norms2 : array-like
        Face normals of the first / second set, one row per face.
    cts1, cts2 : array-like
        Face centres of the first / second set, one row per face.
    sigma : float, optional
        Width passed to ``RBF_kernel_fast``.

    Any argument that is 3-dimensional is reduced to its first slice along
    axis 0 before use.
    """
    def _first_slice_if_3d(value):
        arr = np.asarray(value)
        return arr[0,...] if arr.ndim == 3 else arr

    cts1, norms1, cts2, norms2 = (
        _first_slice_if_3d(v) for v in (cts1, norms1, cts2, norms2)
    )

    n_first = cts1.shape[0]
    stacked_cts = np.vstack((cts1,cts2))
    stacked_norms = np.vstack((norms1,norms2))
    pair_matrix = RBF_kernel_fast(stacked_cts,sigma)*(np.matmul(stacked_norms,np.transpose(stacked_norms))**2)

    # Block layout of pair_matrix: [:n, :n] = set 1 with itself,
    # [n:, n:] = set 2 with itself, [n:, :n] = set 2 against set 1.
    total = np.sum(pair_matrix[0:n_first,0:n_first])
    total += np.sum(pair_matrix[n_first:,n_first:])
    total -= 2*np.sum(pair_matrix[n_first:,0:n_first])

    return total




In [None]:
# Locations of the two atlas volumes used below.
# NOTE(review): these are absolute, machine-specific paths; the notebook will
# not run elsewhere without editing them.
vol_scp_path = "/Users/markolchanyi/Desktop/Edlow_Brown/Projects/Atlases/CRSEG_atlas/superstructures_v2/SCP_ROSTRAL.mgz"
vol_ctg_path = "/Users/markolchanyi/Desktop/Edlow_Brown/Projects/Atlases/CRSEG_atlas/superstructures_v2/CTG.mgz"

# load_nifti(..., return_img=False) is unpacked as (data, affine); the affine
# is not used afterwards, so both loads share the throw-away name.
vol_scp, dummy_affine = load_nifti(vol_scp_path, return_img=False)
vol_ctg, dummy_affine = load_nifti(vol_ctg_path, return_img=False)



In [None]:
# Triangulate both volumes at full resolution (step_size=1). mesh() returns
# (face centroids, face unit normals, vertices, faces), unpacked here into
# per-structure names that the later cells read.
cts_scp, cts_norms_scp, verts_scp, faces_scp = mesh(vol_scp,step_size=1)
cts_ctg, cts_norms_ctg, verts_ctg, faces_ctg = mesh(vol_ctg,step_size=1)

In [None]:
# Display the SCP triangular mesh (verts_scp / faces_scp from the meshing cell)
# using Matplotlib. This can also be done with mayavi (see the
# skimage.measure.marching_cubes docstring).
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Fancy indexing: `verts[faces]` to generate a collection of triangles.
# The collection is deliberately NOT called `mesh`: rebinding that name would
# replace the mesh() function that other cells still call.
scp_surface = Poly3DCollection(verts_scp[faces_scp])
scp_surface.set_edgecolor('k')
ax.add_collection3d(scp_surface)

ax.set_xlabel("x-axis")
ax.set_ylabel("y-axis")
ax.set_zlabel("z-axis")

# add_collection3d does not rescale the axes, so the limits are set by hand.
ax.set_xlim(0, 400)
ax.set_ylim(0, 400)
ax.set_zlim(0, 400)

plt.tight_layout()
plt.show()

In [None]:
# Kernel width for the SCP-vs-CTG comparison; passed to the call below so the
# value is defined in exactly one place.
sigma = 3.0

# perf_counter() is a monotonic high-resolution clock, so the measured
# interval is not affected by system clock adjustments.
startTime = time.perf_counter()

hd2 = varifold_distance_FAST(cts_norms_scp,cts_norms_ctg,cts_scp,cts_ctg,sigma=sigma)

executionTime = (time.perf_counter() - startTime)
print('Execution time in seconds: ' + str(executionTime))
print("HD is: ", hd2)



In [None]:
# Sweep: circularly shift the SCP volume along axis 1 and record the varifold
# distance between the unshifted and the shifted coarse (step_size=3) meshes.
# Note that this rebinds the *_scp mesh variables to the coarse SCP mesh.
cts_scp, cts_norms_scp, verts_scp, faces_scp = mesh(vol_scp,step_size=3)
rng = np.arange(-30,30)
print(rng)

# Explicit float dtype: zeros_like(rng) would inherit rng's integer dtype and
# silently truncate every stored distance to a whole number.
hd_array = np.zeros(rng.shape, dtype=np.float64)

for counter, i in enumerate(rng):
    # np.roll wraps around: voxels pushed off one end re-enter at the other.
    vol_scp_shifted = np.roll(vol_scp,i,axis=1)
    cts_shift, cts_norms_shift, verts_shift, faces_shift = mesh(vol_scp_shifted,step_size=3)
    hd_array[counter] = varifold_distance_FAST(cts_norms_scp,cts_norms_shift,cts_scp,cts_shift,sigma=30.0)
    print("done with: ", counter, " hd is: ", hd_array[counter])
    
    

In [None]:
# Plot the distance-vs-shift curve on its own figure. A bare plt.plot() draws
# on whatever figure is current, which can be the 3-D axes of an earlier cell
# when figures stay open under an interactive backend.
fig_hd, ax_hd = plt.subplots()
ax_hd.plot(rng, hd_array)
ax_hd.set_xlabel("shift applied with np.roll along axis 1")
ax_hd.set_ylabel("varifold distance to unshifted mesh")
plt.show()

In [None]:
# Coarse (step_size=3) mesh of the CTG volume. It is stored under its own
# names: the original cell wrote it into the *_scp variables, which left
# variables named "scp" holding CTG geometry.
cts_ctg_coarse, cts_norms_ctg_coarse, verts_ctg_coarse, faces_ctg_coarse = mesh(vol_ctg,step_size=3)

# Display resulting triangular mesh using Matplotlib. This can also be done
# with mayavi (see the skimage.measure.marching_cubes docstring).
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Fancy indexing: `verts[faces]` to generate a collection of triangles.
# The collection is deliberately NOT called `mesh`: rebinding that name would
# replace the mesh() function that other cells still call.
ctg_surface = Poly3DCollection(verts_ctg_coarse[faces_ctg_coarse])
ctg_surface.set_edgecolor('k')
ax.add_collection3d(ctg_surface)

ax.set_xlabel("x-axis")
ax.set_ylabel("y-axis")
ax.set_zlabel("z-axis")

# add_collection3d does not rescale the axes, so the limits are set by hand.
ax.set_xlim(0, 400)
ax.set_ylim(0, 400)
ax.set_zlim(0, 400)

plt.tight_layout()
plt.show()

In [3]:
# Locations of the test and atlas volumes.
# NOTE(review): absolute, machine-specific paths; edit before running elsewhere.
test_path = "/Users/markolchanyi/Desktop/Edlow_Brown/Projects/datasets/ex_vivo_test_data/EXC007/scratch/f_mask4.nii.gz"
atlas_path = "/Users/markolchanyi/Desktop/Edlow_Brown/Projects/datasets/ex_vivo_test_data/EXC007/scratch/m_mask4.nii.gz"

# load_nifti(..., return_img=False) is unpacked as (data, affine); the affine
# is not used afterwards.
vol_test, dummy_affine = load_nifti(test_path, return_img=False)
vol_atlas, dummy_affine = load_nifti(atlas_path, return_img=False)

# Full-resolution (step_size=1) surface mesh of each volume.
cts_test, cts_norms_test, verts_test, faces_test = mesh(vol_test, step_size=1)
cts_atlas, cts_norms_atlas, verts_atlas, faces_atlas = mesh(vol_atlas, step_size=1)




In [12]:
# Load the saved displacement array and print it.
# NOTE(review): the recorded output of this cell shows magnitudes around
# 1.7e8 -- confirm that this is the expected scale for the field.
displacement_path = "/Users/markolchanyi/Desktop/Edlow_Brown/Projects/datasets/ex_vivo_test_data/EXC007/scratch/displacement.npy"
displacement = np.load(displacement_path)

print(displacement)

[[[[[-1.7450186e+08 -1.7450186e+08 -1.7450186e+08]
    [-1.7058046e+08 -1.7450186e+08 -1.7450186e+08]
    [-1.6665907e+08 -1.7450186e+08 -1.7450186e+08]
    ...
    [ 1.6665907e+08 -1.7450186e+08 -1.7450186e+08]
    [ 1.7058046e+08 -1.7450186e+08 -1.7450186e+08]
    [ 1.7450186e+08 -1.7450186e+08 -1.7450186e+08]]

   [[-1.7450186e+08 -1.7058046e+08 -1.7450186e+08]
    [-1.7058046e+08 -1.7058046e+08 -1.7450186e+08]
    [-1.6665907e+08 -1.7058046e+08 -1.7450186e+08]
    ...
    [ 1.6665907e+08 -1.7058046e+08 -1.7450186e+08]
    [ 1.7058046e+08 -1.7058046e+08 -1.7450186e+08]
    [ 1.7450186e+08 -1.7058046e+08 -1.7450186e+08]]

   [[-1.7450186e+08 -1.6665907e+08 -1.7450186e+08]
    [-1.7058046e+08 -1.6665907e+08 -1.7450186e+08]
    [-1.6665907e+08 -1.6665907e+08 -1.7450186e+08]
    ...
    [ 1.6665907e+08 -1.6665907e+08 -1.7450186e+08]
    [ 1.7058046e+08 -1.6665907e+08 -1.7450186e+08]
    [ 1.7450186e+08 -1.6665907e+08 -1.7450186e+08]]

   ...

   [[-1.7450186e+08  1.6665907e+08 -1.745018

In [4]:
# Overlay the test and atlas triangular meshes in one Matplotlib 3-D view.
# This can also be done with mayavi (see the skimage.measure.marching_cubes
# docstring).
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Fancy indexing: `verts[faces]` to generate a collection of triangles.
# The collections are deliberately NOT called `mesh`: rebinding that name
# would replace the mesh() function defined earlier in the notebook.
surface_test = Poly3DCollection(verts_test[faces_test])
surface_test.set_edgecolor('k')
surface_test.set_facecolor('b')
ax.add_collection3d(surface_test)

# Atlas surface: black edges with red faces, mirroring the blue test surface.
# (The original called set_edgecolor twice, so the first call had no effect.)
surface_atlas = Poly3DCollection(verts_atlas[faces_atlas])
surface_atlas.set_edgecolor('k')
surface_atlas.set_facecolor('r')
ax.add_collection3d(surface_atlas)

ax.set_xlabel("x-axis")
ax.set_ylabel("y-axis")
ax.set_zlabel("z-axis")

# add_collection3d does not rescale the axes, so the limits are set by hand.
ax.set_xlim(0, 100)
ax.set_ylim(0, 100)
ax.set_zlim(0, 100)

plt.tight_layout()
plt.show()