## Testing parameters for Kimimaro skeletonization
#### Objectives:
 ##### 1. Download a skeleton from CloudVolume (CV).
 ##### 2. Choose a random point on the skeleton.
 ##### 3. Download a 256-voxel cube of segmentation centred on that point.
 ##### 4. Run Kimimaro with a given parameter set.
 ##### 5. Select the ground-truth (CATMAID) nodes that fall inside the cube's bounding box.
 ##### 6. Calculate the similarity between the new skeleton and the ground-truth skeleton.
 ##### 7. Repeat across the parameter space and at different points along the skeleton.

In [None]:
from cloudvolume import Bbox, Vec, CloudVolume, volumecutout, view
import kimimaro
from meshparty import trimesh_io, trimesh_vtk, skeleton_io
import navis
import pymaid
import numpy as np
import sys
from pathlib import Path
from matplotlib import pyplot as plt
import random
from scipy import spatial
import math

# Make the local FANC_auto_recon/transforms checkout importable so that the
# fanc_seg_utils imports below resolve. NOTE(review): machine-specific path.
_FANC_TRANSFORMS_DIR = '/Users/brandon/Documents/Repositories/Python/FANC_auto_recon/transforms/'
if _FANC_TRANSFORMS_DIR not in sys.path:
    sys.path.append(_FANC_TRANSFORMS_DIR)


from fanc_seg_utils import neuroglancer_utilities
from fanc_seg_utils import catmaid_utilities
from fanc_seg_utils import skeletonization



In [None]:
# Establish connections: a CATMAID session (ground-truth skeletons) and a
# CloudVolume handle on the segmentation layer.
_CATMAID_KEYS = '/Users/brandon/Documents/MN_Analysis/catmaid_keys.txt'
_SEG_CLOUDPATH = 'https://storage.googleapis.com/zetta_lee_fly_vnc_001_segmentation/vnc1_full_v3align_2/realigned_v1/seg/full_run_v1'

instance = catmaid_utilities.catmaid_login('fanc', 60, _CATMAID_KEYS)
cv = CloudVolume(_SEG_CLOUDPATH)

In [None]:
def choose_vertex(skeleton):
    ''' Pick a random vertex from a skeleton
    Args:
        skeleton: cv skeleton object (anything with an indexable ``vertices`` array)
    Returns:
        coordinates of one skeleton vertex, chosen uniformly at random
    Raises:
        ValueError: if the skeleton has no vertices'''
    # Use the argument, not a notebook global.
    vertices = skeleton.vertices
    if len(vertices) == 0:
        raise ValueError('skeleton has no vertices to choose from')
    # randrange excludes its upper bound; randint(0, len) could pick len and
    # index past the end of the array.
    return vertices[random.randrange(len(vertices))]

def download_vol(pt, size=256, seg_mip = np.array([17.2,17.2,45])):
    ''' Download a segmentation cube centred on a point
    Args:
        pt: np.array, mip0 xyz voxel coord at the centre of the cube
        size: int, edge length of the cube to download, in voxels of the
            segmentation mip (previously ignored; 512 was always used)
        seg_mip: np.array, resolution of the segmentation layer to read,
            passed to CloudVolume as ``mip``
    Returns:
        segmentation cutout returned by the module-level ``cv.download_point``'''
    # pt is in mip0 voxels; x/y are divided by 4 to reach the segmentation mip,
    # z is left unchanged. NOTE(review): assumes seg_mip is 4x mip0 in x/y and
    # 1x in z (matches the default seg_mip and img_res = [4.3, 4.3, 45]).
    seg_pt = np.asarray(pt) // np.array([4, 4, 1])
    return cv.download_point(seg_pt, size=size, mip=seg_mip)

def return_bounds(img_vol,transform = True):
    ''' Return the physical-space bounding box of an image volume
    Args:
        img_vol: cv cutout whose ``bounds`` are in segmentation-mip voxels
            (4x mip0 in x/y, 1x in z)
        transform: bool, default is True. Map the bbox corners from FANC4 to
            FANC3 space (neuroglancer_utilities.fanc4_to_3) so the bbox can
            filter a CATMAID skeleton, which is in FANC3 space.
    Returns:
        list [[xmin,xmax],[ymin,ymax],[zmin,zmax]] of np.arrays, in mip0 voxels
        scaled by (4.3, 4.3, 45) -- i.e. physical units, not voxels'''
    # bounds.to_list() is [xmin, ymin, zmin, xmax, ymax, zmax]; scale x/y by 4
    # to go from the segmentation mip back to mip0 voxels.
    corners = (np.asarray(img_vol.bounds.to_list(), dtype='float64')
               * np.array([4, 4, 1, 4, 4, 1])).reshape(2, 3)
    if transform:
        tf = neuroglancer_utilities.fanc4_to_3(corners)
        axes = [tf['x'], tf['y'], tf['z']]
    else:
        # Previously this branch crashed by indexing the flat array with 'x'.
        axes = [corners[:, 0], corners[:, 1], corners[:, 2]]

    resolution = (4.3, 4.3, 45)
    # Sort each pair so bbox_filter always receives (min, max), even if the
    # transform were to swap the corner order.
    return [np.sort(np.asarray(ax, dtype='float64')) * res
            for ax, res in zip(axes, resolution)]

def bbox_filter(points, bbox):
    ''' Keep only the points lying strictly inside a bounding box
    Args:
        points: np.array, n x 3 array of vertex coords to filter
        bbox: three (min, max) pairs for x, y and z, e.g. as returned by
            return_bounds; must be in the same units as ``points``
    Returns:
        the rows of ``points`` strictly inside the box (points exactly on a
        face are excluded)'''
    inside = np.ones(points.shape[0], dtype=bool)
    for axis in range(3):
        lower, upper = bbox[axis][0], bbox[axis][1]
        coord = points[:, axis]
        inside &= (coord > lower) & (coord < upper)
    return points[inside]

def transform_pts(skeleton, img_res = np.array([4.3,4.3,45])):
    ''' Move skeleton vertices from FANC4 into FANC3 (CATMAID) space, in place
    args:
        skeleton: cv skeleton object; its vertices are divided by img_res
            before the transform, so they are expected in physical units
            (voxel coords * img_res)
        img_res: np.array, NG mip0 resolution used to convert to and from the
            voxel coords that fanc4_to_3 operates on
    returns:
        the same skeleton object, with vertices replaced by the transformed
        coords scaled back by img_res (physical units)'''
    voxel_coords = skeleton.vertices / img_res
    moved = neuroglancer_utilities.fanc4_to_3(voxel_coords)
    # fanc4_to_3 returns per-axis columns keyed 'x', 'y', 'z'; re-stack to n x 3.
    stacked = np.array(list(zip(moved['x'], moved['y'], moved['z'])))
    skeleton.vertices = stacked * img_res
    return skeleton
    
    
def skeletonize(image_vol,seg_id,img_res,**kwargs):
    ''' Run kimimaro skeletonization on one segment of a cutout.

    Args:
        image_vol: cv cutout, downloaded at 4x mip0 in x/y and 1x in z
            (see download_vol)
        seg_id: ID of the segment to skeletonize
        img_res: np.array, mip0 resolution (x, y, z); the cutout's voxel size
            is taken to be img_res * (4, 4, 1)
        **kwargs: overrides for any key of the default TEASAR params below,
            plus 'dust_threshold' (default 500). Other keys are ignored.
    Returns:
        kimimaro skeleton for seg_id, with vertices shifted to the cutout's
        global position and mapped to FANC3 (CATMAID) space by transform_pts.
    Raises:
        KeyError: if kimimaro produced no skeleton for seg_id (e.g. it is not in
            the cutout, or is smaller than dust_threshold).
    '''
    default_teasar = {
        'scale': 1,
        'const': 200, # physical units
        'pdrf_exponent': 4,
        'pdrf_scale': 100000,
        'soma_detection_threshold': 1100, # physical units
        'soma_acceptance_threshold': 3500, # physical units
        'soma_invalidation_scale': 1.0,
        'soma_invalidation_const': 300, # physical units
        'max_paths': None, # default None
    }
    teasar_params = {key: kwargs.get(key, default) for key, default in default_teasar.items()}
    dust_threshold = kwargs.get('dust_threshold', 500)

    # Voxel size of the cutout. The same value must drive both kimimaro's
    # anisotropy and the global offset below; the previous hard-coded
    # (17, 17, 45) disagreed with the 17.2 nm used for the offset.
    mip_scale = np.array([4, 4, 1])
    seg_res = np.asarray(img_res, dtype=float) * mip_scale

    skels = kimimaro.skeletonize(
        image_vol,
        teasar_params=teasar_params,
        object_ids=[seg_id], # process only the specified label
        dust_threshold=dust_threshold, # skip connected components with fewer than this many voxels
        anisotropy=tuple(seg_res),
        fix_branching=True, # default True
        fix_borders=True, # default True
        fill_holes=True, # default False
        fix_avocados=True, # default False
        progress=True, # default False, show progress bar
        parallel=0, # <= 0 all cpu, 1 single process, 2+ multiprocess
        parallel_chunk_size=100, # how many skeletons to process before updating progress bar
    )

    # Report which overrides this run used (name=value, one line).
    if kwargs:
        print(', '.join('{}={}'.format(key, val) for key, val in kwargs.items()))

    if seg_id not in skels:
        raise KeyError(
            'kimimaro returned no skeleton for segment {}; it may be absent from '
            'the cutout or below dust_threshold={}'.format(seg_id, dust_threshold))

    skel = skels[seg_id]
    # Cutout-local physical coords -> global physical coords: the cutout's min
    # corner (segmentation-mip voxels) -> mip0 voxels -> physical units.
    offset = np.array(image_vol.bounds.to_list()[0:3]) * mip_scale * img_res
    skel.vertices = skel.vertices + offset
    transform_pts(skel)
    return skel


def measure_similarity(a, b, sigma=2000, omega=4000):
    ''' Score how well point cloud ``a`` is matched by point cloud ``b``.

    For every node in ``a`` the nearest node in ``b`` is found, and the node
    scores
        exp(-|nA - nB| / (nA + nB)) * exp(-d**2 / (2 * sigma**2))
    where d is the distance to that nearest node and nA / nB are the number of
    nodes within ``omega`` of the node in its own cloud (counting itself). The
    result is the mean over the nodes of ``a``, in [0, 1], and is 1 for
    identical clouds. It is not symmetric; compare_against_ground_truth
    averages both directions.

    Args:
        a: array-like, (n, 3) node coordinates
        b: array-like, (m, 3) node coordinates, same units as ``a``
        sigma: distance scale of the proximity term, same units as the coords
        omega: radius used for the local node-density term
    Returns:
        float similarity score
    Raises:
        ValueError: if either point set is empty
    '''
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if len(a) == 0 or len(b) == 0:
        raise ValueError('measure_similarity needs at least one node in each point set')

    # Nearest node in b for each node of a.
    cross = spatial.distance.cdist(a, b)
    closest_ix = cross.argmin(axis=1)
    closest_dist = cross.min(axis=1)

    # Local density: nodes within omega of each node. Each node counts itself,
    # so the density denominator below is always >= 2.
    close_a = (spatial.distance.cdist(a, a) <= omega).sum(axis=1)
    close_b = (spatial.distance.cdist(b, b) <= omega).sum(axis=1)[closest_ix]

    density_term = np.exp(-np.abs(close_a - close_b) / (close_a + close_b))
    proximity_term = np.exp(-(closest_dist ** 2) / (2 * sigma ** 2))
    return float(np.mean(density_term * proximity_term))

def compare_against_ground_truth(new, ground_truth, **similarity_kwargs):
    ''' Symmetric similarity between a new skeleton and the ground truth.

    measure_similarity only scores how well its first argument is covered by
    its second, so both directions are computed and averaged.

    Args:
        new: (n, 3) node coordinates of the new skeleton
        ground_truth: (m, 3) node coordinates of the reference skeleton
        **similarity_kwargs: forwarded to measure_similarity (e.g. sigma, omega)
    Returns:
        mean of the forward and reverse similarity scores'''
    forward = measure_similarity(new, ground_truth, **similarity_kwargs)
    reverse = measure_similarity(ground_truth, new, **similarity_kwargs)
    return (forward + reverse) / 2

def find_min_dist(new_skeleton,ground_truth):
    ''' Mean squared distance from each ground-truth node to the new skeleton.

    Despite the name this is not a single minimum: for every ground-truth node
    the distance to its nearest new-skeleton node is squared, and those values
    are averaged (lower is better).

    Args:
        new_skeleton: (n, 3) array of candidate skeleton node coords
        ground_truth: (m, 3) array of reference node coords
    Returns:
        mean of the squared nearest-neighbour distances'''
    pairwise = spatial.distance.cdist(ground_truth, new_skeleton)
    nearest = pairwise.min(axis=1)
    return np.square(nearest).mean()

#### Define Parameter Space

In [None]:
# TEASAR parameter grid to sweep; each range's stop value is excluded.
param_space = dict(
    scale=range(1, 4),
    const=range(250, 750, 50),
    pdrf_exponent=range(1, 7),
    pdrf_scale=range(10000, 100000, 5000),
)

# Every combination, with pdrf_scale varying fastest and scale slowest.
params = [
    {'scale': scale, 'const': const, 'pdrf_exponent': exponent, 'pdrf_scale': pdrf}
    for scale in param_space['scale']
    for const in param_space['const']
    for exponent in param_space['pdrf_exponent']
    for pdrf in param_space['pdrf_scale']
]

In [None]:
# Download the test segment's skeleton from CloudVolume and convert its
# vertices from physical units to mip0 voxel coords.
img_res = np.array([4.3, 4.3, 45])
seg_id = 74172093612626737
sk = cv.skeleton.get(seg_id)
sk.vertices = np.divide(sk.vertices, img_res)

## Download the same neuron from CATMAID (ground truth), found via its FANC4_ID annotation
n = pymaid.get_neurons('annotations:FANC4_ID: {}'.format(seg_id), remote_instance=instance)


In [None]:
# For several random points on the skeleton: download a cutout, skeletonize it
# with a sample of the parameter sets, and score each result against the
# CATMAID ground-truth nodes that fall inside the cutout.
all_sims = []       # per trial: scores divided by that trial's maximum
all_ranks = []      # per trial: rank of each parameter set (0 = smallest distance)
all_fragments = []  # per trial: the skeleton fragments produced
for j in range(5):
    img_vol = download_vol(choose_vertex(sk))
    # The bbox and the ground-truth nodes depend only on the cutout, so compute
    # them once per trial rather than once per parameter set.
    bbox = return_bounds(img_vol)
    nodes_to_compare = bbox_filter(n.nodes[['x','y','z']].values, bbox)
    sims = []
    neuron_fragments = []

    for i in range(1,1000,200):
        new_neuron = skeletonize(img_vol,seg_id,img_res,**params[i])
        sims.append(find_min_dist(new_neuron.vertices,nodes_to_compare))
        neuron_fragments.append(new_neuron)

    sims = np.asarray(sims)
    order = sims.argsort()
    ranks = np.empty_like(order)
    ranks[order] = np.arange(len(sims))
    all_ranks.append(ranks)
    all_sims.append(sims / sims.max())
    all_fragments.append(neuron_fragments)
    


In [None]:
# Overlay the ground-truth CATMAID nodes (red) and the last new skeleton in the
# x-z plane, zoomed to the last cutout's bounding box.
plt.scatter(n.nodes['x'], n.nodes['z'], c='r')
new_xyz = new_neuron.vertices
plt.scatter(new_xyz[:, 0], new_xyz[:, 2], alpha=.05)
plt.xlim(bbox[0][0], bbox[0][1])
# bottom=zmax, top=zmin: z increases downward on the plot.
plt.ylim(bbox[2][1], bbox[2][0])
measure_similarity(new_neuron.vertices, nodes_to_compare)

## Detect Somas

In [None]:
# Soma-detection overrides passed to skeletonize as keyword arguments.
soma_params = dict(
    soma_detection_threshold=2000,   # physical units
    soma_acceptance_threshold=5500,  # physical units
    soma_invalidation_scale=1.0,
    soma_invalidation_const=800,     # physical units
)

In [None]:
# Cutout centred on a fixed mip0 voxel coordinate (presumably inside a soma; verify).
soma_center = np.array([62944, 114880, 2300])
img_vol = download_vol(soma_center)

In [None]:
# Skeletonize the soma cutout with the soma-detection overrides applied.
new_neuron = skeletonize(image_vol=img_vol, seg_id=seg_id, img_res=img_res, **soma_params)

In [None]:
# Number of pdrf_scale values in the sweep grid (same range as param_space['pdrf_scale']).
len(param_space['pdrf_scale'])

In [None]:
# Total number of parameter combinations in the sweep grid.
len(params)