In [1]:
"""
Purpose: New skeletonization method that does
not use a recursive method but simply uses the followoing
algorithm: 
- Do mesh splitting to find which mesh pieces aren't connect to
  the soma
0) Do Soma Extraction and split the meshes from there
For Each significant mesh that was split off in beginning
1) Poisson Surface Reconstruction
2) CGAL skeletonization of all signfiicant pieces
3) Using CGAL skeleton, find the leftover mesh not skeletonized
4) Do surface reconstruction on the parts that are left over
- with some downsampling
5) Stitch the skeleton 


---- Afterwards stitching:
1) Compute the soma mesh center point
2) For meshes that were originally connected to soma
a. Find the closest skeletal point to soma center
b. Add an edge from closest point to soma center
3) Then do stitching algorithm on all of remaining disconnected
    skeletons

"""

"\nPurpose: New skeletonization method that does\nnot use a recursive method but simply uses the followoing\nalgorithm: \n- Do mesh splitting to find which mesh pieces aren't connect to\n  the soma\n0) Do Soma Extraction and split the meshes from there\nFor Each significant mesh that was split off in beginning\n1) Poisson Surface Reconstruction\n2) CGAL skeletonization of all signfiicant pieces\n3) Using CGAL skeleton, find the leftover mesh not skeletonized\n4) Do surface reconstruction on the parts that are left over\n- with some downsampling\n5) Stitch the skeleton \n\n\n---- Afterwards stitching:\n1) Compute the soma mesh center point\n2) For meshes that were originally connected to soma\na. Find the closest skeletal point to soma center\nb. Add an edge from closest point to soma center\n3) Then do stitching algorithm on all of remaining disconnected\n    skeletons\n\n"

In [2]:
from pykdtree.kdtree import KDTree
import time
import trimesh
import numpy as np
from pathlib import Path

import calcification_Module as cm #module that allows for calcification
import time
import os
import pathlib

from tqdm.notebook import tqdm

In [3]:
import meshlab
from importlib import reload
meshlab = reload(meshlab)
from meshlab import Decimator , Poisson
import skeleton_utils as sk

# Importing Example Mesh and Example Soma

In [4]:
import soma_extraction_utils as soma_utils
from pathlib import Path
import trimesh

In [5]:
def load_somas(segment_id,main_mesh_total):
    try:
        current_soma = trimesh.load_mesh("./Dustin_soma.off")
        return [current_soma]
    except:
        print("No Soma currently available so must compute own")
        (total_soma_list, 
             run_time, 
             total_soma_list_sdf) = soma_utils.extract_soma_center(
                                segment_id,
                                main_mesh_total.vertices,
                                main_mesh_total.faces,
                                outer_decimation_ratio= 0.25,
                                large_mesh_threshold = 60000,
                                large_mesh_threshold_inner = 40000,
                                soma_width_threshold = 0.32,
                                soma_size_threshold = 20000,
                               inner_decimation_ratio = 0.25,
                               volume_mulitplier=7,
                               side_length_ratio_threshold=3,
                                soma_size_threshold_max=192000,
                                delete_files=True
            )
        return total_soma_list
    else:
        return []

In [6]:
segment_id = 12345

# Load soma mesh
# #loads in the Dustin mesh
main_mesh_path = Path("./Dustin.off")
main_mesh_total = trimesh.load_mesh(str(main_mesh_path.absolute()))

soma_mesh_list = load_somas(segment_id,main_mesh_total)
print(f"Soma List = {soma_mesh_list}")

soma_mesh_list_centers = [np.array(np.mean(k.vertices,axis=0)).astype("float")
                           for k in soma_mesh_list]
print(f"soma_mesh_list_centers = {soma_mesh_list_centers}")

if len(soma_mesh_list) <= 0:
    print(f"**** No Somas Found for Mesh {segment_id}")
    


Soma List = [<trimesh.Trimesh(vertices.shape=(1864, 3), faces.shape=(3640, 3))>]
soma_mesh_list_centers = [array([1326076.34924893,  732678.79350858,  884152.51410944])]


In [7]:
from importlib import reload
sk = reload(sk)

# Do mesh splitting to find which mesh pieces aren't connect to the soma

In [8]:
split_meshes = sk.split_significant_pieces(
                            main_mesh_total,
                            significance_threshold=15,
                            print_flag=False)
len(split_meshes)

110

# Find which bounding box contains the soma

In [9]:
"""
Pseudocode: 
For all meshes in list
1) compute soma center
2) Find all the bounding boxes that contain the soma center
3) Find the mesh with the closest distance from 
   one vertex to the soma center and tht is winner
"""
containing_mesh_indices=dict([(i,[]) for i,sm_c in enumerate(soma_mesh_list_centers)])
for k,sm_center in enumerate(soma_mesh_list_centers):

    viable_meshes = [j for j,m in enumerate(split_meshes) 
             if trimesh.bounds.contains(m.bounds,sm_center.reshape(-1,3))
                    ]
    if len(viable_meshes) == 0:
        raise Exception(f"The Soma {k} with {sm_center} was not contained in any of the boundying boxes")
    elif len(viable_meshes) == 1:
        containing_mesh_indices[k] = viable_meshes[0]
    else:
        #find which mesh is closer to the soma midpoint
        min_distances_to_soma = []
        for v_i in viable_meshes:
            # build the KD Tree
            viable_neuron_kdtree = KDTree(soma_mesh_list[v_i].vertices)
            distances,closest_node = viable_neuron_kdtree.query(sm_centers.reshape(-1,3))
            min_distances_to_soma.append(np.min(distances))
        print(f"min_distances_to_soma = {min_distances_to_soma}")
        containing_mesh_indices[k] = np.argmin(min_distances_to_soma)

In [10]:
non_soma_touching_meshes = [m for i,m in enumerate(split_meshes)
                     if i not in list(containing_mesh_indices.values())]
len(non_soma_touching_meshes)
soma_touching_meshes = dict([(i,split_meshes[m_i]) 
                             for i,m_i in containing_mesh_indices.items()])
soma_touching_meshes

{0: <trimesh.Trimesh(vertices.shape=(325120, 3), faces.shape=(651866, 3))>}

# Part 1: For each soma containing mesh: Do Skeletonization and stitching

# - util functions for skeletonization

In [11]:
def subtract_soma(current_soma,main_mesh):
    
    start_time = time.time()
    face_midpoints_soma = np.mean(current_soma.vertices[current_soma.faces],axis=1)


    curr_mesh_bbox_restriction,faces_bbox_inclusion = (
                    sk.bbox_mesh_restrcition(main_mesh,
                                             current_soma.bounds,
                                            mult_ratio=1.3)
    )

    face_midpoints_neuron = np.mean(curr_mesh_bbox_restriction.vertices[curr_mesh_bbox_restriction.faces],axis=1)

    soma_kdtree = KDTree(face_midpoints_soma)

    distances,closest_node = soma_kdtree.query(face_midpoints_neuron)

    distance_threshold = 550
    distance_passed_faces  = distances<distance_threshold

    faces_to_keep = np.array(list(set(np.arange(0,len(main_mesh.faces))).difference(set(faces_bbox_inclusion[distance_passed_faces]))))
    without_soma_mesh = main_mesh.submesh([faces_to_keep],append=True)

    #get the significant mesh pieces
    mesh_pieces = sk.split_significant_pieces(without_soma_mesh,significance_threshold=200)
    print(f"Total Time for soma mesh cancellation = {np.round(time.time() - start_time,3)}")
    return mesh_pieces

In [12]:
"""
Pseudocode: 
0) Do Soma Extraction and split the meshes from there
For Each significant mesh that was split off in beginning
1) Poisson Surface Reconstruction
2) CGAL skeletonization of all signfiicant pieces 
    (if above certain size ! threshold) 
            --> if not skip straight to surface skeletonization
3) Using CGAL skeleton, find the leftover mesh not skeletonized
4) Do surface reconstruction on the parts that are left over
- with some downsampling
5) Stitch the skeleton 

"""

'\nPseudocode: \n0) Do Soma Extraction and split the meshes from there\nFor Each significant mesh that was split off in beginning\n1) Poisson Surface Reconstruction\n2) CGAL skeletonization of all signfiicant pieces \n    (if above certain size ! threshold) \n            --> if not skip straight to surface skeletonization\n3) Using CGAL skeleton, find the leftover mesh not skeletonized\n4) Do surface reconstruction on the parts that are left over\n- with some downsampling\n5) Stitch the skeleton \n\n'

# - setting up the paths for data writing

In [13]:
from shutil import rmtree
from pathlib import Path

mesh_base_path=Path("./Dustin_vp4/")
current_name="Dustin"

if mesh_base_path.exists():
    rmtree(str(mesh_base_path.absolute()))
mesh_base_path.mkdir(parents=True,exist_ok=True)
print(list(mesh_base_path.iterdir()))

[]


In [14]:
def skeletonize_connected_branch(current_mesh,
                        output_folder="./temp",
                        delete_temp_files=True,
                        name="None",
                        surface_reconstruction_size=50,
                        n_surface_downsampling = 1,
                        n_surface_samples=1000,
                        skeleton_print=False,
                        mesh_subtraction_distance_threshold=3000,
                        mesh_subtraction_buffer=50,
                        max_stitch_distance = 18000,
                        ):
    """
    Purpose: To take a mesh and construct a full skeleton of it
    (Assuming the Soma is already extracted)
    
    1) Poisson Surface Reconstruction
    2) CGAL skeletonization of all signfiicant pieces 
        (if above certain size ! threshold) 
                --> if not skip straight to surface skeletonization
    3) Using CGAL skeleton, find the leftover mesh not skeletonized
    4) Do surface reconstruction on the parts that are left over
    - with some downsampling
    5) Stitch the skeleton 
    """
    
    #check that the mesh is all one piece
    current_mesh_splits = sk.split_significant_pieces(current_mesh,
                               significance_threshold=1)
    if len(current_mesh_splits) > 1:
        raise Exception(f"The mesh passed has {len(current_mesh_splits)} pieces")

    # check the size of the branch and if small enough then just do
    # Surface Skeletonization
    if len(current_mesh.faces) < surface_reconstruction_size:
        #do a surface skeletonization
        surf_sk = sk.generate_surface_skeleton(current_mesh.vertices,
                                    current_mesh.faces,
                                    surface_samples=n_surface_samples,
                                             n_surface_downsampling=n_surface_downsampling )
        return surf_sk
    else:
    
        #if can't simply do a surface skeletonization then 
        #use cgal method that requires temp folder

        if type(output_folder) != type(Path()):
            output_folder = Path(str(output_folder))
            output_folder.mkdir(parents=True,exist_ok=True)
            
        # CGAL Step 1: Do Poisson Surface Reconstruction
        Poisson_obj = Poisson(output_folder,overwrite=True)
        

        skeleton_start = time.time()
        print("     Starting Screened Poisson")
        new_mesh,output_subprocess_obj = Poisson_obj(   
                                    vertices=current_mesh.vertices,
                                     faces=current_mesh.faces,
                                    mesh_filename=name + ".off",
                                     return_mesh=True,
                                     delete_temp_files=False,
                                    )
        print(f"-----Time for Screened Poisson= {time.time()-skeleton_start}")
            
        #2) Filter away for largest_poisson_piece:
        mesh_pieces = sk.split_significant_pieces(new_mesh,
                                            significance_threshold=surface_reconstruction_size)
        
        if skeleton_print:
            print(f"Signifiant mesh pieces of {surface_reconstruction_size} size "
                 f"after poisson = {len(mesh_pieces)}")
        skeleton_ready_for_stitching = np.array([])
        skeleton_files = [] # to be erased later on if need be
        if len(mesh_pieces) <= 0:
            if skeleton_print:
                print("No signficant skeleton pieces so just doing surface skeletonization")
            # do surface skeletonization on all of the pieces
            surface_mesh_pieces = sk.split_significant_pieces(new_mesh,
                                            significance_threshold=2)
            
            #get the skeletons for all those pieces
            current_mesh_skeleton_list = [
                sk.generate_surface_skeleton(p.vertices,
                                    p.faces,
                                    surface_samples=n_surface_samples,
                                    n_surface_downsampling=n_surface_downsampling )
                for p in surface_mesh_pieces
            ]
            
            skeleton_ready_for_stitching = np.vstack(current_mesh_skeleton_list)
            
            #will stitch them together later
        else: #if there are parts that can do the cgal skeletonization
            skeleton_start = time.time()
            print("     Starting Calcification")
            for zz,piece in enumerate(mesh_pieces):
                current_mesh_path = output_folder / f"{name}_{zz}"
                
                written_path = sk.write_neuron_off(piece,current_mesh_path)
                
                skeleton_start = time.time()
                print(f"Path sneding to calcification = {written_path[:-4]}")
                time.sleep(10)
                cm.calcification(written_path[:-4])
                print(f"Time for skeletonizatin = {time.time() - skeleton_start}")
                time.sleep(10)
                skeleton_files.append(str(current_mesh_path) + "_skeleton.cgal")
                
            if skeleton_print:
                print(f"-----Time for Running Calcification = {time.time()-skeleton_start}")
            
            #collect the skeletons and subtract from the mesh
            significant_poisson_skeleton = sk.read_skeleton_edges_coordinates(skeleton_files)
            boolean_significance_threshold=5
            
            
            mesh_pieces_leftover =  sk.mesh_subtraction_by_skeleton(current_mesh,
                                                        significant_poisson_skeleton,
                                                        buffer=mesh_subtraction_buffer,
                                                        bbox_ratio=1.2,
                                                        distance_threshold=significant_poisson_skeleton,
                                                        significance_threshold=boolean_significance_threshold,
                                                        print_flag=False
                                                       )
            
            # *****adding another significance threshold*****
            leftover_meshes_sig = [k for k in mesh_pieces_leftover if len(k.faces) > 50]
            leftover_meshes = sk.combine_meshes(leftover_meshes_sig)
            
            leftover_meshes_sig_surf_sk = []
            for m in tqdm(leftover_meshes_sig):
                surf_sk = sk.generate_surface_skeleton(m.vertices,
                                               m.faces,
                                               surface_samples=n_surface_samples,
                                    n_surface_downsampling=n_surface_downsampling )
                if len(surf_sk) > 0:
                    leftover_meshes_sig_surf_sk.append(surf_sk)
            leftover_surfaces_total = np.vstack(leftover_meshes_sig_surf_sk)
            
            skeleton_ready_for_stitching = np.vstack([significant_poisson_skeleton,leftover_surfaces_total])
        
        #now want to stitch together whether generated from 
        if skeleton_print:
            print(f"After cgal process the un-stitched skeleton has shape {skeleton_ready_for_stitching.shape}")
        
        stitched_skeletons_full = sk.stitch_skeleton(
                                                  skeleton_ready_for_stitching,
                                                  max_stitch_distance=max_stitch_distance,
                                                  stitch_print = False,
                                                  main_mesh = []
                                                )
        stitched_skeletons_full_cleaned = sk.clean_skeleton(stitched_skeletons_full)
        
        # erase the skeleton files if need to be
        if delete_temp_files:
            for sk_fi in skeleton_files:
                Path(sk_fi).unlink()
        
        # if created temp folder then erase if empty
        if str(output_folder.absolute()) == str(Path("./temp").absolute()):
            print("The process was using a temp folder")
            if len(list(output_folder.iterdir())) == 0:
                print("Temp folder was empty so deleting it")
                if output_folder.exists():
                    rmtree(str(output_folder.absolute()))
        
        return stitched_skeletons_full_cleaned    

In [15]:
import calcification_Module as cm

In [16]:
#cm.calcification("/notebooks3/Users/celii/Documents/Complete_Pinky100_Pipeline/notebooks/Platinum/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_0")

In [17]:
sk = reload(sk)

soma_touching_meshes_skeletons = []
for s_i,main_mesh in soma_touching_meshes.items():
    #Do the mesh subtraction to get the disconnected pieces
    current_soma = soma_mesh_list[s_i]
    
    mesh_pieces = subtract_soma(current_soma,main_mesh)

    #get each branch skeleton
    total_soma_skeletons = []
    for dendrite_index,picked_dendrite in enumerate(mesh_pieces):
        dendrite_name=current_name + f"_soma_{s_i}_branch_{dendrite_index}"
        stitched_dendrite_skeleton = skeletonize_connected_branch(picked_dendrite,
                                                       output_folder=mesh_base_path,
                                                       name=dendrite_name,
                                                        skeleton_print = True)
        
        if len(stitched_dendrite_skeleton)<=0:
            print(f"*** Dendrite {dendrite_index} did not have skeleton computed***")
        else: 
            total_soma_skeletons.append(stitched_dendrite_skeleton)
    
#     #stitch the branches skeleton to the soma centroid
#     soma_stitched_dendrtie_skeletons = soma_skeleton_stitching(
#                     total_soma_skeletons,
#                     current_soma
#     )
    
#     soma_touching_meshes_skeletons.append(soma_stitched_dendrtie_skeletons)

Total Time for soma mesh cancellation = 1.204
     Starting Screened Poisson
IN INPUT FILE VALIDATION LOOP
LEAVING LOOP, MESH VALIDATED
Using port = 813
xvfb-run -n 813 -s "-screen 0 800x600x24" meshlabserver $@  -i /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0.off -o /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_poisson.off -s /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/poisson_183352.mls
-----Time for Screened Poisson= 33.07263779640198


face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_normals all zero, ignoring!
face_norma

Signifiant mesh pieces of 50 size after poisson = 502
     Starting Calcification
Path sneding to calcification = /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_0
Time for skeletonizatin = 17.854615211486816
Path sneding to calcification = /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_1
Time for skeletonizatin = 10.098324298858643
Path sneding to calcification = /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_2
Time for skeletonizatin = 10.097452640533447
Path sneding to calcification = /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_3
Time for skeletonizatin = 10.185581684112549
Path sneding to calcification = /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_4
Time for skeletonizatin = 10.097547054290771
Path sneding to calcification = /notebooks/Platinum_Skeletonization_vp3/Dustin_vp4/Dustin_soma_0_branch_0_5
Time for skeletonizatin = 10.133883237838745


KeyboardInterrupt: 

# Do Skeletonization of all non-soma touching branches

# Stitch the All Soma touching and non soma touching branches