In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from neuron import h
h.load_file('stdrun.hoc')
import sys
if '..' not in sys.path:
    sys.path = ['..'] + sys.path
from neuroutils.trees import ImpedanceTree
from neuroutils.nodes import ImpedanceNode

In [None]:
# --- Passive membrane parameters, shared by every section ---
ra    =   200.     # axial resistivity [Ω.cm]
rm    = 15000.     # specific membrane resistance [Ω.cm2]
cm    =     1.     # specific membrane capacitance [μF/cm2]
e_pas = -65.       # passive reversal potential, also used as v_init [mV]
g_pas = 1/rm       # passive leak conductance [S/cm2]
taum  = rm*cm*1e-6 # membrane time constant [s]  (Ω.cm2 · μF/cm2 = 1e-6 s)

# --- Stimulus frequency: F == 0 selects a DC step, F > 0 an oscillatory (Izap) drive ---
F = 0.             # [Hz]
ω = 2*np.pi*F      # angular frequency [rad/s]
if F > 0:
    T = 1000/F     # stimulus period [ms]; only defined for F > 0

In [None]:
n_sections = 4
sections = [h.Section(name=f'sec_{k}') for k in range(n_sections)]

# Uniform passive biophysics on every section (g_pas == 1/rm).
for section in sections:
    section.cm = cm
    section.Ra = ra
    section.insert('pas')
    section.g_pas = g_pas
    section.e_pas = e_pas

# Per-section geometry: (L [μm], diam [μm], nseg).
geometry = [(20, 20, 1), (500, 10, 25), (500, 4, 25), (300, 3, 15)]

# sec_0 -> sec_1 -> sec_2 in a chain; connection order is kept as-is because it
# can influence the child order NEURON reports for sec_1.
for k in (0, 1, 2):
    sections[k].L, sections[k].diam, sections[k].nseg = geometry[k]
sections[1].connect(sections[0](1), 0)
sections[2].connect(sections[1](1), 0)

# Optional fourth section, branching from the distal end of sec_1 alongside sec_2.
if n_sections == 4:
    sections[3].L, sections[3].diam, sections[3].nseg = geometry[3]
    sections[3].connect(sections[1](1), 0)
h.topology()

In [None]:
# Inject current at the middle of sec_2; this segment is also passed as the
# root of the impedance tree below. Times are in ms (taum is in s, hence *1e3).
stim_seg = sections[2](0.5)
settle = 10*taum*1e3   # ten membrane time constants [ms]
if F != 0:
    # Oscillatory drive with equal start/end frequencies (f0 == f1 == F).
    # NOTE(review): Izap is presumably a custom point process compiled from a
    # project .mod file — confirm it is available.
    stim = h.Izap(stim_seg)
    stim.f0 = F
    stim.f1 = F
    # Settling time plus five stimulus periods.
    stim.dur = settle + 5*T
else:
    # DC current step.
    stim = h.IClamp(stim_seg)
    stim.dur = settle
stim.delay = 5*taum*1e3   # baseline of five membrane time constants [ms]
stim.amp = 0.01

In [None]:
# Build the impedance tree rooted at the stimulated segment and populate it.
# NOTE(review): ImpedanceTree comes from the local `neuroutils` package; the
# notes below are inferred from how later cells use it — confirm against its docs.
tree = ImpedanceTree(root_seg=stim_seg)
# Impedances at the stimulus frequency F [Hz] (F == 0 is the DC case).
# Presumably must run before compute_attenuations(), which appears to use them.
tree.compute_impedances(F)
tree.compute_attenuations()

In [None]:
# Flatten the tree into a list of NEURON segments in tree-iteration order;
# the recordings below (and hence V and ΔV) are indexed in this same order.
segments = []
for tree_node in tree:
    segments.append(tree_node.seg)
n_segments = len(segments)

In [None]:
def _record_voltage(segment):
    """Return a new h.Vector that records the membrane potential of `segment`."""
    vec = h.Vector()
    vec.record(segment._ref_v)
    return vec


# Simulation time, then one voltage recorder per segment so that
# v_rec[k] corresponds to segments[k].
t_rec = h.Vector()
t_rec.record(h._ref_t)
v_rec = [_record_voltage(segment) for segment in segments]

In [None]:
# Baseline (delay) + stimulus (dur) + a post-stimulus interval equal to the delay.
h.tstop = stim.dur + 2*stim.delay
h.v_init = e_pas
use_cvode = (F == 0)
# DC runs use the variable-step integrator; oscillatory runs use a fixed step
# no larger than 1/1000 of the stimulus period.
h.cvode_active(1 if use_cvode else 0)
if not use_cvode:
    h.dt = min(h.dt, T/1000)
h.run()

In [None]:
# Convert recordings to arrays: t has shape (n_times,), V has shape (n_segments, n_times).
t = np.array(t_rec)
V = np.array([np.array(vec) for vec in v_rec])

t_on, t_off = stim.delay, stim.delay + stim.dur
if F > 0:
    # Peak-to-peak amplitude over the last two stimulus periods before the drive ends.
    window = (t > t_off - 2*T) & (t < t_off)
    ΔV = np.ptp(V[:, window], axis=1)
else:
    # DC deflection: V at the last sample not after stimulus end, minus V at the
    # last sample not after stimulus onset (the step lasts 10 τm).
    k_on = np.flatnonzero(t <= t_on)[-1]
    k_off = np.flatnonzero(t <= t_off)[-1]
    ΔV = V[:, k_off] - V[:, k_on]

In [None]:
# Per-node impedance table; values in Ω are printed in MΩ (×1e-6).
# A is |node.A[0]| (presumably the first child's attenuation) for nodes with
# children, and 1 for leaf nodes.
# NOTE(review): the column labelled 'Za' prints node.Ra.
header = '{:>12s} {:>8s} {:>13s} {:>13s} {:>13s} {:>7s}'.format(
    'Segment', 'Za (MΩ)', 'Zm (MΩ)', 'Zp (MΩ)', 'Zload (MΩ)', 'A')
print(header)
# Rule width tracks the actual header width.
print('=' * len(header))
for node in tree:
    att = np.abs(node.A[0]) if len(node.children) > 0 else 1
    # Right-align 'name(x)' in the 12-char Segment column so longer section
    # names do not shift the numeric columns.
    seg_label = '{}({:5.3f})'.format(node.seg.sec.name(), node.seg.x)
    print('{:>12s} {:8.4f} {:13.1f} {:13.1f} {:13.1f} {:7.4f}'.format(
        seg_label,
        node.Ra*1e-6,
        node.Zm*1e-6,
        node.Zp*1e-6,
        node.Zload*1e-6,
        att))

In [None]:
# Compare the measured attenuation ΔV[i]/ΔV[j] between two recorded segments
# with the one predicted by the impedance tree.
# segments[0] is treated as the tree root (the stimulated segment), the same
# assumption the segment-0 → k comparison makes.
# NOTE(review): confirm that iterating an ImpedanceTree yields the root first.
i, j = 0, 32  # indices into `segments`; j = n_segments-1 selects the last one
seg_i, seg_j = segments[i], segments[j]
node_i = tree.find_node(ImpedanceNode(seg_i))
node_j = tree.find_node(ImpedanceNode(seg_j))
# compute_attenuation(seg) is used in this notebook as the root → seg
# attenuation, i.e. ΔV[root]/ΔV[seg]. The i → j attenuation is then
# ΔV[i]/ΔV[j] = A(root→j) / A(root→i); no division is needed when i is the root.
A_computed = tree.compute_attenuation(seg_j)
if i != 0:
    A_computed = A_computed / tree.compute_attenuation(seg_i)
A_measured = ΔV[i]/ΔV[j]
print('Attenuation between segments {} and {}: {:.6f} (measured), {:.6f} (computed).'.format(
    node_i, node_j, A_measured, A_computed))

In [None]:
# Attenuation from segment 0 to every other segment k: measured from the
# simulated deflections, and computed from the impedance tree (one call per segment).
A_0_k_measured = ΔV[0] / ΔV[1:]
A_0_k_computed = list(map(tree.compute_attenuation, segments[1:]))

In [None]:
# Predicted steady levels for segments i and j, drawn as dashed lines over the
# stimulus window. For DC (F == 0) ΔV is the full deflection; for F > 0 it is a
# peak-to-peak amplitude, so half of it approximates the excursion from rest.
# vj is the level predicted at j from ΔV[i] and the computed attenuation.
if F == 0:
    vi = e_pas + ΔV[i]
    vj = e_pas + ΔV[i]/A_computed
else:
    vi = e_pas + ΔV[i]/2
    vj = e_pas + ΔV[i]/A_computed/2
t_stim = [stim.delay, stim.delay+stim.dur]
fig, ax = plt.subplots(1, 2, figsize=(7.5,2.5), width_ratios=(1,2))

# Left: computed vs. measured attenuation from segment 0 to every other
# segment k, with the identity line for reference.
lim = [A_0_k_measured.min()*0.999, A_0_k_measured.max()*1.001]
ax[0].plot(lim, lim, lw=2, color=[.4,.4,.4])
ax[0].plot(A_0_k_measured, A_0_k_computed, 'ko', markerfacecolor='w', markersize=4)
ax[0].grid(which='major', axis='both', ls=':', lw=0.5, color=[.6,.6,.6])
ax[0].set_xlabel('Measured attenuation')
ax[0].set_ylabel('Computed attenuation')
ax[0].set_title('Segment 0 → k')

# Right: voltage at segments i (black) and j (red), labelled with their NEURON
# segment names, plus the predicted levels (dashed).
ax[1].plot(t, V[i], 'k', lw=0.5, label=str(segments[i]))
ax[1].plot(t, V[j], 'r', lw=0.5, label=str(segments[j]))
ax[1].plot(t_stim, vi+np.zeros(2), 'k--', lw=2)
ax[1].plot(t_stim, vj+np.zeros(2), 'r--', lw=2)
ax[1].set_xlabel('Time (ms)')
ax[1].set_ylabel('Vm (mV)')
ax[1].set_title('Membrane potential')
ax[1].legend(loc='best', frameon=False, fontsize=8)
ax[1].grid(which='major', axis='y', ls=':', lw=0.5, color=[.6,.6,.6])

# Despine this figure explicitly rather than whatever figure is current.
sns.despine(fig=fig)
fig.tight_layout()