In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os

# Fraction of accelerator memory the XLA client may preallocate (see the JAX
# GPU memory allocation docs). It is set *before* JAX is imported so that it is
# guaranteed to be in the environment before any backend gets initialised.
# NOTE(review): with the platform forced to "cpu" below this setting is
# presumably irrelevant — confirm whether it can be dropped.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

from jax import config

# Use 64-bit floats/ints instead of JAX's 32-bit default.
config.update("jax_enable_x64", True)
# Run everything on the CPU backend.
config.update("jax_platform_name", "cpu")

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import pandas as pd
import matplotlib as mpl

import jaxley as jx

In [17]:
# Identifier of the cell to process. It selects the SWC morphology file and the
# input/output pickle files used in the cells below.
cell_id = "20170610_1"
# Recording used to filter the bipolar-cell output table. Original note: any
# recording can be picked — presumably because the bipolar-cell locations do
# not differ between recordings (TODO confirm).
rec_id = 1  # Can pick any here.

In [18]:
# Build the jaxley cell from the SWC morphology of `cell_id`.
# NOTE(review): nseg is presumably the number of compartments per branch, and
# max_branch_len / min_radius are presumably in um — confirm against the
# jaxley `read_swc` documentation.
cell = jx.read_swc(
    f"morphologies/{cell_id}.swc",
    nseg=4,
    max_branch_len=300.0,
    min_radius=5.0,
)

In [19]:
# Load the bipolar-cell output table for this cell. The cells below read its
# "cell_id", "rec_id", "x_loc", "y_loc" and "bc_id" columns.
# NOTE(review): unpickling executes arbitrary code — only load files produced
# by this project.
bc_output_df = pd.read_pickle(
    f"results/data/off_bc_output_{cell_id}.pkl"
)

In [20]:
# Keep only the rows that belong to the selected cell and recording.
stim = bc_output_df[
    (bc_output_df["cell_id"] == cell_id) & (bc_output_df["rec_id"] == rec_id)
]

In [21]:
def compute_jaxley_stim_locations(x, y, max_dist=20.0, branch_xyzr=None):
    """Find all branches that pass close to the point ``(x, y)``.

    For every branch, the Euclidean distance in the xy-plane (z is ignored)
    between ``(x, y)`` and each traced point of the branch is computed. A
    branch is "in reach" if its closest traced point is strictly less than
    ``max_dist`` away.

    Args:
        x: x-coordinate of the query location, in the same units as the
            morphology coordinates.
        y: y-coordinate of the query location.
        max_dist: Reach radius; a branch is kept if its minimum distance is
            ``< max_dist``. Defaults to 20 (um, per the original comment).
        branch_xyzr: Optional iterable with one 2D array per branch whose
            first two columns hold the x and y coordinates of the traced
            points. Defaults to ``cell.xyzr`` of the notebook-global ``cell``.

    Returns:
        A tuple ``(branch_inds, comps, min_dists)`` of three equally long
        lists with one entry per branch in reach:

        * ``branch_inds``: index of the branch within ``branch_xyzr``.
        * ``comps``: relative location of the closest traced point on the
          branch, in ``[0, 1]``. It is the index of that point divided by the
          index of the last point, so it only matches the path-length fraction
          if the traced points are evenly spaced. Branches with a single
          traced point get ``0.5``.
        * ``min_dists``: distance from ``(x, y)`` to that closest point.
    """
    if branch_xyzr is None:
        # Fall back to the morphology loaded earlier in the notebook.
        branch_xyzr = cell.xyzr

    branch_inds_in_pixel = []
    comps_in_pixel = []
    min_dist_of_branch_in_pixel = []

    for branch_ind, xyzr in enumerate(branch_xyzr):
        dists = np.sqrt((x - xyzr[:, 0]) ** 2 + (y - xyzr[:, 1]) ** 2)
        num_points = len(dists)
        if num_points == 0:
            # A branch without traced points cannot be in reach (and would
            # make the reductions below raise).
            continue

        argmin_dist = np.argmin(dists)
        min_dist = dists[argmin_dist]
        # Written as `not (a < b)` so that NaN distances are rejected, exactly
        # like a failed `< max_dist` check.
        if not min_dist < max_dist:
            continue

        if num_points > 1:
            comp = argmin_dist / (num_points - 1)
        else:
            # Single traced point: place the location mid-branch.
            comp = 0.5

        branch_inds_in_pixel.append(branch_ind)
        comps_in_pixel.append(comp)
        min_dist_of_branch_in_pixel.append(min_dist)

    return branch_inds_in_pixel, comps_in_pixel, min_dist_of_branch_in_pixel

In [22]:
# Pull the bipolar-cell x/y locations and ids out of the filtered table as
# NumPy arrays (one entry per row of `stim`).
bc_loc_x, bc_loc_y, bc_ids = (
    stim[column].to_numpy() for column in ("x_loc", "y_loc", "bc_id")
)

In [23]:
# NOTE(review): initialised here but never updated in the visible cells —
# confirm whether it is still needed.
bcs_which_stimulate = 0

# Flat, parallel lists with one entry per (bipolar cell, branch in reach) pair.
branch_inds_for_every_bc = []
comp_inds_for_every_bc = []
mind_dists_of_branches_for_every_bc = []
bc_ids_per_stim = []

# The loop variable is `bc_id`, not `id`: at notebook top level a variable
# named `id` would shadow the builtin for the rest of the session.
for x, y, bc_id in zip(bc_loc_x, bc_loc_y, bc_ids):
    branches, comps, min_dist_of_branch_in_pixel = compute_jaxley_stim_locations(x, y)
    branch_inds_for_every_bc.extend(branches)
    comp_inds_for_every_bc.extend(comps)
    mind_dists_of_branches_for_every_bc.extend(min_dist_of_branch_in_pixel)
    # Repeat the bipolar-cell id once per branch it reaches so all lists stay aligned.
    bc_ids_per_stim.extend([bc_id] * len(branches))

In [24]:
# Sanity check: display which cell the stimulus table is being built for.
cell_id

'20170610_1'

In [25]:
# Assemble the stimulus metadata table: one row per (bipolar cell, branch in
# reach) pair. `cell_id` is a scalar and is broadcast to every row.
# `from_dict` is a classmethod, so it is called on the class directly instead
# of on a throwaway empty DataFrame instance.
stim_df = pd.DataFrame.from_dict(
    {
        "cell_id": cell_id,
        "bc_id": bc_ids_per_stim,
        "branch_ind": branch_inds_for_every_bc,
        "comp": comp_inds_for_every_bc,
        "dist_from_bc": mind_dists_of_branches_for_every_bc,
    }
)

In [26]:
# Inspect the assembled table; each row pairs one bipolar cell with one branch
# that lies within reach of it.
stim_df

Unnamed: 0,cell_id,bc_id,branch_ind,comp,dist_from_bc
0,20170610_1,25,73,1.000000,18.656333
1,20170610_1,36,67,1.000000,18.021363
2,20170610_1,36,68,1.000000,18.384294
3,20170610_1,36,69,0.700000,13.410218
4,20170610_1,36,70,0.854545,1.202681
...,...,...,...,...,...
231,20170610_1,78,23,1.000000,13.104314
232,20170610_1,78,28,1.000000,19.774721
233,20170610_1,78,95,1.000000,11.906523
234,20170610_1,84,148,0.581633,11.011436


In [14]:
# For every row, store how many rows (i.e. branches in reach) share its bc_id.
# NOTE(review): the recorded execution count of this cell is lower than that
# of the cell building `stim_df`; re-run the notebook top to bottom so this
# column is computed from the current table.
stim_df["num_synapses_of_bc"] = (
    stim_df.groupby("bc_id")["bc_id"].transform(len)
)

In [15]:
# Persist the stimulus metadata table for this cell.
stim_df.to_pickle(
    f"results/data/stimuli_meta_{cell_id}.pkl"
)