In [None]:
import json
import pickle
import numpy as np
import matplotlib.pyplot as plt
from brian2 import *

%matplotlib inline

In [None]:
cell_types = ['RS', 'IB', 'FS']
# Directory holding one ``<cell_type>.json`` parameter file per cell type.
config_dir = '../DL/configs'
config_files = [f'{config_dir}/{cell_type}.json' for cell_type in cell_types]


def load_json(path):
    """Parse and return the JSON document at ``path``.

    The file is opened in a ``with`` block so its handle is closed as soon as
    it has been parsed (the previous ``json.load(open(...))`` leaked it).
    """
    with open(path) as fid:
        return json.load(fid)


# Parsed JSON config for each cell type, keyed by cell type.
parameters = {cell_type: load_json(config_file)
              for cell_type, config_file in zip(cell_types, config_files)}

In [None]:
# Brian2 objects created for each cell type, keyed by cell type.
neuron_groups = {}
state_monitors = {}
spike_monitors = {}

net = Network()

# Number of identical neurons simulated for each cell type.
n_neurons_per_group = 2

# Adaptive exponential integrate-and-fire (AdEx) template. ``{0}`` is replaced
# by the cell type, so each group refers to its own constants (``gL_RS``,
# ``gL_IB``, ...), which are resolved from the namespace at run time.
nrn_eqs = """
dV/dt = (-gL_{0} * (V - EL_{0}) + gL_{0} * DeltaT_{0} * 
         exp((V - VT_{0}) / DeltaT_{0}) + I_{0} - w) / Cm_{0} : volt (unless refractory)
dw/dt = (a_{0} * (V - EL_{0}) - w) / tauw_{0} : amp
"""

for cell_type in cell_types:
    params = parameters[cell_type]
    # Model constant name -> (raw value from the JSON config, Brian2 unit).
    P = {
        'Cm': (params['C_m'], pF),
        'gL': (params['g_L'], nS),
        'EL': (params['E_L'], mV),
        'VT': (params['V_th'], mV),
        'Vpeak': (params['V_peak'], mV),
        'Vreset': (params['V_reset'], mV),
        'DeltaT': (params['Delta_T'], mV),
        'tauw': (params['tau_w'], ms),
        'a': (params['a'], nS),
        'b': (params['b'], pA),
        # The absolute refractory period is optional in the config.
        'tarp': (params.get('tau_arp', 0), ms),
        'I': (params['I_e'], pA),
    }
    # Publish every constant as a notebook global named ``<name>_<cell_type>``.
    # The groups below get no explicit ``namespace``, so Brian2 looks these
    # names up in the namespace from which ``net.run`` is called. Assigning
    # through ``globals()`` does what the former ``exec`` of a formatted
    # string did, without ever parsing config values as source code.
    for par_name, (value, unit) in P.items():
        globals()[f'{par_name}_{cell_type}'] = value * unit

    eqs = nrn_eqs.format(cell_type)
    group = NeuronGroup(n_neurons_per_group, model=eqs,
                        threshold=f'V>Vpeak_{cell_type}',
                        reset=f'V=Vreset_{cell_type}; w+=b_{cell_type}',
                        refractory=f'tarp_{cell_type}',
                        method='exponential_euler')
    # Start every neuron at the leak reversal potential with no adaptation current.
    group.V = params['E_L'] * mV
    group.w = 0 * pA
    neuron_groups[cell_type] = group

    # Record V and w of all neurons, plus every spike.
    state_monitors[cell_type] = StateMonitor(group, ['V', 'w'], record=True)
    spike_monitors[cell_type] = SpikeMonitor(group)

    net.add(group, state_monitors[cell_type], spike_monitors[cell_type])

In [None]:
stim_start, stim_dur = 100 * ms, 200 * ms
stim_stop = stim_start + stim_dur
# Unstimulated period simulated after the pulse; same length as the baseline.
post_stim_dur = stim_start


def set_injected_current(stimulus_on):
    """Set the injected current ``I_<cell_type>`` of every cell type.

    The neuron equations refer to ``I_<cell_type>`` as an external constant
    that Brian2 resolves from the notebook namespace when ``net.run`` is
    called, so rebinding these globals between runs switches the stimulus.

    Parameters
    ----------
    stimulus_on : bool
        If true, use each cell type's configured ``I_e`` (in pA); otherwise
        set the current to 0 pA.
    """
    for cell_type in cell_types:
        amplitude = parameters[cell_type]['I_e'] if stimulus_on else 0
        globals()[f'I_{cell_type}'] = amplitude * pA


# The network continues from its current time, so re-running this cell would
# append a second protocol to the monitors. Rebuild the network first.
assert net.t == 0 * ms, 'Re-run the network construction cell before simulating again.'

set_injected_current(False)
net.run(stim_start)
set_injected_current(True)
net.run(stim_dur)
set_injected_current(False)
net.run(post_stim_dur)

In [None]:
# Strip units from the recorded traces: time in ms, V in mV, w in pA.
# No monitor sets its own dt, so one time base serves every cell type.
# NOTE(review): the globals ``V`` and ``w`` share names with the model's state
# variables; if the simulation cells are re-run afterwards, Brian2 may warn
# about ambiguous names. Consider renaming them (e.g. ``V_mV``, ``w_pA``).
time = state_monitors[cell_types[0]].t / ms
V, w = {}, {}
for cell_type in cell_types:
    monitor = state_monitors[cell_type]
    V[cell_type] = monitor.V / mV
    w[cell_type] = monitor.w / pA

In [None]:
# Membrane potential (top) and adaptation current (bottom) of neuron 0 of each
# group, one colour per cell type, zoomed on the stimulus window.
fig, ax = plt.subplots(2, 1, sharex=True)
cmap = plt.get_cmap('jet', len(cell_types))
for idx, cell_type in enumerate(cell_types):
    color = cmap(idx)
    ax[0].plot(time, V[cell_type][0], color=color, lw=1)
    ax[1].plot(time, w[cell_type][0], color=color, lw=1, label=cell_type)

# Shared styling for both panels: dotted grid, no top/right spines.
for panel, ylabel in zip(ax, ('V (mV)', 'w (pA)')):
    panel.grid(which='major', axis='both', color=[.6, .6, .6], ls=':', lw=0.5)
    panel.spines['right'].set_visible(False)
    panel.spines['top'].set_visible(False)
    panel.set_ylabel(ylabel)

ax[0].set_xlim([stim_start / ms - 20, stim_stop / ms + 100])
ax[-1].set_xlabel('Time (ms)')
ax[1].legend(loc='upper right')
fig.tight_layout()