In [None]:
import os
import sys
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
if '..' not in sys.path:
    sys.path = ['..'] + sys.path
from dlutils.utils import extract_mechanisms
from neuroutils.nodes import Node, SWCImpedanceNode
from neuroutils.trees import SWCImpedanceTree

In [None]:
# Locate the optimization run described in thorny.json and load the parameters
# of the selected individual. `with` blocks make sure the file handles are closed.
config_file = 'thorny.json'
with open(config_file, 'r') as fh:
    config = json.load(fh)
optim_dir = config['optimization_folder']
cell_type = config['cell_type']
cell_name = config['cell_name']
optim_run = config['optimization_run']
individual = config['individual']
# Run folder layout: <optimization_folder>/<Cell_type>/<cell_name>/<optimization_run>
base_folder = os.path.join(optim_dir, cell_type.capitalize(), cell_name, optim_run)
swc_file = os.path.join(base_folder, config['swc_file'])
params_file = os.path.join(base_folder, f'individual_{individual}.json')
with open(params_file, 'r') as fh:
    parameters = json.load(fh)
mechanisms = extract_mechanisms(os.path.join(base_folder, 'parameters.json'), cell_name+'_')

In [None]:
def find_param(params, name):
    """Look up parameter `name` in a list of parameter records.

    Each record is a dict with 'param_name', 'sectionlist' and 'value' keys.
    Records are scanned in order:

    * the first matching record whose 'sectionlist' is 'all' ends the scan,
      and its value is returned as-is;
    * otherwise every matching record contributes one entry to a dict keyed
      by numeric section type (somatic=1, axonal=2, basal=3, apical=4).

    Returns the 'all' value, or the (possibly empty) per-section dict.
    Raises KeyError if a matching record names an unknown section list
    before an 'all' record is reached.
    """
    type_of_section = {'somatic': 1, 'axonal': 2, 'basal': 3, 'apical': 4}
    per_section = {}
    for record in (p for p in params if p['param_name'] == name):
        where = record['sectionlist']
        if where == 'all':
            return record['value']
        per_section[type_of_section[where]] = record['value']
    return per_section
# Passive properties of the optimized model. find_param returns either a single
# value (parameter defined on 'all' sections) or a {section type ID: value} dict.
cm = find_param(parameters, 'cm')
ra = find_param(parameters, 'Ra')
# Membrane resistance is the reciprocal of g_pas. Handle both return shapes:
# calling .items() on a scalar 'all' value would raise AttributeError.
# NOTE(review): assumes SWCImpedanceTree accepts a scalar rm, as it must already
# for cm/ra when they are defined on 'all' — confirm.
g_pas = find_param(parameters, 'g_pas')
if isinstance(g_pas, dict):
    rm = {k: 1/v for k, v in g_pas.items()}
else:
    rm = 1/g_pas

In [None]:
# Build the impedance tree from the SWC morphology and the passive parameters.
# NOTE(review): the meaning of root_point=1 is not visible here — presumably the
# SWC ID of the node used as root; confirm against SWCImpedanceTree.
tree = SWCImpedanceTree(
    swc_file,
    cm,
    rm,
    ra,
    root_point=1,
)

In [None]:
# Frequency passed to the impedance computation.
# NOTE(review): 0 presumably means DC (a later cell reads node._lambda_DC) — confirm units.
F = 0
tree.compute_impedances(F)
# Attenuations are computed after the impedances above.
tree.compute_attenuations()

In [None]:
# Sanity check, only for cell DH070813 rooted at node 1:
# 1. Walk from SWC point `stop` up towards `start` in the raw SWC table.
# 2. Sum the Euclidean segment lengths along the way.
# 3. Compare the sum with the segment lengths (`_h`) the impedance tree stores
#    along the same path.
# 4. Report the path length and the attenuation between the two points.
if cell_name == 'DH070813' and tree.root.ID == 1:
    col_names = 'ID', 'typ', 'x', 'y', 'z', 'diam', 'parent_ID'
    # SWC files are whitespace-separated and may start with '#' header lines.
    # A single-space separator breaks on runs of spaces, and without comment='#'
    # header lines would be parsed as data.
    df = pd.read_table(swc_file, sep=r'\s+', header=None, names=col_names,
                       comment='#', index_col='ID')
    xyz = ['x', 'y', 'z']
    # Hand-picked SWC IDs; `start` must be an ancestor of `stop`.
    start, stop = 4858, 4887
    IDs_on_path = [stop]
    distances = []
    ID = stop
    parent_ID = int(df.loc[ID, 'parent_ID'])
    while parent_ID != -1:
        distances.append(np.sqrt(np.sum((df.loc[ID, xyz] - df.loc[parent_ID, xyz])**2)))
        ID = parent_ID
        parent_ID = int(df.loc[ID, 'parent_ID'])
        IDs_on_path.append(ID)
        if ID == start:
            break
    # Without this check, a `start` that is not an ancestor of `stop` would
    # silently yield the whole path up to the root.
    if ID != start:
        raise ValueError(f'SWC point {start} is not an ancestor of point {stop}')
    IDs_on_path = np.array(IDs_on_path[::-1])
    distances = np.array(distances[::-1])
    path = tree.find_connecting_path(start, stop)
    dst = np.array([n._h for n in path[1:]])
    # The two sums are computed independently in floating point, so compare
    # them with a tolerance rather than exact equality.
    assert np.isclose(distances.sum(), dst.sum()), (distances.sum(), dst.sum())
    n1 = tree.find_node_with_ID(start)
    n2 = tree.find_node_with_ID(stop)
    print('Path distance between points ({:.2f},{:.2f},{:.2f}) and ({:.2f},{:.2f},{:.2f}): {:.5f}.'.\
          format(n1._x, n1._y, n1._z, n2._x, n2._y, n2._z, np.sum(distances)))
    print('Attenuation between points ({:.2f},{:.2f},{:.2f}) and ({:.2f},{:.2f},{:.2f}): {:.5f}.'.\
          format(n1._x, n1._y, n1._z, n2._x, n2._y, n2._z, tree.compute_attenuation(start, stop)))

In [None]:
# For every basal (type 3) and apical (type 4) node, accumulate from parent to child:
#   A        - product of per-branch attenuations from the first dendritic node
#   D        - path distance (µm, per the plot labels below)
#   D_lambda - path distance in units of the DC length constant
# Accumulation reads the parent's entry, so this assumes iterating `tree`
# visits a parent before its children — TODO confirm.
A = {}
D = {}
D_lambda = {}
for node in tree:
    if node.parent is None or node._node_type not in (3, 4):
        continue
    parent = node.parent
    idx = parent.children.index(node)
    # NOTE(review): parent.A[idx] is presumably the attenuation from parent to its idx-th child.
    step_A = parent.A[idx]
    step_D = node._h
    # 1e-4 presumably converts µm to cm to match the units of _lambda_DC — TODO confirm.
    step_D_lambda = 1e-4*node._h/node._lambda_DC
    if parent.ID in A:
        A[node.ID] = A[parent.ID] * step_A
        D[node.ID] = D[parent.ID] + step_D
        D_lambda[node.ID] = D_lambda[parent.ID] + step_D_lambda
    else:
        # The parent is not a basal/apical node already in the maps, so the
        # accumulation starts here. An explicit membership test replaces the
        # former bare `except:`, which also hid unrelated errors.
        A[node.ID] = step_A
        D[node.ID] = step_D
        D_lambda[node.ID] = step_D_lambda
# Make basal distances negative so the basal and apical trees fall on opposite
# sides of zero. This runs after accumulation, which needs the unsigned values.
for node in tree:
    if node.ID in D:
        sign = -1 if node._node_type == 3 else 1
        D[node.ID] *= sign
        D_lambda[node.ID] *= sign
# Keep only the magnitude of each accumulated attenuation.
for k, v in A.items():
    A[k] = np.abs(v)

In [None]:
# Save the attenuation and signed-distance maps (keyed by node ID) for this cell type.
# `with` guarantees the files are flushed and closed.
with open(f'A_{cell_type}.pkl', 'wb') as fh:
    pickle.dump(A, fh)
with open(f'D_{cell_type}.pkl', 'wb') as fh:
    pickle.dump(D, fh)

In [None]:
# Attenuation in dB (20*log10 of the magnitude) versus signed distance from the
# dendritic root: in µm (left) and in length constants (right). Basal nodes
# have negative distances.
fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
plotted_IDs = [node.ID for node in tree
               if node._node_type in (3, 4) and node != tree.root]
att_dB = 20*np.log10([A[i] for i in plotted_IDs])
# One plot call per panel rather than one per node: creating thousands of
# single-point Line2D artists is very slow.
ax[0].plot([D[i] for i in plotted_IDs], att_dB, 'k.', ms=1)
ax[1].plot([D_lambda[i] for i in plotted_IDs], att_dB, 'k.', ms=1)
ax[0].set_xlabel('Distance from root (μm)')
ax[1].set_xlabel('Distance from root (λ)')
ax[0].set_ylabel('Attenuation (dB)')
sns.despine()
fig.tight_layout()