In [1]:
"""
Purpose: Would extract the spines from a mesh and test
that you can make it completely manifold
"""

'\nPurpose: Would extract the spines from a mesh and test\nthat you can make it completely manifold\n'

In [2]:
import numpy as np
import datajoint as dj
import time
import pymeshfix
import os
import datetime
import calcification_Module as cm
from meshparty import trimesh_io

#for supressing the output
import os, contextlib
import pathlib
import subprocess

#for error counting
from collections import Counter

#for reading in the new raw_skeleton files
import csv



from Skeleton_Stitcher import stitch_skeleton_with_degree_check, find_skeleton_distance

In [3]:
# Database connection settings.
# Host and user may be overridden with the DJ_HOST / DJ_USER environment
# variables. The password is never stored in the notebook: it is taken from
# DJ_PASS, or prompted for interactively when that variable is unset.
# NOTE(review): the password that used to be hardcoded here should be treated
# as leaked and rotated.
from getpass import getpass

dj.config['database.host'] = os.environ.get('DJ_HOST', '10.28.0.34')
dj.config['database.user'] = os.environ.get('DJ_USER', 'celiib')
dj.config['database.password'] = os.environ.get('DJ_PASS') or getpass('DataJoint password: ')
dj.config['safemode'] = True
dj.config["display.limit"] = 20

schema = dj.schema('microns_pinky')
pinky = dj.create_virtual_module('pinky', 'microns_pinky')


Connecting celiib@10.28.0.34:3306


In [4]:
#output for the skeleton edges to be stored by datajoint
""" OLD WAY THAT DATAJOINT WAS GETTING MAD AT 
def read_skeleton(file_path):
    with open(file_path) as f:
        bones = list()
        for line in f.readlines():
            bones.append(np.array(line.split()[1:], float).reshape(-1, 3))
    return np.array(bones)
"""

""" NEW FLAT LIST WAY, this is outdated for one below"""
#
def read_skeleton_flat(file_path):
    """Read a text skeleton file into one flat float array of 3D points.

    The first whitespace-separated token on every line is ignored (presumably a
    point count -- it is not validated). The remaining numbers are read as
    x y z triplets. The points of each line are emitted in order, followed by a
    [nan, nan, nan] row that separates one line's points from the next.

    Returns:
        float ndarray of shape (N, 3), or an empty float array if the file is empty.
    """
    separator = [np.nan, np.nan, np.nan]
    rows = []
    with open(file_path) as handle:
        for raw_line in handle:
            coords = np.array(raw_line.split()[1:], float).reshape(-1, 3)
            rows.extend(coords)
            rows.append(separator)
    return np.array(rows).astype(float)


""" New read function: for adjusted 2 vert skeleton output"""
# def read_raw_skeleton(file_path):
#     edges = list()
#     with open(file_path) as f:
#         reader = csv.reader(f, delimiter=' ', quoting=csv.QUOTE_NONE)
#         for i,row in enumerate(reader):
#             v1 = (float(row[1]),float(row[2]),float(row[3]))
#             v2 = (float(row[4]),float(row[5]),float(row[6]))
#             edges.append((v1,v2))
#     return np.array(edges).astype(float)


def read_skeleton_revised(file_path):
    """Read a text skeleton file into an array of edges.

    Each line is parsed by dropping its first whitespace-separated token
    (presumably a vertex count -- not validated here) and reshaping the
    remaining numbers into (k, 3) points. Consecutive points on a line become
    the edges (p0, p1), (p1, p2), ..., (p[k-2], p[k-1]).

    Args:
        file_path: path of the text skeleton file.

    Returns:
        float ndarray of shape (n_edges, 2, 3) holding the edges of all lines
        in file order. A completely empty file yields an empty float array of
        shape (0,), as before.
    """
    edge_chunks = []
    with open(file_path) as f:
        for line in f:
            points = np.array(line.split()[1:], float).reshape(-1, 3)
            # pair every point with its successor along the polyline
            edge_chunks.append(np.stack((points[:-1], points[1:]), axis=1))

    if not edge_chunks:
        return np.array([]).astype(float)
    # A single concatenate (instead of vstack-ing inside the loop) keeps this
    # linear in the size of the file.
    return np.concatenate(edge_chunks, axis=0).astype(float)


In [5]:
# Make sure the scratch folder for the OFF files exists in the current working
# directory; makedirs with exist_ok is a no-op when the folder is already there.
os.makedirs("pymesh_neurons", exist_ok=True)

In [6]:
#create the output file
##write the OFF file for the neuron
import pathlib
def write_Whole_Neuron_Off_file(neuron_ID,
                                vertices=None,
                                triangles=None,
                                folder="pymesh_neurons"):
    """Write a triangle mesh to ``<cwd>/<folder>/neuron_<neuron_ID>.off``.

    Args:
        neuron_ID: identifier used to build the file name.
        vertices: sequence of vertex coordinates; the first three entries of
            each item are written. Defaults to no vertices.
        triangles: sequence of vertex-index triples. Defaults to no faces.
        folder: directory to write into, relative to the current working
            directory (or absolute). It is created if it does not exist.

    Returns:
        (path_and_filename, filename, file_loc) as strings; path_and_filename
        does not include the ".off" extension.
    """
    # None instead of a shared mutable [] default
    vertices = [] if vertices is None else vertices
    triangles = [] if triangles is None else triangles

    file_loc = pathlib.Path.cwd() / folder
    file_loc.mkdir(parents=True, exist_ok=True)
    filename = "neuron_" + str(neuron_ID)
    path_and_filename = file_loc / filename

    # 'with' guarantees the file is flushed and closed before callers
    # (e.g. meshlabserver) read it back.
    with open(str(path_and_filename) + ".off", "w") as f:
        f.write("OFF\n")
        # header: vertex count, face count, edge count (always 0 here)
        f.write(str(len(vertices)) + " " + str(len(triangles)) + " 0\n")
        for verts in vertices:
            f.write(str(verts[0]) + " " + str(verts[1]) + " " + str(verts[2]) + "\n")
        for faces in triangles:
            f.write("3 " + str(faces[0]) + " " + str(faces[1]) + " " + str(faces[2]) + "\n")

    print("Done writing OFF file")
    return str(path_and_filename), str(filename), str(file_loc)

In [7]:
def meshlab_fix_manifold(key,folder="pymesh_NEURITES"):
    """Run the non-manifold-edge MeshLab script on a neuron's OFF mesh.

    Reads ``<cwd>/<folder>/neuron_<segment_id>.off`` and writes
    ``<cwd>/<folder>/neuron_<segment_id>_mls.off``. It uses the script
    ``remeshing_remove_non_man_edges.mls`` from the current working directory.

    Args:
        key: mapping with at least a "segment_id" entry.
        folder: directory holding the input mesh (relative to cwd, or absolute).

    Returns:
        Path of the output mesh as a string.

    Raises:
        Exception: if meshlabserver exits with a non-zero return code.
    """
    file_loc = pathlib.Path.cwd() / folder
    path_and_filename = str(file_loc / ("neuron_" + str(key["segment_id"])))

    input_mesh = path_and_filename + ".off"
    output_mesh = path_and_filename + "_mls.off"
    meshlab_script = str(pathlib.Path.cwd()) + "/" + "remeshing_remove_non_man_edges.mls"

    print("starting meshlabserver fixing non-manifolds")
    subprocess_result_1 = run_meshlab_script(meshlab_script, input_mesh, output_mesh)

    # Check the exit status directly instead of parsing the CompletedProcess
    # repr; the repr check breaks as soon as stdout/stderr are captured.
    if subprocess_result_1.returncode != 0:
        raise Exception('neuron' + str(key["segment_id"]) +
                        ' did not fix the manifold edges')

    return output_mesh

def meshlab_fix_manifold_path(path_and_filename,segment_id=-1):
    """Run the non-manifold-edge MeshLab script on the mesh at a given path.

    Args:
        path_and_filename: mesh path, with or without a trailing ".off".
        segment_id: only used in the error message.

    Returns:
        Path of the output mesh (``<path>_mls.off``) as a string.

    Raises:
        Exception: if meshlabserver exits with a non-zero return code.
    """
    path_and_filename = str(path_and_filename)
    # Strip the extension if one was given. Keep everything BEFORE ".off";
    # the old [-4:] slice kept only ".off" itself.
    if path_and_filename.endswith(".off"):
        path_and_filename = path_and_filename[:-4]

    input_mesh = path_and_filename + ".off"
    output_mesh = path_and_filename + "_mls.off"

    meshlab_script = str(pathlib.Path.cwd()) + "/" + "remeshing_remove_non_man_edges.mls"

    print("starting meshlabserver fixing non-manifolds")
    subprocess_result_1 = run_meshlab_script(meshlab_script, input_mesh, output_mesh)

    # check the exit status directly rather than parsing the repr
    if subprocess_result_1.returncode != 0:
        raise Exception('neuron' + str(segment_id) +
                        ' did not fix the manifold edges')

    return output_mesh

def meshlab_fix_manifold_path_specific_mls(path_and_filename,segment_id=-1,meshlab_script=""):
    """Run a (possibly custom) MeshLab script on the mesh at a given path.

    Args:
        path_and_filename: mesh path, with or without a trailing ".off".
        segment_id: only used in the error message.
        meshlab_script: path of the script to run. If empty/None, falls back to
            ``remeshing_remove_non_man_edges.mls`` in the current working directory.

    Returns:
        Path of the output mesh (``<path>_mls.off``) as a string.

    Raises:
        Exception: if meshlabserver exits with a non-zero return code.
    """
    path_and_filename = str(path_and_filename)
    # Strip the extension if one was given. Keep everything BEFORE ".off";
    # the old [-4:] slice kept only ".off" itself.
    if path_and_filename.endswith(".off"):
        path_and_filename = path_and_filename[:-4]

    input_mesh = path_and_filename + ".off"
    output_mesh = path_and_filename + "_mls.off"

    if not meshlab_script:
        meshlab_script = str(pathlib.Path.cwd()) + "/" + "remeshing_remove_non_man_edges.mls"

    print("meshlab_script = " + str(meshlab_script))
    subprocess_result_1 = run_meshlab_script(meshlab_script, input_mesh, output_mesh)

    # check the exit status directly rather than parsing the repr
    if subprocess_result_1.returncode != 0:
        raise Exception('neuron' + str(segment_id) +
                        ' did not fix the manifold edges')

    return output_mesh



In [8]:
def run_meshlab_script(mlx_script,input_mesh_file,output_mesh_file):
    """Run a MeshLab filter script headlessly through meshlabserver.

    meshlabserver is launched under ``xvfb-run`` so it has a virtual X display
    (server args "-screen 0 800x600x24").

    Args:
        mlx_script: path of the MeshLab script (passed with -s).
        input_mesh_file: mesh to read (passed with -i).
        output_mesh_file: mesh to write (passed with -o).

    Returns:
        The subprocess.CompletedProcess of the run; callers inspect its exit status.
    """
    # Pass an argument list instead of a shell string, so paths containing
    # spaces or shell metacharacters reach meshlabserver verbatim.
    command = [
        "xvfb-run", "-a", "-s", "-screen 0 800x600x24",
        "meshlabserver",
        "-i", str(input_mesh_file),
        "-o", str(output_mesh_file),
        "-s", str(mlx_script),
    ]
    return subprocess.run(command)

In [9]:
# (segmentation, segment_id) pairs present in CoarseLabelFinal, combined with
# DataJoint's `+` (union) with the pairs present in CoarseLabelOrphan.
final_label_keys = dj.U("segmentation", "segment_id") & pinky.CoarseLabelFinal.proj()
orphan_label_keys = dj.U("segmentation", "segment_id") & pinky.CoarseLabelOrphan.proj()
key_source = final_label_keys + orphan_label_keys

key_source

segmentation  segmentation id,segment_id  segment id unique within each Segmentation
3,648518346341371119
3,648518346349386137
3,648518346349470171
3,648518346349471156
3,648518346349471500
3,648518346349471562
3,648518346349471565
3,648518346349471910
3,648518346349472574
3,648518346349472601


In [None]:
# key = dict(segmentation=3,segment_id=648518346341371119)
# split_significance_threshold = 100

# global_time = time.time()
# #get the mesh with the error segments filtered away
# start_time = time.time()
# new_key = remove_error_segments(key)
# print(f"Step 1: Retrieving Mesh and removing error segments: {time.time() - start_time}")

# #where i deal with the error segments
# if new_key["vertices"].size<2:
#     start_time = time.time()
#     print("All faces were error segments, inserting dummy entry")
#     #create the key with None
#     new_key["n_vertices"] = 0
#     new_key["n_triangles"] = 0
#     new_key["vertices"] = np.array([]).astype(float)
#     new_key["triangles"] = np.array([]).astype(float)
#     new_key["n_edges"] = 0
#     new_key["edges"] = np.array([]).astype(float)
#     new_key["n_bodies"] = 0
#     new_key["n_bodies_stitched"] = 0
#     new_key["largest_mesh_perc"] = 0
#     new_key["largest_mesh_distance_perc"] = 0
#     self.insert1(new_key,skip_duplicates=True)

#     #insert dummy dictionary into correspondence table
# #             new_correspondence_dict = dict(segmentation=key["segmentation"],
# #                                            segment_id=key["segment_id"],
# #                                            time_updated=str(datetime.datetime.now()),
# #                                            n_correspondence = 0,
# #                                            correspondence=np.array([]).astype(float))

# #             #if all goes well then write to correspondence database
# #             ta3p100.NeuronRawSkeletonCorrespondence.insert1(new_correspondence_dict,skip_duplicates=True)


#     print(f"Step 2: Inserting dummy dictionary: {time.time() - start_time}")
#     print(f"Total time: {time.time() - global_time}")
#     print("\n\n")

# else:
#     mesh = trimesh_io.Mesh(vertices=new_key["vertices"], faces=new_key["triangles"])
#     total_splits = mesh.split(only_watertight=False)
#     print(f"There were {len(total_splits)} after split and significance threshold")
#     mesh_pieces = [k for k in total_splits if len(k.faces) > split_significance_threshold]
#     print(f"There were {len(mesh_pieces)} after split and significance threshold")
#     for g,mh in enumerate(mesh_pieces):
#         print(f"Mesh piece {g} with number of faces {len(mh.faces)}")

#     print(f"Step 2a: Getting the number of splits: {time.time() - start_time}")

#     #get the largest mesh piece
#     largest_mesh_index = -1
#     largest_mesh_size = 0

#     for t,msh in enumerate(mesh_pieces):
#         if len(msh.faces) > largest_mesh_size:
#             largest_mesh_index = t
#             largest_mesh_size = len(msh.faces) 

#     #largest mesh piece
#     largest_mesh_perc = largest_mesh_size/len(mesh.faces)
#     new_key["largest_mesh_perc"] = largest_mesh_perc
#     print("largest mesh perc = " + str(largest_mesh_perc))

#     largest_mesh_skeleton_distance = -1

#     paths_used = []
#     total_edges = np.array([])

#     for h,m in enumerate(mesh_pieces): 
#         print(f"Working on split {h} with face total = {len(m.faces)}")



#         #print("Step 2: Remove all error semgents")
#         start_time = time.time()
#         #pass the vertices and faces to pymeshfix to become watertight
#         #meshfix = pymeshfix.MeshFix(new_key["vertices"],new_key["triangles"])
#         meshfix = pymeshfix.MeshFix(m.vertices,m.faces)
#         meshfix.repair(verbose=False,joincomp=True,remove_smallest_components=False)
#         print(f"Step 2: Pymesh shrinkwrapping: {time.time() - start_time}")

#         #print("Step 2: Writing Off File")
#         start_time = time.time()
#         #write the new mesh to off file
#         path_and_filename,filename,file_loc = write_Whole_Neuron_Off_file(str(new_key["segment_id"]),meshfix.v,meshfix.f)
#         print(f"Step 3: Writing shrinkwrap off file: {time.time() - start_time}")
#         paths_used.append(path_and_filename)

#         #Run the meshlabserver scripts
#         start_time = time.time()
#         #output_mesh = meshlab_fix_manifold(key) old way without path
#         output_mesh = meshlab_fix_manifold_path(path_and_filename,key["segment_id"])
#         print(f"Step 4: Meshlab fixing non-manifolds: {time.time() - start_time}")

#         print(output_mesh[:-4])

#         #send to be skeletonized
#         start_time = time.time()
#         return_value = cm.calcification(output_mesh[:-4])
#         if return_value > 0:
#             raise Exception('skeletonization for neuron ' + str(new_key["segment_id"]) + 
#                             ' did not finish... exited with error code: ' + str(return_value))
#         #print(f"Step 5: Generating Skeleton: {time.time() - start_time}")



#         #read in the skeleton files into an array
#         #start_time = time.time()

#         ##****** this needs to be changed for reading them in******
#         bone_array = read_skeleton_revised(output_mesh[:-4]+"_skeleton.cgal")
#         #correspondence_array = read_skeleton_revised(output_mesh[:-4]+"_correspondance.cgal")
#         #print(bone_array)
#         if len(bone_array) <= 0:
#             raise Exception('No skeleton generated for ' + str(new_key["segment_id"]))

# #             if len(correspondence_array) <= 0:
# #                 raise Exception('No CORRESPONDENCE generated for ' + str(new_key["segment_id"]))

#         print(f"Step 5: Generating and reading Skeleton: {time.time() - start_time}")

#         #get the largest mesh skeleton distance
#         if h == largest_mesh_index:
#             largest_mesh_skeleton_distance = find_skeleton_distance(bone_array)

#             #add the skeleton edges to the total edges
#         if not total_edges.any():
#             total_edges = bone_array
#         else:
#             total_edges = np.vstack([total_edges,bone_array])

#     total_edges_stitched = stitch_skeleton_with_degree_check(total_edges)
#     #get the total skeleton distance for the stitched skeleton
#     total_skeleton_distance = find_skeleton_distance(total_edges_stitched)
#     largest_mesh_distance_perc = largest_mesh_skeleton_distance/total_skeleton_distance

#     start_time = time.time()
#     new_key["n_edges"] = len(total_edges_stitched)
#     new_key["edges"] = total_edges_stitched
#     new_key["n_bodies"] = len(total_splits)
#     new_key["n_bodies_stitched"] = len(mesh_pieces)
#     new_key["largest_mesh_perc"] = largest_mesh_perc
#     new_key["largest_mesh_distance_perc"] = largest_mesh_distance_perc

#     #self.insert1(new_key,skip_duplicates=True)
#     print(f"Step 6: Inserting dictionary: {time.time() - start_time}")
#     #raise Exception("done with one neuron")
#     for path_and_filename in paths_used:
#         os.system("rm "+str(path_and_filename)+"*")

#     print(f"Total time: {time.time() - global_time}")
#     print("\n\n")



In [10]:
"""
Pseudocode code:
1) Get the labels
2) Get the decimated mesh
3) Get the undecimated mesh
4) Do a KD tree to map the decimated vertices to the undecimated and give it the labels
"""


'\nPseudocode code:\n1) Get the labels\n2) Get the decimated mesh\n3) Get the undecimated mesh\n4) Do a KD tree to map the decimated vertices to the undecimated and give it the labels\n'

In [27]:
search_key = dict(segment_id=648518346341371119,segmentation=3,decimation_ratio=0.35)
new_key = dict(segment_id = search_key["segment_id"],segmentation=search_key["segmentation"])
dec_vert_labels,dec_tri_labels = (pinky.OverlayedSpineLabel & search_key).fetch1("vertices","triangles")

In [12]:
new_key = dict(segment_id = search_key["segment_id"],segmentation=search_key["segmentation"])

In [14]:
#get the decimated mesh
dec_mesh_table = pinky.PymeshfixDecimatedExcitatoryStitchedMesh & search_key
dec_vertices, dec_triangles = dec_mesh_table.fetch1("vertices","triangles")


In [38]:
# Make sure that the labels match up: every decimated vertex must have exactly
# one label. Fail loudly instead of only displaying the two lengths.
assert len(dec_vert_labels) == len(dec_vertices), (
    f"{len(dec_vert_labels)} vertex labels vs {len(dec_vertices)} decimated vertices")
#len(dec_tri_labels),len(dec_triangles)
len(dec_vert_labels), len(dec_vertices)

(223210, 223210)

In [36]:
#get the undecimated mesh
undec_mesh_table = pinky.ExcitatoryStitchedMeshVp2 & new_key
undec_vertices, undec_triangles = undec_mesh_table.fetch1("vertices","triangles")
print(len(undec_triangles))

1274428


In [16]:
from scipy.spatial import KDTree
dec_KDTree = KDTree(dec_vertices)

In [17]:
start_time = time.time()
distances, nearest_nodes = dec_KDTree.query(undec_vertices)
print(f"Total time = {time.time() - start_time}")

Total time = 129.28571438789368


In [40]:
# Every undecimated vertex must map to a real decimated vertex index.
# scipy's KDTree.query reports a missing neighbour as index len(data).
# Previously np.max(nearest_nodes) was computed here but its result was discarded.
assert np.max(nearest_nodes) < len(dec_vertices)
print(Counter(dec_vert_labels))

Counter({2: 54795, 13: 44085, 15: 41153, 3: 35567, 5: 27356, 14: 11763, 4: 6051, 6: 1567, 10: 762, 12: 111})


In [23]:
# Transfer labels: each undecimated vertex takes the label of the decimated
# vertex nearest to it (nearest_nodes comes from the KDTree query above).
# This variable was previously used without ever being defined in the notebook.
undecimated_vert_labels = np.asarray(dec_vert_labels)[nearest_nodes]
Counter(undecimated_vert_labels)

Counter({15: 116243,
         14: 33249,
         4: 16801,
         13: 124971,
         2: 153983,
         3: 102038,
         5: 85413,
         10: 2137,
         6: 4620,
         12: 267})

In [42]:
# Vertices whose nearest decimated vertex lies farther than this distance could
# be relabelled to 10. That override is commented out below, so the threshold is
# currently unused and the new labels are a plain copy.
distance_threshold = 200
undecimated_vert_labels_new = undecimated_vert_labels.copy()
# undecimated_vert_labels_new[np.where(distances > distance_threshold)] = 10

Counter(undecimated_vert_labels_new)

Counter({15: 116243,
         14: 33249,
         4: 16801,
         13: 124971,
         2: 153983,
         3: 102038,
         5: 85413,
         10: 2137,
         6: 4620,
         12: 267})

In [43]:
undec_triangles

array([[353238, 105573, 115153],
       [586809, 506429, 245211],
       [370585, 249401, 488177],
       ...,
       [397914, 174993, 446972],
       [397914, 446972, 509367],
       [577749, 598622, 406695]])

In [44]:
undec_triangles[:,0]

array([353238, 586809, 370585, ..., 397914, 397914, 577749])

In [45]:
# Each triangle inherits the label of its first corner (column 0 of
# undec_triangles); the other two corners are not consulted.
first_corner = undec_triangles[:, 0]
traingle_labels = undecimated_vert_labels_new[first_corner]

Counter(traingle_labels)

Counter({2: 306320,
         15: 231805,
         13: 249033,
         3: 203752,
         14: 66398,
         10: 4221,
         5: 169747,
         6: 9217,
         4: 33401,
         12: 534})

In [46]:
# Save the migrated labels so the decimated -> undecimated transfer can be checked later
np.savez("undecimated_labels.npz",
         vertices=undecimated_vert_labels_new,
         triangles=traingle_labels)




In [30]:
np.savez("decimated_labels.npz",vertices=dec_vert_labels,triangles=dec_tri_labels)

In [29]:
pinky.OverlayedSpineLabel()

segmentation  segmentation id,segment_id  segment id unique within each Segmentation,decimation_ratio,vertices,triangles
3,648518346341371119,0.35,=BLOB=,=BLOB=
3,648518346349386137,0.35,=BLOB=,=BLOB=
3,648518346349470171,0.35,=BLOB=,=BLOB=
3,648518346349471156,0.35,=BLOB=,=BLOB=
3,648518346349471500,0.35,=BLOB=,=BLOB=
3,648518346349471562,0.35,=BLOB=,=BLOB=
3,648518346349471565,0.35,=BLOB=,=BLOB=
3,648518346349471910,0.35,=BLOB=,=BLOB=
3,648518346349472574,0.35,=BLOB=,=BLOB=
3,648518346349472601,0.35,=BLOB=,=BLOB=


In [None]:
search_key = dict(segment_id=648518346341371119,segmentation=3,decimation_ratio=0.35)
new_key = dict(segment_id = search_key["segment_id"],segmentation=search_key["segmentation"])
dec_vert_labels,dec_tri_labels = (pinky.OverlayedSpineLabel & search_key).fetch1("vertices","triangles")