In [1]:
# % matplotlib inline
import importlib
import importlib.util
import os
import random
import time
from os.path import isfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from neuprint import Client, skeleton
from neuprint import fetch_synapses, NeuronCriteria as NC, SynapseCriteria as SC
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

# neuPrint credentials: never hard-code the token in the notebook (it ends up in version
# control and in every shared copy). Export it in the shell that launches Jupyter, e.g.
#   export NEUPRINT_APPLICATION_CREDENTIALS=<your token>
# Any token that was previously pasted into this notebook should be treated as leaked and regenerated.
token_id = os.environ.get('NEUPRINT_APPLICATION_CREDENTIALS')
if not token_id:
    raise RuntimeError('Set the NEUPRINT_APPLICATION_CREDENTIALS environment variable to your neuPrint token.')

np.set_printoptions(precision=5, suppress=True)  # suppress scientific float notation
# Project root; override with the CLEAN_MITO_HOME environment variable on other machines.
home_dir = os.environ.get('CLEAN_MITO_HOME', '/home/gsager56/hemibrain/clean_mito_code')
c = Client('neuprint.janelia.org', dataset='hemibrain:v1.2.1', token=token_id)
neuron_quality = pd.read_csv(home_dir + '/saved_data/neuron_quality.csv')
neuron_quality_np = neuron_quality.to_numpy()
server = 'http://hemibrain-dvid.janelia.org'


def _load_util_module(name):
    """Execute ``<home_dir>/util_files/<name>.py`` as a fresh module and return it.

    Loading by file path (rather than ``import``) lets the helper modules live outside
    ``sys.path``, and re-running the cell picks up edits to those files.
    """
    spec = importlib.util.spec_from_file_location(name, f'{home_dir}/util_files/{name}.py')
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


utils = _load_util_module('utils')
skel_clean_utils = _load_util_module('skel_clean_utils')
GLM_utils = _load_util_module('GLM_utils')
config = _load_util_module('config')

node_class_dict = config.node_class_dict
analyze_neurons = config.analyze_neurons

In [2]:
# Re-load skel_clean_utils so edits to the .py file take effect without restarting the kernel
spec = importlib.util.spec_from_file_location('skel_clean_utils', home_dir+'/util_files/skel_clean_utils.py')
skel_clean_utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(skel_clean_utils)

# Logistic-regression parameters for the trivial-leaf classifier. This cell prunes leaves
# with the geometric length rule below instead, so they are loaded but not used here.
std_vals = pd.read_csv(home_dir + '/saved_data/trivial_leaf_std.csv').to_numpy()
mean_vals = pd.read_csv(home_dir + '/saved_data/trivial_leaf_mean.csv').to_numpy()
betas = pd.read_csv(home_dir + '/saved_data/trivial_leaf_betas.csv').to_numpy()

# A leaf whose reach (minus base radius and twice its tip radius) is at most this is pruned.
# Same units as the skeleton coordinates; the origin of 91.439 / 8 is not documented here (TODO).
TRIVIAL_LEAF_MAX_LENGTH = 91.439 / 8

# Skeleton columns are used positionally below: 0 rowId, 1-3 x/y/z, 4 radius (neuprint layout),
# 5 link = parent rowId (-1 at the root).
for i_neuron in np.where(np.isin(neuron_quality_np[:, 1], analyze_neurons))[0]:
    bodyId, neuron_type = neuron_quality_np[i_neuron, [0, 1]]
    skel_file = home_dir + f'/saved_neuron_skeletons/s_pandas_{bodyId}_{neuron_type}_200nm.csv'
    new_skel_file = home_dir + f'/saved_clean_skeletons/s_pandas_{bodyId}_{neuron_type}_200nm.csv'
    if not isfile(skel_file) or isfile(new_skel_file):
        continue  # no resampled skeleton to clean, or already cleaned on a previous run

    # Classify the raw neuPrint skeleton only to get the root node and to skip neurons that
    # classify_nodes rejects (returns None). Synapses are fetched once and reused below.
    old_s_pandas = c.fetch_skeleton(bodyId, format='pandas', heal=True, with_distances=False)  # I will heal the skeleton later
    synapses = fetch_synapses(NC(bodyId=bodyId))
    node_classes, important_nodes = skel_clean_utils.classify_nodes(old_s_pandas, synapses.copy(), neuron_quality.iloc[i_neuron])
    if node_classes is None:
        continue

    s_pandas = pd.read_csv(skel_file)
    assert len(s_pandas) > 10
    s_pandas = skel_clean_utils.heal_resampled_skel(s_pandas, bodyId)
    skeleton.reorient_skeleton(s_pandas, rowId=important_nodes['root node'])
    columns = s_pandas.columns

    # --- 1. prune trivial leaves -------------------------------------------------------
    # s_np is an explicit, writable copy and is the single working array for every edit
    # below; s_pandas is rebuilt from it, so no write relies on pandas/numpy memory sharing.
    s_np = s_pandas.to_numpy(copy=True)
    leaf_nodes, branch_nodes = utils.find_leaves_and_branches(s_np)
    keep_bool = np.ones(len(s_np), dtype=bool)
    for leaf_node in leaf_nodes:
        leaf_idxs = utils.get_down_idxs(s_np, leaf_node, np.isin(s_np[:, 0], branch_nodes))
        base_idx = leaf_idxs[-1]
        dists = np.sqrt(np.sum((s_np[leaf_idxs, :][:, [1, 2, 3]] - s_np[base_idx, [1, 2, 3]][np.newaxis, :]) ** 2, axis=1))
        leaf_length = np.max(dists) - s_np[base_idx, 4]
        if leaf_length - (s_np[leaf_idxs[0], 4] * 2) <= TRIVIAL_LEAF_MAX_LENGTH:
            keep_bool[leaf_idxs[:-1]] = False  # this is a trivial leaf
    s_np = s_np[keep_bool]

    # --- 2. sanity check: every leaf must walk down to the root (link == -1) -----------
    leaf_nodes, branch_nodes = utils.find_leaves_and_branches(s_np)
    row_of = {node: i for i, node in enumerate(s_np[:, 0])}
    for node in leaf_nodes:
        idx = row_of[node]
        for _step in range(len(s_np)):
            if s_np[idx, 5] == -1:
                break
            idx = row_of[s_np[idx, 5]]
        else:  # more steps than nodes: the links contain a cycle
            raise AssertionError(f'leaf {node} of {bodyId} does not reach the root')

    # --- 3. merge parent/child nodes that sit at exactly the same coordinate ----------
    # Repeat until a pass deletes nothing, since a merge can expose further coincident pairs.
    while True:
        skel_coords = s_np[:, [1, 2, 3]]
        row_cols = np.empty((0, 2), dtype=int)  # candidate (rowId, rowId) pairs, each node at most once
        for row in range(len(s_np) - 1):
            cols = np.where(np.all(skel_coords[row].reshape((1, 3)) == skel_coords[row + 1:], axis=1))[0]
            if len(cols) >= 2:
                # several nodes share this coordinate: take the first parent/child pair among them
                same_coord_idxs = np.append(row, cols + row + 1)
                same_coord_nodes = s_np[same_coord_idxs, 0]
                same_coord_down_nodes = s_np[same_coord_idxs, 5]
                if np.any(np.isin(same_coord_down_nodes, same_coord_nodes)):
                    i_idx = np.where(np.isin(same_coord_down_nodes, same_coord_nodes))[0][0]
                    j_idx = np.where(same_coord_down_nodes[i_idx] == same_coord_nodes)[0][0]
                    assert s_np[same_coord_idxs[i_idx], 5] == s_np[same_coord_idxs[j_idx], 0]
                    this_row, this_col = s_np[same_coord_idxs[[i_idx, j_idx]], 0]
                    if this_row not in row_cols and this_col not in row_cols:
                        row_cols = np.append(row_cols, [[this_row, this_col]], axis=0)
            elif len(cols) == 1:
                this_row = s_np[row, 0]
                this_col = s_np[cols[0] + row + 1, 0]
                if this_row not in row_cols and this_col not in row_cols:
                    row_cols = np.append(row_cols, [[this_row, this_col]], axis=0)
        assert len(np.unique(row_cols)) == len(row_cols.flatten())

        row_of = {node: i for i, node in enumerate(s_np[:, 0])}
        del_idxs = []
        for row_node, col_node in row_cols:
            row, col = row_of[row_node], row_of[col_node]
            assert np.all(skel_coords[row] == skel_coords[col])
            if s_np[row, 5] == s_np[col, 0]:
                up_idx, down_idx = row, col
            elif s_np[col, 5] == s_np[row, 0]:
                up_idx, down_idx = col, row
            else:
                continue  # coincident but not parent/child: keep both nodes
            # Re-attach the children of the redundant node up_idx to its parent down_idx.
            # Write into s_np itself: s_np is what gets filtered and turned back into
            # s_pandas, so an update made only through s_pandas could be lost.
            s_np[s_np[:, 5] == s_np[up_idx, 0], 5] = s_np[down_idx, 0]
            del_idxs.append(up_idx)
        if not del_idxs:
            break
        s_np = s_np[~np.isin(np.arange(len(s_np)), del_idxs)]
    s_pandas = pd.DataFrame(data=s_np, columns=columns)
    assert len(s_np) == len(s_pandas)

    # --- 4. re-check that every leaf still reaches the root after merging --------------
    leaf_nodes, branch_nodes = utils.find_leaves_and_branches(s_np)
    row_of = {node: i for i, node in enumerate(s_np[:, 0])}
    for node in leaf_nodes:
        idx = row_of[node]
        for _step in range(len(s_np)):
            if s_np[idx, 5] == -1:
                break
            idx = row_of[s_np[idx, 5]]
        else:
            raise AssertionError(f'leaf {node} of {bodyId} does not reach the root')

    # --- 5. insert a 'distance' column at position 6 ----------------------------------
    new_cols = np.concatenate([np.array(s_pandas.columns)[:6], ['distance'], np.array(s_pandas.columns)[6:]])
    s_pandas = skel_clean_utils.append_distance(s_pandas)  # create a distance field in the dataframe
    s_pandas = s_pandas.reindex(columns=new_cols)

    # --- 6. flip theta/phi (columns 7, 8) so each direction points down the skeleton ----
    s_np = s_pandas.to_numpy(copy=True)
    row_of = {node: i for i, node in enumerate(s_np[:, 0])}
    n_flipped = 0
    for idx in range(len(s_np)):
        if s_np[idx, 5] == -1:
            continue  # the root has no downstream neighbour to compare against
        xyz = np.array(utils.spherical_2_cart(1, s_np[idx, 7], s_np[idx, 8]))
        xyz = xyz / np.sqrt(np.sum(xyz ** 2))
        down_idx = row_of[s_np[idx, 5]]
        down_xyz = s_np[down_idx, [1, 2, 3]] - s_np[idx, [1, 2, 3]]
        down_norm = np.sqrt(np.sum(down_xyz ** 2))
        assert down_norm > 0
        down_xyz = down_xyz / down_norm
        cos_angle = np.sum(down_xyz * xyz)
        assert np.abs(cos_angle) < 1.01, f'{down_xyz}, {xyz}, {cos_angle}'
        if cos_angle < 0:
            # xyz is pointed in the wrong direction
            _r, theta, phi = utils.cart_2_spherical(-xyz[0], -xyz[1], -xyz[2])
            s_np[idx, 7] = theta
            s_np[idx, 8] = phi
            n_flipped += 1
    frac_wrong = n_flipped / len(s_np)
    # s_np is a copy: write the corrected angles back, otherwise the saved CSV keeps the old ones
    s_pandas.iloc[:, [7, 8]] = s_np[:, [7, 8]]

    # --- 7. classify the cleaned skeleton and save -------------------------------------
    node_classes, important_nodes = skel_clean_utils.classify_nodes(s_pandas, synapses.copy(), neuron_quality.iloc[i_neuron])
    s_pandas['node_classes'] = node_classes
    assert len(s_pandas) > 10
    s_pandas.to_csv(new_skel_file, index=False)
    print(f'Finished with {bodyId} {neuron_type}')
print('Finished')

Finished with 1158187240 LC4
Finished with 1158864995 LC4


Exception ignored in: <ssl.SSLSocket fd=62, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=6, laddr=('172.22.86.239', 57292), raddr=('206.241.0.102', 443)>
Exception ignored in: <ssl.SSLSocket fd=61, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=6, laddr=('172.22.86.239', 57284), raddr=('206.241.0.102', 443)>


Finished with 1189559257 LC4
Finished with 1218901359 LC4
Finished with 1249932198 LC4
Finished with 1251287671 LC4
Finished with 1281303666 LC4
Finished with 1405780725 LC4
Finished with 1438524573 LC4
Finished with 1466861327 LC4
Finished with 1471601440 LC4
Finished with 1498574596 LC4
Finished with 1503999967 LC4
Finished with 1562673627 LC4
Finished with 1590979045 LC4
Finished with 1621357756 LC4
Finished with 1625080038 LC4
Finished with 1688505620 LC4
Finished with 1715459859 LC4
Finished with 1745821751 LC4
Finished with 1749258134 LC4
Finished with 1782668028 LC4
Finished with 1810956698 LC4
Finished with 1809264255 LC4
Finished with 1839288696 LC4
Finished with 1840636280 LC4
Finished with 1874035952 LC4
Finished with 1876894387 LC4
Finished with 1876898200 LC4
Finished with 1877217777 LC4
Finished with 1877930505 LC4
Finished with 1877939213 LC4
Finished with 1906496111 LC4
Finished with 1907924777 LC4
Finished with 1907933561 LC4
Finished with 1906159299 LC4
Finished with 

In [3]:
# Deliberate stop so "Run All" does not continue into the next cell, which repeats the
# skeleton cleaning with a logistic-regression trivial-leaf classifier instead of the
# geometric rule above. An explicit `raise` is used because `assert` statements are
# removed when Python runs with -O, which would silently disable this guard.
raise AssertionError('Intentional stop: run the logistic-regression cleaning cell below manually if needed.')

AssertionError: 

In [None]:
# Re-load skel_clean_utils so edits to the .py file take effect without restarting the kernel
spec = importlib.util.spec_from_file_location('skel_clean_utils', home_dir+'/util_files/skel_clean_utils.py')
skel_clean_utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(skel_clean_utils)

# Logistic-regression trivial-leaf classifier: leaf features are z-scored with
# mean_vals/std_vals and scored with betas, where betas[:, 0] multiplies a constant column.
std_vals = pd.read_csv(home_dir + '/saved_data/trivial_leaf_std.csv').to_numpy()
mean_vals = pd.read_csv(home_dir + '/saved_data/trivial_leaf_mean.csv').to_numpy()
betas = pd.read_csv(home_dir + '/saved_data/trivial_leaf_betas.csv').to_numpy()

# Leaves whose predicted probability of being trivial is at least this are pruned.
TRIVIAL_LEAF_PROB_THRESHOLD = 0.6

# Skeleton columns are used positionally below: 0 rowId, 1-3 x/y/z, 4 radius (neuprint layout),
# 5 link = parent rowId (-1 at the root).
for i_neuron in np.where(np.isin(neuron_quality_np[:, 1], analyze_neurons))[0]:
    bodyId, neuron_type = neuron_quality_np[i_neuron, [0, 1]]
    skel_file = home_dir + f'/saved_neuron_skeletons/s_pandas_{bodyId}_{neuron_type}_200nm.csv'
    new_skel_file = home_dir + f'/saved_clean_skeletons/s_pandas_{bodyId}_{neuron_type}_200nm.csv'
    if not isfile(skel_file) or isfile(new_skel_file):
        continue  # no resampled skeleton to clean, or already cleaned on a previous run

    # Classify the raw neuPrint skeleton only to get the root node and to skip neurons that
    # classify_nodes rejects (returns None). Synapses are fetched once and reused below.
    old_s_pandas = c.fetch_skeleton(bodyId, format='pandas', heal=True, with_distances=False)  # I will heal the skeleton later
    synapses = fetch_synapses(NC(bodyId=bodyId))
    node_classes, important_nodes = skel_clean_utils.classify_nodes(old_s_pandas, synapses.copy(), neuron_quality.iloc[i_neuron])
    if node_classes is None:
        continue

    s_pandas = pd.read_csv(skel_file)
    assert len(s_pandas) > 10
    s_pandas = skel_clean_utils.heal_resampled_skel(s_pandas, bodyId)
    skeleton.reorient_skeleton(s_pandas, rowId=important_nodes['root node'])
    columns = s_pandas.columns

    # --- 1. prune leaves the classifier calls trivial ----------------------------------
    # s_np is an explicit, writable copy and is the single working array for every edit
    # below; s_pandas is rebuilt from it, so no write relies on pandas/numpy memory sharing.
    s_np = s_pandas.to_numpy(copy=True)
    leaf_nodes, branch_nodes = utils.find_leaves_and_branches(s_np)
    leaf_space, nodes = skel_clean_utils.get_is_trivial_leaf_space(bodyId, leaf_nodes, mean_vals.shape[1], s_pandas.copy())
    zscore_leaf = (leaf_space - mean_vals) / std_vals
    # Size the intercept column from the feature matrix itself so it always matches the
    # rows returned alongside `nodes`.
    design = np.append(np.ones((len(zscore_leaf), 1)), zscore_leaf, axis=1)
    q_vals = np.matmul(design, betas.T)[:, 0]
    probs = 1 / (1 + np.exp(-q_vals))
    is_trivial = probs >= TRIVIAL_LEAF_PROB_THRESHOLD

    keep_bool = np.ones(len(s_np), dtype=bool)
    for node in nodes[is_trivial]:
        keep_bool[utils.get_down_idxs(s_np, node, np.isin(s_np[:, 0], branch_nodes))[:-1]] = False
    s_np = s_np[keep_bool]

    # --- 2. sanity check: every leaf must walk down to the root (link == -1) -----------
    leaf_nodes, branch_nodes = utils.find_leaves_and_branches(s_np)
    row_of = {node: i for i, node in enumerate(s_np[:, 0])}
    for node in leaf_nodes:
        idx = row_of[node]
        for _step in range(len(s_np)):
            if s_np[idx, 5] == -1:
                break
            idx = row_of[s_np[idx, 5]]
        else:  # more steps than nodes: the links contain a cycle
            raise AssertionError(f'leaf {node} of {bodyId} does not reach the root')

    # --- 3. merge parent/child nodes that sit at exactly the same coordinate ----------
    # Repeat until a pass deletes nothing, since a merge can expose further coincident pairs.
    while True:
        skel_coords = s_np[:, [1, 2, 3]]
        row_cols = np.empty((0, 2), dtype=int)  # candidate (rowId, rowId) pairs, each node at most once
        for row in range(len(s_np) - 1):
            cols = np.where(np.all(skel_coords[row].reshape((1, 3)) == skel_coords[row + 1:], axis=1))[0]
            if len(cols) >= 2:
                # several nodes share this coordinate: take the first parent/child pair among them
                same_coord_idxs = np.append(row, cols + row + 1)
                same_coord_nodes = s_np[same_coord_idxs, 0]
                same_coord_down_nodes = s_np[same_coord_idxs, 5]
                if np.any(np.isin(same_coord_down_nodes, same_coord_nodes)):
                    i_idx = np.where(np.isin(same_coord_down_nodes, same_coord_nodes))[0][0]
                    j_idx = np.where(same_coord_down_nodes[i_idx] == same_coord_nodes)[0][0]
                    assert s_np[same_coord_idxs[i_idx], 5] == s_np[same_coord_idxs[j_idx], 0]
                    this_row, this_col = s_np[same_coord_idxs[[i_idx, j_idx]], 0]
                    if this_row not in row_cols and this_col not in row_cols:
                        row_cols = np.append(row_cols, [[this_row, this_col]], axis=0)
            elif len(cols) == 1:
                this_row = s_np[row, 0]
                this_col = s_np[cols[0] + row + 1, 0]
                if this_row not in row_cols and this_col not in row_cols:
                    row_cols = np.append(row_cols, [[this_row, this_col]], axis=0)
        assert len(np.unique(row_cols)) == len(row_cols.flatten())

        row_of = {node: i for i, node in enumerate(s_np[:, 0])}
        del_idxs = []
        for row_node, col_node in row_cols:
            row, col = row_of[row_node], row_of[col_node]
            assert np.all(skel_coords[row] == skel_coords[col])
            if s_np[row, 5] == s_np[col, 0]:
                up_idx, down_idx = row, col
            elif s_np[col, 5] == s_np[row, 0]:
                up_idx, down_idx = col, row
            else:
                continue  # coincident but not parent/child: keep both nodes
            # Re-attach the children of the redundant node up_idx to its parent down_idx.
            # Write into s_np itself: s_np is what gets filtered and turned back into
            # s_pandas, so an update made only through s_pandas could be lost.
            s_np[s_np[:, 5] == s_np[up_idx, 0], 5] = s_np[down_idx, 0]
            del_idxs.append(up_idx)
        if not del_idxs:
            break
        s_np = s_np[~np.isin(np.arange(len(s_np)), del_idxs)]
    s_pandas = pd.DataFrame(data=s_np, columns=columns)
    assert len(s_np) == len(s_pandas)

    # --- 4. re-check that every leaf still reaches the root after merging --------------
    leaf_nodes, branch_nodes = utils.find_leaves_and_branches(s_np)
    row_of = {node: i for i, node in enumerate(s_np[:, 0])}
    for node in leaf_nodes:
        idx = row_of[node]
        for _step in range(len(s_np)):
            if s_np[idx, 5] == -1:
                break
            idx = row_of[s_np[idx, 5]]
        else:
            raise AssertionError(f'leaf {node} of {bodyId} does not reach the root')

    # --- 5. insert a 'distance' column at position 6 ----------------------------------
    new_cols = np.concatenate([np.array(s_pandas.columns)[:6], ['distance'], np.array(s_pandas.columns)[6:]])
    s_pandas = skel_clean_utils.append_distance(s_pandas)  # create a distance field in the dataframe
    s_pandas = s_pandas.reindex(columns=new_cols)

    # --- 6. flip theta/phi (columns 7, 8) so each direction points down the skeleton ----
    s_np = s_pandas.to_numpy(copy=True)
    row_of = {node: i for i, node in enumerate(s_np[:, 0])}
    n_flipped = 0
    for idx in range(len(s_np)):
        if s_np[idx, 5] == -1:
            continue  # the root has no downstream neighbour to compare against
        xyz = np.array(utils.spherical_2_cart(1, s_np[idx, 7], s_np[idx, 8]))
        xyz = xyz / np.sqrt(np.sum(xyz ** 2))
        down_idx = row_of[s_np[idx, 5]]
        down_xyz = s_np[down_idx, [1, 2, 3]] - s_np[idx, [1, 2, 3]]
        down_norm = np.sqrt(np.sum(down_xyz ** 2))
        assert down_norm > 0
        down_xyz = down_xyz / down_norm
        cos_angle = np.sum(down_xyz * xyz)
        assert np.abs(cos_angle) < 1.01, f'{down_xyz}, {xyz}, {cos_angle}'
        if cos_angle < 0:
            # xyz is pointed in the wrong direction
            _r, theta, phi = utils.cart_2_spherical(-xyz[0], -xyz[1], -xyz[2])
            s_np[idx, 7] = theta
            s_np[idx, 8] = phi
            n_flipped += 1
    frac_wrong = n_flipped / len(s_np)
    # s_np is a copy: write the corrected angles back, otherwise the saved CSV keeps the old ones
    s_pandas.iloc[:, [7, 8]] = s_np[:, [7, 8]]

    # --- 7. classify the cleaned skeleton and save -------------------------------------
    node_classes, important_nodes = skel_clean_utils.classify_nodes(s_pandas, synapses.copy(), neuron_quality.iloc[i_neuron])
    s_pandas['node_classes'] = node_classes
    assert len(s_pandas) > 10
    s_pandas.to_csv(new_skel_file, index=False)
    print(f'Finished with {bodyId} {neuron_type}')
print('Finished')

In [None]:
# Progress report: for each analysed cell type print
# (# neurons with a cleaned skeleton on disk, # neurons of that type, type name).
for neuron_type in config.analyze_neurons:
    # Rows of this type; computed once and iterated without rebinding the loop variable.
    type_rows = np.where(neuron_quality_np[:, 1] == neuron_type)[0]
    num_analyzed = sum(
        isfile(home_dir + f'/saved_clean_skeletons/s_pandas_{neuron_quality_np[i_row, 0]}_{neuron_type}_200nm.csv')
        for i_row in type_rows
    )
    print(num_analyzed, len(type_rows), neuron_type)