In [None]:
"""
Try the skeletonization that uses both the meshparty and the meshafterparty

Pseudocode: 
1) Do MP skeletonization and mesh correspondence that divides it into branches
2) Find all of the pieces that need MAP skeletonization
and then combine them into connected component meshes
3) Do MAP skeletonization and mesh correspondence for those larger pieces
4) For each MP connected skeleton:
a. Find the closest MAP skeleton branch endpoint and add a stitch to the appropriate branch
  of the MP skeleton
5) Check all of skeleton is connected:
a. If No --> then stitch until fully connected and add stitching point 
to the smaller widths of the branches
"""

In [1]:
import sys

# Make the meshAfterParty project modules importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
MESH_AFTER_PARTY_PATH = "/meshAfterParty/"
if MESH_AFTER_PARTY_PATH not in sys.path:
    sys.path.append(MESH_AFTER_PARTY_PATH)

In [2]:
import time
from importlib import reload

import numpy as np
import trimesh
from meshparty import trimesh_io

import trimesh_utils as tu
import meshparty_skeletonize as m_sk
import neuron_utils as nru
import neuron_visualizations as nviz



In [3]:
# Compressed neuron to load; the same path is also passed as the original-mesh source
neur_file = "/notebooks/test_neurons/Segmentation_2/meshparty/864691135548568516_single_soma_inhib_axon_cloud"
current_neuron = nru.decompress_neuron(
    filepath=neur_file,
    original_mesh=neur_file,
)

Decompressing Neuron in minimal output mode...please wait


In [7]:
# Pick limb index 1 of the neuron for inspection (visualized below as "L1")
curr_limb = current_neuron[1]

In [None]:
# Display the limb's soma-touching vertices (the same attribute is later used as the
# soma border vertices and to pick the skeleton root)
curr_limb.current_touching_soma_vertices

In [8]:
# Plot the limb mesh with its soma-touching vertices overlaid as a scatter
nviz.plot_objects(meshes=curr_limb.mesh,
                 scatters=[curr_limb.current_touching_soma_vertices])

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

In [4]:
# Visualize every branch of limb "L1" in the context of the whole neuron
nviz.visualize_neuron(current_neuron,
                     limb_branch_dict=dict(L1="all"))


 Working on visualization type: mesh


VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

In [None]:
# Getting all of the 

# getting a limb to practice on

In [None]:
# Limb used for the rest of the skeletonization experiment.
# NOTE(review): this is limb 0, while `curr_limb` above is limb 1 — confirm this is intended.
limb_obj = current_neuron[0]
limb_obj.mesh.show()

# 1) Do MP skeletonization and mesh correspondence that divides it into branches


In [None]:
# networkx is used below to group touching mesh pieces into connected components.
# The reloads pick up edits to the project modules without restarting the kernel.
import networkx as nx
m_sk = reload(m_sk)
tu = reload(tu)

In [None]:
# Step 1: MeshParty (MP) skeletonization of the limb, then split the skeleton into
# branches together with their mesh correspondence and median widths.

# The skeleton root is the first soma-touching vertex of the limb.
# TODO: eventually take the root from soma_to_piece_touching_vertices[i].
root_curr = np.array(limb_obj.current_touching_soma_vertices[0])

# A single reload is enough to pick up edits to meshparty_skeletonize for both calls below.
m_sk = reload(m_sk)
sk_meshparty_obj, limb_mesh_mparty = m_sk.skeletonize_mesh_largest_component(limb_obj.mesh,
                                                                             root=root_curr)

(segment_branches,                          # skeleton branches
 divided_submeshes, divided_submeshes_idx,  # mesh correspondence (meshes and face indices)
 segment_widths_median) = m_sk.skeleton_obj_to_branches(sk_meshparty_obj,
                                                         mesh=limb_mesh_mparty)

In [None]:
# Step 2: split the MP branches into
#   (a) thick, connected regions that need MAP skeletonization, and
#   (b) the remaining MP sublimbs (connected groups of the other branches).
width_threshold = 450   # branches with median width above this are MAP candidates
size_threshold = 1000   # a connected group of MAP candidates must have more faces than this


def _connected_components_of_pieces(n_pieces, edges):
    """Group mesh pieces into connected components.

    Parameters
    ----------
    n_pieces : int
        Number of pieces; graph nodes are the positions 0..n_pieces-1.
    edges : iterable of (int, int)
        Pairs of touching piece positions, as returned by tu.mesh_list_connectivity.

    Returns
    -------
    list of set
        One set of piece positions per connected component (isolated pieces included).
    """
    graph = nx.Graph()
    graph.add_nodes_from(np.arange(n_pieces))
    graph.add_edges_from(edges)
    return list(nx.connected_components(graph))


pieces_above_threshold = np.where(segment_widths_median > width_threshold)[0]

width_large = segment_widths_median[pieces_above_threshold]
sk_large = [segment_branches[k] for k in pieces_above_threshold]
mesh_large_idx = [divided_submeshes_idx[k] for k in pieces_above_threshold]

mesh_large_connectivity = tu.mesh_list_connectivity(meshes=mesh_large_idx,
                                                    main_mesh=limb_mesh_mparty,
                                                    print_flag=False)

# Components are sets of positions into mesh_large_idx / pieces_above_threshold
conn_comp = _connected_components_of_pieces(len(mesh_large_idx), mesh_large_connectivity)

# Keep components whose total face count exceeds size_threshold, translated back
# to the original branch indices (already grouped by component)
filtered_pieces = [pieces_above_threshold[list(cc)]
                   for cc in conn_comp
                   if np.sum([len(mesh_large_idx[k]) for k in cc]) > size_threshold]

if len(filtered_pieces) > 0:
    # One combined mesh per component; these need MAP skeletonization and mesh correspondence
    mesh_pieces_for_MAP = [limb_mesh_mparty.submesh([np.concatenate(divided_submeshes_idx[k])], append=True, repair=False)
                           for k in filtered_pieces]

    # Every branch not sent to MAP stays with the MP skeleton/correspondence
    pieces_idx_MP = np.setdiff1d(np.arange(len(divided_submeshes_idx)), np.concatenate(filtered_pieces))
    mesh_idx_MP = [divided_submeshes_idx[k] for k in pieces_idx_MP]

    mesh_large_connectivity_MP = tu.mesh_list_connectivity(meshes=mesh_idx_MP,
                                                           main_mesh=limb_mesh_mparty,
                                                           print_flag=False)
    sublimbs_MP = _connected_components_of_pieces(len(mesh_idx_MP), mesh_large_connectivity_MP)
    # Translate component positions back to original branch indices
    sublimbs_MP_orig_idx = [pieces_idx_MP[list(k)] for k in sublimbs_MP]

    # Gather the per-sublimb meshes, skeleton branches and widths
    sublimb_mesh_idx_branches_MP = [divided_submeshes_idx[k] for k in sublimbs_MP_orig_idx]
    sublimb_meshes_MP = [limb_mesh_mparty.submesh([np.concatenate(k)], append=True, repair=False)
                         for k in sublimb_mesh_idx_branches_MP]
    sublimb_skeleton_branches = [segment_branches[k] for k in sublimbs_MP_orig_idx]
    widths_MP = [segment_widths_median[k] for k in sublimbs_MP_orig_idx]

else:  # no piece needs MAP processing
    print("No MAP processing needed: just returning the Meshparty skeletonization and mesh correspondence")
    # Deliberately stops "Run All" here: the remaining MAP cells have nothing to do
    raise Exception("Returning MP correspondence")

In [None]:
import skeleton_utils as sk

# Step 3: MAP skeletonization of every connected mesh piece flagged for MAP processing
start_time = time.time()
skeletons_MAP = []
for map_mesh_piece in mesh_pieces_for_MAP:
    skeletons_MAP.append(sk.skeletonize_connected_branch(map_mesh_piece))
print(f"Total MAP skeleton time = {time.time() - start_time}")

In [None]:
# Display the MAP skeletons (one per entry of mesh_pieces_for_MAP)
skeletons_MAP

# Find which pieces are actually touching the soma so we know when to add the soma-extending piece

# Doing the Mesh Correspondence for the skeletons

In [None]:
from tqdm_utils import tqdm
import compartment_utils as cu

In [None]:
#dictionary mapping soma to its touching border vertices
curr_soma_to_piece_touching_vertices = dict()
curr_soma_to_piece_touching_vertices[0] = limb_obj.current_touching_soma_vertices

In [None]:
"""
Idea: Find which sublimb contains the soma-touching vertices
so the soma-extending branch can be added to it

Pseudocode: 
1) get the vertices touching the soma
2) Find the sublimbs that contain these vertices

"""

In [None]:
cu = reload(cu)

# For each MAP skeleton: clean it, decompose it into branches, and compute each
# branch's mesh correspondence and width. Results are stored per MAP piece, as one
# inner list per entry of mesh_pieces_for_MAP, aligned by branch index.
branch_skeletons_MAP = []
branch_meshes_idx_MAP = []
branch_meshes_MAP = []
branch_widths_MAP = []
distance_by_mesh_center = True
# min_distance_to_junction used when cleaning each MAP skeleton (same for every piece)
filter_end_node_length = 4001

start_time = time.time()

for curr_limb_sk, curr_limb_mesh in zip(skeletons_MAP, mesh_pieces_for_MAP):

    distance_cleaned_skeleton = sk.clean_skeleton(curr_limb_sk,
                                                  distance_func=sk.skeletal_distance,
                                                  min_distance_to_junction=filter_end_node_length,  # this used to be a tuple i think when moved the parameter up to function defintion
                                                  return_skeleton=True,
                                                  soma_border_vertices=None,
                                                  skeleton_mesh=curr_limb_mesh,
                                                  endpoints_must_keep=None,
                                                  print_flag=False)
    new_cleaned_skeleton = sk.clean_skeleton_with_decompose(distance_cleaned_skeleton)

    curr_limb_branches_sk_uneven = sk.decompose_skeleton_to_branches(new_cleaned_skeleton)

    # Per-branch results for this MAP piece. They must be re-initialised for every
    # piece; sub_limb_corr was previously never created, which raised NameError.
    sub_limb_corr = []
    sub_limb_width = []
    sub_limb_mesh = []
    for curr_branch_sk in tqdm(curr_limb_branches_sk_uneven):
        curr_branch_face_correspondence, width_from_skeleton = cu.mesh_correspondence_adaptive_distance(
            curr_branch_sk,
            curr_limb_mesh,
            skeleton_segment_width=1000,
            distance_by_mesh_center=distance_by_mesh_center)
        sub_limb_corr.append(curr_branch_face_correspondence)
        sub_limb_width.append(width_from_skeleton)

        if len(curr_branch_face_correspondence) > 0:
            sub_limb_mesh.append(curr_limb_mesh.submesh([list(curr_branch_face_correspondence)], append=True, repair=False))
        else:
            # No faces matched: keep an empty placeholder so the lists stay aligned by branch
            sub_limb_mesh.append(trimesh.Trimesh(vertices=np.array([]), faces=np.array([])))

    branch_meshes_MAP.append(sub_limb_mesh)
    branch_skeletons_MAP.append(curr_limb_branches_sk_uneven)
    branch_meshes_idx_MAP.append(sub_limb_corr)
    branch_widths_MAP.append(sub_limb_width)

print(f"Total time for mesh correspondence = {time.time() - start_time}")

In [None]:
# One figure per MAP piece: branch meshes and branch skeletons in random colors,
# with soma 0's touching vertices scattered for orientation
for sublimb_branch_meshes, sublimb_branch_skeletons in zip(branch_meshes_MAP,
                                                            branch_skeletons_MAP):
    nviz.plot_objects(
        meshes=sublimb_branch_meshes,
        meshes_colors="random",
        skeletons=sublimb_branch_skeletons,
        skeletons_colors="random",
        scatters=[curr_soma_to_piece_touching_vertices[0]],
        scatter_size=0.3,
    )