In [None]:
import time

import numpy as np
import datajoint as dj
import trimesh

In [None]:
import os
import contextlib

def print_trimesh(current_mesh, file_name):
    """Export ``current_mesh`` to ``file_name`` while discarding anything it writes to stdout.

    Despite the name, nothing is printed: the mesh's ``export`` method is called with
    ``file_name`` and any console chatter it produces is sent to ``os.devnull``.

    Args:
        current_mesh: object with an ``export(file_name)`` method (e.g. a trimesh mesh).
        file_name: destination path passed straight through to ``export``.

    Returns:
        None
    """
    with open(os.devnull, "w") as silent_sink:
        with contextlib.redirect_stdout(silent_sink):
            current_mesh.export(file_name)

In [None]:
# DataJoint virtual modules: (module name, database schema name), created in this order.
pinky, ta3p100 = [
    dj.create_virtual_module(module_name, schema_name)
    for module_name, schema_name in (
        ("pinky", "microns_pinky"),
        ("ta3p100", "microns_ta3p100"),
    )
]



In [None]:
# Choose which neuron to load: set segment_id to a specific id to override,
# otherwise the mesh_index-th excitatory AllenSoma segment is used.
segment_id = None
mesh_index = 2

if segment_id is None:
    # gets the possible segment ids
    segment_ids = (pinky.AllenSoma() & "cell_class='excitatory'").fetch("segment_id")
    neuron_id = segment_ids[mesh_index]
else:
    neuron_id = segment_id

print(neuron_id)
# Use neuron_id (not segment_ids[...]) so the explicit-override branch loads the requested neuron.
# segmentation=3 is the segmentation version keyed in pinky.Mesh (NOTE(review): confirm this is the intended release).
key = dict(segment_id=neuron_id, segmentation=3)
vertices, triangles = (pinky.Mesh & key).fetch("vertices", "triangles")
if len(vertices) == 0:
    raise ValueError(f"No mesh found in pinky.Mesh for key {key}")

# Assign via attributes (rather than the constructor) so trimesh does not post-process/merge vertices.
unfiltered_mesh = trimesh.Trimesh()
unfiltered_mesh.vertices = vertices[0]
unfiltered_mesh.faces = triangles[0]

In [None]:
def filter_mesh_significant_outside_pieces(unfiltered_mesh, significance_threshold=2000, n_sample_points=1000,
                                           outside_fraction_threshold=0.9, random_seed=None):
    """
    Split a mesh into connected pieces and keep the significant pieces lying outside the largest one.

    Pseudocode:
    1) split the mesh into unconnected pieces
    2) keep only pieces with more than ``significance_threshold`` faces
    3) take the piece with the most faces as the main mesh
    4) for every other significant piece:
        a. sample ``n_sample_points`` of its vertices (with replacement)
        b. compute their signed distance to the main mesh
        c. keep the piece if more than ``outside_fraction_threshold`` of the samples have
           negative signed distance (trimesh's convention for "outside the mesh")

    Parameters
    ----------
    unfiltered_mesh : trimesh.Trimesh
        Full mesh, possibly made of several disconnected components.
    significance_threshold : int
        A piece needs strictly more faces than this to be considered.
    n_sample_points : int
        Number of vertices sampled per piece for the inside/outside vote; must be positive.
    outside_fraction_threshold : float
        Fraction of sampled vertices that must lie outside the main mesh for a piece to be kept.
    random_seed : int or None
        Seed for vertex sampling. ``None`` uses numpy's global random state (previous behaviour).

    Returns
    -------
    main_mesh : trimesh.Trimesh or None
        The largest significant piece, or ``None`` if no piece passed the size threshold.
    final_mesh_pieces : list
        The OTHER significant pieces judged to be outside ``main_mesh``
        (``main_mesh`` itself is not included).

    Raises
    ------
    ValueError
        If ``n_sample_points`` is not positive (the outside fraction would be undefined).
    """
    if n_sample_points <= 0:
        raise ValueError("n_sample_points must be a positive integer")

    mesh_pieces = unfiltered_mesh.split(only_watertight=False)
    print(f"There were {len(mesh_pieces)} pieces after mesh split")

    significant_pieces = [m for m in mesh_pieces if len(m.faces) > significance_threshold]
    print(f"There were {len(significant_pieces)} pieces found after size threshold")
    if not significant_pieces:
        print("THERE WERE NO MESH PIECES GREATER THAN THE significance_threshold")
        # Same (main, pieces) shape as the success path so callers can always unpack two values.
        return None, []

    # Index of the first piece with the most faces (ties resolve to the lower index).
    max_index = max(range(len(significant_pieces)), key=lambda i: len(significant_pieces[i].faces))
    max_face_len = len(significant_pieces[max_index].faces)
    print("max_index = " + str(max_index))
    print("max_face_len = " + str(max_face_len))

    main_mesh = significant_pieces[max_index]
    # Global numpy state when unseeded (legacy behaviour); a dedicated generator when seeded.
    rng = np.random if random_seed is None else np.random.RandomState(random_seed)

    final_mesh_pieces = []
    for i, mesh in enumerate(significant_pieces):
        if i == max_index:
            continue

        # Sample with replacement so pieces with fewer vertices than n_sample_points still work.
        idx = rng.randint(len(mesh.vertices), size=n_sample_points)
        points = mesh.vertices[idx, :]

        start_time = time.time()
        # trimesh convention: points OUTSIDE the mesh get a NEGATIVE signed distance.
        signed_distance = trimesh.proximity.signed_distance(main_mesh, points)
        print(f"Total time = {time.time() - start_time}")

        outside_percentage = np.mean(signed_distance < 0)
        if outside_percentage > outside_fraction_threshold:
            final_mesh_pieces.append(mesh)
            print(f"Mesh piece {i} OUTSIDE mesh")
        else:
            print(f"Mesh piece {i} inside mesh :( ")

    return main_mesh, final_mesh_pieces



In [None]:
# Runs the filtering function for inside and outside meshes
global_timer = time.time()  # wall-clock start of this cleansing step

# Thresholds
significance_threshold = 10  # pieces need MORE than this many faces to be kept for consideration
n_sample_points = 3  # vertices sampled per piece for the inside/outside vote (small -> fast but noisy)
start_time = time.time()

# main_mesh is the largest significant piece; child_meshes are the other
# significant pieces judged to lie outside it.
filter_result = filter_mesh_significant_outside_pieces(
    unfiltered_mesh,
    significance_threshold=significance_threshold,
    n_sample_points=n_sample_points,
)
# When nothing passes the threshold the filter may return an empty list instead of a pair;
# normalize so the unpack below never fails.
main_mesh, child_meshes = filter_result if filter_result else (None, [])
if main_mesh is None:
    print("No significant mesh pieces were found; main_mesh is None")
print(f"Total time for Mesh Cleansing: {time.time() - start_time}")


# HOW TO SAVE AND LOAD THE MESHES