In [1]:
import numpy as np
import matplotlib.pyplot as plt
import math
import statistics as st
from collections import Counter
import os
import os.path
import scipy
import csv
from scipy import signal
import pandas as pd

In [2]:
plot_source="neuronmonitor" # choose "neuronmonitor" or "jk_custom" for CARLsim data file source
# Create time points (ms, per the plot axis labels)
Tmin, Tmax, dt = 0, 1000, 0.025  # Step size
T = np.arange(Tmin, Tmax + dt, dt)

# Declare the constant current injected into the presynaptic neuron (pA, per the output file name)
I_ext = 150

## Nine parameters for IM model
# Presynaptic neuron: Axo_Axonic
k, a, b, d, C, vr, vt, c, vpeak = 3.961462878, 0.004638608, 8.683644937, 15, 165, -57.09978287, \
                                  -51.71875628, -73.96850421, 27.79863559

preNeuronType = "Axo-Axonic"

## Make a state vector that has a (v, u) pair for each timestep
s = np.zeros((len(T), 2))

# Presynaptic spike times (ms). "neuronmonitor" uses a hard-coded spike train (presumably
# recorded from CARLsim); "jk_custom" starts empty and is filled by the simulation loop.
if plot_source=="neuronmonitor":
    spike_times = np.array([15, 40, 67, 97, 130, 166, 206, 249, 295, 343, 392, 442, 493, 544, 595, 646, 697, 748, 799, 850, 901, 952])
elif plot_source=="jk_custom":
    spike_times = np.array([])
else:
    # Fail fast: otherwise spike_times is undefined and the error only surfaces much later
    raise ValueError(f"Unknown plot_source {plot_source!r}; expected 'neuronmonitor' or 'jk_custom'")
isi_mode_list = np.array([])  # NOTE(review): not used anywhere in the visible notebook


## Initial values: start at the resting potential with zero recovery current
s[0, 0] = vr
s[0, 1] = 0
# Note that s1[0] is v, s1[1] is u. This is Izhikevich equation in vector form
def s_dt(s1, I):
    """Right-hand side of the nine-parameter Izhikevich equations.

    ``s1`` holds the state pair (v, u): membrane potential and recovery
    variable. ``I`` is the input current. Returns ``np.array([dv/dt, du/dt])``.
    The parameters k, vr, vt, C, a, b are read from module globals at call
    time, so the function uses whichever neuron's parameters were assigned
    most recently.
    """
    membrane, recovery = s1[0], s1[1]
    dv = (k * (membrane - vr) * (membrane - vt) - recovery + I) * (1 / C)
    du = a * (b * (membrane - vr) - recovery)
    return np.array([dv, du])


## SIMULATE the presynaptic neuron driven by the constant current I_ext,
## integrating the Izhikevich equations with classic 4th-order Runge-Kutta (step dt).
for t in range(len(T)-1):
  # Calculate the four constants of Runge-Kutta method
  k_1 = s_dt(s[t], I_ext)
  k_2 = s_dt(s[t] + 0.5*dt*k_1, I_ext)
  k_3 = s_dt(s[t] + 0.5*dt*k_2, I_ext)
  k_4 = s_dt(s[t] + dt*k_3, I_ext)

  # Weighted RK4 average of the four slopes
  s[t+1] = s[t] + (1.0/6)*dt*(k_1 + 2*k_2 + 2*k_3 + k_4)

  # Reset the neuron if it has spiked
  if s[t+1, 0] >= vpeak:
    s[t, 0]   = vpeak # add Dirac pulse for visualisation (pre-spike sample set to vpeak)
    s[t+1, 0] = c   # Reset to the after-spike reset potential c (not the resting potential vr)
    s[t+1, 1] += d  # Update recovery variable
    # Only the "jk_custom" source records spikes here; "neuronmonitor" keeps its hard-coded train.
    # NOTE(review): this rounds the spike time up (math.ceil) while the postsynaptic loop
    # rounds down (math.floor) -- confirm which convention matches CARLsim.
    if plot_source=="jk_custom":
        spike_times = np.append(spike_times, math.ceil(t*dt))
    #print(math.floor(t * dt))

# Presynaptic voltage / recovery traces (rebound later by the postsynaptic simulation)
v = s[:, 0]
u = s[:, 1]


## Nine parameters for IM model
# Pyramidal (postsynaptic neuron). These rebind the module-level names that s_dt reads
# at call time, so the same equations now describe the Pyramidal cell.
# Note that the tuple order differs from the Axo-Axonic assignment above.
postNeuronType = "Pyramidal"
a, b, c, d, k, vr, vt, C, vpeak = 0.00838350334098279, -42.5524776883928, -38.8680990294091, 588.0, \
                                  0.792338703789581, -63.2044008171655, -33.6041733124267, 366.0, 35.8614648558726

# Synaptic Parameters and equations
# Presynaptic spike times (ms) become the event queue consumed by the simulation loop.
synaptic_event_times = list(spike_times)
print(synaptic_event_times)

# Tsodyks-Markram (TM) synapse parameters: peak conductance g0, time constants tau_d
# (conductance decay), tau_f (facilitation, decay of u) and tau_r (recovery of x),
# utilization increment U, and synaptic reversal potential e_rev (mV).
g0, tau_d, tau_f, tau_r, Utilization, e_rev = 3.644261648, 10.71107251, 17.20004939, 435.8103009, 0.259914361, -70

# TM implementation to emulate: 'Carlsim' (active) or 'Keivan'
tm_model = 'Carlsim'#'Keivan'

if tm_model == 'Keivan':
    # In CARLsim there is no change to the g parameter from the value it is set to in connect().
    # Note: NS is unclear why this calculation exists; it is related to eq. 13 in (Moradi, 2022) supp. mat.
    g0 /= Utilization

def synaptic_event(delta_t_, g, g0, tau_d_, tau_r_, tau_f_, utilization, x0, y0, u0, e_syn, model=None):
    """Apply one presynaptic spike to the Tsodyks-Markram (TM) synapse state.

    Parameters
    ----------
    delta_t_ : float
        Time since the previous presynaptic event. Only the Keivan variant uses it;
        it integrates the between-event dynamics analytically.
    g : float
        Current synaptic conductance. The CARLsim variant adds an increment to it;
        the Keivan variant overwrites it with ``g0 * y``.
    g0 : float
        Conductance scale.
    tau_d_, tau_r_, tau_f_ : float
        Decay, recovery and facilitation time constants.
    utilization : float
        Utilization increment U. It must be non-zero for the CARLsim variant.
    x0, y0, u0 : float
        TM state before the spike.
    e_syn : float
        Unused; kept for interface compatibility.
    model : str, optional
        ``'Carlsim'`` selects CARLsim's update; any other value selects the Keivan
        update. Defaults to the notebook-level ``tm_model`` switch.

    Returns
    -------
    tuple
        ``(g, x0, y0, u0)`` after the spike.
    """
    if model is None:
        model = tm_model  # original behaviour: follow the notebook-level switch
    if model != 'Carlsim':
        # TM Model that depends on tau_d
        tau1r = tau_d_ / ((tau_d_ - tau_r_) if tau_d_ != tau_r_ else 1e-13)
        y_ = y0 * math.exp(-delta_t_ / tau_d_)
        x_ = 1 + (x0 - 1 + tau1r * y0) * math.exp(-delta_t_ / tau_r_) - tau1r * y_
        u_ = u0 * math.exp(-delta_t_ / tau_f_)
        u0 = u_ + utilization * (1 - u_)
        y0 = y_ + u0 * x_
        x0 = x_ - u0 * x_
        g = g0 * y0
    else:
        # CARLsim's TM model. Use the utilization argument, not the global `Utilization`,
        # so the function honours what the caller passes.
        A = 1 / utilization  # plays the same role as g0 /= Utilization in the Keivan variant
        u_ = u0 + utilization * (1 - u0)
        x_ = x0 - u_ * x0
        g = g + g0 * (A * u_ * x0)  # increment uses the pre-spike x0
        u0 = u_
        x0 = x_

    return g, x0, y0, u0


def synaptic_current(I, g, delta_t_, tau_d_, e_syn_, tm_model):
    """Return the synaptic current for the selected TM implementation.

    - ``'Carlsim'``: conductance times driving force (no time-based decay here).
    - ``'Keivan'``: conductance decayed exponentially over ``delta_t_`` with
      time constant ``tau_d_``, times driving force.
    - Any other model name: the previous current ``I`` is returned unchanged.
    """
    if tm_model == 'Carlsim':
        return g * e_syn_
    if tm_model == 'Keivan':
        return g * math.exp(-delta_t_ / tau_d_) * e_syn_
    return I

def synaptic_decay(g, tau_d_, tm_model):
    """Decay conductance ``g`` by one forward-Euler step of size 1 with time constant ``tau_d_``.

    ``tm_model`` is accepted for signature symmetry with the other helpers but is unused.
    """
    retention = 1 - (1 / tau_d_)
    return g * retention

def stp_variable_update(u, x, tau_f, tau_r, tm_model):
    """Advance the STP variables by one forward-Euler step of size 1.

    ``u`` decays toward 0 with time constant ``tau_f``; ``x`` recovers toward 1
    with time constant ``tau_r``. ``tm_model`` is unused. Returns ``(u, x)``.
    """
    u_next = u * (1 - 1 / tau_f)
    x_next = x + (1 - x) * (1 / tau_r)
    return u_next, x_next

# Initialize synaptic state variables:
#   g  - synaptic conductance, x0 - TM resource variable (starts at 1, recovers toward 1),
#   y0 - Keivan-variant active fraction, I - synaptic current.
g, x0, y0 = 0.0, 1.0, 0.0
I = 0.0
# Utilization state starts at 0 for both TM variants; the former if/else assigned the same
# value in both branches. NS does not see evidence in CARLsim that it is initialized otherwise.
u0 = 0

# Make a state vector that has a (v, u) pair for each timestep
s = np.zeros((len(T), 2))

# Initial Izhikevich state variables: postsynaptic resting potential, zero recovery
s[0, 0] = vr
s[0, 1] = 0


# Note that s1[0] is v, s1[1] is u. This is Izhikevich equation in vector form
def s_dt(s1, I):
    """Izhikevich right-hand side; same equations as the earlier s_dt definition.

    ``s1`` is the (v, u) state pair and ``I`` the input current. Returns
    ``np.array([dv/dt, du/dt])``. The parameters k, vr, vt, C, a, b come from
    module globals at call time; here those are the Pyramidal values.
    """
    v_now = s1[0]
    u_now = s1[1]
    v_rate = (k * (v_now - vr) * (v_now - vt) - u_now + I) * (1 / C)
    u_rate = a * (b * (v_now - vr) - u_now)
    return np.array([v_rate, u_rate])

# SIMULATE
# SIMULATE the postsynaptic neuron driven by the presynaptic spike train through the TM synapse.
# With no presynaptic spikes (possible for "jk_custom"), the synapse never activates
# instead of raising IndexError on synaptic_event_times[0].
next_synaptic_event_time = synaptic_event_times[0] if synaptic_event_times else math.inf
delta_t, I_syn = 0.0, [0]  # I_syn[0] is the current at T[0]; one value is appended per step
synaptic_event_time = next_synaptic_event_time  # time of the most recently processed event
spike_times = np.array([])  # postsynaptic spike times; the presynaptic train was copied above
# Integer steps per millisecond: avoids a float modulo inside the loop
steps_per_ms = round(1 / dt)
for t in range(len(T) - 1):
    v = s[t, 0]
    e_syn = v - e_rev  # driving force relative to the synaptic reversal potential
    time = T[t]

    # in CARLsim this occurs before the synaptic spike current update. doSTPUpdateAndDecayCond() is before globalStateUpdate().
    if (tm_model == 'Carlsim' and t % steps_per_ms == 0): # run every whole number millisecond
        g = synaptic_decay(g, tau_d, tm_model)
        u0, x0 = stp_variable_update(u0, x0, tau_f, tau_r, tm_model)

    # Deliver the next presynaptic spike once simulated time reaches it
    if next_synaptic_event_time <= time:
        inter_event_time = next_synaptic_event_time - synaptic_event_time
        g, x0, y0, u0 = synaptic_event(inter_event_time, g, g0, tau_d, tau_r, tau_f, Utilization, x0, y0, u0, e_syn)
        # Trace line; note the "v:" field prints the driving force e_syn
        print("t:%d I:%.3f\tg:%.3f\tu:%.3f\tx:%.3f\tv:%.3f\tsynaptic spike" % (time,I,g,u0,x0,e_syn))
        synaptic_event_time = time
        # Consume the processed event (mutates synaptic_event_times)
        if len(synaptic_event_times) > 1:
            del synaptic_event_times[0]
            next_synaptic_event_time = synaptic_event_times[0]
        else:
            next_synaptic_event_time = math.inf

    # After the first event, refresh the synaptic current once per whole millisecond
    delta_t = time - synaptic_event_time
    if delta_t >= 0:
        if (t % steps_per_ms == 0):
            I = synaptic_current(I, g, delta_t, tau_d, e_syn, tm_model)
            print("t:%d I:%.3f\tg:%.3f\tu:%.3f\tx:%.3f\tv:%.3f\tcurrent decay" % (t/(1/dt),I,g,u0,x0,e_syn))
    I_syn.append(I)

    # Calculate the four constants of Runge-Kutta method (synaptic current enters as -I)
    k_1 = s_dt(s[t], -I)
    k_2 = s_dt(s[t] + 0.5 * dt * k_1, -I)
    k_3 = s_dt(s[t] + 0.5 * dt * k_2, -I)
    k_4 = s_dt(s[t] + dt * k_3, -I)

    s[t + 1] = s[t] + (1.0 / 6) * dt * (k_1 + 2 * k_2 + 2 * k_3 + k_4)

    # Reset the neuron if it has spiked
    if s[t + 1, 0] >= vpeak:
        s[t, 0] = vpeak  # add Dirac pulse for visualisation
        s[t + 1, 0] = c  # Reset to the after-spike reset potential c
        s[t + 1, 1] += d  # Update recovery variable
        spike_times = np.append(spike_times, math.floor(t * dt))

# Postsynaptic voltage and recovery traces
v = s[:, 0]
u = s[:, 1]

# First define function to flip the sign of the current
def sign(lst):
    """Return a new list containing every element of ``lst`` negated."""
    flipped = []
    for value in lst:
        flipped.append(-value)
    return flipped

# downsample
if plot_source=="neuronmonitor":
    steps = 40 #1/dt
    dt = 1  # Step size
    T = np.arange(Tmin, Tmax + dt, dt)

# Compare to CARLsim simulation output
if plot_source=="neuronmonitor":
    FH = np.loadtxt("stp_compare_results.csv")
elif plot_source=="jk_custom":
    FH = np.loadtxt("HC_IM_05_26_aac_pyr_I_150pA_fast_1_slow_0.txt")
I = FH[1::2]
V = FH[0::2]
I = sign(I)
I = I[0:len(I_syn)-1]
V = V[0:len(I_syn)-1]
ax1 = plt.subplot(211)
ax1.plot(T,np.append(V, [V[-1]]), label = "CARLsim TM") # added the last value of V to ensure the same length as time vector
plt.ylabel('Membrane potential (mV)')
ax2 = plt.subplot(212)
ax2.plot(T,np.append(I, [I[-1]]), label = "CARLsim TM") # added the last value of I to ensure the same length as time vector
plt.ylabel('Synaptic Current (pA)')
plt.xlabel('Time (ms)')
plt.legend(loc = "center right")
plt.tight_layout()

# downsample
if plot_source=="neuronmonitor":
    v = v[0:v.size:steps]
    I_syn = I_syn[0:len(I_syn):steps]

## Plot the membrane potential
ax1 = plt.subplot(211)
ax1.plot(T, v, color = "orange", linestyle='dotted', label = "Keivan TM", alpha=0.85)
plt.ylabel('Membrane potential (mV)')
plt.title(f"{postNeuronType}")
ax2 = plt.subplot(212)
ax2.plot(T, I_syn, color = "orange", linestyle='dotted', label = "Keivan TM", alpha=0.85)
plt.ylabel('Synaptic Current (pA)')
plt.xlabel('Time (ms)')
plt.legend(loc = "center right")
plt.tight_layout()
fileOutputName = preNeuronType + '_' + postNeuronType + '_' + str(I_ext) + 'pA' + \
      '_CARLsim_vs_Keivan_superimposed.png'
plt.savefig(fileOutputName, dpi=800)
plt.clf()


# Look at the error between CARLsim and python computed synaptic signal
I = np.append(I, [I[-1]])
V = np.append(V, [V[-1]])

af = scipy.fft.fft(I_syn)
bf = scipy.fft.fft(I)
c = scipy.ifft(af * scipy.conj(bf))
time_shift = np.argmax(abs(c))
#         print(time_shift)

I_2 = I[:-1]
I_2 = I_2[::int(1/dt)]
I_syn_2 = I_syn[0+time_shift:]
I_syn_2 = I_syn_2[::int(1/dt)]
I_syn_2 = np.array(np.array(I_syn_2,dtype=np.float32))
V_2 = V[:-1]
V_2 = V_2[::int(1/dt)]
v_2 = v[0+time_shift:]
v_2 = v_2[::int(1/dt)]

if len(V_2) == len(v_2):
    ax1 = plt.subplot(211)
    ax1.plot(abs(V_2 - v_2))
    plt.ylabel('Membrane potential (mV)')
    plt.title(f"{postNeuronType}")
    ax2 = plt.subplot(212)
    ax2.plot(abs(I_2 - I_syn_2))
    plt.ylabel('I_syn diff (HCO - Carlsim) (pA)')
    plt.xlabel('Time (ms)')
    plt.tight_layout()
    fileOutputName = preNeuronType + '_' + postNeuronType + '_' + str(I_ext) + 'pA' + \
          '_CARLsim_vs_Keivan_error.png'
    plt.savefig(fileOutputName, dpi=800)
    plt.clf()

    # Append min, max, mean, and median of errors between V and I
    pctErrorV = abs((V_2 - v_2)/v_2)
    pctErrorI = abs((I_2 - I_syn_2)/I_syn_2)
    maxErrorV = max(pctErrorV[~np.isnan(pctErrorV)])
    maxErrorI = max(pctErrorI[~np.isnan(pctErrorI)])
    minErrorV = min(pctErrorV[~np.isnan(pctErrorV)])
    minErrorI = min(pctErrorI[~np.isnan(pctErrorI)])
    meanErrorV = np.mean(pctErrorV[~np.isnan(pctErrorV)])
    meanErrorI = np.mean(pctErrorI[~np.isnan(pctErrorI)])
    medianErrorV = np.median(pctErrorV[~np.isnan(pctErrorV)])
    medianErrorI = np.median(pctErrorI[~np.isnan(pctErrorI)])
    mismatchV = sum(abs(V_2-v_2)*dt)/sum(abs((V_2 + v_2)/2))
    mismatchI = sum(abs(I_2-I_syn_2)*dt)/sum(abs((I_2 + I_syn_2)/2))


[15, 40, 67, 97, 130, 166, 206, 249, 295, 343, 392, 442, 493, 544, 595, 646, 697, 748, 799, 850, 901, 952]
t:15 I:0.000	g:3.644	u:0.260	x:0.740	v:6.796	synaptic spike
t:15 I:24.765	g:3.644	u:0.260	x:0.740	v:6.796	current decay
t:16 I:22.236	g:3.304	u:0.245	x:0.741	v:6.730	current decay
t:17 I:19.996	g:2.996	u:0.231	x:0.741	v:6.675	current decay
t:18 I:18.005	g:2.716	u:0.217	x:0.742	v:6.630	current decay
t:19 I:16.232	g:2.462	u:0.205	x:0.742	v:6.592	current decay
t:20 I:14.648	g:2.232	u:0.193	x:0.743	v:6.561	current decay
t:21 I:13.230	g:2.024	u:0.181	x:0.744	v:6.536	current decay
t:22 I:11.958	g:1.835	u:0.171	x:0.744	v:6.517	current decay
t:23 I:10.816	g:1.664	u:0.161	x:0.745	v:6.501	current decay
t:24 I:9.789	g:1.508	u:0.152	x:0.745	v:6.490	current decay
t:25 I:8.863	g:1.368	u:0.143	x:0.746	v:6.481	current decay
t:26 I:8.029	g:1.240	u:0.134	x:0.747	v:6.475	current decay
t:27 I:7.275	g:1.124	u:0.127	x:0.747	v:6.472	current decay
t:28 I:6.594	g:1.019	u:0.119	x:0.748	v:6.470	current deca

t:454 I:1.287	g:0.353	u:0.132	x:0.241	v:3.644	current decay
t:455 I:1.163	g:0.320	u:0.124	x:0.242	v:3.631	current decay
t:456 I:1.051	g:0.290	u:0.117	x:0.244	v:3.618	current decay
t:457 I:0.949	g:0.263	u:0.110	x:0.246	v:3.605	current decay
t:458 I:0.858	g:0.239	u:0.104	x:0.248	v:3.593	current decay
t:459 I:0.775	g:0.216	u:0.098	x:0.249	v:3.581	current decay
t:460 I:0.700	g:0.196	u:0.092	x:0.251	v:3.568	current decay
t:461 I:0.633	g:0.178	u:0.086	x:0.253	v:3.556	current decay
t:462 I:0.572	g:0.161	u:0.081	x:0.254	v:3.544	current decay
t:463 I:0.516	g:0.146	u:0.077	x:0.256	v:3.531	current decay
t:464 I:0.467	g:0.133	u:0.072	x:0.258	v:3.519	current decay
t:465 I:0.422	g:0.120	u:0.068	x:0.260	v:3.507	current decay
t:466 I:0.381	g:0.109	u:0.064	x:0.261	v:3.494	current decay
t:467 I:0.344	g:0.099	u:0.060	x:0.263	v:3.482	current decay
t:468 I:0.311	g:0.090	u:0.057	x:0.265	v:3.469	current decay
t:469 I:0.281	g:0.081	u:0.054	x:0.266	v:3.457	current decay
t:470 I:0.254	g:0.074	u:0.050	x:0.268	v:

t:921 I:-0.929	g:0.169	u:0.081	x:0.265	v:-5.500	current decay
t:922 I:-0.846	g:0.153	u:0.077	x:0.267	v:-5.523	current decay
t:923 I:-0.770	g:0.139	u:0.072	x:0.268	v:-5.546	current decay
t:924 I:-0.701	g:0.126	u:0.068	x:0.270	v:-5.569	current decay
t:925 I:-0.638	g:0.114	u:0.064	x:0.272	v:-5.592	current decay
t:926 I:-0.581	g:0.103	u:0.060	x:0.273	v:-5.615	current decay
t:927 I:-0.529	g:0.094	u:0.057	x:0.275	v:-5.638	current decay
t:928 I:-0.481	g:0.085	u:0.053	x:0.277	v:-5.660	current decay
t:929 I:-0.438	g:0.077	u:0.050	x:0.278	v:-5.683	current decay
t:930 I:-0.399	g:0.070	u:0.047	x:0.280	v:-5.705	current decay
t:931 I:-0.363	g:0.063	u:0.045	x:0.282	v:-5.727	current decay
t:932 I:-0.330	g:0.057	u:0.042	x:0.283	v:-5.750	current decay
t:933 I:-0.301	g:0.052	u:0.040	x:0.285	v:-5.772	current decay
t:934 I:-0.274	g:0.047	u:0.037	x:0.287	v:-5.794	current decay
t:935 I:-0.249	g:0.043	u:0.035	x:0.288	v:-5.816	current decay
t:936 I:-0.227	g:0.039	u:0.033	x:0.290	v:-5.837	current decay
t:937 I:

  c = scipy.ifft(af * scipy.conj(bf))
  c = scipy.ifft(af * scipy.conj(bf))
  pctErrorI = abs((I_2 - I_syn_2)/I_syn_2)


<Figure size 432x288 with 0 Axes>

In [3]:
# Voltage comparison between Keivan's model and CARLsim, printed in order:
# max, min, mean and median relative error, then the normalized mismatch.
for voltage_stat in (maxErrorV, minErrorV, meanErrorV, medianErrorV, mismatchV):
    print(voltage_stat)

0.0010360326336681467
1.2928933573394543e-08
0.0003057951979061734
0.0003578259458574965
0.0003130776190437925


In [4]:
# Synaptic-current comparison between Keivan's model and CARLsim, printed in order:
# max, min, mean and median relative error, then the normalized mismatch.
for current_stat in (maxErrorI, minErrorI, meanErrorI, medianErrorI, mismatchI):
    print(current_stat)

1.1583349103049558
5.181246986581737e-07
0.007962890744922471
0.0024832550344909762
0.0037698842660146663
