In [None]:
import os
import matplotlib
%matplotlib nbagg
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.signal import gaussian
from scipy.io import savemat
from scipy.io import loadmat
import brian2 as b2
import multiprocessing as mp

## Set up HPC inputs

In [None]:
def fsigm( x, amp, center, slope ):
    """Rising sigmoid that goes from 0 towards ``amp`` as ``x`` increases.

    The curve is offset along ``x`` so that its value at ``x == center`` is
    ``sqrt(2)/2``; ``slope`` sets how gradual the transition is. The logarithm
    used for the offset is only real for ``amp > sqrt(2)/2``.
    """
    half_root_two = b2.sqrt(2.0)/2.0
    # Horizontal offset that pins the value at `center` to sqrt(2)/2.
    shift = b2.log(amp/(amp-half_root_two)-1.0)*slope
    exponent = (x+shift-center)/slope
    return amp-amp/(1.0+b2.exp(exponent))

def fsigm_down( x, amp, center, slope ):
    """Falling sigmoid that goes from ``amp`` towards 0 as ``x`` increases.

    The curve is offset along ``x`` so that its value at ``x == center`` is
    ``sqrt(2)/2``; ``slope`` sets how gradual the transition is. The logarithm
    used for the offset is only real for ``amp > sqrt(2)/2``.
    """
    half_root_two = b2.sqrt(2.0)/2.0
    # Horizontal offset that pins the value at `center` to sqrt(2)/2.
    shift = -slope * b2.log(amp/half_root_two-1.0)
    exponent = (x-shift-center)/slope
    return amp / (1.0+b2.exp(exponent))

In [None]:
# Total simulated time and the bin width of the time-varying Poisson rate.
tstop=3000.0*b2.ms
poisson_resolution = 10.0*b2.ms

# Number of rate bins in the whole run and in its first half.
n_ramp_bins = int(tstop/poisson_resolution)
n_half_bins = int(tstop/poisson_resolution/2.0)

# Earlier Gaussian-shaped alternative, kept for reference:
# ext_ramp = gaussian(int(tstop/poisson_resolution), std=int(tstop/poisson_resolution)/10.0)*1000.0

# Rising sigmoid over the full run (centred at one third of the run) ...
ext_ramp = fsigm(b2.arange(n_ramp_bins), 1.0, tstop/3.0/poisson_resolution, 1.0)
# ... whose second half is overwritten by a falling sigmoid, giving a plateau-shaped envelope.
ext_ramp[n_half_bins:] = fsigm_down(
    b2.arange(n_half_bins), 1.0, tstop/6.0/poisson_resolution, 1.0)
# Scale the unit-amplitude envelope by 1000.
ext_ramp *= 1000.0

# Quick visual check of the envelope against time in ms.
b2.figure()
ramp_times_ms = np.arange(ext_ramp.shape[0])*poisson_resolution/b2.ms
b2.plot(ramp_times_ms, ext_ramp)


## Set up network

In [None]:
# Flag that is not referenced anywhere else in the visible notebook.
PLOT = False

# Passive & synaptic properties
# Membrane area used to convert the per-area values below into absolute ones.
area = 10000*b2.umetre**2

# Membrane capacitance: 1 uF/cm^2 scaled by the membrane area.
Cm = 1*b2.ufarad*b2.cm**-2 * area

# Leak conductance of each cell type (per-area value scaled by the membrane area).
pyr_gL = 4e-4*b2.siemens/b2.cm**2 * area
som_gL = 2e-4*b2.siemens/b2.cm**2 * area
parv_gL = 2e-4*b2.siemens/b2.cm**2 * area
# Show the leak conductances and their reciprocals (the corresponding resistances).
print(pyr_gL, som_gL, parv_gL)
print(1.0/pyr_gL, 1.0/som_gL, 1.0/parv_gL)

# Reversal potentials: leak, the g_hpc_* (E_ampa) and the g_som_*/g_parv_* (E_gaba) synaptic terms.
EL = -70*b2.mV
E_ampa = 0*b2.mV
E_gaba = -75*b2.mV
# Decay time constants of the synaptic conductances in the equations below.
tau_ampa = 2*b2.ms
tau_gaba = 1*b2.ms


# Model equations, one set per cell type. Each cell integrates a membrane
# current Im made of a leak term and a synaptic term Is. Is sums three
# conductance-based inputs (g_hpc_* driven towards E_ampa, g_som_* and g_parv_*
# driven towards E_gaba), and every conductance decays exponentially with
# tau_ampa or tau_gaba. The three sets differ only in the leak conductance and
# in the names of the conductance variables.
pyr_eqs='''
dv/dt = 1/Cm * Im : volt
Im = pyr_gL*(EL-v) + Is : amp
Is = g_hpc_pyr*(E_ampa-v) + g_som_pyr*(E_gaba-v) + g_parv_pyr*(E_gaba-v): amp
dg_hpc_pyr/dt = -g_hpc_pyr/tau_ampa : siemens
dg_som_pyr/dt = -g_som_pyr/tau_gaba : siemens
dg_parv_pyr/dt = -g_parv_pyr/tau_gaba : siemens
'''
# NOTE(review): g_som_som and g_parv_som are declared here, but no synapse in
# run_simulation increments them, so they appear to stay at their initial value.
som_eqs='''
dv/dt = 1/Cm * Im : volt
Im = som_gL*(EL-v) + Is : amp
Is = g_hpc_som*(E_ampa-v) + g_som_som*(E_gaba-v) + g_parv_som*(E_gaba-v): amp
dg_hpc_som/dt = -g_hpc_som/tau_ampa : siemens
dg_som_som/dt = -g_som_som/tau_gaba : siemens
dg_parv_som/dt = -g_parv_som/tau_gaba : siemens
'''
# NOTE(review): likewise g_parv_parv has no synapse targeting it in run_simulation.
parv_eqs='''
dv/dt = 1/Cm * Im : volt
Im = parv_gL*(EL-v) + Is : amp
Is = g_hpc_parv*(E_ampa-v) + g_som_parv*(E_gaba-v) + g_parv_parv*(E_gaba-v): amp
dg_hpc_parv/dt = -g_hpc_parv/tau_ampa : siemens
dg_som_parv/dt = -g_som_parv/tau_gaba : siemens
dg_parv_parv/dt = -g_parv_parv/tau_gaba : siemens
'''
# Membrane potential applied by the reset statement 'v = V_reset' in run_simulation.
V_reset = -70*b2.mV
# The two `whs` values compared throughout the notebook; run_simulation turns
# each one into the HPC->SOM synaptic weight in nS.
whss = (2, 24)

def run_simulation(parpair, with_monitors=False):
    """Simulate the PYR / SOM / PV three-cell circuit once and return its traces.

    Parameters
    ----------
    parpair : tuple
        ``(seed, whs)``: the value passed to ``b2.seed`` and the HPC->SOM
        synaptic weight expressed in nS.
    with_monitors : bool, optional
        When True, also record the SOM and PV membrane potentials and the
        spike times of both Poisson inputs and of both interneurons.

    Returns
    -------
    dict
        Always holds ``'seed'``, ``'whs'``, ``'tstore'`` (sample times in ms)
        and ``'pyr_vs'`` (pyramidal membrane potential in mV). With
        ``with_monitors`` it additionally holds ``'som_vs'`` and ``'parv_vs'``
        (mV) plus ``'hpc_IN'``, ``'hpc_pyr'``, ``'som_spikes'`` and
        ``'parv_spikes'`` (spike times in ms).

    Notes
    -----
    Brian2 looks up the external names used in the model strings (synaptic
    weights, TimedArrays) in the namespace of the code that calls ``run``, so
    those objects are deliberately kept as local variables of this function
    rather than being moved into helpers.
    """
    seed, whs = parpair

    # All three cells share the same spiking mechanism; only the equations differ.
    spiking = dict(
        threshold='v > -45*mV', refractory='v > -55*mV', reset='v = V_reset',
        method='exponential_euler')

    pyr = b2.NeuronGroup(N=1, model=pyr_eqs, **spiking)
    pyr.v = EL

    som = b2.NeuronGroup(N=1, model=som_eqs, **spiking)
    som.v = EL

    parv = b2.NeuronGroup(N=1, model=parv_eqs, **spiking)
    parv.v = EL

    # Time-varying Poisson inputs. Two independent sources at (100 + ramp) Hz
    # drive the interneurons (source 0 -> SOM, source 1 -> PV).
    varying_stimulus = b2.TimedArray((100.0+ext_ramp)*b2.Hz, dt=poisson_resolution)
    varying_poisson = b2.PoissonGroup(2, rates='varying_stimulus(t)')
    # The pyramidal cell gets its own, weaker input at (10 + 0.1*ramp) Hz.
    # Fix: this group previously referenced 'varying_stimulus(t)', which left
    # pyr_varying_stimulus unused and drove the pyramidal cell at the
    # interneuron rate.
    pyr_varying_stimulus = b2.TimedArray((10.0+0.1*ext_ramp)*b2.Hz, dt=poisson_resolution)
    pyr_varying_poisson = b2.PoissonGroup(1, rates='pyr_varying_stimulus(t)')

    # Synapses: each presynaptic spike increments the matching conductance.
    w_hpc_pyr = 1*b2.nS # 2.0
    S_hpc_pyr = b2.Synapses(pyr_varying_poisson, pyr, on_pre='g_hpc_pyr += w_hpc_pyr')
    S_hpc_pyr.connect(i=0, j=pyr)

    # The HPC->SOM weight is the parameter varied between simulations.
    w_hpc_som = whs*b2.nS
    S_hpc_som = b2.Synapses(varying_poisson, som, on_pre='g_hpc_som += w_hpc_som')
    S_hpc_som.connect(i=0, j=som)

    w_hpc_parv = 25*b2.nS # 25
    S_hpc_parv = b2.Synapses(varying_poisson, parv, on_pre='g_hpc_parv += w_hpc_parv')
    S_hpc_parv.connect(i=1, j=parv)

    w_parv_pyr = 27*b2.nS # 50
    S_parv_pyr = b2.Synapses(parv, pyr, on_pre='g_parv_pyr += w_parv_pyr')
    S_parv_pyr.connect(i=parv, j=pyr)

    w_som_pyr = 5*b2.nS
    S_som_pyr = b2.Synapses(som, pyr, on_pre='g_som_pyr += w_som_pyr')
    S_som_pyr.connect(i=som, j=pyr)

    w_som_parv = 100*b2.nS
    S_som_parv = b2.Synapses(som, parv, on_pre='g_som_parv += w_som_parv')
    S_som_parv.connect(i=som, j=parv)

    # Monitors are kept inside a list (not as individual locals) so that
    # b2.collect() below does not pick them up a second time.
    monitors = [
        b2.StateMonitor(pyr, 'v', record=[0]),
    ]
    if with_monitors:
        monitors += [
            b2.StateMonitor(som, 'v', record=[0]),
            b2.StateMonitor(parv, 'v', record=[0]),
            b2.SpikeMonitor(varying_poisson),
            b2.SpikeMonitor(pyr_varying_poisson),
            b2.SpikeMonitor(som),
            b2.SpikeMonitor(parv)
        ]

    # collect() gathers the Brian objects bound to local names in this function.
    net = b2.Network(b2.collect())
    net.add(monitors)
    b2.seed(seed)
    net.run(tstop, report='text')

    pyr_v_mon = monitors[0]
    retdict = {
        'seed': seed,
        'whs': whs,
        'tstore': pyr_v_mon.t/b2.ms,
        'pyr_vs': pyr_v_mon[0].v/b2.mV,
    }
    if with_monitors:
        # Same order in which the optional monitors were appended above.
        som_v_mon, parv_v_mon, hpc_in_mon, hpc_pyr_mon, som_spk_mon, parv_spk_mon = monitors[1:]
        retdict.update({
            'som_vs': som_v_mon[0].v/b2.mV,
            'parv_vs': parv_v_mon[0].v/b2.mV,
            'hpc_IN': hpc_in_mon.t/b2.ms,
            'hpc_pyr': hpc_pyr_mon.t/b2.ms,
            'som_spikes': som_spk_mon.t/b2.ms,
            'parv_spikes': parv_spk_mon.t/b2.ms
        })
    return retdict

## Run multiple simulations with different seeds

In [None]:
# Every (seed, whs) combination: seeds 4040..4139 crossed with each value in whss.
parlist = [
    (seed, whs)
    for seed in range(4040, 4140)
    for whs in whss
]
# Run the simulations in a pool of 16 worker processes.
with mp.Pool(processes=16) as pool:
    results = pool.map(run_simulation, parlist)


In [None]:
b2.figure()
b2.subplot(111)
# Group the pyramidal-cell traces by the `whs` value they were simulated with.
pyr_vs = dict((whs, []) for whs in whss)
for sim in results:
    tstore = sim['tstore']
    pyr_vs[sim['whs']].append(sim['pyr_vs'])
# Overlay the across-seed mean trace for the first two `whs` values.
for idx in (0, 1):
    b2.plot(tstore, b2.mean(pyr_vs[whss[idx]], axis=0), alpha=0.5)


In [None]:
# Create the output directory if needed; exist_ok avoids the check-then-create race.
os.makedirs('dat', exist_ok=True)
# Store the shared time axis and one stack of pyramidal traces per `whs` value.
savemat(
    os.path.join('dat', 'results_brian2'),
    {'tstore': tstore,
     'pyr_vs_{0:d}'.format(whss[0]): pyr_vs[whss[0]],
     'pyr_vs_{0:d}'.format(whss[1]): pyr_vs[whss[1]],})

## Plot manuscript figure

In [None]:
# Per-`whs` containers for the single-trial traces and spike times.
pyr_vs_figure, som_vs_figure, parv_vs_figure = {}, {}, {}
hpc_IN_figure, hpc_pyr_figure = {}, {}
som_spikes_figure, parv_spikes_figure = {}, {}

# Maps each key of run_simulation's result onto the container that stores it.
figure_stores = {
    'pyr_vs': pyr_vs_figure,
    'som_vs': som_vs_figure,
    'parv_vs': parv_vs_figure,
    'hpc_IN': hpc_IN_figure,
    'hpc_pyr': hpc_pyr_figure,
    'som_spikes': som_spikes_figure,
    'parv_spikes': parv_spikes_figure,
}

# One fully monitored run per `whs` value, always with seed 4040.
for whs in whss:
    results = run_simulation((4040, whs), with_monitors = True)
    for key, store in figure_stores.items():
        store[whs] = results[key]
    tstore = results['tstore']

In [None]:
# Reload the stored multi-seed traces; from here on `results` is the loadmat dict.
results_path = os.path.join('dat', 'results_brian2')
results = loadmat(results_path)

In [None]:
b2.figure()
b2.subplot(111)
# Mean pyramidal trace across stored runs, one line per `whs` value;
# the first row of the loaded 'tstore' array is the time axis.
for idx in (0, 1):
    stored_key = 'pyr_vs_{0:d}'.format(whss[idx])
    b2.plot(results['tstore'][0], b2.mean(results[stored_key], axis=0), alpha=0.5)

In [None]:
def adjust_spines(ax, spines):
    """Keep only the listed spines of ``ax`` and drop ticks on the other axes.

    Parameters
    ----------
    ax : axes-like object
        Must expose ``spines`` (a mapping of location name to spine) and the
        ``xaxis`` / ``yaxis`` attributes.
    spines : container of str
        Locations to keep, e.g. ``['left', 'bottom']``. Kept spines are moved
        outward by 10 points; every other spine is made invisible.
    """
    for side, spine in ax.spines.items():
        if side not in spines:
            spine.set_color('none')  # hide this spine
            continue
        spine.set_position(('outward', 10))  # push outward by 10 points

    # Ticks stay on an axis only if its anchoring spine is kept.
    for axis, anchor in ((ax.yaxis, 'left'), (ax.xaxis, 'bottom')):
        if anchor in spines:
            axis.set_ticks_position(anchor)
        else:
            axis.set_ticks([])

In [None]:
# Manuscript figure: a 6-row x 2-column grid with one column per `whs` value.
# Rows, top to bottom: input rate envelope, HPC input spikes, SOM Vm, PV Vm,
# single-trial PYR Vm, and the mean PYR Vm taken from the reloaded `results`.
plt.figure(figsize=(6,15))
# Axes are stored under their 1-based subplot index (odd = left, even = right column).
axes = {}
for nw, whs in enumerate(whss):
    # Column index wraps modulo 2, so a third `whs` value would reuse column 0.
    col = nw%2
    # Time axis (ms) of the rate envelope; does not depend on `whs`.
    t_ext_ramp = np.arange(0, ext_ramp.shape[0])*poisson_resolution/b2.ms
    # Row 1: rate envelope of the Poisson inputs.
    npanel = 1+col
    axes[npanel] = plt.subplot(6, 2, npanel)
    axes[npanel].plot(t_ext_ramp, ext_ramp, clip_on=False)
    adjust_spines(axes[npanel], ['left'])
    # Blank every y tick label, then name only the second tick and the
    # second-to-last tick.
    a=axes[npanel].get_yticks().tolist()
    for n in range(len(a)):
        a[n] = ''
    a[1] = 'Offset'
    a[-2] = 'Max'
    axes[npanel].set_yticklabels(a)
    if col == 0:
        axes[npanel].set_ylabel('HPC spike rate (AU)')

    # Row 2: input spike rasters with random vertical jitter of +/-0.4;
    # the pyramidal-cell input is drawn around y=1, the interneuron inputs around y=0.
    npanel = 3+col
    axes[npanel] = plt.subplot(6, 2, npanel)
    axes[npanel].plot(
        hpc_pyr_figure[whs],
        np.ones(hpc_pyr_figure[whs].shape)+np.random.uniform(-0.4, 0.4, hpc_pyr_figure[whs].shape),
        '|', clip_on=False)
    axes[npanel].plot(
        hpc_IN_figure[whs],
        np.zeros(hpc_IN_figure[whs].shape)+np.random.uniform(-0.4, 0.4, hpc_IN_figure[whs].shape),
        '|', clip_on=False)
    adjust_spines(axes[npanel], [])
    if col == 0:
        axes[npanel].set_ylabel('HPC spikes')

    # Row 3: SOM membrane potential. A vertical line from -45 to 40 is drawn at
    # each recorded spike time, plus a jittered tick mark around y=80.
    npanel = 5+col
    axes[npanel] = plt.subplot(6, 2, npanel)
    axes[npanel].plot(tstore, som_vs_figure[whs], 'k', clip_on=False)
    for som_spike in som_spikes_figure[whs]:
        axes[npanel].plot([som_spike, som_spike], [-45, 40], 'k')
    axes[npanel].plot(
        som_spikes_figure[whs],
        np.ones(som_spikes_figure[whs].shape) * 80.0 + np.random.uniform(-20.0, 20.0, som_spikes_figure[whs].shape),
            '|', clip_on=False)
    adjust_spines(axes[npanel], ['left'])
    if col == 0:
        axes[npanel].set_ylabel(r'SOM $V_m$ (mV)')

    # Row 4: PV membrane potential, drawn the same way as the SOM row.
    npanel = 7+col
    axes[npanel] = plt.subplot(6, 2, npanel)
    axes[npanel].plot(tstore, parv_vs_figure[whs], 'k', clip_on=False)
    for parv_spike in parv_spikes_figure[whs]:
        axes[npanel].plot([parv_spike, parv_spike], [-45, 40], 'k')
    axes[npanel].plot(
        parv_spikes_figure[whs],
        np.ones(parv_spikes_figure[whs].shape) * 80.0 + np.random.uniform(-20.0, 20.0, parv_spikes_figure[whs].shape),
            '|', clip_on=False)
    adjust_spines(axes[npanel], ['left'])
    if col == 0:
        axes[npanel].set_ylabel(r'PV $V_m$ (mV)')

    # Row 5: pyramidal membrane potential from the single monitored run.
    npanel = 9+col
    axes[npanel] = plt.subplot(6, 2, npanel)
    axes[npanel].plot(tstore, pyr_vs_figure[whs], clip_on=False)
    adjust_spines(axes[npanel], ['left'])
    if col == 0:
        axes[npanel].set_ylabel(r'PYR $V_m$ (mV)')

    # Row 6: mean pyramidal membrane potential across the runs stored in `results`.
    npanel = 11+col
    axes[npanel] = plt.subplot(6, 2, npanel)
    axes[npanel].plot(results['tstore'][0], b2.mean(results['pyr_vs_{0:d}'.format(whs)], axis=0), clip_on=False)
    adjust_spines(axes[npanel], ['left', 'bottom'])
    if col == 0:
        axes[npanel].set_ylabel(r'Mean PYR $V_m$ (mV)')

# Link the axes after plotting: every panel shares x with the top panel of its
# column, and the two panels of each row share y. Then redo the spines so only
# the left column keeps a y axis and only the bottom row keeps an x axis.
# NOTE(review): calling .join() on get_shared_x_axes()/get_shared_y_axes() was
# deprecated in newer Matplotlib releases (3.6, to my knowledge) and later
# removed -- confirm against the Matplotlib version this notebook is run with.
for npanel in range(2, 13, 2):
    axes[1].get_shared_x_axes().join(axes[1], axes[npanel-1])
    axes[2].get_shared_x_axes().join(axes[2], axes[npanel])
    axes[npanel-1].get_shared_y_axes().join(axes[npanel-1], axes[npanel])
    if npanel < 11:
        adjust_spines(axes[npanel-1], ['left'])
        adjust_spines(axes[npanel], [])
    else:
        adjust_spines(axes[npanel-1], ['left', 'bottom'])
        adjust_spines(axes[npanel], ['bottom'])
sns.despine()
# Write the figure to the working directory in both PDF and PNG form.
plt.savefig('scheme_brian2.pdf')
plt.savefig('scheme_brian2.png')
