# Robustness analysis data generation

**Note!** This notebook cannot be run as is because it pulls its data from an internal DataJoint database. It is meant to showcase the truncation procedure and the feature computation for each neuron. 
If you want to reproduce the analysis, please adjust the data loading accordingly.

In [2]:
# DATAJOINT
import datajoint as dj
dj.config.load('./utils/dj_mysql_conf.json')
c = dj.conn()

from schemata.cell import *
from schemata.classification import *
schema = dj.schema('agberens_morphologies', locals())

import numpy as np
import pandas as pd
import networkx as nx
from utils import NeuronTree as nt

import copy
import multiprocessing
import os
import warnings
from functools import partial

# PLOTTING

import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

Connecting slaturnus@172.25.240.205:3306


  from numpy.core.umath_tests import inner1d


In [3]:
# Distinct values of the `type` attribute among the cells of data set 3.
cell_types_ds3 = (Cell() & 'ds_id =3').fetch('type')
np.unique(cell_types_ds3)

array(['BC', 'BPC', 'BTC', 'ChC', 'DBC', 'MC', 'NGC', 'pyr'], dtype=object)

# Create data for each type individually

In [4]:
# Fetch every non-pyramidal neuron of data set 3 (reduction 0) as a table
# and store it next to the other generated data.
non_pyramidal_query = (Neuron() * Cell()) & {'ds_id': 3, 'reduction_id': 0} & "type!='pyr'"
types = pd.DataFrame(non_pyramidal_query.fetch(as_dict=True))
types.to_csv('./data/types.csv')

In [6]:
# Restrict to the cell types present in `types` and fetch only the primary keys.
type_restriction = dj.OrList(["type='%s'" % cell_type for cell_type in np.unique(types['type'])])
selected_neurons = (Neuron() * Cell()) & dict(ds_id=3, reduction_id=0) & type_restriction
neuron_keys = selected_neurons.proj().fetch(as_dict=True)

In [7]:
# Load one tree per fetched key, in the same order as `neuron_keys`.
trees = [Neuron().get_tree(key) for key in neuron_keys]


In [8]:
def connected_components(G_):
    """Yield node sets obtained by breadth-first search from each unvisited node.

    Iterates over the nodes of ``G_`` in iteration order. For every node that
    has not been part of a previously yielded set, the set of nodes in
    ``nx.bfs_tree(G_, node)`` is yielded and then marked as visited.

    :param G_: graph accepted by ``networkx.bfs_tree``.
    :return: generator of sets of nodes.
    """
    visited = set()
    for node in G_:
        if node in visited:
            continue
        component = set(nx.bfs_tree(G_, node))
        yield component
        # mark the nodes only after handing the set to the consumer
        visited |= component

def truncate_branches(G, perc=.1, no_trunc_branches=None, random_state=None):
    """Return a truncated copy of a neuron tree with distal branches removed.

    Branch points are ranked by their branch order. The children of the
    highest-ranked branch points are candidates for removal; every selected
    child is deleted together with everything reachable from it. The graph of
    ``G`` itself is not modified (a deep copy is truncated).

    :param G: tree object providing ``get_branchpoints()``, ``get_branch_order()``
        and ``get_graph()``.
    :param perc: fraction of branch points to truncate at. If ``None``,
        ``no_trunc_branches`` is used instead.
    :param no_trunc_branches: number of branches to remove; only used when
        ``perc`` is ``None``. One branch point accounts for two branches.
    :param random_state: optional seed or ``numpy.random.RandomState`` used when
        the branches to remove are sampled. ``None`` uses numpy's global
        random state.
    :return: ``nt.NeuronTree`` built from the truncated graph.
    :raises ValueError: if both ``perc`` and ``no_trunc_branches`` are ``None``.
    """
    if perc is None and no_trunc_branches is None:
        raise ValueError("Either 'perc' or 'no_trunc_branches' has to be given.")

    branchpoints = np.asarray(G.get_branchpoints())
    branchorders = G.get_branch_order()

    # identify points that are farthest away in terms of branch order
    bp_bo = np.array([branchorders[point] for point in branchpoints])
    no_branchpoints = len(bp_bo)

    if perc is not None:
        no_to_cut = int(np.round(no_branchpoints * perc))
        no_branches = int(np.round(no_branchpoints * 2 * perc))
    else:
        no_to_cut = int(round(no_trunc_branches / 2))  # since one branch point removes two branches
        no_branches = int(no_trunc_branches)

    G_ = copy.deepcopy(G.get_graph())

    # Nothing to truncate. Without this guard ``sorted_idx[-0:]`` would select
    # *all* branch points instead of none.
    if no_to_cut <= 0:
        return nt.NeuronTree(graph=G_)

    sorted_idx = np.argsort(bp_bo)
    selected_bp = branchpoints[sorted_idx[-no_to_cut:]]

    children_ids = []
    for bp in selected_bp:
        children_ids += list(G_[bp].keys())

    if no_to_cut * 2 >= len(children_ids):
        choice = children_ids
    else:
        # more candidate children than branches to remove: sample among them
        if random_state is None:
            rng = np.random
        elif isinstance(random_state, np.random.RandomState):
            rng = random_state
        else:
            rng = np.random.RandomState(random_state)
        sample_size = min(no_branches, len(children_ids))
        choice = rng.choice(children_ids, sample_size, replace=False)

    to_delete = []
    for d in choice:
        to_delete += list(nx.bfs_tree(G_, d).nodes())

    G_.remove_nodes_from(to_delete)
    return nt.NeuronTree(graph=G_)

In [9]:
def projection(G, u, v):
    """Project the offset between two nodes onto the third coordinate axis.

    :param G: graph whose ``node`` mapping stores a ``'pos'`` entry per node.
    :param u: node the offset points to.
    :param v: node the offset starts from.
    :return: dot product of ``pos(u) - pos(v)`` with the unit vector (0, 0, 1).
    """
    third_axis = np.array([0, 0, 1])
    offset = G.node[u]['pos'] - G.node[v]['pos']
    return np.dot(offset, third_axis)

def _load_cached_features(path):
    """Read a previously written feature table, or return an empty frame if the file is missing."""
    try:
        return pd.read_csv(path)
    except FileNotFoundError:
        return pd.DataFrame()


def _cached_cell_ids(frame):
    """Return the set of ``c_num`` values already present in a feature table."""
    if frame.empty:
        return set()
    return set(frame['c_num'].values)


def _combine_features(cached, new_frames):
    """Concatenate the cached table with the newly computed frames in a single step."""
    frames = ([] if cached.empty else [cached]) + new_frames
    if not frames:
        return cached
    return pd.concat(frames)


def _morphometric_statistics(T, key, perc):
    """Compute the morphometric statistics of one truncated tree as a flat dict."""
    z = copy.copy(key)
    z['truncated'] = perc

    z['branch_points'] = T.get_branchpoints().size
    extend = T.get_extend()

    z['width'] = extend[0]
    z['depth'] = extend[1]
    z['height'] = extend[2]

    tips = T.get_tips()

    z['tips'] = tips.size

    z['stems'] = len(T.edges(1))

    z['total_length'] = np.sum(list(nx.get_edge_attributes(T.get_graph(), 'path_length').values()))
    # get all radii
    radii = nx.get_node_attributes(T.get_graph(), 'radius')
    # delete the soma
    radii.pop(T.get_root())
    z['avg_thickness'] = np.mean(list(radii.values()))

    z['total_surface'] = np.sum(list(T.get_surface().values()))
    z['total_volume'] = np.sum(list(T.get_volume().values()))

    z['max_path_dist_to_soma'] = np.max(T.get_distance_dist()[1])
    z['max_branch_order'] = np.max(list(T.get_branch_order().values()))

    path_angles = []
    for _, successors in T.get_path_angles().items():
        if successors.values():
            path_angles += list(list(successors.values())[0].values())

    # the 99.5th percentile is used in place of the plain maximum
    z['max_path_angle'] = np.percentile(path_angles, 99.5)
    z['median_path_angle'] = np.median(path_angles)

    R = T.get_mst()
    segment_length = R.get_segment_length()
    terminal_segment_pl = [item[1] for item in segment_length.items() if item[0][1] in tips]
    intermediate_segment_pl = [item[1] for item in segment_length.items() if item[0][1] not in tips]

    z['max_segment_path_length'] = np.max(list(segment_length.values()))
    # the leading 0 keeps the median defined when there is no intermediate segment
    z['median_intermediate_segment_pl'] = np.median([0] + intermediate_segment_pl)
    z['median_terminal_segment_pl'] = np.median(terminal_segment_pl)

    tortuosity = [e[2]['path_length'] / e[2]['euclidean_dist'] for e in R.edges(data=True)]

    z['max_tortuosity'] = np.log(np.percentile(tortuosity, 99.5))
    z['median_tortuosity'] = np.log(np.median(tortuosity))

    branch_angles = R.get_branch_angles()
    z['max_branch_angle'] = np.max(branch_angles)
    z['min_branch_angle'] = np.min(branch_angles)
    z['mean_branch_angle'] = np.mean(branch_angles)

    # get maximal degree within data (root excluded); dict() accepts both the
    # plain dict and the degree view that different networkx versions return
    out_degrees = dict(R.get_graph().out_degree())
    z['max_degree'] = np.max([degree for node, degree in out_degrees.items() if node != R.get_root()])

    # get tree asymmetry
    weights, psad = R.get_psad()
    total_weight = np.sum(list(weights.values()))
    if total_weight != 0:
        z['tree_asymmetry'] = np.sum([weights[node] * psad[node] for node in psad.keys()]) / total_weight
    else:
        z['tree_asymmetry'] = 0

    return z


def calculate_truncated_features(perc, trees, neuron_keys, method='branch points', folder='full'):
    """Truncate every tree and store morphometry, persistence and point-cloud features.

    For the given truncation level three CSV files are maintained in
    ``./data/<folder>/``: ``morphometry_truncated_<perc>.csv``,
    ``persistence_truncated_<perc>.csv`` and ``point_cloud_truncated_<perc>.csv``.
    Existing files are read first and only neurons whose ``c_num`` is not yet
    contained in a file are computed and added to it.

    :param perc: truncation level passed on to the truncation routine.
    :param trees: sequence of tree objects, aligned with ``neuron_keys``.
    :param neuron_keys: sequence of key dicts; each needs ``'c_num'`` and ``'ds_id'``.
    :param method: ``'nodes'`` (uses ``tree.truncate_nodes``) or ``'branch points'``
        (uses ``truncate_branches``).
    :param folder: sub-folder of ``./data/`` the files are read from / written to.
    :raises ValueError: if ``method`` is not one of the two supported values.
    """
    if method not in ('nodes', 'branch points'):
        raise ValueError("Unknown truncation method %r; use 'nodes' or 'branch points'." % (method,))

    out_dir = './data/' + folder
    morphometry_path = out_dir + '/morphometry_truncated_%0.2f.csv' % (perc)
    persistence_path = out_dir + '/persistence_truncated_%0.2f.csv' % (perc)
    point_cloud_path = out_dir + '/point_cloud_truncated_%0.2f.csv' % (perc)

    morphometry_data = _load_cached_features(morphometry_path)
    persistence_data = _load_cached_features(persistence_path)
    point_cloud_data = _load_cached_features(point_cloud_path)

    morphometry_done = _cached_cell_ids(morphometry_data)
    persistence_done = _cached_cell_ids(persistence_data)
    point_cloud_done = _cached_cell_ids(point_cloud_data)

    # New results are collected in lists and concatenated once at the end
    # instead of growing the data frames row by row.
    new_morphometry, new_persistence, new_point_clouds = [], [], []

    for idx, t in enumerate(trees):
        key = neuron_keys[idx]
        c_num = key['c_num']

        need_morphometry = c_num not in morphometry_done
        need_persistence = c_num not in persistence_done
        need_point_cloud = c_num not in point_cloud_done
        if not (need_morphometry or need_persistence or need_point_cloud):
            # everything is cached for this neuron, no need to truncate the tree
            continue

        if method == 'nodes':
            T = t.truncate_nodes(perc=perc)
        else:
            T = truncate_branches(t, perc=perc)

        if need_morphometry:
            # create morphometric statistics vector
            z = _morphometric_statistics(T, key, perc)
            new_morphometry.append(pd.DataFrame(z, index=['c_num']))
            morphometry_done.add(c_num)

        #### get persistence
        if need_persistence:
            try:
                persistence = T.get_persistence(f=projection)
                # take the id from the key of *this* neuron; the morphometry
                # vector may not have been computed in this iteration
                persistence['c_num'] = c_num
                new_persistence.append(persistence)
                persistence_done.add(c_num)
            except ValueError as e:
                # NOTE(review): as before, a failing persistence computation also
                # skips the point cloud of this neuron - confirm this is intended.
                warnings.warn('Persistence could not be computed for c_num %s (%s); '
                              'skipping this neuron.' % (c_num, e))
                continue

        ### sample point cloud
        if need_point_cloud:
            pc = nt.NeuronTree.resample_nodes(T.get_graph(), 0.025)
            point_cloud = pd.DataFrame(pc, columns=['x', 'y', 'z'])
            point_cloud['ds_id'] = key['ds_id']
            point_cloud['c_num'] = c_num
            new_point_clouds.append(point_cloud)
            point_cloud_done.add(c_num)

    morphometry_data = _combine_features(morphometry_data, new_morphometry)
    persistence_data = _combine_features(persistence_data, new_persistence)
    point_cloud_data = _combine_features(point_cloud_data, new_point_clouds)

    # make sure the target folder exists so the computed features are not lost
    os.makedirs(out_dir, exist_ok=True)
    morphometry_data.to_csv(morphometry_path)
    persistence_data.to_csv(persistence_path)
    point_cloud_data.to_csv(point_cloud_path)

In [10]:
# Run configuration for the whole trees:
# truncation levels 0.0, 0.1, ..., 0.9, truncated at the branch points.
folder = 'full'
data = trees
method = 'branch points'
truncation = np.arange(10) / 10

In [11]:
# Compute the features for every truncation level, one level per task,
# distributed over 8 worker processes.
compute_level = partial(
    calculate_truncated_features,
    trees=data,
    neuron_keys=neuron_keys,
    method=method,
    folder=folder,
)
with multiprocessing.Pool(8) as pool:
    most_trunc = pool.map(compute_level, truncation)