In [1]:
"""Pseudo Code:
Pt1
1) Create highly discretized skeletons to pull from in main function
Pt2
2) Create a table that pulls from both Orphan and Excitatory that has any sort of dendrite part
To get the width of the dendrite:
3) Get neuron
3b) Pull down the highly discretized skeleton 
4) For each significant dendritic compartment piece in the neurons components
    a. Extract the mesh of that piece and the bounding box
    b. Get rid of all skeleton edges whose both vertices are not within bounding box
    c. Make the KDTree from the mesh
    d. Find the nearest distance for each of the points on the skeleton
    e. Find percentiles of these minimum distances
    f. Save the following to the database:
        1. mesh id
        2. Label/Compartment type
        3. Compartment Index
        4. Different percentiles/quartiles of lengths
"""

""" 2 Ways can do this:
1) Use Mesh as source of KDTree and sample each point of the skeleton
-- think i want to do this because this will give you shortest distance for every point on the skeleton

2) Use skeleton as source of KD






"""

' 2 Ways can do this:\n1) Use Mesh as source of KDTree and sample each point of the skeleton\n-- think i want to do this because this will give you shortest distance for every point on the skeleton\n\n2) Use skeleton as source of KD\n\n\n\n\n\n\n'

In [2]:
import numpy as np
import datajoint as dj
import time
import os
import datetime
import trimesh as trimesh_io

#for supressing the output
import os, contextlib
import pathlib
import subprocess

#for error counting
from collections import Counter

#for reading in the new raw_skeleton files
import csv
import pandas as pd
from tqdm import tqdm

#for filtering
import math
from pykdtree.kdtree import KDTree

In [3]:
#setting the address and the username
# The password is deliberately NOT stored in the notebook: it is read from the
# DJ_PASS environment variable, or prompted for when that variable is unset.
# Host and user can be overridden with DJ_HOST / DJ_USER and otherwise keep
# their previous values.
# NOTE(review): the password that used to be committed here should be rotated.
import getpass

dj.config['database.host'] = os.environ.get('DJ_HOST', '10.28.0.34')
dj.config['database.user'] = os.environ.get('DJ_USER', 'celiib')
dj.config['database.password'] = os.environ.get('DJ_PASS') or getpass.getpass('DataJoint password: ')
dj.config['safemode']=True
dj.config["display.limit"] = 20

# Schema used to declare new tables, plus a virtual module exposing the
# tables that already exist in the same database.
schema = dj.schema('microns_pinky')
pinky = dj.create_virtual_module('pinky', 'microns_pinky')



Connecting celiib@10.28.0.34:3306


In [4]:
#(schema.jobs & "table_name='__dendrite_width'").delete()

In [5]:
#will create the vertex and face indices given a 3x3xn array
#function that takes in a 3x3 array of coordinates for faces and returns triangles and vertices
def index_unique_rows(full_coordinate_array):
    """
    Split an array of nested coordinate rows into unique rows plus an index array.

    The input is flattened to 2-D using its last axis as the row width. The
    first return value holds the unique rows (as ordered by ``np.unique``); the
    second holds, for every original row, the position of that row in the
    unique array, regrouped so each entry has as many indices as the input's
    second-to-last axis is long.
    """
    row_width = full_coordinate_array.shape[-1]
    rows_per_group = full_coordinate_array.shape[-2]

    flat_rows = full_coordinate_array.reshape(-1, row_width)
    unique_rows, inverse = np.unique(flat_rows, axis=0, return_inverse=True)

    return unique_rows, inverse.reshape(-1, rows_per_group)


In [6]:
#function will take in either array of skeletons or single skeleton 
#returns: Either list of the edges or dictionary of list of the edges
def discretize_skeletons(skeletons, maximum_length, skeleton_ids=None):
    """
    Subdivide skeleton edges so that no sub-edge is longer than ``maximum_length``.

    Parameters
    ----------
    skeletons : iterable of arrays
        Each item is one skeleton given as an ``(n_edges, 2, dim)`` array of
        edge endpoints.
    maximum_length : number
        Upper bound on the length of every returned sub-edge. Must be positive.
    skeleton_ids : iterable, optional
        When given, it is zipped with ``skeletons`` and the result is a dict
        keyed by these ids instead of a list.

    Returns
    -------
    list or dict
        One ``(n_sub_edges, 2, dim)`` float array per input skeleton. The
        sub-edges of an input edge are consecutive, start exactly at its first
        endpoint and end exactly at its second endpoint.

    Raises
    ------
    ValueError
        If ``maximum_length`` is not positive.
    """
    if maximum_length <= 0:
        raise ValueError("maximum_length must be positive")

    def discretize_skeleton(full_edges, maximum_length):
        full_edges = np.asarray(full_edges, dtype=float)
        if full_edges.size == 0:
            # Nothing to subdivide; keep the (n, 2, dim) layout callers reshape on.
            dim = full_edges.shape[-1] if full_edges.ndim == 3 else 3
            return full_edges.reshape(0, 2, dim)

        p0s = full_edges[:, 0]
        p1s = full_edges[:, 1]

        diffs = p1s - p0s
        distances = np.linalg.norm(diffs, axis=1)

        # Number of sub-edges per edge. ceil(d / max) pieces of equal length are
        # each <= max; zero-length edges are kept as a single sub-edge.
        seg_counts = np.maximum(np.ceil(distances / maximum_length).astype(int), 1)

        # Position of every sub-edge inside its parent edge: 0, 1, ..., count-1.
        first_seg = np.cumsum(seg_counts) - seg_counts
        within_edge = np.arange(seg_counts.sum()) - np.repeat(first_seg, seg_counts)

        starts = np.repeat(p0s, seg_counts, axis=0)
        steps = np.repeat(diffs / seg_counts[:, None], seg_counts, axis=0)

        seg_starts = starts + within_edge[:, None] * steps
        seg_ends = starts + (within_edge + 1)[:, None] * steps
        # Snap the final sub-edge onto the true endpoint to avoid rounding drift.
        seg_ends[first_seg + seg_counts - 1] = p1s

        return np.stack((seg_starts, seg_ends), axis=1)

    if skeleton_ids is None:
        return [discretize_skeleton(full_edges, maximum_length) for full_edges in skeletons]

    return {seg_id: discretize_skeleton(full_edges, maximum_length)
            for seg_id, full_edges in zip(skeleton_ids, skeletons)}

In [7]:
#goes and gets the unique mesh faces and vertices
#will take in and populate the soma table based on the key it gets

#Original Query: 

# Compartment labels used to restrict the component tables to dendritic pieces.
comp_interest = ["Apical","Basal","Oblique","Dendrite"]
# A component is only considered when its n_vertex_indices is greater than this value.
significance_threshold = 10000
    

def get_component_mesh(query_key):
    """
    Build a standalone mesh from the dendritic components matching ``query_key``.

    The excitatory component table is tried first and the orphan component
    table second. In both cases rows are restricted by ``query_key``, by the
    compartment types in ``comp_interest`` and by
    ``n_vertex_indices > significance_threshold``. The first table that yields
    any rows decides which stitched mesh table the full mesh is fetched from.

    Parameters
    ----------
    query_key : dict
        Restriction applied to the component and mesh tables; its
        ``"segment_id"`` entry is used in the not-found message.

    Returns
    -------
    (vertices, triangles)
        Unique vertices of the selected triangles and the triangles re-indexed
        into those vertices, as produced by ``index_unique_rows``. Two empty
        arrays are returned when neither component table has a matching row.
    """
    type_restriction = [dict(compartment_type=f) for f in comp_interest]
    size_restriction = "n_vertex_indices>"+str(significance_threshold)

    # (component table, full-mesh table, message printed on a hit), in priority order
    sources = (
        (pinky.CompartmentFinal.ComponentFinal,
         pinky.PymeshfixDecimatedExcitatoryStitchedMesh,
         "Component found in Exhitatory"),
        (pinky.CompartmentOrphan.ComponentOrphan,
         pinky.Decimation35OrphanStitched,
         "Component found in Orphans"),
    )

    for component_table, mesh_table, found_message in sources:
        component_vertices, component_triangles = (component_table() & query_key
                                                   & type_restriction
                                                   & size_restriction).fetch("vertex_indices",
                                                                             "triangle_indices")
        if len(component_vertices) > 0:
            print(found_message)
            vertices_mesh, triangles_mesh = (mesh_table & query_key).fetch("vertices","triangles")
            break
    else:
        # neither the excitatory nor the orphan table had a qualifying component
        print("No Component exists for " + str(query_key["segment_id"]))
        return np.array([]),np.array([])

    # triangle indices of every matching component, concatenated into one array
    triangle_ids = np.hstack(component_triangles).astype("int64")

    # the full mesh is the first (only expected) row of the mesh fetch
    vertices_real = vertices_mesh[0]
    triangles_real = triangles_mesh[0]

    # coordinates of the three corners of every selected triangle
    triangle_coordinates = vertices_real[triangles_real[triangle_ids]]

    vertices_whole, triangles_whole = index_unique_rows(triangle_coordinates)
    return vertices_whole, triangles_whole

In [8]:
def filter_edges_by_bounding_box(edges,max_bb_zone,min_bb_zone):
    """
    Keep the edges whose own axis-aligned extent overlaps the bounding box.

    An edge is dropped as soon as, on any of the three axes, both of its
    endpoints lie strictly above ``max_bb_zone`` or strictly below
    ``min_bb_zone``. Edges that merely touch or straddle the box are kept.
    Note the argument order: the maximum corner comes before the minimum.

    Returns the surviving edges as a numpy array (an empty 1-D array when
    nothing survives).
    """

    def overlaps_box(edge):
        first_point, second_point = edge[0], edge[1]
        for axis in range(3):
            lowest = min(first_point[axis], second_point[axis])
            if lowest > max_bb_zone[axis]:
                return False
            highest = max(first_point[axis], second_point[axis])
            if highest < min_bb_zone[axis]:
                return False
        return True

    return np.array([edge for edge in edges if overlaps_box(edge)])

# Keysource is All Dendritic Sections of Neuron that has Filtered Skeleton that is of significant size

In [9]:
"""
Keysource = 
a. All orphan dendritic components of significant length that have a filtered skeleton
b. All excitatory dendritic components of significant length that have a filtered skeleton
"""

'\nKeysource = \na. All orphan dendritic components of significant length that have a filtered skeleton\nb. All excitatory dendritic components of significant length that have a filtered skeleton\n'

In [10]:
#filtered Skeleton stripping out the mesh
@schema
class DendriteWidth(dj.Computed):
    definition="""
    -> pinky.Segment
    decimation_ratio     : decimal(3,2) 
    compartment_type     : varchar(16)                  # Basal, Apical, spine head, etc.
    component_index      : smallint unsigned            # Which sub-compartment of a certain label this is.
    discrete_length :int #the maximum size of a discretized segment of postsyn skeleton
    ---
    min_width_perc_50           :float #50th pecentile min width range for all skeleton
    min_width_perc_55           :float #55th pecentile min width range for all skeleton
    min_width_perc_60           :float #60th pecentile min width range for all skeleton
    min_width_perc_65           :float #65th pecentile min width range for all skeleton
    min_width_perc_70           :float #70th pecentile min width range for all skeleton
    min_width_perc_75           :float #75th pecentile min width range for all skeleton
    min_width_perc_80           :float #80th pecentile min width range for all skeleton
    min_width_perc_85           :float #85th pecentile min width range for all skeleton
    min_width_perc_90           :float #90th pecentile min width range for all skeleton
    min_width_perc_95           :float #95th pecentile min width range for all skeleton
    min_width_max               :float #max min width range for all skeleton
    """
    
    comp_interest = ["Apical","Basal","Oblique","Dendrite"]
    significance_threshold = 10000

    # value stored in discrete_length and passed to discretize_skeletons
    _DISCRETE_LENGTH = 10
    # percentiles stored in the min_width_perc_* attributes
    _WIDTH_PERCENTILES = tuple(range(50, 100, 5))

    #getting all of the axons and dendrites that have dendrites
    # NOTE(review): only the orphan component table is used here, although the
    # notebook text mentions excitatory components as well — confirm intent.
    key_source = (pinky.CompartmentOrphan.ComponentOrphan & [dict(compartment_type=f) for f in comp_interest] 
         & "n_vertex_indices>"+str(significance_threshold) & pinky.FilteredNeuronSkeleton()).proj() 
    
    """
    3) Get neuron (Obtained from key source)
    3b) Pull down the highly discretized skeleton 
    4) For each significant dendritic compartment piece in the neurons components
        a. Extract the mesh of that piece and the bounding box
        b. Get rid of all skeleton edges whose both vertices are not within bounding box
        c. Make the KDTree from the mesh
        d. Find the nearest distance for each of the points on the skeleton
        e. Find percentiles of these minimum distances
        f. Save the following to the database:
            1. mesh id
            2. Label/Compartment type
            3. Compartment Index
            4. Different percentiles/quartiles of lengths
    """

    @classmethod
    def _constant_widths(cls, value):
        """Return every width attribute set to the same sentinel ``value``."""
        widths = {f"min_width_perc_{perc}": value for perc in cls._WIDTH_PERCENTILES}
        widths["min_width_max"] = value
        return widths

    @classmethod
    def _measured_widths(cls, distances):
        """Return the width attributes computed from skeleton-to-mesh distances."""
        values = np.percentile(distances, cls._WIDTH_PERCENTILES)
        widths = {f"min_width_perc_{perc}": val
                  for perc, val in zip(cls._WIDTH_PERCENTILES, values)}
        widths["min_width_max"] = np.max(distances)
        return widths

    def make(self, key):
        """
        Compute and insert the width statistics for one dendritic component.

        The component mesh is fetched with ``get_component_mesh``, the filtered
        skeleton is discretized, skeleton edges that do not overlap the mesh's
        axis-aligned bounding box are dropped, and the distance from every
        remaining edge endpoint to its nearest mesh vertex is summarized by
        percentiles and the maximum.

        Sentinel rows are inserted instead of measurements when no component
        mesh is found (all widths ``-2``) or when no skeleton edge overlaps the
        bounding box (all widths ``-1``).
        """
        print()
        print()
        
        print(str(key["segment_id"])+ ": " + str(key["compartment_type"])+ "-" + str(key["component_index"] ))
        global_start_time = time.time()

        #*****pull down the skeleton for the mesh (should be from highly discretized list****
        skeleton_data = (pinky.FilteredNeuronSkeleton() & key).fetch(as_dict=True)[0]

        #get the vertices and triangles for the Dendrite
        start_time = time.time()
        vertices_whole, _ = get_component_mesh(key)
        print("len of vertices = " + str(len(vertices_whole)))
        print(f"Step 1: extracted Soma Mesh = {time.time()-start_time}")

        # test for emptiness by size: .any() would also be False for a
        # non-empty array whose values are all zero
        if vertices_whole.size == 0:
            print("ERROR NO COMPONENT MESH FOUND")
            widths = self._constant_widths(-2)
        else:
            print("Mesh successfully extracted")
            #get the discretized skeleton
            new_skeleton = discretize_skeletons([skeleton_data["edges"]],self._DISCRETE_LENGTH)[0]

            # axis-aligned bounding box of the component, taken directly from its vertices
            min_bb = vertices_whole.min(axis=0)
            max_bb = vertices_whole.max(axis=0)

            #filter all of the skeleton points away that are outside of the bounding box of the compartment
            filtered_edges_postsyn = filter_edges_by_bounding_box(new_skeleton,max_bb,min_bb)

            if filtered_edges_postsyn.size == 0:
                print("FILTERED EDGES EMPTY")
                widths = self._constant_widths(-1)
            else:
                # nearest mesh vertex for every endpoint of the remaining edges
                kdtree = KDTree(vertices_whole)
                distances, _ = kdtree.query(filtered_edges_postsyn.reshape(-1,3))
                widths = self._measured_widths(distances)

        new_key = dict(key, discrete_length = self._DISCRETE_LENGTH, **widths)

        self.insert1(new_key,skip_duplicates=True,ignore_extra_fields=True)
        print(f"Total time = {time.time()-global_start_time}")
    
    

In [11]:
# Run the DendriteWidth computation and report the wall-clock time of the
# whole populate call. populate is called with reserve_jobs=True; the
# commented-out cell earlier in the notebook clears this table's entries
# from schema.jobs.
start_time = time.time()
DendriteWidth.populate(reserve_jobs=True)
print(f"Total time = {time.time()-start_time}")



648518346341353058: Dendrite-0
Component found in Orphans
len of vertices = 38766
Step 1: extracted Soma Mesh = 1.33292555809021
Mesh successfully extracted
Total time = 1.9278881549835205


648518346341353574: Basal-0
Component found in Orphans
len of vertices = 45208
Step 1: extracted Soma Mesh = 1.4084727764129639
Mesh successfully extracted
Total time = 2.24190616607666


648518346341353607: Dendrite-0
Component found in Orphans
len of vertices = 68095
Step 1: extracted Soma Mesh = 1.7498245239257812
Mesh successfully extracted
Total time = 2.2729251384735107


648518346341354313: Dendrite-0
Component found in Orphans
len of vertices = 45794
Step 1: extracted Soma Mesh = 1.3231925964355469
Mesh successfully extracted
Total time = 1.7311265468597412


648518346341354496: Dendrite-0
Component found in Orphans
len of vertices = 61171
Step 1: extracted Soma Mesh = 1.6157963275909424
Mesh successfully extracted
Total time = 2.051596164703369


648518346341355048: Dendrite-0
Component 