# Generate traces for illustration

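This notebook simulates the trained RGC model for one noise stimulus and saves the voltage and calcium traces shown in panel b, together with the stimulus image and the time axis, to `../results/02_loss/`.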

In [1]:
# Reload edited project modules (e.g. nex.rgc.utils) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# XLA reads this when the backend initialises, so set it before jax is imported.
# NOTE(review): it only limits GPU memory preallocation; with the CPU platform
# selected below it presumably has no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

from jax import config

# Use 64-bit floats throughout and run all JAX computations on the CPU.
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

In [3]:
import time

import jax.numpy as jnp
from jax import jit, vmap, value_and_grad
import numpy as np
import matplotlib.pyplot as plt
import pickle
import matplotlib as mpl
import h5py

import jaxley as jx

from nex.rgc.utils.data_utils import (
    read_data,
    build_avg_recordings,
    build_training_data,
)
from nex.rgc.utils.utils import (
    build_cell,
    build_kernel,
)

In [4]:
import pandas as pd
from warnings import simplefilter
# Silence pandas PerformanceWarning to keep cell output readable.
# NOTE(review): presumably emitted by the nex.rgc data utilities (e.g. fragmented DataFrames) — confirm.
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

In [14]:
# Root of the nex RGC package, relative to this notebook.
path_prefix = "../../../nex/rgc"
# Training run (relative to `path_prefix`) whose optimized parameters and transforms are loaded below.
results_prefix = "results/train_runs/2024_05_30__10_55_30/0"

In [6]:
# Which recordings to load and how finely to discretize the morphology.
cell_id = "20161028_1"
rec_ids = list(range(1, 16))  # scan fields 1 through 15
start_n_scan = 0
num_datapoints_per_scanfield = 8 * 128  # 1024 datapoints per scan field
nseg = 4  # compartments per branch

In [7]:
# Load the noise stimuli, raw recordings, setup information and the full noise movie.
stimuli, recordings, setup, noise_full = read_data(
    start_n_scan,
    num_datapoints_per_scanfield,
    cell_id,
    rec_ids,
    "noise",
    "..",  # NOTE(review): other inputs are read from `path_prefix`; confirm ".." is the intended data root here
)

# Averaged recordings are cached on disk. Load them if present; otherwise build them
# once and write the cache to the same path that is read (the previous commented-out
# code wrote to a different location than it read from).
avg_recordings_path = f"{path_prefix}/results/intermediate/avg_recordings.pkl"
if os.path.exists(avg_recordings_path):
    # Trusted local cache produced by this project; never unpickle untrusted files.
    with open(avg_recordings_path, "rb") as handle:
        avg_recordings = pickle.load(handle)
else:
    avg_recordings = build_avg_recordings(
        recordings, rec_ids, nseg, num_datapoints_per_scanfield
    )
    os.makedirs(os.path.dirname(avg_recordings_path), exist_ok=True)
    with open(avg_recordings_path, "wb") as handle:
        pickle.dump(avg_recordings, handle)

# Number of averaged recordings per scan field, ordered by rec_id (groupby sorts its keys).
number_of_recordings_each_scanfield = list(avg_recordings.groupby("rec_id").size())

In [8]:
def value_from_parameters(parameter_values, distance):
    """Return the parameter value to apply at ``distance``.

    Placeholder for a distance-dependent parameterization: ``distance`` is
    currently ignored and ``parameter_values`` is passed through unchanged.
    """
    del distance  # intentionally unused for now
    return parameter_values
    
def _set_neuron_params(neuron_params, branch_inds, pstate):
    """Apply each single-key ``{name: value}`` dict in ``neuron_params`` to ``cell[branch_inds]``.

    Returns the updated jaxley parameter state (threaded through ``data_set``).
    """
    for param in neuron_params:
        name = list(param.keys())[0]
        value = value_from_parameters(param[name], 0.0)
        pstate = cell[branch_inds].data_set(name, value, param_state=pstate)
    return pstate


def simulate(params, basal_neuron_params, somatic_neuron_params, currents):
    """Simulate the cell for one stimulus and return all recorded traces.

    Relies on notebook globals defined in later cells (``cell``, ``basal_inds``,
    ``somatic_inds``, ``stim_branch_inds``, ``stim_comps``, ``warmup``, ``dt``,
    ``t_max``). Under ``jit`` these are read when the function is first traced,
    so those cells must run before the first call.

    Args:
        params: List whose first entry is a dict holding the synaptic weights
            under ``"w_bc_to_rgc"``; the remaining entries are passed to
            ``jx.integrate`` as the trainable cell parameters.
        basal_neuron_params: List of single-key dicts ``{name: value}`` set on
            the basal branches.
        somatic_neuron_params: Same format, set on the somatic branches.
        currents: Input currents. They are multiplied by the synaptic weights
            and turned into one step current per stimulus site.

    Returns:
        The output of ``jx.integrate``: traces for every recording inserted on
        ``cell``. In this notebook, that is the voltage recordings followed by
        the ``Cai`` recordings, not only voltages.
    """
    syn_weights = params[0]["w_bc_to_rgc"]
    cell_params = params[1:]

    # Stimuli: one step current per (branch, compartment) stimulus site.
    # Presumably onset `warmup` and duration `t_max - warmup` (jaxley's i_delay, i_dur).
    input_currents = syn_weights * currents
    step_currents = jx.datapoint_to_step_currents(
        warmup, t_max - warmup, input_currents, dt, t_max
    )
    data_stimuli = None
    for branch, comp, step_current in zip(stim_branch_inds, stim_comps, step_currents):
        data_stimuli = cell.branch(branch).loc(comp).data_stimulate(
            step_current, data_stimuli=data_stimuli
        )

    # Parameters: basal first, then somatic, threaded through one parameter state.
    pstate = _set_neuron_params(basal_neuron_params, basal_inds, None)
    pstate = _set_neuron_params(somatic_neuron_params, somatic_inds, pstate)

    v = jx.integrate(
        cell,
        params=cell_params,
        param_state=pstate,
        data_stimuli=data_stimuli,
        # NOTE(review): 90 * 90 = 8100 presumably has to cover the t_max / dt = 8000
        # integration steps; revisit if t_max or dt change. Confirm against jaxley docs.
        checkpoint_lengths=[90, 90],
    )
    return v

# Jitted single-stimulus simulator, and a batched variant that maps over the
# leading axis of `currents` while sharing all parameters.
sim = jit(simulate)
vmapped_sim = jit(vmap(simulate, in_axes=(None, None, None, 0)))


In [9]:
warmup = 5.0  # first argument to jx.datapoint_to_step_currents in simulate() (presumably the stimulus onset; same time units as dt / t_max)
i_amp = 0.1  # amplitude scale passed to build_training_data — NOTE(review): units not visible here

# Input currents per stimulus (indexed along the leading axis, e.g. currents[0] below),
# plus labels and loss weights from the training setup (not used in this notebook).
currents, labels, loss_weights = build_training_data(
    i_amp,
    stimuli,
    avg_recordings,
    rec_ids, 
    num_datapoints_per_scanfield,
    number_of_recordings_each_scanfield,
)

# Branch index and compartment location of every stimulus site, consumed by simulate().
stim_branch_inds = stimuli["branch_ind"].to_numpy()
stim_comps = stimuli["comp"].to_numpy()

In [10]:
dt = 0.025  # integration time step
t_max = 200.0  # simulated duration
# Time axis saved alongside the traces.
# NOTE(review): the `+ 2*dt` presumably makes len(time_vec) match the number of
# samples jx.integrate returns; confirm before changing dt or t_max.
time_vec = np.arange(0, t_max+2*dt, dt)

In [11]:
# Build the morphology.
# NOTE(review): the meaning of the 5.0 argument is not visible here; see nex.rgc.utils.utils.build_cell.
cell = build_cell(cell_id, nseg, 5.0, path_prefix)
# Branch indices belonging to the basal and somatic groups, used by simulate().
basal_inds = list(np.unique(cell.group_nodes["basal"]["branch_index"].to_numpy()))
somatic_inds = list(np.unique(cell.group_nodes["soma"]["branch_index"].to_numpy()))

# Clear any recordings/stimuli the cell may already carry before adding our own.
cell.delete_recordings()
cell.delete_stimuli()

# Record voltage at every averaged-recording site, then Cai at the same sites.
# All `v` recordings come first, then all `Cai`, in the same order as before.
for state in ("v", "Cai"):
    for _, rec in avg_recordings.iterrows():
        cell.branch(rec["branch_ind"]).loc(rec["comp"]).record(state, verbose=False)

print(f"Inserted {len(cell.recordings)} recordings")
print(f"number_of_recordings_each_scanfield {number_of_recordings_each_scanfield}")
number_of_recordings = np.sum(number_of_recordings_each_scanfield)
# Sanity check: one `v` and one `Cai` recording per averaged-recording site.
assert len(cell.recordings) == 2 * len(avg_recordings), "unexpected number of recordings"

  warn(
  warn("Found a segment with length 0. Clipping it to 1.0")


Inserted 294 recordings
number_of_recordings_each_scanfield [12, 6, 15, 21, 13, 10, 9, 10, 10, 6, 4, 11, 8, 4, 8]


In [12]:
# Reset trainables, then make axial resistivity and radius trainable on all basal branches.
cell.delete_trainables()
for trainable_name in ("axial_resistivity", "radius"):
    cell.basal.branch("all").make_trainable(trainable_name)

Number of newly added trainable parameters: 154. Total number of trainable parameters: 154
Number of newly added trainable parameters: 154. Total number of trainable parameters: 308


In [15]:
def _load_run_pickle(relative_path):
    """Unpickle ``relative_path`` from the selected training-run directory.

    Only use on trusted local files produced by the training run; unpickling
    untrusted data can execute arbitrary code.
    """
    with open(f"{path_prefix}/{results_prefix}/{relative_path}", "rb") as handle:
        return pickle.load(handle)


all_opt_params = _load_run_pickle("opt_params/params_10.pkl")
transform_params = _load_run_pickle("transforms/transform_params.pkl")
transform_basal = _load_run_pickle("transforms/transform_basal.pkl")
transform_somatic = _load_run_pickle("transforms/transform_somatic.pkl")

# Optimized parameters are stored in the transformed space; map them back with
# each group's transform.
opt_params, opt_basal_params, opt_somatic_params = all_opt_params

parameters = transform_params.forward(opt_params)
basal_neuron_params = transform_basal.forward(opt_basal_params)
somatic_neuron_params = transform_somatic.forward(opt_somatic_params)

In [16]:
# Simulate the first stimulus with the trained parameters. Reuse the values
# transformed in the previous cell instead of re-running every transform, so both
# cells cannot silently diverge.
test_currents = currents[0]

v_trained = sim(
    parameters,
    basal_neuron_params,
    somatic_neuron_params,
    test_currents,
)

In [17]:
# Save the stimulus image (first entry along the last axis of noise_full), the
# simulated traces (v then Cai recordings) and the time axis for the figure.
# Create the output directory so a fresh checkout does not fail here.
output_dir = "../results/02_loss"
os.makedirs(output_dir, exist_ok=True)

outputs = {
    "noise_image.pkl": noise_full[:, :, 0],
    "v_and_cai.pkl": v_trained,
    "time_vec.pkl": time_vec,
}
for filename, obj in outputs.items():
    with open(f"{output_dir}/{filename}", "wb") as handle:
        pickle.dump(obj, handle)