In [None]:
import os
import subprocess
import sys

# Restrict JAX to GPU 5. CUDA_VISIBLE_DEVICES is only honoured if it is set before
# the CUDA backend initialises, so set it before every import that may pull in JAX
# (presumably including the star-import from `utils` below).
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

## compile the Golgi-cell NEURON mechanisms (.mod files)
# nrnivmodl writes the compiled library into the current working directory, and
# NEURON loads mechanisms from there when `neuron` is first imported, so build them
# before `utils`, `GOL` or `neuron` are imported. check=True reports a failed build
# here instead of as a missing-mechanism error later; the argument list (no shell)
# keeps paths containing spaces intact.
current_dir = os.getcwd()
mod_path = os.path.join(current_dir, 'golgi_NEURON', 'mod_gol')
subprocess.run(['nrnivmodl', mod_path], check=True)

from utils import *

import jax
import braincell

devices = jax.devices()
print("available devices:", devices)

## load the Golgi-cell NEURON morphology / model builder
parent_folder_path = os.path.join(os.getcwd(), 'golgi_NEURON')
sys.path.append(parent_folder_path)
from GOL import Golgi_morpho_1

from neuron import h, gui

# NOTE(review): this path is inserted after `import braincell` above, so it does not
# change which braincell was imported; it only affects imports made afterwards.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
dendritex_path = os.path.join(project_root, 'braincell')
sys.path.insert(0, dendritex_path)

# `brainstate` is expected to come from the `utils` star-import.
brainstate.environ.set(precision=32)
#jax.config.update("jax_disable_jit", True)

## NEURON

In [None]:
def neuron_simulation(dt, tstop=100, v_init=-65, amp=0.0, plot=True):
    """Run the Golgi-cell model in NEURON and return time and the first voltage trace.

    Parameters
    ----------
    dt : float
        Integration step in ms, passed to ``NeuronRun``.
    tstop : float, optional
        End time passed to ``NeuronRun`` (default 100; ms in NEURON's convention).
    v_init : float, optional
        Initial membrane potential passed to ``NeuronRun`` (default -65; mV in NEURON's convention).
    amp : float, optional
        Amplitude of the step stimulus built by ``step_stim`` (default 0.0, i.e. no input).
    plot : bool, optional
        If True (default), plot trace 0 with ``plot_voltage_traces``.

    Returns
    -------
    t : np.ndarray
        Time points returned by ``NeuronRun``.
    v : np.ndarray
        Row 0 of the voltage recordings returned by ``NeuronRun``.
    """
    import time

    ## create cell in NEURON
    cell_neuron = Golgi_morpho_1(el=-55, gl=1, ghcn1=1, ghcn2=1, ena=60, gna=1, ek=-80, gkv11=1, gkv34=1, gkv43=1)
    ## add stim (keep a reference so it stays alive for the run)
    stim = step_stim(cell_neuron, delay=0, dur=10, amp=amp)
    ## simulation; perf_counter is the monotonic clock intended for timing intervals
    time_start = time.perf_counter()
    t_neuron, v_neuron, _spikes = NeuronRun(cell=cell_neuron, stim=stim, tstop=tstop, dt=dt, v_init=v_init)
    elapsed = time.perf_counter() - time_start
    print(f'NEURON simulation time with dt={dt} ms: {elapsed} s')
    ## plot voltage traces
    if plot:
        plot_voltage_traces(t_neuron, v_neuron, indices=[0], title='NEURON Voltage')
    return np.array(t_neuron), np.array(v_neuron)[0, :]

In [None]:
# Reference NEURON run. The comparison cells label this trace as the dt=0.001 ms
# reference ("NEURON Reference (dt=0.001 ms)", "dt=0.01ms vs 0.001ms"), so run it at
# that resolution; at dt=0.01 the NEURON dt comparison would compare a run to itself.
t_neuron, v_neuron = neuron_simulation(dt=0.001)

In [None]:
# def run_neuron_with_input(
#     amp: float,
#     dur: float,
#     delay: float = 0.0,
#     tstop: float = 100.0,
#     dt: float = 0.01,
#     v_init: float = -65.0
# ):
#     h('forall delete_section()')
#     cell = Golgi_morpho_1(
#         el=-55, gl=1, ghcn1=1, ghcn2=1,
#         ena=60, gna=1, ek=-80,
#         gkv11=1, gkv34=1, gkv43=1
#     )
#     stim = step_stim(cell, delay=delay, dur=dur, amp=amp)

#     # run the simulation
#     t, v, spikes = NeuronRun(cell=cell, stim=stim, tstop=tstop, dt=dt, v_init=v_init)
#     return t, v, spikes

# amps = [0.1* i for i in range(100)]
# results = [run_neuron_with_input(amp=a, dur=10.0) for a in amps]

# t_list, v_list, spikes_list = zip(*results)
# for i, (t, v) in enumerate(zip(t_list, v_list)):
#     plot_voltage_traces(t, v, indices=[0], title=f'NEURON Cell #{i}')


In [None]:
class Golgi(braincell.MultiCompartment):
    """Multi-compartment Golgi-cell model (first version).

    NOTE(review): a later cell defines another class with the same name ``Golgi``;
    once that cell runs, it replaces this definition.

    Channels attached (all sized by ``self.varshape``): ``IL``, ``Ih1_Ma2020`` and
    ``Ih2_Ma2020`` (reversal fixed at -20 mV); ``IKv11_Ak2007``, ``IKv34_Ma2020`` and
    ``IKv43_Ma2020`` under ``PotassiumFixed``; ``INa_Rsg`` under ``SodiumFixed``.
    The calcium channels are commented out, so ``gcagrc``, ``gcav23``, ``gcav31``,
    ``gkca31``, ``Gl`` and ``El`` are accepted but not used.
    """
    def __init__(self, popsize, morphology, el, gl, gh1, gh2, ek, gkv11, gkv34, gkv43, ena, gnarsg, gcagrc=0, gcav23=0,
                 gcav31=0, gkca31=0, Gl=0, El=-65, V_init=-65, solver='ind_exp_euler', solver_na = 'rk4', compute_steps_na=2):
        """Build the cell and attach its channels.

        All numeric arguments are plain numbers; units are attached here:
            el, ek, ena: reversal potentials, multiplied by mV.
            gl, gh1, gh2, gkv11, gkv34, gkv43, gnarsg: maximal conductances, multiplied by mS/cm^2.
            V_init: constant initial membrane potential, multiplied by mV.
            solver: integrator name forwarded to ``MultiCompartment``.
            solver_na, compute_steps_na: forwarded to ``INa_Rsg`` as ``solver`` and ``compute_steps``.
        """
        super().__init__(
            popsize=popsize,
            morphology=morphology,
            V_th=20. * u.mV,
            V_initializer=braintools.init.Constant(V_init * u.mV),
            spk_fun=braintools.surrogate.ReluGrad(),
            solver=solver,
        )

        # Leak and the two Ih currents (Ih reversal fixed at -20 mV).
        self.IL = braincell.channel.IL(self.varshape, E=el * u.mV, g_max=gl * u.mS / (u.cm ** 2))
        self.Ih1 = braincell.channel.Ih1_Ma2020(self.varshape, E=-20. * u.mV, g_max=gh1 * u.mS / (u.cm ** 2))
        self.Ih2 = braincell.channel.Ih2_Ma2020(self.varshape, E=-20. * u.mV, g_max=gh2 * u.mS / (u.cm ** 2))

        # Potassium ion with fixed reversal potential and its three channels.
        self.k = braincell.ion.PotassiumFixed(self.varshape, E=ek * u.mV)
        self.k.add(IKv11=braincell.channel.IKv11_Ak2007(self.varshape, g_max=gkv11 * u.mS / (u.cm ** 2)))
        self.k.add(IKv34=braincell.channel.IKv34_Ma2020(self.varshape, g_max=gkv34 * u.mS / (u.cm ** 2)))
        self.k.add(IKv43=braincell.channel.IKv43_Ma2020(self.varshape, g_max=gkv43 * u.mS / (u.cm ** 2)))

        # Sodium ion with fixed reversal potential; INa_Rsg gets its own solver settings.
        self.na = braincell.ion.SodiumFixed(self.varshape, E=ena * u.mV)
        self.na.add(INa_Rsg=braincell.channel.INa_Rsg(self.varshape, g_max=gnarsg * u.mS / (u.cm ** 2), solver = solver_na, compute_steps= compute_steps_na))
        #self.na.add(INa =braincell.channel.INa_HH1952(self.varshape))

        # Disabled calcium / calcium-activated potassium channels (kept for reference).
        #self.ca = braincell.ion.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=10. * u.ms, d=0.5 * u.um)
        #self.ca = braincell.ion.CalciumFixed(self.varshape, E=137.* u.mV, C =5e-5 * u.mM)
        #self.ca.add(ICaL=braincell.channel.ICaGrc_Ma2020(self.varshape, g_max=gcagrc * (u.mS / u.cm ** 2)))
        #self.ca.add(ICaL=braincell.channel.ICav23_Ma2020(self.varshape, g_max=gcav23 * (u.mS / u.cm ** 2)))
        #self.ca.add(ICaL=braincell.channel.ICav31_Ma2020(self.varshape, g_max=gcav31 * (u.mS / u.cm ** 2)))

        #self.kca = braincell.MixIons(self.k, self.ca)
        #self.kca.add(IKca = braincell.channel.IKca1_1_Ma2020(self.varshape, g_max=gkca31 * u.mS / (u.cm ** 2)))

    def step_run(self, t, inp):
        """Call ``self.update(inp)`` with the environment time set to ``t``; return ``self.V.value``."""
        with brainstate.environ.context(t=t):
            self.update(inp)
            return self.V.value
        
# I_pop = u.math.stack([
#     step_input(num=nseg, dur=[100, 0, 0], amp=[i * 0.001, 0, 0], dt=DT)
#     for i in range(popsize)
# ], axis=1)  

# t_braincell, v_braincell = BraincellRun(cell=cell_braincell, I=I_pop, dt=DT)

# v0 = v_braincell[:, :, 0]

# for i in range(v0.shape[1]): 
#     plt.plot(t_braincell, v0[:, i], label=f'cell_{i}')
 
# plt.xlabel("Time")
# plt.ylabel("V (mV)")
# plt.legend()
# plt.show()

In [None]:
def step_input(num, dur, amp):
    """Build a piecewise-constant current that drives compartment 0 only.

    Parameters
    ----------
    num : int
        Number of compartments (width of each section's value row).
    dur : sequence of numbers
        Section durations; multiplied by ``u.ms``.
    amp : sequence of numbers
        Compartment-0 amplitude for each section, indexed in step with ``dur``.

    Returns
    -------
    ``braintools.input.section(...)`` multiplied by ``u.nA``.
    """
    n_sections = len(dur)
    section_values = u.math.zeros((n_sections, num))
    for row in range(n_sections):
        # only column 0 is set; all remaining compartments stay at zero
        section_values = section_values.at[row, 0].set(amp[row])
    durations = dur * u.ms
    return braintools.input.section(values=section_values, durations=durations) * u.nA

class Golgi(braincell.MultiCompartment):
    """Multi-compartment Golgi-cell model built by ``run_simulation``.

    Channels attached (all sized by ``self.varshape``):

    * ``IL`` leak, plus ``Ih1_Ma2020`` / ``Ih2_Ma2020`` with reversal fixed at -20 mV;
    * ``IKv11_Ak2007``, ``IKv34_Ma2020``, ``IKv43_Ma2020`` under ``PotassiumFixed``;
    * ``INa_Rsg`` under ``SodiumFixed``.

    NOTE(review): an earlier cell defines another class named ``Golgi``; running this
    cell replaces it.
    """

    def __init__(
        self,
        popsize,
        morphology,
        E_L,
        gl,
        gh1,
        gh2,
        E_K,
        gkv11,
        gkv34,
        gkv43,
        E_Na,
        gnarsg,
        V_init=-65 * u.mV,
        solver_na = 'rk4',
        compute_steps_na = 1,
        solver='staggered',
    ):
        """Build the cell and attach its channels.

        Parameters
        ----------
        popsize, morphology
            Forwarded to ``braincell.MultiCompartment``.
        E_L, E_K, E_Na : Quantity
            Reversal potentials (with units) for leak, potassium and sodium.
        gl, gh1, gh2, gkv11, gkv34, gkv43, gnarsg
            Maximal conductances without units; multiplied by mS/cm^2 here.
        V_init : Quantity, optional
            Constant initial membrane potential (default -65 mV).
        solver_na, compute_steps_na : optional
            Forwarded to ``INa_Rsg`` as ``solver`` and ``compute_steps``.
        solver : str, optional
            Integrator forwarded to ``MultiCompartment``. Defaults to ``'staggered'``,
            the value this class previously hard-coded.
        """
        super().__init__(
            popsize=popsize,
            morphology=morphology,
            V_th=20. * u.mV,
            V_initializer=braintools.init.Constant(V_init),
            spk_fun=braintools.surrogate.ReluGrad(),
            solver=solver,
        )
        # Leak and the two Ih currents.
        self.IL = braincell.channel.IL(self.varshape, E=E_L, g_max=gl * u.mS / (u.cm ** 2))
        self.Ih1 = braincell.channel.Ih1_Ma2020(self.varshape, E=-20. * u.mV, g_max=gh1 * u.mS / (u.cm ** 2))
        self.Ih2 = braincell.channel.Ih2_Ma2020(self.varshape, E=-20. * u.mV, g_max=gh2 * u.mS / (u.cm ** 2))

        # Potassium ion (fixed reversal) and its channels.
        self.k = braincell.ion.PotassiumFixed(self.varshape, E=E_K)
        self.k.add(IKv11=braincell.channel.IKv11_Ak2007(self.varshape, g_max=gkv11 * u.mS / (u.cm ** 2)))
        self.k.add(IKv34=braincell.channel.IKv34_Ma2020(self.varshape, g_max=gkv34 * u.mS / (u.cm ** 2)))
        self.k.add(IKv43=braincell.channel.IKv43_Ma2020(self.varshape, g_max=gkv43 * u.mS / (u.cm ** 2)))

        # Sodium ion (fixed reversal); INa_Rsg has its own solver settings.
        self.na = braincell.ion.SodiumFixed(self.varshape, E=E_Na)
        self.na.add(INa_Rsg=braincell.channel.INa_Rsg(self.varshape, g_max=gnarsg * u.mS / (u.cm ** 2), solver=solver_na, compute_steps=compute_steps_na))

    def step_run(self, t, inp):
        """Call ``self.update(inp)`` with the environment time set to ``t``; return ``self.V.value``."""
        with brainstate.environ.context(t=t):
            self.update(inp)
            return self.V.value
        

# Load the reconstructed morphology from the ASC file and apply set_passive_params()
# with its default arguments.
morphology = braincell.Morphology.from_asc('golgi.asc')
morphology.set_passive_params()

# Channel conductance parameters derived from the morphology (presumably one value per
# segment). The calcium-related ones (gcagrc, gcav23, gcav31, gkca31) are not passed
# to Golgi in run_simulation.
(
    gl, gh1, gh2,
    gkv11, gkv34, gkv43,
    gnarsg,
    gcagrc, gcav23, gcav31, gkca31,
) = seg_ion_params(morphology)

@brainstate.transform.jit
def simulate(cell_braincell, I):
    """JIT-compiled run of ``cell_braincell`` driven step by step by ``I``.

    Time points are ``arange(I.shape[0]) * dt`` with the global brainstate dt.
    State is initialised and reset before stepping. Returns ``(times_ms, vs_mV)`` as
    unitless arrays.

    NOTE(review): nothing visible in this notebook calls this function;
    ``run_simulation`` repeats the same steps without JIT.
    """
    step_dt = brainstate.environ.get_dt()
    n_steps = I.shape[0]
    times = u.math.arange(n_steps) * step_dt
    cell_braincell.init_state()
    cell_braincell.reset_state()
    vs = brainstate.transform.for_loop(cell_braincell.step_run, times, I)
    t_ms = times.to_decimal(u.ms)
    v_mV = vs.to_decimal(u.mV)
    return t_ms, v_mV

def run_simulation(solver_na, compute_steps_na, dt=0.01 * u.ms, duration=1000, amp=0, popsize=10):
    """Simulate a population of Golgi cells under a step current; return times and voltages.

    Parameters
    ----------
    solver_na : str
        Integrator for ``INa_Rsg`` (this notebook uses ``'rk4'`` and ``'backward_euler'``).
    compute_steps_na : int
        Forwarded to ``INa_Rsg`` as ``compute_steps``.
    dt : Quantity, optional
        Integration step written to the global ``brainstate.environ`` (default 0.01 ms).
    duration : number, optional
        Length of the stimulus section in ms (default 1000).
    amp : number, optional
        Current amplitude injected into compartment 0, multiplied by nA in ``step_input``
        (default 0, i.e. spontaneous activity).
    popsize : int, optional
        Number of cells in the population (default 10).

    Returns
    -------
    times : array
        Time points in ms.
    vs : array
        Membrane voltages in mV, stacked over time steps along axis 0 by ``for_loop``.
    """
    # Set the global dt before building the cell so nothing during construction or
    # stepping sees a stale value.
    brainstate.environ.set(dt=dt)

    cell_braincell = Golgi(
        popsize=popsize,  # number of cells in the population
        morphology=morphology,
        E_L=-55. * u.mV,
        gl=gl,
        gh1=gh1,
        gh2=gh2,
        E_K=-80. * u.mV,
        gkv11=gkv11,
        gkv34=gkv34,
        gkv43=gkv43,
        E_Na=60. * u.mV,
        gnarsg=gnarsg,
        V_init=-65 * u.mV,
        solver_na=solver_na,
        compute_steps_na=compute_steps_na,
    )

    I = step_input(num=len(morphology.segments), dur=[duration, 0, 0], amp=[amp, 0, 0])
    times = u.math.arange(I.shape[0]) * brainstate.environ.get_dt()
    cell_braincell.init_state()
    cell_braincell.reset_state()
    vs = brainstate.transform.for_loop(cell_braincell.step_run, times, I)
    return times.to_decimal(u.ms), vs.to_decimal(u.mV)


In [None]:
RK4_STEPS = range(5, 6)  # compute_steps_na values run with the RK4 INa_Rsg solver
BWD_STEPS = range(1, 1)  # empty range: no backward-Euler runs are performed as written

plt.figure(figsize=(8, 5))  # set the figure size
# Sweep over different numbers of INa_Rsg compute steps.
results_rk4 = {}
results_bwd = {}

for i in RK4_STEPS:
    t, vs_rk4 = run_simulation(solver_na='rk4', compute_steps_na=i)
    plt.plot(t, vs_rk4[:, 0, 0], linestyle='--', label=f'RK4 steps = {i}')
    results_rk4[i] = vs_rk4[:, 0, 0]
for i in BWD_STEPS:
    t, vs_bwd = run_simulation(solver_na='backward_euler', compute_steps_na=i)
    plt.plot(t, vs_bwd[:, 0, 0], label=f'Backward Euler steps = {i}')
    results_bwd[i] = vs_bwd[:, 0, 0]

# run_simulation returns time in ms and voltage in mV. Without a legend the
# per-curve labels above are never displayed.
plt.xlabel('Time (ms)')
plt.ylabel('V (mV)')
plt.title('BrainCell voltage (cell 0, compartment 0)')
plt.legend()
plt.show()


In [None]:
def find_spike_times(t, v, threshold=0.0):
    """Return the times at which ``v`` crosses ``threshold`` upward.

    A crossing is a sample pair with ``v[i] < threshold <= v[i + 1]``. The crossing
    time is found by linear interpolation between ``t[i]`` and ``t[i + 1]``.

    Parameters
    ----------
    t, v : array_like
        Time points and voltages of equal length (singleton dimensions are squeezed).
    threshold : float, optional
        Crossing level in the units of ``v`` (default 0.0, i.e. zero crossings).

    Returns
    -------
    np.ndarray
        Interpolated crossing times; empty if there is no upward crossing.
    """
    t = np.asarray(t).squeeze()
    v = np.asarray(v).squeeze()
    cross_idx = np.where((v[:-1] < threshold) & (v[1:] >= threshold))[0]
    t0, t1 = t[cross_idx], t[cross_idx + 1]
    v0, v1 = v[cross_idx], v[cross_idx + 1]
    # v1 > v0 strictly at every crossing, so the denominator is never zero.
    return t0 + (t1 - t0) * (threshold - v0) / (v1 - v0)


In [None]:
# Spike times of the NEURON reference trace.
spikes_neuron = find_spike_times(t_neuron, v_neuron)
results_diffs = {}

for method, result_dict in [('RK4', results_rk4), ('BWD', results_bwd)]:
    for steps, v in result_dict.items():
        # `t` is the BrainCell time axis left by the sweep cell; every run there uses
        # the same run_simulation dt and duration, so it is shared by all traces.
        spikes_model = find_spike_times(t, v)
        # Pair spikes by index; compare only as many as both traces contain.
        min_len = min(len(spikes_neuron), len(spikes_model))
        diffs = spikes_model[:min_len] - spikes_neuron[:min_len]
        if min_len:
            mean_diff = np.mean(diffs)
            std_diff = np.std(diffs)
        else:
            # No common spikes: report nan explicitly instead of np.mean's empty-slice warning.
            mean_diff = std_diff = np.nan
        results_diffs[(method, steps)] = (diffs, mean_diff, std_diff)
        print(f"{method} steps={steps}: n={min_len}, mean Δt={mean_diff:.4f} ms, std={std_diff:.4f} ms")


In [None]:
# 计算新 NEURON 的 spike 差异
# Spike-time differences between a NEURON run at NEURON_COMPARE_DT and the reference run.
# The comparison run is done here so this cell does not depend on a later cell
# (the notebook previously defined these names only in its last cell).
NEURON_COMPARE_DT = 0.01
t_neuron_compare, v_neuron_compare = neuron_simulation(dt=NEURON_COMPARE_DT)

spikes_compare = find_spike_times(t_neuron_compare, v_neuron_compare)
min_len = min(len(spikes_neuron), len(spikes_compare))
diffs_neuron_dt = spikes_compare[:min_len] - spikes_neuron[:min_len]

if min_len:
    mean_diff_dt = np.mean(diffs_neuron_dt)
    std_diff_dt = np.std(diffs_neuron_dt)
else:
    # No common spikes: report nan explicitly instead of np.mean's empty-slice warning.
    mean_diff_dt = std_diff_dt = np.nan

# Add it to results_diffs, keyed by the comparison dt.
results_diffs[('NEURON_dt', NEURON_COMPARE_DT)] = (diffs_neuron_dt, mean_diff_dt, std_diff_dt)

print(f"NEURON dt={NEURON_COMPARE_DT}ms vs 0.001ms: mean Δt={mean_diff_dt:.4f} ms, std={std_diff_dt:.4f} ms")


In [None]:
import matplotlib.cm as cm

plt.figure(figsize=(10, 5))

all_steps = sorted(set(steps for (_, steps) in results_diffs.keys()))

# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `pyplot.get_cmap(name, lut)` returns the same discretised colormap. Ask for at
# least one colour so an empty results_diffs does not request a zero-length map.
cmap = plt.get_cmap('viridis', max(len(all_steps), 1))

for (method, steps), (diffs, mean_diff, std_diff) in sorted(results_diffs.items()):
    spike_indices = np.arange(1, len(diffs) + 1)
    color = cmap(all_steps.index(steps))  # colour chosen by the step's position
    label = f"{method} steps={steps} (mean={mean_diff:.3f} ms)"
    linestyle = '--' if method == 'RK4' else '-'  # line style distinguishes the method
    plt.plot(spike_indices, diffs, linestyle=linestyle, color=color, linewidth=1.8, label=label)

plt.axhline(0, color='gray', linestyle='--', linewidth=1)

plt.title('Spike Time Differences vs NEURON')
plt.xlabel('Spike index')
plt.ylabel('Δt (ms)')

# Integer x ticks, one per spike index; default=0 keeps an empty results_diffs from raising.
max_index = max((len(diffs) for diffs, _, _ in results_diffs.values()), default=0)
plt.xticks(np.arange(1, max_index + 1, 1))

plt.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), frameon=False)
plt.tight_layout(rect=[0, 0, 0.8, 1])
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()


In [None]:
plt.figure(figsize=(16, 5))

# Every numeric step / dt key in results_diffs, sorted; a curve's position in this
# list picks its colour on the gradient.
all_steps = sorted({s for _, s in results_diffs if isinstance(s, (int, float))})
n_steps = len(all_steps)

def blue_gradient(i, n):
    """Return the RGB colour for item ``i`` of ``n`` on the ramp used for RK4 curves.

    Red falls linearly from 1.0 (``i == 0``) to 0.3 (``i == n - 1``); green and blue
    stay at 0.5 and 1.0. When ``n <= 1`` the ramp has no length, so the first colour
    is returned instead of dividing by zero.
    """
    frac = i / (n - 1) if n > 1 else 0.0
    return (1.0 - 0.7 * frac, 0.5, 1.0)

def orange_gradient(i, n):
    """Return the RGB colour for item ``i`` of ``n`` on the ramp used for BWD curves.

    Green falls linearly from 0.8 (``i == 0``) to 0.1 (``i == n - 1``); red and blue
    stay at 1.0 and 0.5. When ``n <= 1`` the ramp has no length, so the first colour
    is returned instead of dividing by zero.
    """
    frac = i / (n - 1) if n > 1 else 0.0
    return (1.0, 0.8 - 0.7 * frac, 0.5)

for (method, steps), (diffs, mean_diff, std_diff) in sorted(results_diffs.items()):
    spike_indices = np.arange(1, len(diffs) + 1)
    step_idx = all_steps.index(steps)

    # The NEURON dt-comparison curve is drawn in solid red, outside the RK4/BWD colour ramps.
    if method == 'NEURON_dt':
        plt.plot(spike_indices, diffs, color='red', linewidth=2, linestyle='-',
                 label=f'NEURON dt={steps} ms (mean={mean_diff:.3f})')
        continue

    if method == 'RK4':
        color = blue_gradient(step_idx, n_steps)
        plt.plot(spike_indices, diffs, color=color, linestyle='--', linewidth=1.8,
                 label=f'RK4 steps={steps} (mean={mean_diff:.3f})')

    elif method == 'BWD':
        color = orange_gradient(step_idx, n_steps)
        plt.plot(spike_indices, diffs, color=color, linestyle='-', linewidth=1.8,
                 label=f'BWD steps={steps} (mean={mean_diff:.3f})')

plt.axhline(0, color='black', linestyle='-', linewidth=1)  # a slightly heavier zero line is enough

plt.title('Spike Time Differences vs NEURON Reference (dt=0.001 ms)')
plt.xlabel('Spike index')
plt.ylabel('Δt (ms)')

# Integer ticks on the x axis; default=0 keeps an empty results_diffs from raising.
max_index = max((len(diffs) for diffs, _, _ in results_diffs.values()), default=0)
plt.xticks(np.arange(1, max_index + 1, 1))

# Legend to the right of the axes
plt.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), frameon=False)
plt.tight_layout(rect=[0, 0, 0.8, 1])
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()


In [None]:
# Second NEURON run at dt = 0.01 ms; its outputs are the names read by the
# NEURON dt-comparison cell.
(t_neuron_compare, v_neuron_compare) = neuron_simulation(0.01)