In [1]:
"""
Purpose: To create a table that can be autopopulated to find the soma centers
- the table will be a dj.Computed that populates from whichever table (created
    by Christos and Stelios) holds the version numbers and ids of the postsynaptic targets

What the table will save:
- soma center
- soma mesh (vertices and triangles)

"""

'\nPurpose: To create a table that can be autopopulated to find the soma centeres\n\nWhat table will save: \n- soma center\n- soma mesh (vertices and triangles)\n\n'

In [None]:
import cgal_Segmentation_Module as csm
from whole_neuron_classifier_datajoint_adapted import extract_branches_whole_neuron
import time
import trimesh
import numpy as np
import datajoint as dj

# The virtual module `m65` and the schema that new tables are declared in
# both point at the same database, so keep its name in one place.
DATABASE_NAME = "microns_minnie65_01"
m65 = dj.create_virtual_module('m65', DATABASE_NAME)
schema = dj.schema(DATABASE_NAME)

# mesh stitching functions

In [None]:
#combine all the meshes into one mesh
def add_mesh_piece(main_mesh_vertices,main_mesh_faces,sub_mesh_vertices,sub_mesh_faces):
    """
    Purpose: Takes in a large mesh piece and a list of other meshes and
    returns one large mesh with all the meshes appended

    Parameters:
    main_mesh_vertices (np.array) : np array storing the vertices as rows and the coordinates as elements
    main_mesh_faces (np.array) : np array storing the faces as rows and the referenced vertex indices as elements
    sub_mesh_vertices (list of np.arrays) : vertices arrays for all subsegments to be added
    sub_mesh_faces (list of np.arrays) : faces arrays for all subsegments to be added

    Returns:
    mesh_vertices (np.array) : vertices of the NEW CONCATENATED MESH (main first, then each submesh in order)
    mesh_faces (np.array) : faces of the NEW CONCATENATED MESH, re-indexed into the combined vertex array

    Raises:
    Exception : if sub_mesh_faces and sub_mesh_vertices have different lengths

    Pseudocode:
    - Checks: make sure the sub_mesh lists are non-empty and of the same length
    1) Walk the submeshes, offsetting each submesh's face indices by the number
       of vertices that precede it in the combined mesh
    2) Stack all vertex blocks and all face blocks once
    3) Print the number of segments added and the final mesh size
    4) Return the combined vertices and faces
    """
    # Checks: make sure the sub_mesh lists are non-empty and of the same length
    if len(sub_mesh_vertices) <= 0:
        print("There were no vertices in submesh to add, returning main mesh")
        return main_mesh_vertices, main_mesh_faces
    if len(sub_mesh_faces) <= 0:
        print("There were no face in submesh to add, returning main mesh")
        return main_mesh_vertices, main_mesh_faces
    if len(sub_mesh_faces) != len(sub_mesh_vertices):
        raise Exception("The sub_mesh_faces and sub_mesh_vertices length did not match")

    #1) Collect the pieces; stacking once at the end avoids re-copying the
    #   growing main mesh for every submesh (quadratic with repeated vstack)
    vertex_blocks = [main_mesh_vertices]
    face_blocks = [main_mesh_faces]
    vertex_offset = len(main_mesh_vertices)
    for sub_verts, sub_faces in zip(sub_mesh_vertices, sub_mesh_faces):
        vertex_blocks.append(sub_verts)
        # submesh faces index their own vertices from 0, so shift them past
        # every vertex that comes before this submesh in the combined array
        face_blocks.append(np.asarray(sub_faces) + vertex_offset)
        vertex_offset += len(sub_verts)

    #2) Stack all vertex blocks and all face blocks once
    main_mesh_vertices = np.vstack(vertex_blocks)
    main_mesh_faces = np.vstack(face_blocks)

    #3) Print out number of segments added, total faces/vertices for new mesh
    print(f"Added {len(sub_mesh_vertices)} subsegements \n  --> final mesh: {len(main_mesh_vertices)} vertices and {len(main_mesh_faces)} faces")

    return main_mesh_vertices,main_mesh_faces

# meshlab functions

In [None]:
def run_meshlab_script(mlx_script,input_mesh_file,output_mesh_file):
    """
    Apply a MeshLab script to a mesh file by running meshlabserver under xvfb-run.

    Parameters:
    mlx_script : path of the MeshLab script to apply
    input_mesh_file : path of the mesh to read
    output_mesh_file : path where meshlabserver writes the processed mesh

    Returns:
    subprocess.CompletedProcess : result of the call; check ``.returncode`` for success
    """
    import shlex
    import subprocess

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and prevents shell injection through file names.
    # xvfb-run supplies a virtual X display (presumably because meshlabserver
    # needs one on a headless machine).
    command = ["xvfb-run", "-a", "-s", "-screen 0 800x600x24", "meshlabserver",
               "-i", str(input_mesh_file),
               "-o", str(output_mesh_file),
               "-s", str(mlx_script)]

    # log a copy-pasteable version of the command
    print(" ".join(shlex.quote(part) for part in command))
    return subprocess.run(command)

import os, contextlib
import pathlib
import subprocess
def meshlab_fix_manifold_path_specific_mls(input_path_and_filename,
                                           output_path_and_filename="",
                                           segment_id=-1,meshlab_script=""):
    """
    Run a MeshLab script on an .off mesh and return the path of the processed mesh.

    Parameters:
    input_path_and_filename (str or path-like) : input mesh; must end in ".off"
    output_path_and_filename (str) : where to write the result; when empty,
        defaults to "<input without .off>_mls.off"
    segment_id : only used in the error message
    meshlab_script (str) : MeshLab script to apply; when empty, defaults to
        "remeshing_remove_non_man_edges.mls" in the current working directory

    Returns:
    str : the output mesh path

    Raises:
    Exception : if the input is not an .off file or meshlabserver exits with a non-zero code
    """
    import pathlib

    input_mesh = str(input_path_and_filename)
    if not input_mesh.endswith(".off"):
        raise Exception("Not passed off file")
    path_and_filename = input_mesh[:-len(".off")]

    # empty string (or None) means "use the default"
    output_mesh = output_path_and_filename or path_and_filename + "_mls.off"

    if not meshlab_script:
        meshlab_script = str(pathlib.Path.cwd()) + "/" + "remeshing_remove_non_man_edges.mls"

    subprocess_result = run_meshlab_script(meshlab_script,
                                           input_mesh,
                                           output_mesh)

    # Check the exit status directly instead of parsing the repr of the
    # CompletedProcess, which changes shape if stdout/stderr are ever captured.
    if subprocess_result.returncode != 0:
        raise Exception('neuron' + str(segment_id) +
                        ' did not fix the manifold edges')

    return output_mesh


In [None]:
@schema
class SomaCenters(dj.Computed):
    """
    Soma center and soma mesh for each neuron segment.

    ``make`` does the following:
    - exports the neuron mesh and runs the MeshLab script "poisson_working_meshlab.mls" on it
    - keeps the largest watertight piece of the result
    - labels that piece's faces with ``extract_branches_whole_neuron``
    - stores the faces labelled as soma together with the mean of their vertices
    """

    # TODO: replace the placeholder below with the upstream table that holds the
    # segment ids and version numbers of the postsynaptic targets.
    definition="""
    -> m65.[************TABLE TO BE FILLED IN***********************]
    ---
    soma_center             : longblob                 # the xyz coordinates of the soma (with the [4,4,40] adjustment already applied)
    vertices            : longblob                     # vertices for soma mesh
    faces                : longblob                    # faces array for soma mesh
    """

    # face label that extract_branches_whole_neuron assigns to the soma
    # (value taken from the original code)
    _SOMA_LABEL = 5.0
    # per-axis divisor applied to the mean soma vertex (the "[4,4,40] adjustment");
    # presumably the x/y/z voxel resolution of the dataset — confirm
    _SOMA_CENTER_SCALE = np.array([4, 4, 40])

    def _fetch_mesh(self, key):
        """
        Return (vertices, faces) of the neuron mesh for ``key``.

        TODO: retrieve the vertices and faces array of the mesh from the upstream
        mesh table once it is known. Until then this raises, so populate() fails
        loudly (and records the error per key) instead of hitting a NameError.
        """
        raise NotImplementedError(
            "SomaCenters: mesh retrieval is not implemented yet for key " + str(key))

    def make(self, key):
        """
        Compute the soma of one neuron and insert it into the table.

        1) retrieve the neuron mesh for ``key``
        2) MLS-remesh it with meshlabserver
        3) keep the largest watertight piece
        4) label its faces with the CGAL-based classifier and keep the soma faces
        5) insert the soma center, vertices and faces

        Raises:
        NotImplementedError : until mesh retrieval (``_fetch_mesh``) is filled in
        Exception : if MeshLab fails, no watertight piece remains, or no face is labelled soma
        """
        import pathlib

        segment_id = key["segment_id"]
        version = key["version"]

        #1) retrieve the vertices and faces array of the mesh
        new_mesh_vertices, new_mesh_faces = self._fetch_mesh(key)

        #2) MLS remeshing
        # Temp files are named by segment AND version so that populate workers
        # (reserve_jobs=True) handling different versions of the same segment
        # never overwrite each other's files.
        temp_dir = pathlib.Path.cwd() / "temp"
        temp_dir.mkdir(parents=True, exist_ok=True)
        file_stem = str(temp_dir / (str(segment_id) + "_" + str(version)))
        input_path = file_stem + "_original.off"
        output_path = file_stem + "_mls.off"
        largest_mesh_path = file_stem + "_largest_piece.off"

        original_main = trimesh.Trimesh(new_mesh_vertices, new_mesh_faces)
        original_main.export(input_path)

        meshlab_script_path_and_name = str(pathlib.Path.cwd() / "poisson_working_meshlab.mls")
        print(meshlab_script_path_and_name)
        print(input_path)
        print(output_path)
        print("Running the mls function")
        meshlab_fix_manifold_path_specific_mls(input_path_and_filename=input_path,
                                               output_path_and_filename=output_path,
                                               segment_id=segment_id,
                                               meshlab_script=meshlab_script_path_and_name)

        #3) keep the largest watertight piece of the remeshed neuron
        new_mesh = trimesh.load_mesh(output_path)
        # list() because split() returns a list or a numpy object array
        # depending on the trimesh version
        mesh_splits = list(new_mesh.split(only_watertight=True))
        print("Total mesh splits = " + str(len(mesh_splits)))
        if not mesh_splits:
            raise Exception("neuron " + str(segment_id) +
                            " had no watertight pieces after MLS remeshing")
        # max() returns the first piece among ties, matching the previous np.where(...)[0]
        largest_mesh = max(mesh_splits, key=lambda piece: len(piece.faces))
        largest_mesh.export(largest_mesh_path)
        print("done exporting")

        #4) run the CGAL segmentation / whole-neuron classifier and keep the soma faces
        faces = np.array(largest_mesh.faces)
        verts = np.array(largest_mesh.vertices)
        _, faces_labels = extract_branches_whole_neuron(import_Off_Flag=False,segment_id=segment_id,vertices=verts,
                             triangles=faces,pymeshfix_Flag=False,
                             import_CGAL_Flag=False,
                             return_Only_Labels=True,
                             clusters=3,
                             smoothness=0.2)

        soma_faces = np.where(np.asarray(faces_labels) == self._SOMA_LABEL)[0]
        if len(soma_faces) == 0:
            # averaging an empty vertex set would silently store NaN
            raise Exception("neuron " + str(segment_id) + " had no faces labelled as soma")
        soma_mesh = largest_mesh.submesh([soma_faces],append=True)

        soma_center = soma_mesh.vertices.mean(axis=0).astype("float") / self._SOMA_CENTER_SCALE
        print("Poor man's center from just averagin vertices = " + str(soma_center))

        #5) Insert the key into the table
        insert_key = dict(key,
                          soma_center=soma_center,
                          vertices=soma_mesh.vertices,
                          faces=soma_mesh.faces)
        self.insert1(insert_key,skip_duplicates=True)

In [None]:
# Populate SomaCenters using DataJoint job reservation, so several workers can
# run this cell in parallel.
# To clear stale job reservations for this table (DataJoint names Computed tables
# "__<snake_case>", so presumably '__soma_centers' — verify before running):
#   (schema.jobs & "table_name='__soma_centers'").delete()
import time
populate_start = time.time()
SomaCenters.populate(reserve_jobs=True)
elapsed_seconds = time.time() - populate_start
print(f"Total time for SomaCenters populate = {elapsed_seconds}")