In [1]:
import os
# Fraction of accelerator (GPU) memory JAX may preallocate; it has to be set
# before JAX initializes its backend, hence before `import jax` below.
# NOTE(review): this notebook later forces the CPU platform, so this setting
# likely has no effect — confirm before relying on it.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".9"

import subprocess
import tempfile

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from copy import deepcopy
from jax import jit, vmap, value_and_grad, lax
from jax.tree_util import tree_map
from jaxley.optimize.transforms import ParamTransform
import jax
import jaxlib
import jaxley as jx
import jax.numpy as jnp
from jaxley.channels import Leak
import jaxley.optimize.transforms as jt
import pandas as pd
import optax
from itertools import chain
import random
import pickle
import dill 
import optax
import json
from scipy import interpolate

from Allen_mech_allActive import * 

from neuron import h

from swc_utils import read_swc_varComps

import time
start_time = time.time()  # wall-clock reference taken when this imports cell runs

from jax import config

# Enable 64-bit floats in JAX (default is 32-bit); must happen before arrays are created.
config.update("jax_enable_x64", True)
# Run all JAX computation on the CPU backend.
config.update("jax_platform_name", "cpu")

from jaxley.optimize.utils import l2_norm

--No graphics will be displayed.


In [2]:
# Locate the patch-seq morphology (SWC) for this cell and look up its ephys session.

ID = 623434306
trans_file = str(ID) + '_transformed.swc'
base_dir = '/Users/elena.westeinde/Datasets/patch_seq/patch_seq_morphology'
data_path = '/Users/elena.westeinde/Datasets/patch_seq/electrophysiology'
raw_ephys_dir = '/Users/elena.westeinde/Datasets/raw_ephys'

swc_file = os.path.join(base_dir, trans_file)

metadata_file = '/Users/elena.westeinde/Datasets/patch_seq/specimen_metadata/20200711_patchseq_metadata_mouse.csv'
metadata = pd.read_csv(metadata_file)

metadata_cell = metadata[metadata['cell_specimen_id'] == ID]
if metadata_cell.empty:
    # Fail with a clear message instead of an IndexError on `.values[0]`.
    raise ValueError(f"cell_specimen_id {ID} not found in {metadata_file}")
ephys_ID = metadata_cell['ephys_session_id'].values[0]

# Per-cell folder holding raw and processed ephys for this specimen.
processed_data_dir = os.path.join(raw_ephys_dir, str(ID))
# NOTE: despite its name, `output_dir` is the path of the exported ephys JSON
# *file* (it is later opened with json.load), not a directory.
output_dir = os.path.join(processed_data_dir, 'ephys_ID_' + str(ephys_ID) + '.json')

metadata_cell.head()


In [4]:
import shutil

code_dir = '/Users/elena.westeinde/Code/patch_seq'
print(swc_file)
# Place a copy of the morphology next to the analysis code. shutil.copy returns
# the destination path, which is left as this cell's displayed value.
copied_swc_path = shutil.copy(swc_file, code_dir)
copied_swc_path

/Users/elena.westeinde/Datasets/patch_seq/patch_seq_morphology/623434306_transformed.swc


'/Users/elena.westeinde/Code/patch_seq/623434306_transformed.swc'

In [None]:
def find_files(substring, directory):
    """
    Recursively collect files under ``directory`` whose *file name* contains ``substring``.

    Args:
        substring: Text that must appear in the file name (directory names are not matched).
        directory: Root directory to walk.

    Returns:
        Sorted list of full paths. ``os.walk`` yields entries in arbitrary,
        filesystem-dependent order; sorting makes callers that pick ``[0]``
        select the same file on every run.
    """
    matching_files = [
        os.path.join(root, file_name)
        for root, _dirs, files in os.walk(directory)
        for file_name in files
        if substring in file_name
    ]
    return sorted(matching_files)
data_files = find_files(str(ephys_ID), data_path)
if not data_files:
    # Later code indexes data_files[0]; fail here with a readable message instead.
    raise FileNotFoundError(f"No electrophysiology file containing '{ephys_ID}' found under {data_path}")

def call_load_data_in_allensdk(ID, script_dir='/Users/elena.westeinde/Code/patch_seq', env_name='allensdk'):
    """
    Run ``load_ephys.load_data(ID)`` inside a separate conda environment.

    A temporary ``wrapper.py`` is written into ``script_dir`` (next to
    ``load_ephys.py``, so its import resolves) and executed with ``conda run``
    from that directory. The wrapper is removed afterwards whether or not the
    command succeeds.

    Args:
        ID: Specimen ID; interpolated into the wrapper as ``load_data({ID})``.
        script_dir: Directory containing ``load_ephys.py``.
        env_name: Conda environment that has allensdk installed.

    Returns:
        The ``subprocess.CompletedProcess`` of the ``conda run`` call.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero
            (its stderr is printed before re-raising).
    """
    wrapper_code = f"""
import sys
from load_ephys import load_data

# Call the function with the ID and default parameters
load_data({ID})
"""
    wrapper_path = os.path.join(script_dir, 'wrapper.py')

    with open(wrapper_path, 'w') as f:
        f.write(wrapper_code)

    # Argument list + cwd instead of "cd ... && ..." with shell=True: nothing is
    # parsed by a shell, so directories with spaces/metacharacters are safe.
    command = ["conda", "run", "-n", env_name, "python", "wrapper.py"]

    try:
        result = subprocess.run(
            command,
            cwd=script_dir,
            check=True,
            capture_output=True,
            text=True,
        )
        print("Command output:", result.stdout)
        return result

    except subprocess.CalledProcessError as e:
        print("Command failed with error:", e.stderr)
        raise

    finally:
        # Clean up wrapper file
        if os.path.exists(wrapper_path):
            os.remove(wrapper_path)

# Location of the ipfx pipeline output for this NWB file: a folder named after
# the file (without its extension) inside processed_data_dir.
folder_name = os.path.splitext(os.path.basename(data_files[0]))[0]
processed_data_file = os.path.join(processed_data_dir, folder_name, 'output.json')

# Step 1: export the ephys JSON with load_ephys.load_data (allensdk env) when it
# is missing — presumably load_data writes output_dir; confirm.
if not os.path.exists(output_dir):
    call_load_data_in_allensdk(ID)

# Step 2: run the ipfx feature-extraction pipeline when its output is missing.
# Guarding on its own output lets each step be re-run independently.
if not os.path.exists(processed_data_file):
    command = [
        "conda", "run",
        "-n", "allensdk",
        "python", "-m", "ipfx.bin.run_pipeline_from_nwb_file",
        data_files[0],
        processed_data_dir,
    ]

    try:
        result = subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True
        )
        print("Command output:", result.stdout)
    except subprocess.CalledProcessError as e:
        print("Command failed with error:", e.stderr)
        # Without the pipeline output the sweep file below cannot be loaded;
        # surface the real failure instead of a later FileNotFoundError.
        raise

with open(output_dir, 'r') as f:
    ephys_data = json.load(f)
    print('Data file loaded')

with open(processed_data_file, 'r') as f:
    sweep_data = json.load(f)
    print('Sweep file loaded')


In [None]:
# Build the jaxley cell from the SWC morphology.
nseg_per_branch = 2  # passed positionally to jx.read_swc (segments per branch)
# assign_groups=True is what makes the group views used later (cell.soma,
# cell.basal, cell.axon) available — presumably derived from SWC type ids; confirm.
cell = jx.read_swc(swc_file, nseg_per_branch, max_branch_len=20000.0, assign_groups=True)

print(cell.shape)
print(list(cell.groups.keys()))
test = cell.nodes  # handle on the node table for interactive inspection; not used by later cells



In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 7))
colors = plt.cm.tab10.colors
# Draw each morphological group in its own colour (axon in black). The group
# view is fetched fresh for every draw, in the same order as before.
for group_name, group_color in (("basal", colors[2]), ("axon", 'k'), ("soma", colors[0])):
    getattr(cell, group_name).vis(ax=ax, col=group_color)
ax.axis("off")
ax.set_title("Groups")
plt.show()

In [6]:
# Get relevant ephys parameters from ephys data to set initial conditions

# resting potential
# resistance
# capacitance
# rheobase
# reversal potentials
# conductances if available

# decide on loss attributes
# Option 1. point-point comparison of voltage/calcium traces
# Option 2. Comparison of features extracted from traces & calculated from model (e.g. AP threshold, AP amplitude, AP width, frequency, etc.)
# ^ is what was done for original Allen biophysical model fitting, but optimization was done by hand


# All active Allen biophysical model:
# Axon was removed, replaced with AIS of 2 compartments (?) 30 um each, 60 um length total
# inhibitory neurons only had a single dendritic type, excitatory neurons had basal + apical

# optimized passive & active properties with the same procedure
# Passive properties: one value for each was uniformly distributed across all compartments
    # specific capacitance (cm): 'capacitance'
    # passive conductance (g_pas): Leak_gLeak
    # passive reversal potential (e_pas): Leak_eLeak
    # cytoplasmic resistivity (Ra): 'axial_resistivity'

# Active properties: uniformly distributed across all compartments for each group. Every group received a separate set of channels
# ion channel identities & mechanisms were identical to those in the Allen perisomatic models (inhibitory cells used something different — TODO: look into this)
# slow inactivating K current was replaced with two separate K currents: one fast inactivating, one slow inactivating (Kv1 & Kv2). 
# Original M current was replaced by a model from rat CA1 neurons Im_v2
# total of 26 free parameters: 18 active conductance densities, 4 intracellular Ca2+ dynamics parameters, 4 passive parameters

# simulation temperature was 34 degrees C; check whether the cell was recorded at the same temperature and, if not, scale kinetics with a Q10 of 2.3

In [7]:
# Free-parameter bounds for the first run, copied from the Allen all-active
# biophysical model fits. Each row is (parameter name, lower, upper); every
# parameter is wrapped in a SigmoidTransform so the optimiser can work in an
# unconstrained space.
#
# Row order must match the order of cell.get_parameters() (i.e. the order in
# which parameters are made trainable), because ParamTransform pairs them by
# position.
#
# NOTE(review): several initial values set in the channel-configuration cell lie
# on or outside these bounds — *_NaTs_gNaTs and *_CaPump_gamma equal their upper
# bound, *_CaLVA_gCaLVA = 3e-4 exceeds 1e-4. transform.inverse() of such values
# would give ±inf/NaN; this is harmless only while init() re-randomises the
# unconstrained parameters. Confirm before optimising from cell.get_parameters().
_bound_specs = [
    ('capacitance', 0.5, 10),
    ('axial_resistivity', 50, 150),
    ('Leak_gLeak', 1e-7, 1e-2),
    ('Leak_eLeak', -110, -60),
    ('soma_H_gH', 1e-7, 1e-5),
    ('soma_NaTs_gNaTs', 1e-7, 5e-1),
    ('soma_Nap_gNap', 1e-7, 5e-2),
    ('soma_SKE2_gSKE2', 1e-7, 1e-2),
    ('soma_SKv3_1_gSKv3_1', 1e-7, 1),
    ('soma_CaHVA_gCaHVA', 1e-7, 1e-2),
    ('soma_CaLVA_gCaLVA', 1e-7, 1e-4),
    ('soma_CaPump_gamma', 5e-4, 5e-2),
    ('soma_CaPump_decay', 100, 1000),
    ('basal_H_gH', 1e-7, 1e-5),
    ('basal_NaTs_gNaTs', 1e-7, 5e-2),
    ('basal_Nap_gNap', 1e-7, 5e-2),
    ('basal_SKv3_1_gSKv3_1', 1e-7, 1),
    ('basal_Im_v2_gIm_v2', 1e-7, 1e-2),
    ('axon_NaTs_gNaTs', 1e-7, 5e-2),
    ('axon_Nap_gNap', 1e-7, 5e-2),
    ('axon_Kd_gKd', 1e-7, 1e-2),
    ('axon_Kv2like_gKv2like', 1e-7, 1e-1),
    ('axon_K_T_gK_T', 1e-7, 1e-2),
    ('axon_SKE2_gSKE2', 1e-7, 1e-2),
    ('axon_SKv3_1_gSKv3_1', 1e-7, 1),
    ('axon_CaHVA_gCaHVA', 1e-7, 1e-2),
    ('axon_CaLVA_gCaLVA', 1e-7, 1e-4),
    ('axon_CaPump_gamma', 5e-4, 5e-2),
    ('axon_CaPump_decay', 100, 1000),
]

bounds = [{name: jt.SigmoidTransform(lower, upper)} for name, lower, upper in _bound_specs]

In [None]:


# Configure passive and active membrane mechanisms and mark their parameters as
# trainable.
#
# Some values (e.g. resting potential) could eventually be derived from the
# ephys data; for now initial values follow the Allen all-active models (ideally
# they would be read automatically from the model's fit_parameters.json).
#
# The ORDER in which parameters are made trainable defines the order of
# cell.get_parameters(), which must match the `bounds` list used by
# ParamTransform — keep the two in sync.

cell.set("v", -80)  # initial voltage

# ---------- Passive: one value shared by all compartments ----------
cell.set('capacitance', 1)  # uF/cm^2
cell.make_trainable('capacitance')

cell.set('axial_resistivity', 100)  # ohm cm
cell.make_trainable('axial_resistivity')

cell.insert(Leak())
cell.set('Leak_gLeak', 0.0005)
cell.set('Leak_eLeak', -70)
cell.make_trainable('Leak_gLeak')
cell.make_trainable('Leak_eLeak')

# ---------- Active: an independently named copy of each channel per group ----------
# (group, channel class, renamed channel, [(parameter, initial value), ...])
active_channel_specs = [
    # Soma
    ("soma", H, "soma_H", [("soma_H_gH", 3e-06)]),
    ("soma", NaTs, "soma_NaTs", [("soma_NaTs_gNaTs", 0.5)]),  # replaces Nav
    ("soma", Nap, "soma_Nap", [("soma_Nap_gNap", 3e-06)]),
    ("soma", SKE2, "soma_SKE2", [("soma_SKE2_gSKE2", 0.0004)]),
    ("soma", SKv3_1, "soma_SKv3_1", [("soma_SKv3_1_gSKv3_1", 0.03)]),
    ("soma", CaHVA, "soma_CaHVA", [("soma_CaHVA_gCaHVA", 1e-05)]),
    ("soma", CaLVA, "soma_CaLVA", [("soma_CaLVA_gCaLVA", 0.0003)]),
    # Ca pump: shell depth matches allen
    ("soma", CaPump, "soma_CaPump", [("soma_CaPump_gamma", 0.05), ("soma_CaPump_decay", 589)]),
    # Basal dendrites
    ("basal", H, "basal_H", [("basal_H_gH", 3e-06)]),
    ("basal", NaTs, "basal_NaTs", [("basal_NaTs_gNaTs", 0.05)]),
    ("basal", Nap, "basal_Nap", [("basal_Nap_gNap", 0.0003)]),
    ("basal", SKv3_1, "basal_SKv3_1", [("basal_SKv3_1_gSKv3_1", 0.03)]),
    ("basal", Im_v2, "basal_Im_v2", [("basal_Im_v2_gIm_v2", 0.0001)]),  # M current
    # Axon
    ("axon", NaTs, "axon_NaTs", [("axon_NaTs_gNaTs", 0.05)]),
    ("axon", Nap, "axon_Nap", [("axon_Nap_gNap", 0.0003)]),
    ("axon", Kd, "axon_Kd", [("axon_Kd_gKd", 0.0001)]),
    ("axon", Kv2like, "axon_Kv2like", [("axon_Kv2like_gKv2like", 0.01)]),
    ("axon", K_T, "axon_K_T", [("axon_K_T_gK_T", 0.0001)]),
    ("axon", SKE2, "axon_SKE2", [("axon_SKE2_gSKE2", 0.0001)]),
    ("axon", SKv3_1, "axon_SKv3_1", [("axon_SKv3_1_gSKv3_1", 0.03)]),
    ("axon", CaHVA, "axon_CaHVA", [("axon_CaHVA_gCaHVA", 1e-05)]),
    ("axon", CaLVA, "axon_CaLVA", [("axon_CaLVA_gCaLVA", 0.0003)]),
    ("axon", CaPump, "axon_CaPump", [("axon_CaPump_gamma", 0.05), ("axon_CaPump_decay", 589)]),
]

for group_name, channel_cls, channel_name, initial_values in active_channel_specs:
    # The group view is re-fetched for every call (as `cell.<group>.insert/set`
    # would), so each operation sees the cell's current state.
    getattr(cell, group_name).insert(channel_cls().change_name(channel_name))
    for param_name, initial_value in initial_values:
        getattr(cell, group_name).set(param_name, initial_value)
        cell.make_trainable(param_name)

# ---------- Reversal potentials, concentrations, temperature ----------
# Parameters shared between channels overwrite each other on insertion, so the
# final values are set only after all channels are in place. Values are taken
# from the Allen biophysical models white paper.
cell.insert(CaNernstReversal())
cell.set('eK', -107.0)
cell.set('eNa', 53.0)
cell.set('CaCon_e', 2)  # mM
cell.set('CaCon_i', 1e-04)  # mM
cell.set("celsius", 34)  # temperature
    

In [9]:
test = cell.nodes  # node table after channel insertion, bound for interactive inspection; not used by later cells

In [10]:
# Define training functions

# Maps between constrained (physical) and unconstrained (optimiser) parameter
# space. `bounds` must list parameters in the same order as cell.get_parameters().
transform = ParamTransform(bounds)


def simulate(params, dt):
    """Integrate `cell` with the given (constrained) trainable parameters.

    Args:
        params: Parameters in the structure returned by cell.get_parameters().
        dt: Integration step forwarded to jx.integrate as ``delta_t`` (ms in this notebook).

    Returns:
        The jx.integrate output; callers index it per recording (``[0]`` is the voltage).
    """
    integrate_kwargs = {"params": params, "delta_t": dt, "voltage_solver": "jaxley.thomas"}
    return jx.integrate(cell, **integrate_kwargs)

# Setup loss function
def loss_from_v(jaxley_output, target_v):
    """
    Mean absolute error between the simulated voltage and the target trace.

    Args:
        jaxley_output: Simulation output; ``jaxley_output[0]`` is the first
            recording (the somatic voltage in this notebook).
        target_v: Recorded voltage trace to match.

    Returns:
        Scalar mean absolute voltage difference.

    jaxley returns 1-2 more samples than the resampled data (origin not yet
    understood). When lengths differ the traces are aligned at their END by
    dropping the leading surplus samples of whichever trace is longer.
    """
    v = jnp.array(jaxley_output[0].T)
    diff = len(v) - len(target_v)
    if diff != 0:
        # Lengths are static, so under jit this prints only while tracing.
        print('Lengths of voltage traces do not match, difference is: ', diff)
        if diff > 0:
            v = v[diff:]
        else:
            # v is the shorter trace here: trim the target (not v) so both match.
            target_v = target_v[-diff:]
    voltage_loss = jnp.mean(jnp.abs(v - target_v))
    return voltage_loss

def loss_fun(opt_params, target_v, dt):
    """Training loss as a function of the *unconstrained* optimiser parameters.

    The parameters are mapped to constrained space with ``transform.forward``,
    simulated, and scored with ``loss_from_v``. A regulariser on ``opt_params``
    could be added here to penalise specific parameter regions if needed.

    Returns:
        ``1.0 * voltage loss + 1e-8``.
    """
    constrained_params = transform.forward(opt_params)
    sim_output = simulate(constrained_params, dt)
    voltage_loss = loss_from_v(sim_output, target_v)
    return 1.0 * voltage_loss + 1e-8

def init(opt_params, init_params = None):
    """
    Choose starting values for the unconstrained optimisation parameters.

    Args:
        opt_params: List of dicts ``{name: array}`` (e.g. the output of
            ``transform.inverse(cell.get_parameters())``). Only the array
            shapes are used; the input is NOT modified.
        init_params: Optional explicit starting parameters, returned unchanged.

    Returns:
        ``init_params`` if given; otherwise a new list with the same structure
        whose arrays are drawn as ``2.0 * np.random.randn(shape)`` from the
        global NumPy RNG (seed it with ``np.random.seed`` for reproducibility).
    """
    if init_params is not None:
        return init_params

    randomized = []
    for param_dict in opt_params:
        # Draws happen in list order, so a fixed np.random.seed reproduces the
        # same starting point; every entry of each dict is randomised.
        randomized.append({
            key: jnp.asarray(2.0 * np.random.randn(*np.shape(value)))
            for key, value in param_dict.items()
        })
    return randomized


def resample_timeseries(time, data, new_dt=0.1):
    """
    Resample a time series onto a uniform grid with an interpolating cubic spline.

    Args:
        time: Original sample times (increasing, as required by splrep).
        data: Sample values at ``time``.
        new_dt: Grid spacing, in the same units as ``time`` (default 0.1).

    Returns:
        (new_time, new_data): the grid ``time[0], time[0] + new_dt, ...``
        (stopping before ``time[-1]``) and the spline evaluated on it.
    """
    grid = np.arange(time[0], time[-1], new_dt)
    # s=0 forces the spline through every sample (no smoothing), minimising artifacts.
    spline_rep = interpolate.splrep(time, data, s=0)
    return grid, interpolate.splev(grid, spline_rep)

In [11]:
# Compiled entry points. `dt` is a plain Python float in this notebook, so it is
# marked static everywhere (one compilation per distinct dt value).
jitted_sim = jit(simulate, static_argnums=(1,))
# vmap needs one in_axes entry per positional argument: batch over `params`,
# share `dt` (None = not mapped).
vmapped_sim = jit(vmap(simulate, in_axes=(0, None)), static_argnums=(1,))
# Parallelise the loss over a batch of simulation outputs sharing one target trace.
vmapped_loss = vmap(loss_from_v, in_axes=(0, None))
jitted_loss_fn = jit(loss_fun, static_argnums=(2,))
grad_fn = jit(value_and_grad(loss_fun), static_argnums=(2,))

In [None]:
# Visual check that resampling one long-square sweep to 0.1 ms keeps spike shape.
preview_sweep_id = '20'
preview_sweep = ephys_data['long_squares'][preview_sweep_id]

_, new_voltage = resample_timeseries(preview_sweep['time'], preview_sweep['voltage'], new_dt=0.0001)
new_time, new_current = resample_timeseries(preview_sweep['time'], preview_sweep['current'], new_dt=0.0001)

fig, axes = plt.subplots(2, 1, figsize=(7, 5))
axes[0].plot(preview_sweep['time'], preview_sweep['voltage'], color='k', linewidth=3)
axes[0].plot(new_time, new_voltage, color='y', linewidth=1)
axes[1].plot(preview_sweep['time'], preview_sweep['current'], color='k', linewidth=3)
axes[1].plot(new_time, new_current, color='y', linewidth=1)
axes[0].set_ylabel("membrane potential (mV)")
axes[1].set_ylabel("current (pA)")
axes[1].set_xlabel("time (s)")
axes[0].set_title(f"Sweep {preview_sweep_id}: original (black) vs resampled (yellow)")
for panel in axes:
    panel.set_xlim(0.5, 1.5)
plt.show()

# The coarsest step that still preserves spike shape is 0.0001 s (0.1 ms); the
# original recording was sampled every 0.02 ms, so it is resampled to 0.1 ms.

In [None]:
# Fit a single sweep (the first long square) first; multiple sweeps come later.

# ---- target data ----
sweep_ids = list(ephys_data['long_squares'].keys())
sweep_id = sweep_ids[0]

new_dt = 0.0001  # resampling step in seconds; must equal dt (ms, below) / 1000
sweep = ephys_data['long_squares'][sweep_id]
_, new_voltage = resample_timeseries(sweep['time'], sweep['voltage'], new_dt=new_dt)
new_time, new_current = resample_timeseries(sweep['time'], sweep['current'], new_dt=new_dt)

# Stimulus features for this sweep, looked up by sweep number rather than by
# list position. NOTE(review): assumes ipfx sweep_features entries carry a
# 'sweep_number' key; falls back to positional indexing when none matches.
sweep_features = sweep_data['sweep_extraction']['sweep_features']
matching_sweeps = [s for s in sweep_features if s.get('sweep_number') == int(sweep_id)]
sweep_info = matching_sweeps[0] if matching_sweeps else sweep_features[int(sweep_id)]
stim_amp = sweep_info['stimulus_amplitude']  # pA
stim_dur = sweep_info['stimulus_duration']  # seconds
stim_start = sweep_info['stimulus_start_time']  # seconds

# ---- fitted window and stimulus parameters ----
pre_stim = 500  # ms of baseline kept before stimulus onset
dt = 0.1  # time step in ms
t_max = 3000.0  # ms, end of the fitted window in recording time
window_start = round(stim_start * 1e3) - pre_stim  # ms, recording time where the window begins
# Onset is measured from the start of the *simulated* window, which begins at
# window_start in recording time. Hence it is pre_stim, so simulated and
# recorded stimuli line up for any stim_start.
i_delay = pre_stim  # ms
i_dur = round(stim_dur * 1e3)  # ms
i_amp = stim_amp * 1e-3  # nA current stim amplitude, converted from pA

# Sample indices into the resampled trace (assumes new_time starts at 0 — TODO confirm).
idx_max = int(t_max / dt)
idx_min = int(window_start / dt)

target_v = new_voltage[idx_min:idx_max]
target_time = new_time[idx_min:idx_max]

# ---- stimulus and recordings; cleared first so re-running this cell is idempotent ----
cell.delete_stimuli()
cell.delete_recordings()
# The simulation spans the same duration as the target window.
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max - window_start)

# stimulate the soma
cell.branch(0).loc(0.0).stimulate(current)

# record from the soma: voltage, intracellular Ca, Ca reversal (in this order)
cell.branch(0).comp(0).record("v", verbose=True)
cell.branch(0).comp(0).record("CaCon_i", verbose=True)
cell.branch(0).comp(0).record("eCa", verbose=True)

cell.init_states()




In [14]:
# # set = init()
# # Transform constrained parameters into unconstrained space
# # This is done because optimization algorithms like GD work better in unconstrained space
# # Avoids getting stuck at boundary values
# # prevents invalid parameter values during optimization steps
# # makes optimization landscape smoother
# # lets optimizers take larger steps without violating constraints
# start_time = time.time()
# params = cell.get_parameters()
# opt_params = transform.inverse(params)
# l = jitted_loss_fn(opt_params, target_v, dt)
# print("Loss time", time.time() - start_time)
# # returns only the loss function value (not the gradient)
# # used only for evaluation purposes

# start_time = time.time()
# loss_val, grad_val = grad_fn(opt_params, target_v, dt)
# print("Gradient time", time.time() - start_time)
# # returns both the loss function value and the gradient
# # needed in gradient-based optimization process 

In [None]:
import psutil  # NOTE(review): not used in this cell
losses_of_every_starting_point = []  # Stores loss trajectories for each random start
best_traces = []                     # Stores best-fitting voltage traces
best_losses = []                     # Stores lowest loss achieved for each start
all_best_params = []                 # Stores optimal parameters found
total_sims = 0                       # Counter for total simulations run
beta = 0.8                           # Gradient normalization power factor
epochs = 2                           # Number of optimisation steps

# Reproducible random start: init() draws from the global NumPy RNG.
seed = 1
_ = np.random.seed(seed)

# Optimise in the unconstrained space defined by `transform`.
params = cell.get_parameters()
opt_params = transform.inverse(params)

opt_params = init(opt_params)  # set random initial parameters
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(opt_params)

train_losses = []
grad_norms = []
best_loss = 10000.0

for epoch in range(epochs):
    # Two timers: one for the whole step, one for the optimiser update only.
    step_start = time.time()
    loss_val, grad_val = grad_fn(opt_params, target_v, dt)
    total_sims += 1

    if loss_val < best_loss:
        best_loss = loss_val
        # Record the unconstrained parameters that produced this loss as Python
        # floats (first element of each array — assumes one shared value per
        # parameter; TODO confirm for multi-valued parameters).
        best_params = [
            {k: float(v[0]) for k, v in d.items()}
            for d in opt_params
        ]
#         with open(os.path.join(save_dir, "opt_params_unconstrained.pkl"), "wb") as handle:
#             pickle.dump(opt_params, handle)

    # Normalise the gradient by ||g||**beta: with beta < 1 the resulting norm
    # is ||g||**(1 - beta), shrinking large and enlarging small gradients
    # while keeping their direction.
    grad_norm = l2_norm(grad_val)
    grad_val = tree_map(lambda x: x / grad_norm**beta * 1.0, grad_val)
    grad_norm_corrected = l2_norm(grad_val)
    train_losses.append(loss_val)
    # Stop early if loss is below required threshold
    if loss_val < 0.1:  # required_losses[iteration]:
        break

    update_start = time.time()
    updates, opt_state = optimizer.update(grad_val, opt_state)
    update_norm = l2_norm(updates)
    opt_params = optax.apply_updates(opt_params, updates)
    grad_norms.append(grad_norm)
    print(f"Update time: {time.time() - update_start:.2f}")
    # Whole-step time: simulation + gradient + update.
    print(f"Step {epoch + 1} time: {time.time() - step_start:.2f}, loss: {float(loss_val):.4f}")

    # if epoch % 5 == 0:
    #     print(f"loss in epoch {epoch}: {loss_val:.4f}, gradient_norm {grad_norm:.4f}, corrected {grad_norm_corrected:.4f}")

    #     params = transform.forward(opt_params)
    #     output = jitted_sim(params, dt)

    #     jax_v = np.array(output[0].T)
    #     jax_cai = np.array(output[1].T)

    #     fig, ax = plt.subplots(2, 3, figsize=(10, 5))
    #     _ = ax[0,0].plot(target_v, c="k", linewidth=4)
    #     _ = ax[0,0].plot(jax_v[3:], c="y", linestyle="--")
    #     _ = ax[0,1].plot(target_v, c="k", linewidth=4) #, marker="o")
    #     _ = ax[0,1].plot(jax_v[3:], c="y",linestyle="--") #, marker="o")
    #     _ = ax[0,1].set_xlim([5000, 8200])
    #     _ = ax[0,1].set_ylim([-70, -40])
    #     _ = ax[0,2].plot(target_v, c="k", linewidth=4) #, marker="o")
    #     _ = ax[0,2].plot(jax_v[:-3], c="y",linestyle="--") #, marker="o")
    #     _ = ax[0,2].set_xlim([8000, 17000])
    #     _ = ax[0,2].set_ylim([-40, -20])
    #     # set figure title
    #     _ = ax[0,1].set_title("Data vs Jaxley Voltage, adjusted")
    #     # add legend
    #     _ = ax[0,2].legend(["Data", "Jaxley"])

    #     plt.tight_layout()
    #     plt.show()
    

In [None]:
print(type(best_params[0]['capacitance']))  # sanity check: the training loop stores best_params values via float(...)

In [18]:
params = transform.forward(opt_params)
output = jitted_sim(params, dt)

# Rows follow the order of the record() calls: voltage first, then intracellular Ca.
jax_v, jax_cai = (np.array(trace.T) for trace in (output[0], output[1]))



In [None]:
t_max = int(t_max)

# Align simulated and recorded traces end-to-end by dropping the leading surplus
# samples of whichever trace is longer (jaxley typically returns a couple of
# extra samples). Equivalent to jax_v[2:] when jaxley is exactly 2 samples longer.
n_aligned = min(len(jax_v), len(target_v))

fig, ax = plt.subplots(1, 1, figsize=(5, 3))
ax.plot(target_time, target_v, c="k", linewidth=4)
ax.plot(target_time[len(target_time) - n_aligned:], jax_v[len(jax_v) - n_aligned:], c="y", linestyle="--")
ax.set_xlabel("time (s)")
ax.set_ylabel("membrane potential (mV)")
ax.set_title("Data vs Jaxley Voltage, adjusted")
ax.legend(["Data", "Jaxley"])

plt.tight_layout()
plt.show()