In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import os

# XLA reads this variable when its backend client is created. Set it before
# jax is imported so it takes effect regardless of when the backend starts.
# NOTE(review): this controls GPU memory preallocation; with
# jax_platform_name="cpu" below it presumably has no effect — kept for when
# the platform line is changed.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

from jax import config

# Use 64-bit floats/ints in JAX (default is 32-bit).
config.update("jax_enable_x64", True)
# Force computations onto the CPU backend.
config.update("jax_platform_name", "cpu")

In [2]:
import jax.numpy as jnp
from jax import jit, vmap
import numpy as np
import matplotlib.pyplot as plt
import pickle
import h5py
import pandas as pd
import matplotlib as mpl
from scipy.ndimage import rotate

import jaxley as jx
from jaxley.channels import HH

In [3]:
def compute_jaxley_branch(roi_pos, xyzrs=None):
    """Locate the branch (and position along it) whose traced point is nearest an ROI.

    Args:
        roi_pos: Sequence/array whose first three entries are the ROI's x, y, z
            coordinates. Any trailing entries are ignored; object-dtype input
            (e.g. a row that also contains ids) is accepted because the
            coordinates are cast to float.
        xyzrs: Iterable of per-branch point arrays whose first three columns
            are x, y, z (the layout of ``cell.xyzr``). Defaults to the
            module-level ``cell.xyzr`` so existing calls keep working.

    Returns:
        ``(branch_ind, comp)`` where ``branch_ind`` is the index of the branch
        holding the closest traced point (the first such branch on ties) and
        ``comp`` is that point's relative position along the branch,
        ``point_index / (n_points - 1)``, which lies in [0, 1]. Branches with a
        single traced point report ``0.5``.

    Raises:
        ValueError: if there are no branches to search.
    """
    if xyzrs is None:
        # Implicit dependency on the `cell` created in the morphology loop.
        xyzrs = cell.xyzr

    # Cast so mixed/object rows do not trigger slow object arithmetic.
    target = np.asarray(roi_pos[:3], dtype=float)

    min_dists = []
    min_comps = []
    for xyzr in xyzrs:
        coords = np.asarray(xyzr, dtype=float)[:, :3]
        # Squared distances suffice for finding the minimum.
        sq_dists = np.sum((target - coords) ** 2, axis=1)
        closest = int(np.argmin(sq_dists))
        n_points = len(coords)
        comp = closest / (n_points - 1) if n_points > 1 else 0.5
        min_dists.append(sq_dists[closest])
        min_comps.append(comp)

    if not min_dists:
        raise ValueError("compute_jaxley_branch: no branches to search")

    best = int(np.argmin(min_dists))
    return best, min_comps[best]

In [4]:
# Collect the SWC morphology files. Only the top level of `morphologies/` is
# scanned because each file is later opened as f"morphologies/{name}" (a
# recursive walk would yield names whose paths are wrong). Non-SWC files
# (e.g. .DS_Store) are skipped, and sorting makes the output row order
# reproducible across machines.
fnames = sorted(
    fname
    for fname in os.listdir("morphologies")
    if fname.lower().endswith(".swc")
    and os.path.isfile(os.path.join("morphologies", fname))
)

# Per-ROI recording setup; must contain cell_id, rec_id, roi_id, roi_x/y/z.
setup_df = pd.read_pickle("results/data/setup.pkl")

In [5]:
# Map every recorded ROI onto the nearest Jaxley branch and the relative
# position along that branch. One record per ROI is collected in a list and
# turned into a DataFrame once (no per-row DataFrame + concat).
# NOTE: `compute_jaxley_branch` reads the module-level `cell`, so the loop
# variable must keep that name.
records = []

for morph_full in fnames:
    cell_id = os.path.splitext(morph_full)[0]
    df = setup_df[setup_df["cell_id"] == cell_id]
    cell = jx.read_swc(f"morphologies/{morph_full}", nseg=4, max_branch_len=300.0, min_radius=1.0)

    for roi in df[["roi_x", "roi_y", "roi_z", "cell_id", "rec_id", "roi_id"]].itertuples(index=False):
        # Pass only the numeric coordinates (avoids an object-dtype array).
        roi_pos = np.array([roi.roi_x, roi.roi_y, roi.roi_z], dtype=float)
        jaxley_branch, jaxley_compartment = compute_jaxley_branch(roi_pos)
        records.append(
            {
                "cell_id": roi.cell_id,
                "rec_id": roi.rec_id,
                "roi_id": roi.roi_id,
                "roi_x": roi.roi_x,
                "roi_y": roi.roi_y,
                "roi_z": roi.roi_z,
                "branch_ind": int(jaxley_branch),
                "comp": jaxley_compartment,
            }
        )

# Explicit columns keep the schema (and allow an empty result instead of
# pd.concat raising on an empty list).
write_dfs = pd.DataFrame(
    records,
    columns=["cell_id", "rec_id", "roi_id", "roi_x", "roi_y", "roi_z", "branch_ind", "comp"],
)

  warn("Found a segment with length 0. Clipping it to 1.0")


In [6]:
# Persist the ROI -> (branch_ind, comp) table for downstream notebooks.
pd.to_pickle(write_dfs, "results/data/recording_meta.pkl")