In [None]:
"""
Purpose: Clean implementation that 
takes a fully processed neuron and is able to 
split the limbs that were accidentally included together

Pseudocode

"""

In [1]:
from os import sys
sys.path.append("/meshAfterParty/")
#sys.path.append("../../meshAfterParty/meshAfterParty")

In [2]:
from importlib import reload
import os
from pathlib import Path
os.getcwd()

import neuron_utils as nru
nru = reload(nru)
import neuron
neuron=reload(neuron)
import neuron_visualizations as nviz
import time
import system_utils as su



# Importing Test Neuron

In [None]:
# Load the processed test neuron (per its file name, a double-soma neuron) from disk.
# The same path is passed as both the compressed file and the original mesh source.
# NOTE(review): absolute path tied to this container's layout — consider a configurable data dir.
current_file = "/notebooks/test_neurons/meshafterparty_processed/12345_double_soma_meshafterparty"
neuron_obj = nru.decompress_neuron(filepath=current_file,
                                  original_mesh=current_file,
                                  minimal_output=True)


In [3]:
import skeleton_utils as sk
from pykdtree.kdtree import KDTree
import trimesh_utils as tu
import numpy as np
import networkx_utils as xu
import compartment_utils as cu
import networkx as nx
import numpy_utils as nu
import copy
import general_utils as gu
import system_utils as su

In [None]:
# Re-wrap the decompressed neuron with freshly reloaded modules so that edits to
# neuron_utils / neuron are picked up without restarting the kernel.
current_neuron = neuron_obj
nru = reload(nru)
neuron = reload(neuron)
current_neuron = neuron.Neuron(current_neuron)

In [None]:
# Inspect the labels attached to the objects at integer indices 0 and 1 of the neuron
# (presumably limbs 0 and 1 — integer indexing is used for limbs elsewhere in this notebook).
current_neuron[0].labels,current_neuron[1].labels

# Pre-Step 0: Prepare the containers to store the new data

In [4]:
def check_if_branch_needs_splitting(curr_limb,soma_idx,curr_soma_mesh,
                                    significant_skeleton_threshold=30000,
                                    print_flag=False):
    """
    Decide whether a limb is really several limbs fused together at a soma.

    The candidate cut point is the skeleton node of the limb's starting branch
    (for soma `soma_idx`) that lies closest to any of that soma's touching
    border vertices. The whole limb skeleton graph is cut at that node. The
    limb needs splitting when at least two of the resulting components have a
    skeleton length greater than `significant_skeleton_threshold`.

    Parameters
    ----------
    curr_limb : neuron.Limb
        Limb to inspect.
    soma_idx : int
        Soma whose starting branch and "touching_soma_vertices" are used.
    curr_soma_mesh : mesh
        Not used by this function.
    significant_skeleton_threshold : float
        Minimum skeleton length for a separated component to count.
    print_flag : bool
        When True, print the component count and lengths.

    Returns
    -------
    np.ndarray or None
        The cut coordinate (one row of the starting branch skeleton nodes)
        when the limb should be split, otherwise None.

    Raises
    ------
    Exception
        If the cut coordinate does not match exactly one node of the limb
        skeleton graph.
    """
    starting_branch = curr_limb[curr_limb.get_starting_branch_by_soma(soma_idx)]
    border_verts = curr_limb.get_concept_network_data_by_soma(soma_idx)["touching_soma_vertices"]

    # Candidate cut: the starting-branch skeleton node nearest to the soma border
    branch_nodes = np.unique(starting_branch.skeleton.reshape(-1,3),axis=0)
    dist_to_border, _ = KDTree(border_verts).query(branch_nodes)
    cut_coordinate = branch_nodes[np.argmin(dist_to_border),:]

    # Locate that coordinate in the graph of the WHOLE limb skeleton
    limb_graph = sk.convert_skeleton_to_graph(curr_limb.skeleton)
    matching_nodes = xu.get_nodes_with_attributes_dict(limb_graph,dict(coordinates=cut_coordinate))
    if len(matching_nodes) != 1:
        raise Exception("Node to cut was not of length 1")

    # Remove the cut node and see what falls apart
    limb_graph.remove_node(matching_nodes[0])
    pieces = list(nx.connected_components(limb_graph))
    if len(pieces) < 2:
        return None

    piece_lengths = np.array([
        sk.calculate_skeleton_distance(sk.convert_graph_to_skeleton(limb_graph.subgraph(piece)))
        for piece in pieces
    ])
    n_significant_skeletons = np.sum(piece_lengths > significant_skeleton_threshold)
    if print_flag:
        print(f"n_significant_skeletons={n_significant_skeletons}")
        print(f"skeleton_lengths = {piece_lengths}")

    return cut_coordinate if n_significant_skeletons >= 2 else None

In [18]:
def _expand_faces_until_touching_border(face_graph, seed_faces, border_faces,
                                        min_touching_faces=10):
    """
    Grow a set of mesh faces outward, one adjacency ring per iteration, until it
    contains at least `min_touching_faces` of `border_faces`.

    The grown set always keeps the faces it already has (seed faces missing from
    `face_graph` are kept but contribute no neighbors). Growth stops early when
    no new faces can be reached, so the loop terminates even if the target can
    never be met. The target is capped at len(border_faces).

    Parameters
    ----------
    face_graph : networkx.Graph
        Face-adjacency graph of the mesh (nodes are face indices).
    seed_faces : iterable of int
        Face indices to start growing from.
    border_faces : set of int
        Face indices that lie on the soma border.
    min_touching_faces : int
        Desired number of border faces inside the grown set.

    Returns
    -------
    (np.ndarray, int, int)
        Sorted grown face indices, number of growth iterations performed, and
        number of border faces contained in the grown set.
    """
    target = min(min_touching_faces, len(border_faces))
    grown = set(seed_faces)
    frontier = set(grown)
    n_touching = len(border_faces.intersection(grown))
    n_iterations = 0
    while n_touching < target and len(frontier) > 0:
        reached = set()
        for face in frontier:
            if face in face_graph:
                reached.update(xu.get_neighbors(face_graph, face))
        # only faces not seen before form the next ring
        frontier = reached - grown
        grown |= frontier
        n_touching = len(border_faces.intersection(grown))
        n_iterations += 1
    return np.array(sorted(grown), dtype=int), n_iterations, n_touching


def split_limb_on_soma(curr_limb,soma_idx,curr_soma_mesh,
                       current_neuron_mesh,
                       soma_meshes,
                       cut_coordinate=None,
                       print_flag=False,
                       min_touching_soma_faces=10):
    """
    Split a limb into two new limbs at the point where its starting branch
    (for soma `soma_idx`) meets that soma.

    Steps
    -----
    A) Cut the starting-branch skeleton at `cut_coordinate` into two sub-skeletons.
    B) Assign the starting-branch mesh faces to the two sub-skeletons.
    C/D) Check that both new branch meshes touch the soma; if only one does,
         grow the other toward the soma border faces.
    E) Check that neither new branch mesh splits into multiple pieces.
    F) Divide the remaining branches of the limb between the two sides according
       to which starting-branch skeleton endpoint their neighbor attaches to.
    G) Build a neuron.Limb for each side.

    Parameters
    ----------
    curr_limb : neuron.Limb
        Limb to split.
    soma_idx : int
        Soma at which to split.
    curr_soma_mesh : mesh
        Mesh of soma `soma_idx`.
    current_neuron_mesh : mesh
        Whole neuron mesh, forwarded to nru.compute_all_concept_network_data_from_limb.
    soma_meshes : list
        All soma meshes, forwarded to nru.compute_all_concept_network_data_from_limb.
    cut_coordinate : array or None
        Skeleton coordinate to cut at. If None it is recomputed as the
        starting-branch skeleton node closest to the soma border vertices.
    print_flag : bool
        Print progress/debug information.
    min_touching_soma_faces : int
        When one new branch mesh does not touch the soma, it is grown until it
        contains this many soma-border faces (capped at the number available).

    Returns
    -------
    list of neuron.Limb
        One limb per side of the cut, in the order of the starting-branch
        skeleton endpoints.

    Raises
    ------
    Exception
        On any inconsistency: the cut node is not found exactly once, the
        starting branch skeleton does not have exactly 2 endpoints, the new
        pieces cannot both be made to touch the soma, a new piece is
        disconnected, or the neighbor/component matching is ambiguous.
    """
    # 1) The branch of this limb that starts at the soma
    curr_starting_branch_idx = curr_limb.get_starting_branch_by_soma(soma_idx)
    curr_branch = curr_limb[curr_starting_branch_idx]
    curr_branch_sk = curr_branch.skeleton

    if cut_coordinate is None:
        if print_flag:
            print("Having to recalculate the cut coordinate")
        # starting-branch skeleton node closest to the soma border
        curr_soma_border = curr_limb.get_concept_network_data_by_soma(soma_idx)["touching_soma_vertices"]
        unique_skeleton_nodes = np.unique(curr_branch_sk.reshape(-1,3),axis=0)
        distances,_ = KDTree(curr_soma_border).query(unique_skeleton_nodes)
        cut_coordinate = unique_skeleton_nodes[np.argmin(distances),:]

    def _pieces_touching_soma(pieces):
        # Which of `pieces` touch the soma mesh, plus the touching vertices for each
        return tu.mesh_pieces_connectivity(main_mesh=tu.combine_meshes([curr_branch.mesh,curr_soma_mesh]),
                                           central_piece=curr_soma_mesh,
                                           periphery_pieces=pieces,
                                           merge_vertices=True,
                                           return_vertices=True,
                                           print_flag=False)

    # --------------- Part A: Finding the new split skeletons ------------------------- #
    G = sk.convert_skeleton_to_graph(curr_branch_sk)

    if print_flag:
        print(f"cut_coordinate={cut_coordinate}")
    node_to_cut = xu.get_nodes_with_attributes_dict(G,dict(coordinates=cut_coordinate))
    if len(node_to_cut) != 1:
        raise Exception("Node to cut was not of length 1")
    node_to_cut = node_to_cut[0]

    endpoint_nodes = xu.get_nodes_of_degree_k(G,1)
    # The "1 - closest" indexing below only makes sense for a path with 2 endpoints
    if len(endpoint_nodes) != 2:
        raise Exception(f"Starting branch skeleton should have 2 endpoints but had {len(endpoint_nodes)}")
    # endpoint coordinates; their order defines the order of the two new limbs
    endpoint_nodes_coord = xu.get_node_attributes(G,node_list=endpoint_nodes)
    paths_to_endpt = [nx.dijkstra_path(G,node_to_cut,k) for k in endpoint_nodes]
    path_lengths = np.array([len(k) for k in paths_to_endpt])
    closest_endpoint = np.argmin(path_lengths)
    farthest_endpoint = 1 - closest_endpoint

    if path_lengths[closest_endpoint] <= 1:
        # The cut node is itself an endpoint: step one node inward so both sides are non-empty
        if print_flag:
            print("Having to readjust endpoint")
        node_to_cut = paths_to_endpt[farthest_endpoint][1]
        paths_to_endpt = [nx.dijkstra_path(G,node_to_cut,k) for k in endpoint_nodes]

    # The cut node stays on the side of the closest endpoint
    paths_to_endpt[farthest_endpoint].remove(node_to_cut)
    subgraph_list = [G.subgraph(k) for k in paths_to_endpt]

    # First node of each path = where each new branch starts (at / next to the cut)
    starting_endpoints = xu.get_node_attributes(G,node_list=[k[0] for k in paths_to_endpt])
    exported_skeletons = [sk.convert_graph_to_skeleton(s) for s in subgraph_list]

    # --------------- Part B: Getting Initial Mesh Correspondence ------------------------- #
    div_st_branch_face_corr = []
    div_st_branch_width = []
    for sub_sk in exported_skeletons:
        curr_branch_face_correspondence, width_from_skeleton = cu.mesh_correspondence_adaptive_distance(sub_sk,
                                                  curr_branch.mesh,
                                                  skeleton_segment_width = 1000,
                                                  distance_by_mesh_center=True)
        div_st_branch_face_corr.append(curr_branch_face_correspondence)
        div_st_branch_width.append(width_from_skeleton)

    divided_submeshes,divided_submeshes_idx = cu.groups_of_labels_to_resolved_labels(current_mesh = curr_branch.mesh,
                                          face_correspondence_lists=div_st_branch_face_corr)

    # ---------------- Part C: Checking that both pieces are touching the soma ------------- #
    touching_pieces,touching_pieces_verts = _pieces_touching_soma(divided_submeshes)
    if print_flag:
        print(f"touching_pieces = {touching_pieces}")

    # --------------- Part D: Growing a non-touching piece until both touch the soma ------------- #
    if len(touching_pieces) == 0:
        raise Exception("There were none of the new meshes that were touching the soma")
    if len(touching_pieces) < 2:
        label_to_expand = 1 - touching_pieces[0]
        if print_flag:
            print(f"new_mesh {label_to_expand} was not touching the mesh so need to expand until touches soma")

        # face-adjacency graph of the starting branch mesh and its soma-border faces
        total_mesh_graph = nx.from_edgelist(curr_branch.mesh.face_adjacency)
        border_vertices = curr_limb.get_concept_network_data_by_soma(soma_idx)["touching_soma_vertices"]
        border_faces = set(tu.vertices_coordinates_to_faces(curr_branch.mesh,border_vertices))

        divided_submeshes_idx = list(divided_submeshes_idx)
        final_faces,counter,n_touching_soma = _expand_faces_until_touching_border(
            total_mesh_graph,
            divided_submeshes_idx[label_to_expand],
            border_faces,
            min_touching_faces=min_touching_soma_faces)
        if print_flag:
            print(f"Took {counter} iterations to expand the label back")

        # Every face not claimed by the grown piece goes to the piece that already touched
        other_mesh_faces = np.setdiff1d(np.arange(0,len(curr_branch.mesh.faces)),final_faces)
        if len(other_mesh_faces) == 0:
            raise Exception(f"Expanding new_mesh {label_to_expand} absorbed the whole branch mesh "
                            f"(border faces reached = {n_touching_soma})")

        divided_submeshes_idx[label_to_expand] = final_faces
        divided_submeshes_idx[touching_pieces[0]] = other_mesh_faces

        # Resolve the labels again to make sure the expansion did not cut off one of the labels
        if print_flag:
            print(f"divided_submeshes_idx = {divided_submeshes_idx}")
        divided_submeshes,divided_submeshes_idx = cu.groups_of_labels_to_resolved_labels(curr_branch.mesh,divided_submeshes_idx)
        if print_flag:
            print(f"divided_submeshes_idx = {divided_submeshes_idx}")

        divided_submeshes = [curr_branch.mesh.submesh([k],append=True) for k in divided_submeshes_idx]

        # Both pieces must touch the soma now
        touching_pieces,touching_pieces_verts = _pieces_touching_soma(divided_submeshes)
        if len(touching_pieces) != 2:
            raise Exception(f"Number of touching pieces not equal to 2 even after correction: {touching_pieces}")

    # Key the soma-border vertices by piece index rather than by position in touching_pieces
    soma_border_verts = dict(zip(touching_pieces,touching_pieces_verts))

    # ----------------- Part E: Check that the new meshes can't be split ----------------- #
    for j,sub in enumerate(divided_submeshes):
        c_mesh,_ = tu.split(sub)
        if len(c_mesh) > 1:
            raise Exception(f"New Mesh {j} had {len(c_mesh)} pieces after split")

    # ----------------- Part F: Reorganize the Concept Network ----------------- #
    # Each neighbor of the starting branch must attach at exactly one of the two
    # starting-branch skeleton endpoints; that endpoint index picks its side.
    neighbors_to_starting_node = xu.get_neighbors(curr_limb.concept_network,curr_starting_branch_idx)
    match = {k:[] for k in neighbors_to_starting_node}
    for ex_neighbor in neighbors_to_starting_node:
        ex_neighbor_branch = curr_limb[ex_neighbor]
        for j,endpt in enumerate(endpoint_nodes_coord):
            if len(nu.matching_rows(ex_neighbor_branch.endpoints,endpt))>0:
                match[ex_neighbor].append(j)

    if print_flag:
        print(f"match = {match}")
    for k,v in match.items():
        if len(v) != 1:
            raise Exception(f"Neighbor {k} did not have one matching but instead had {v}")

    # Without the starting branch, the rest of the limb falls into connected components;
    # each one must contain exactly one neighbor and follows that neighbor's side.
    concept_network_copy = copy.deepcopy(curr_limb.concept_network)
    concept_network_copy.remove_node(curr_starting_branch_idx)
    concept_conn_comp = list(nx.connected_components(concept_network_copy))
    if print_flag:
        print(f"concept_conn_comp = {concept_conn_comp}")

    neighbor_set = set(neighbors_to_starting_node)
    new_branch_groups = [[],[]]
    for c in concept_conn_comp:
        matching_neighbor = c.intersection(neighbor_set)
        if len(matching_neighbor) != 1:
            raise Exception(f"matching_neighbor was not size 1 : {matching_neighbor}")
        matching_neighbor = list(matching_neighbor)[0]
        new_branch_groups[match[matching_neighbor][0]].extend(list(c))

    if print_flag:
        print(f"new_branch_groups = {new_branch_groups}")

    # ----------------- Part G: Put Everything Back into a Limb Object ----------------- #
    new_limbs = []
    for curr_new_branch_idx in range(len(new_branch_groups)):
        if print_flag:
            print(f"\n--- Working on new limb {curr_new_branch_idx} -------")
        group = new_branch_groups[curr_new_branch_idx]

        # a) New concept network: the group's branches plus the new starting piece (last)
        curr_limb_divided_skeletons = [curr_limb[k].skeleton for k in group] + [exported_skeletons[curr_new_branch_idx]]
        new_limb_starting_coordinate = starting_endpoints[curr_new_branch_idx]
        endpoints = neuron.Branch(exported_skeletons[curr_new_branch_idx]).endpoints
        curr_limb_concept_network = nru.branches_to_concept_network(curr_limb_divided_skeletons,
                                                                    new_limb_starting_coordinate,
                                                                    np.array(endpoints).reshape(-1,3),
                                                                    touching_soma_vertices=soma_border_verts[curr_new_branch_idx])

        # Run the consistency checks on the new concept network (raises on failure)
        new_limb_starting_branch_idx = nru.check_concept_network(curr_limb_concept_network,
                                                                 closest_endpoint=new_limb_starting_coordinate,
                                                                 curr_limb_divided_skeletons=curr_limb_divided_skeletons,
                                                                 print_flag=print_flag)[0]

        # b) New limb mesh: faces of the group's branches followed by the new piece's faces.
        # new_limb_branch_face_idx[j] are the positions of branch j's faces in that concatenation.
        new_limb_branch_face_idx = []
        remaining_meshes_face_idx = []
        total_face_count = 0
        for k in group:
            curr_face_idx = curr_limb[k].mesh_face_idx
            remaining_meshes_face_idx.append(curr_face_idx)
            new_limb_branch_face_idx.append(np.arange(total_face_count,total_face_count+len(curr_face_idx)))
            total_face_count += len(curr_face_idx)

        # divided_submeshes_idx index the branch mesh; map them to limb-mesh face indices
        last_face_idx = np.array(curr_branch.mesh_face_idx[divided_submeshes_idx[curr_new_branch_idx]])
        remaining_meshes_face_idx.append(last_face_idx)
        new_limb_branch_face_idx.append(np.arange(total_face_count,total_face_count+len(last_face_idx)))

        final_remaining_faces = np.concatenate(remaining_meshes_face_idx)
        curr_new_limb_mesh = curr_limb.mesh.submesh([final_remaining_faces],append=True,repair=False)

        # c) Branch correspondence in the same order as curr_limb_divided_skeletons
        curr_limb_correspondence = dict()
        for j,neighb in enumerate(group):
            curr_limb_correspondence[j] = dict(branch_skeleton = curr_limb[neighb].skeleton,
                                               width_from_skeleton = curr_limb[neighb].width,
                                               branch_mesh=curr_limb[neighb].mesh,
                                               branch_face_idx=new_limb_branch_face_idx[j])
        curr_limb_correspondence[len(group)] = dict(
                        branch_skeleton = exported_skeletons[curr_new_branch_idx],
                        width_from_skeleton = div_st_branch_width[curr_new_branch_idx],
                        branch_mesh=divided_submeshes[curr_new_branch_idx],
                        branch_face_idx=new_limb_branch_face_idx[-1])

        curr_limb_concept_network_dicts = {soma_idx:curr_limb_concept_network}

        new_limb_obj = neuron.Limb(mesh = curr_new_limb_mesh,
                                   curr_limb_correspondence=curr_limb_correspondence,
                                   concept_network_dict=curr_limb_concept_network_dicts)
        new_limb_obj.all_concept_network_data = nru.compute_all_concept_network_data_from_limb(new_limb_obj,
                                                                                               current_neuron_mesh=current_neuron_mesh,
                                                                                               soma_meshes=soma_meshes)
        new_limbs.append(new_limb_obj)

    return new_limbs

In [19]:
def recursive_limb_splitting(curr_limb,soma_meshes,current_neuron_mesh,significant_skeleton_threshold=30000,
                            print_flag=False):
    """
    Recursively split a limb wherever it is fused with another limb at a soma.

    For each soma the limb touches (in the order of curr_limb.touching_somas()):
    - ask check_if_branch_needs_splitting for a cut coordinate;
    - if there is none, move on to the next soma;
    - otherwise split the limb with split_limb_on_soma (its stdout/stderr is
      suppressed unless print_flag) and recurse into every resulting limb,
      returning all of their results concatenated.
    If no soma requires a split, the limb itself is returned in a list.

    Arguments:
    curr_limb: neuron.Limb to split
    soma_meshes: list of all soma meshes, indexed by soma index
    current_neuron_mesh: whole neuron mesh, forwarded to split_limb_on_soma
    significant_skeleton_threshold: skeleton length a separated piece must exceed to count
    print_flag: print debugging information

    Returns:
    list of neuron.Limb objects

    Example:
    ex_limb = current_neuron[2]
    split_limbs = recursive_limb_splitting(ex_limb,soma_meshes,current_neuron.mesh)

    color_choices = ["red","black"]
    sk.graph_skeleton_and_mesh(other_meshes=[split_limbs[0].mesh,split_limbs[1].mesh],
                               other_meshes_colors=color_choices,
                               other_skeletons=[split_limbs[0].skeleton,split_limbs[1].skeleton],
                               other_skeletons_colors=color_choices)
    """
    somas_touched = curr_limb.touching_somas()
    meshes_touched = [soma_meshes[k] for k in somas_touched]

    if print_flag:
        print(f"total_somas_idx = {somas_touched}")
        print(f"total_soams_meshes = {meshes_touched}")

    def _split_at(soma_idx,soma_mesh,cut_coordinate):
        # Single place for the split call shared by the verbose and silent paths
        return split_limb_on_soma(curr_limb,soma_idx,soma_mesh,
                                  current_neuron_mesh = current_neuron_mesh,
                                  soma_meshes=soma_meshes,
                                  cut_coordinate=cut_coordinate,
                                  print_flag=print_flag)

    for soma_idx,soma_mesh in zip(somas_touched,meshes_touched):
        cut_coordinate = check_if_branch_needs_splitting(curr_limb,soma_idx,soma_mesh,
                                   significant_skeleton_threshold=significant_skeleton_threshold,
                                   print_flag=print_flag)
        if print_flag:
            print(f"cut_coordinate = {cut_coordinate}")

        # Nothing fused at this soma: try the next one
        if cut_coordinate is None:
            continue

        if print_flag:
            split_limb_objs = _split_at(soma_idx,soma_mesh,cut_coordinate)
        else:
            with su.suppress_stdout_stderr():
                split_limb_objs = _split_at(soma_idx,soma_mesh,cut_coordinate)

        if print_flag:
            print(f"split_limb_objs = {split_limb_objs}")

        # Each piece may still be fused at another soma: recurse and flatten
        fully_split = []
        for piece in split_limb_objs:
            fully_split.extend(recursive_limb_splitting(curr_limb=piece,
                                                        soma_meshes=soma_meshes,
                                                        current_neuron_mesh = current_neuron_mesh,
                                                        significant_skeleton_threshold=significant_skeleton_threshold,
                                                        print_flag=print_flag))
        return fully_split

    # No soma needed a split: this limb is final
    if print_flag:
        print("Hit Recursive return point and returning limb")
    return [curr_limb]


# Part that will run limb splitting for all the limbs in the neuron 

In [None]:
# Rebuild current_neuron with the reloaded neuron module so it uses the latest class definitions
neuron = reload(neuron)
current_neuron = neuron.Neuron(current_neuron)

# How to run it with the Neuron Object Available

In [40]:
# Pick up edits to neuron_utils / neuron before limb_split is (re)defined below
nru = reload(nru)
neuron = reload(neuron)


from copy import deepcopy
def limb_split(limbs,soma_meshes,current_neuron_mesh,print_flag=False):
    """
    Split every limb at the somas where separate limbs were fused together, and
    repackage the resulting limbs into the containers used to build a Neuron.

    Parameters
    ----------
    limbs : list of neuron.Limb
        Limbs to split (each is passed through recursive_limb_splitting).
    soma_meshes : list
        Soma meshes; the position in the list is the soma index.
    current_neuron_mesh : mesh
        Whole neuron mesh, forwarded to recursive_limb_splitting.
    print_flag : bool
        Print progress/debug information. Nothing is printed when False.

    Returns
    -------
    new_limb_correspondence : dict
        new limb idx -> branch idx -> dict(branch_skeleton, width_from_skeleton,
        branch_mesh, branch_face_idx)
    new_soma_to_piece_connectivity : dict
        soma idx -> list of new limb indices touching that soma,
        e.g. {0: [0, 1, 3, 4, 5, 9], 1: [1, 2, 6, 7, 8]}
    new_limb_meshes : list
        Mesh of every new limb, indexed by new limb idx.
    new_limb_concept_networks : dict
        new limb idx -> soma idx -> concept network (deep copy taken after
        sp_limb.set_concept_network_directional(starting_soma=soma idx)),
        e.g. {0: {0: Graph}, 1: {0: Graph, 1: Graph}, 2: {1: Graph}}
    new_limb_labels : list of str
        "MergeError" if the new limb has concept-network data for more than one
        soma, otherwise "Normal".
    """
    new_limb_correspondence = dict()
    new_soma_to_piece_connectivity = {k:[] for k,_ in enumerate(soma_meshes)}
    new_limb_meshes = []
    new_limb_concept_networks = dict()
    new_limb_labels = []

    for limb_idx,curr_limb in enumerate(limbs):
        if print_flag:
            print(f"\n----- Working on Limb {limb_idx}--------")
        split_limbs = recursive_limb_splitting(curr_limb,soma_meshes,
                                               current_neuron_mesh=current_neuron_mesh,
                                               print_flag=print_flag)
        if print_flag:
            print(f"Found {len(split_limbs)} limbs after limb split")

        for sp_limb in split_limbs:
            # New limbs are numbered in the order they are produced
            new_limb_idx = len(new_limb_meshes)

            # a) Branch correspondence of the new limb
            new_limb_correspondence[new_limb_idx] = dict()
            for curr_branch_idx in sp_limb.get_branch_names():
                curr_branch = sp_limb[curr_branch_idx]
                new_limb_correspondence[new_limb_idx][curr_branch_idx] = dict(
                                                        branch_skeleton = curr_branch.skeleton,
                                                        width_from_skeleton=curr_branch.width,
                                                        branch_mesh=curr_branch.mesh,
                                                        branch_face_idx=curr_branch.mesh_face_idx)

            # b) Register the new limb with every soma it touches
            touching_somas = sp_limb.touching_somas()
            for s in touching_somas:
                new_soma_to_piece_connectivity[s].append(new_limb_idx)

            # c) Limb mesh
            new_limb_meshes.append(sp_limb.mesh)

            # d) One concept network per touching soma. set_concept_network_directional
            # presumably re-orients sp_limb.concept_network in place, so keep a deep copy per soma.
            concept_network_dict = dict()
            for s in touching_somas:
                sp_limb.set_concept_network_directional(starting_soma=s)
                concept_network_dict[s] = deepcopy(sp_limb.concept_network)
                if print_flag:
                    starting_nodes = xu.get_starting_node(concept_network_dict[s],only_one=False)
                    print(f"Soma {s}: starting node(s) = {starting_nodes}")
                    print(f"starting node data = {concept_network_dict[s].nodes[starting_nodes[0]]}")
            new_limb_concept_networks[new_limb_idx] = concept_network_dict

            # e) Merge label: a limb attached to more than one soma is a merge error
            if len(sp_limb.concept_network_data_by_soma.keys()) > 1:
                new_limb_labels.append("MergeError")
            else:
                new_limb_labels.append("Normal")

    return new_limb_correspondence,new_soma_to_piece_connectivity,new_limb_meshes,new_limb_concept_networks,new_limb_labels
            
            

In [None]:
# ex_limb = current_neuron[1]
# concept_network_dict = dict()
# for s in ex_limb.touching_somas():
#     ex_limb.set_concept_network_directional(starting_soma=s)
#     concept_network_dict[s] = ex_limb.concept_network

# concept_network_dict

# Create the preprocessed data and make a neuron out of it

In [None]:
# Collect limb_split's inputs from the current neuron and run the split on every limb.
# NOTE(review): soma indices are hard-coded to [0, 1] (this test neuron's two somas);
# derive them from the neuron before reusing this cell on other neurons.
neuron = reload(neuron)
limbs = [current_neuron[k] for k in current_neuron.get_limb_node_names(return_int=True)]
soma_meshes = [current_neuron.concept_network.nodes[nru.soma_label(k)]["data"].mesh for k in [0,1]]
current_neuron_mesh = current_neuron.mesh

(new_limb_correspondence,
 new_soma_to_piece_connectivity,
 new_limb_meshes,
 new_limb_concept_networks,
 new_limb_labels) = limb_split(limbs,soma_meshes,current_neuron_mesh)

In [None]:
# Rebuild the split limb at index limb_idx as a neuron.Limb from limb_split's outputs, for inspection
limb_idx = 3
neuron= reload(neuron)
new_limb = neuron.Limb(mesh = new_limb_meshes[limb_idx],
                      curr_limb_correspondence=new_limb_correspondence[limb_idx],
                      concept_network_dict=new_limb_concept_networks[limb_idx],
                      )

In [None]:
# Shape of the rebuilt limb's skeleton array
new_limb.skeleton.shape

In [None]:
# Plot the limb mesh with its full skeleton (red) and the last 10 skeleton entries highlighted (black)
sk.graph_skeleton_and_mesh(other_meshes=[new_limb.mesh],
                           other_skeletons=[new_limb.skeleton,new_limb.skeleton[-10:]],
                          other_skeletons_colors=["red","black"])

In [None]:
# All split-limb concept networks: new limb idx -> soma idx -> graph
new_limb_concept_networks

In [None]:
# Nodes of new limb 1's concept network (for soma 0) that carry a "touching_soma_vertices" attribute
xu.get_all_nodes_with_certain_attribute_key(new_limb_concept_networks[1][0],"touching_soma_vertices")

In [None]:
# Same check for new limb 1's concept network built for soma 1
xu.get_all_nodes_with_certain_attribute_key(new_limb_concept_networks[1][1],"touching_soma_vertices")

In [None]:
# Nodes of new limb 1's soma-1 concept network that carry a "starting_coordinate" attribute
xu.get_all_nodes_with_certain_attribute_key(new_limb_concept_networks[1][1],"starting_coordinate")

In [None]:
# Display the concept networks again after the checks above
new_limb_concept_networks

# Testing the pn function 

In [None]:
# NOTE(review): preprocess_neuron (pn) is imported and reloaded but never used below —
# the call still runs the limb_split defined in this notebook. If this section is meant to
# test the pn version, it presumably should call pn.limb_split(...) — confirm pn exposes it.
nru = reload(nru)
import preprocess_neuron as pn
pn = reload(pn)

limbs = [current_neuron[k] for k in current_neuron.get_limb_node_names(return_int=True)]
soma_meshes = [current_neuron.concept_network.nodes[nru.soma_label(k)]["data"].mesh for k in [0,1]]
current_neuron_mesh = current_neuron.mesh

(new_limb_correspondence,
 new_soma_to_piece_connectivity,
 new_limb_meshes,
 new_limb_concept_networks,
 new_limb_labels) = limb_split(limbs,soma_meshes,current_neuron_mesh)

# Debugging the incorrect concept networks using pickled intermediate data

In [8]:
import system_utils as su

# Reload the intermediate objects that were saved while debugging limb_split.
# NOTE(review): su.decompress_pickle presumably unpickles its input — only load
# trusted, locally produced files.
# "current_neuron.pbz2" is bound to current_neuron_mesh, i.e. it is used as the
# neuron mesh argument of limb_split, not as a Neuron object.
[new_limb_objs,
 soma_meshes,
 limb_concept_networks,
 current_neuron_mesh] = [su.decompress_pickle(pickle_path)
                         for pickle_path in ("new_limb_objs.pbz2",
                                             "soma_meshes.pbz2",
                                             "limb_concept_networks.pbz2",
                                             "current_neuron.pbz2")]

In [32]:
# Pick limb 1 from the reloaded limb objects for visual inspection
ex_limb = new_limb_objs[1]

In [None]:
import skeleton_utils as sk

# Show the selected limb's mesh together with its skeleton
sk.graph_skeleton_and_mesh(
    other_meshes=[ex_limb.mesh],
    other_skeletons=[ex_limb.skeleton],
)

In [42]:
# Limb 1 / key 0: nodes carrying "touching_soma_vertices".
# NOTE(review): by execution count this ran before the pn.limb_split debug cell,
# so new_limb_concept_networks came from an earlier run — re-run in order.
xu.get_all_nodes_with_certain_attribute_key(new_limb_concept_networks[1][0],"touching_soma_vertices")

{11: TrackedArray([[764369.4, 961579. , 877947.4],
               [763335.2, 961687.8, 879612.2],
               [763925.8, 960916.3, 879439.4],
               [763378.6, 961830.4, 878169.8],
               [764348. , 960489.9, 878697.4],
               [763189.8, 961895.4, 879268.1],
               [764502.8, 960665.2, 878480.4],
               [763201.2, 961529.1, 879621.8],
               [763385.6, 961419.9, 879647.8],
               [764359.6, 961425.1, 878059.9],
               [764093.4, 960579.9, 879224.1],
               [763293.6, 961809.8, 879253.9],
               [763290.2, 961772.8, 879492.6],
               [763181. , 962043.1, 879046.7],
               [764403. , 960454.9, 878594.8],
               [763599.8, 961072.9, 879668.7],
               [763677.2, 960913.9, 879506.8],
               [763037.8, 961926.4, 878059.2],
               [764585.2, 960803.8, 878456.6],
               [763032.2, 962070.7, 878683.1],
               [764290.6, 961644.6, 877834. ],
         

In [43]:
# Limb 1 / key 1: nodes carrying "touching_soma_vertices".
# NOTE(review): by execution count this ran before the pn.limb_split debug cell,
# so new_limb_concept_networks came from an earlier run — re-run in order.
xu.get_all_nodes_with_certain_attribute_key(new_limb_concept_networks[1][1],"touching_soma_vertices")

{42: TrackedArray([[857494.2, 996403.1, 860228.6],
               [856763.8, 995194.4, 861137. ],
               [856815. , 995692.9, 861534.8],
               [856808.2, 995366.9, 861323.4],
               [857435.2, 996167.1, 861367.6],
               [857673. , 996411.6, 861035. ],
               [857012.2, 995560.1, 860352.4],
               [856939.4, 995636.8, 860228.3],
               [857700.6, 996514.1, 860354.4],
               [857855.4, 996494.9, 860550.9],
               [857581.2, 996336.4, 860069.7],
               [856718.8, 995567.6, 860360.4],
               [857359.4, 996054.1, 859950.1],
               [856829.2, 995195.4, 860931.1],
               [857512. , 996274.3, 861203.9],
               [856786.2, 995564.2, 861440.3],
               [856853.4, 995452.4, 860459. ],
               [857039.7, 995754.9, 860137.8],
               [857354.9, 996202.1, 859973.8],
               [856640.5, 995209.2, 860957.8],
               [857708.9, 996521.4, 860641.9],
         

In [44]:
# Re-run pn.limb_split on the inputs reloaded from the pickled intermediates
# (new_limb_objs, soma_meshes, current_neuron_mesh). limb_concept_networks is
# loaded alongside them but is not an input to limb_split.
nru = reload(nru)
import preprocess_neuron as pn
pn = reload(pn)

(new_limb_correspondence,
 new_soma_to_piece_connectivity,
 new_limb_meshes,
 new_limb_concept_networks,
 new_limb_labels) = pn.limb_split(new_limb_objs,soma_meshes,current_neuron_mesh)


----- Working on Limb 0--------
Found 1 limbs after limb split

----- Working on Limb 1--------
Found 1 limbs after limb split

----- Working on Limb 2--------
Found 2 limbs after limb split

----- Working on Limb 3--------
Found 1 limbs after limb split

----- Working on Limb 4--------
Found 2 limbs after limb split

----- Working on Limb 5--------
Found 1 limbs after limb split

----- Working on Limb 6--------
Found 1 limbs after limb split

----- Working on Limb 7--------
Found 1 limbs after limb split

----- Working on Limb 8--------
Found 1 limbs after limb split

----- Working on Limb 9--------
Found 1 limbs after limb split


In [30]:
# Inspect the attributes of node 11 in limb 1's key-0 concept network; node 11 is
# the node reported by the "touching_soma_vertices" lookup for that network
new_limb_concept_networks[1][0].nodes[11]

{'endpoints': array([[764206.9, 960501.9, 878885.6],
        [783215. , 967106. , 880939. ]]),
 'data': <neuron.Branch at 0x7f245471d518>}

In [None]:
# Overlay the mesh of branch k of limb 4 (red) on the full limb-4 skeleton to
# see where that branch sits. Limb 4 is one of the limbs limb_split divided in two.
# One mesh is plotted, so exactly one mesh color is passed.
k=2
debug_limb = new_limb_objs[4]
sk.graph_skeleton_and_mesh(other_meshes=[debug_limb[k].mesh],
                           other_meshes_colors=["red"],
                           other_skeletons=[debug_limb.skeleton])