In [1]:
# to add the spiking model codebase to the path
import sys
sys.path.append('..')

from brian2 import *
from params import paramsJercog as p

import numpy as np

from network import JercogEphysNetwork
from results import ResultsEphys

%matplotlib
import matplotlib.pyplot as plt

INFO       Cache size for target "cython": 11811 MB.
You can call "clear_cache('cython')" to delete all files from the cache or manually delete files in the "C:\Users\mikejseay\.cython\brian_extensions" directory. [brian2]


Using matplotlib backend: Qt5Agg


In [4]:
p['setUpFRExc'] = 5 * Hz
p['setUpFRInh'] = 14 * Hz

p['threshExc'] = 130 * pA  # this is a bad approximation and replaced later by a precise test
p['threshInh'] = 140 * pA

p['gainExc'] = 0.49 * Hz / pA
p['gainInh'] = 1.46 * Hz / pA

In [5]:
p['recordStateVariables'] = ['v', 'sE', 'sI']
p['duration'] = 1 * second
p['nUnits'] = 1000
p['propConnect'] = 0.5

nExc = 400
nInh = 100
simType = 'poisson'  # poisson or fixed

- simulate 400 poisson processes that represent presynaptic E Units at E_set
- simulate 100 poisson processes that represent presynaptic I units at I_set
- connect them all to a single real IAF E unit and I unit
- set the synaptic weights based on the result of a cross-homeo training that converged to the set points
- measure the excitatory and inhibitory currents

In [6]:
defaultclock.dt = p['dt']
p['nInh'] = int(p['propInh'] * p['nUnits'])
p['nExc'] = int(p['nUnits'] - p['nInh'])
p['nExcSpikemon'] = int(p['nExc'] * p['propSpikemon'])
p['nInhSpikemon'] = int(p['nInh'] * p['propSpikemon'])
p['nIncInh'] = int(p['propConnect'] * p['propInh'] * p['nUnits'])
p['nIncExc'] = int(p['propConnect'] * (1 - p['propInh']) * p['nUnits'])

N = Network()

# NOTICE NO NOISE

unitModel = '''
                dv/dt = (gl * (eLeak - v) - iAdapt +
                         sE - sI) / Cm : volt (unless refractory)
                diAdapt/dt = -iAdapt / tauAdapt : amp

                dsE/dt = (-sE + uE) / tauFallE : amp
                duE/dt = -uE / tauRiseE : amp
                dsI/dt = (-sI + uI) / tauFallI : amp
                duI/dt = -uI / tauRiseI : amp
                
                eLeak : volt
                vReset : volt
                vThresh : volt
                betaAdapt : amp * second
                gl : siemens
                Cm : farad
                '''

resetCode = '''
        v = vReset
        iAdapt += betaAdapt / tauAdapt 
        '''

threshCode = 'v >= vThresh'

unitsExc = NeuronGroup(N=1, model=unitModel, method=p['updateMethod'], threshold=threshCode, reset=resetCode,
                       refractory=p['refractoryPeriodExc'], clock=defaultclock)
unitsInh = NeuronGroup(N=1, model=unitModel, method=p['updateMethod'], threshold=threshCode, reset=resetCode,
                       refractory=p['refractoryPeriodInh'], clock=defaultclock)

unitsExc.v = p['eLeakExc']
unitsExc.vReset = p['vResetExc']
unitsExc.vThresh = p['vThreshExc']
unitsExc.betaAdapt = p['betaAdaptExc']
unitsExc.eLeak = p['eLeakExc']
unitsExc.Cm = p['membraneCapacitanceExc']
unitsExc.gl = p['gLeakExc']

unitsInh.v = p['eLeakInh']
unitsInh.vReset = p['vResetInh']
unitsInh.vThresh = p['vThreshInh']
unitsInh.betaAdapt = p['betaAdaptInh']
unitsInh.eLeak = p['eLeakInh']
unitsInh.Cm = p['membraneCapacitanceInh']
unitsInh.gl = p['gLeakInh']

N.add(unitsExc, unitsInh)

In [7]:
if simType == 'poisson':  # make the Poisson spikers
    
    poissonExc = PoissonGroup(N=nExc, rates=p['setUpFRExc'], clock=defaultclock)
    poissonInh = PoissonGroup(N=nInh, rates=p['setUpFRInh'], clock=defaultclock)
    
elif simType == 'fixed':  # instead, make fixed rate that tiles the time period
    
    numSpikesExc = int(np.round(p['setUpFRExc'] * p['duration']))
    numSpikesInh = int(np.round(p['setUpFRInh'] * p['duration']))

    fixedIndicesExc = np.zeros((numSpikesExc, ), )
    fixedIndicesInh = np.zeros((numSpikesInh, ), )

    spikeTimesExc = np.linspace(0, p['duration'] / second, numSpikesExc + 1) * second
    spikeTimesExc = spikeTimesExc[:-1]

    spikeTimesInh = np.linspace(0, p['duration'] / second, numSpikesInh + 1) * second
    spikeTimesInh = spikeTimesInh[:-1]

    poissonExc = SpikeGeneratorGroup(N=1, indices=fixedIndicesExc, times=spikeTimesExc, clock=defaultclock)
    poissonInh = SpikeGeneratorGroup(N=1, indices=fixedIndicesInh, times=spikeTimesInh, clock=defaultclock)

N.add(poissonExc, poissonInh)

In [8]:
# make the synapses

onPreStrings = ('uE_post += jEE / tauRiseEOverMS',
                'uE_post += jIE / tauRiseEOverMS',
                'uI_post += jEI / tauRiseIOverMS',
                'uI_post += jII / tauRiseIOverMS',)

tauRiseEOverMS = p['tauRiseExc'] / ms
tauRiseIOverMS = p['tauRiseInh'] / ms
vTauExcOverMS = p['membraneCapacitanceExc'] / p['gLeakExc'] / ms
vTauInhOverMS = p['membraneCapacitanceInh'] / p['gLeakInh'] / ms

synapsesEE = Synapses(model='jEE: amp',
                     source=poissonExc,
                     target=unitsExc,
                     on_pre=onPreStrings[0])
synapsesIE = Synapses(model='jIE: amp',
                     source=poissonExc,
                     target=unitsInh,
                     on_pre=onPreStrings[1])
synapsesEI = Synapses(model='jEI: amp',
                     source=poissonInh,
                     target=unitsExc,
                     on_pre=onPreStrings[2])
synapsesII = Synapses(model='jII: amp',
                     source=poissonInh,
                     target=unitsInh,
                     on_pre=onPreStrings[3])

synapsesEE.connect()
synapsesIE.connect()
synapsesEI.connect()
synapsesII.connect()

In [9]:
useBetterWeights = True

if useBetterWeights:
    jEEmult = 1.0249649370098968
    jIEmult = 0.9913780811668145
    jEImult = 0.9750494640455015
    jIImult = 1.0108023944074438
else:
    jEEmult = 1
    jIEmult = 1
    jEImult = 1
    jIImult = 1

usejEE = p['jEE'] / p['nIncExc'] * vTauExcOverMS * jEEmult
usejIE = p['jIE'] / p['nIncExc'] * vTauInhOverMS * jIEmult
usejEI = p['jEI'] / p['nIncInh'] * vTauExcOverMS * jEImult
usejII = p['jII'] / p['nIncInh'] * vTauInhOverMS * jIImult

synapsesEE.jEE = usejEE
synapsesIE.jIE = usejIE
synapsesEI.jEI = usejEI
synapsesII.jII = usejII

In [10]:
print(usejEE, usejIE, usejEI, usejII, )

143.49509118 pA 123.92226015 pA 136.50692497 pA 101.08023944 pA


In [11]:
N.add(synapsesEE, synapsesIE, synapsesEI, synapsesII)

# create monitors
spikeMonExc = SpikeMonitor(unitsExc[:p['nExcSpikemon']])
spikeMonInh = SpikeMonitor(unitsInh[:p['nInhSpikemon']])
stateMonExc = StateMonitor(unitsExc,
                           p['recordStateVariables'],
                           record=p['indsRecordStateExc'],
                           clock=defaultclock)
stateMonInh = StateMonitor(unitsInh,
                           p['recordStateVariables'],
                           record=p['indsRecordStateInh'],
                           clock=defaultclock)
N.add(spikeMonExc, spikeMonInh, stateMonExc, stateMonInh)

In [12]:
tauRiseE = p['tauRiseExc']
tauFallE = p['tauFallExc']
tauRiseI = p['tauRiseInh']
tauFallI = p['tauFallInh']
tauAdapt = p['adaptTau']
tauRiseEOverMS = p['tauRiseExc'] / ms
tauRiseIOverMS = p['tauRiseInh'] / ms
vTauExcOverMS = p['membraneCapacitanceExc'] / p['gLeakExc'] / ms
vTauInhOverMS = p['membraneCapacitanceInh'] / p['gLeakInh'] / ms

N.run(p['duration'], report=p['reportType'], report_period=p['reportPeriod'], profile=p['doProfile'])

Starting simulation at t=0. s for a duration of 1. s
1. s (100%) simulated in 1s


In [13]:
useDType = np.single

timeArray = np.arange(0, float(p['duration']), float(p['dt']), dtype=useDType)

spikeMonExcT = np.array(spikeMonExc.t, dtype=useDType)
spikeMonExcI = np.array(spikeMonExc.i, dtype=useDType)
spikeMonInhT = np.array(spikeMonInh.t, dtype=useDType)
spikeMonInhI = np.array(spikeMonInh.i, dtype=useDType)
stateMonExcV = np.array(stateMonExc.v / mV, dtype=useDType)
stateMonInhV = np.array(stateMonInh.v / mV, dtype=useDType)

stateMonExcSE = np.array(stateMonExc.sE / pA, dtype=useDType)
stateMonInhSE = np.array(stateMonInh.sE / pA, dtype=useDType)
stateMonExcSI = np.array(stateMonExc.sI / pA, dtype=useDType)
stateMonInhSI = np.array(stateMonInh.sI / pA, dtype=useDType)

In [14]:
# plots

f, ax = plt.subplots()
ax.plot(timeArray, stateMonExcV[0, :])
ax.plot(timeArray, stateMonInhV[0, :])

f, ax = plt.subplots()
ax.plot(timeArray, stateMonExcSE[0, :])
ax.plot(timeArray, stateMonExcSI[0, :])

f, ax = plt.subplots()
ax.plot(timeArray, stateMonExcSE[0, :] - stateMonExcSI[0, :])

[<matplotlib.lines.Line2D at 0x265b10c8dc8>]

In [15]:
# this one works better...
sESumExc = (stateMonExc.sE).sum() * p['dt'] * 1e3
sESumInh = (stateMonInh.sE).sum() * p['dt'] * 1e3
sISumExc = (stateMonExc.sI).sum() * p['dt'] * 1e3
sISumInh = (stateMonInh.sI).sum() * p['dt'] * 1e3

In [16]:
print(sESumExc, sESumInh, sISumExc, sISumInh, )

276.66296644 nC 238.92594386 nC 188.63632144 nC 139.680859 nC


In [17]:
print(usejEE, usejIE, usejEI, usejII, )

143.49509118 pA 123.92226015 pA 136.50692497 pA 101.08023944 pA


In [18]:
print(spikeMonExcT.size / p['duration'], spikeMonInhT.size / p['duration'])

8. Hz 18. Hz


In [19]:
doEmpiricalTest = True

In [20]:
if doEmpiricalTest:
    pForEphys = p.copy()
    if 'iExtRange' not in pForEphys:
        pForEphys['recordStateVariables'] = ['v', ]
        pForEphys['propInh'] = 0.5
        pForEphys['duration'] = 1 * second
        pForEphys['iExtRange'] = np.linspace(0, .3, 3001) * nA
    JEN = JercogEphysNetwork(pForEphys)
    JEN.build_classic()
    JEN.run()
    RE = ResultsEphys()
    RE.init_from_network_object(JEN)
    RE.calculate_thresh_and_gain()
    
    p['threshExc'] = RE.threshExc
    p['threshInh'] = RE.threshInh
    p['gainExc'] = RE.gainExc
    p['gainInh'] = RE.gainInh

Starting simulation at t=0. s for a duration of 1. s
1. s (100%) simulated in 1s


In [22]:
f, ax = plt.subplots(2, 2)
RE.calculate_and_plot(f, ax)

IndexError: too many indices for array

In [41]:
f, ax = plt.subplots(2, 2)

I_ext_range = RE.p['iExtRange']
ExcData = RE.spikeMonExcC / RE.p['duration']
InhData = RE.spikeMonInhC / RE.p['duration']

I_index_for_ISI = int(len(I_ext_range) * .9) - 1

# reconstruct time
stateMonT = np.arange(0, float(RE.p['duration']), float(RE.p['dt']))

# might be useful...
# ax.axhline(useThresh, color=useColor, linestyle=':')  # Threshold
# ax.axhline(RE.p['eLeak' + unitType] / mV, color=useColor, linestyle='--')  # Resting

useThresh = RE.p['vThreshExc'] / mV
ax[0, 0].plot(stateMonT, RE.stateMonExcV[I_index_for_ISI, :], color='g')
ax[0, 0].vlines(RE.spikeTrainsExc[()][I_index_for_ISI], useThresh, useThresh + 40, color='g', lw=.3)
ax[0, 0].set(xlim=(0., RE.p['duration'] / second), ylabel='mV', xlabel='Time (s)')

useThresh = RE.p['vThreshInh'] / mV
ax[0, 1].plot(stateMonT, RE.stateMonInhV[I_index_for_ISI, :], color='g')
ax[0, 1].vlines(RE.spikeTrainsInh[()][I_index_for_ISI], useThresh, useThresh + 40, color='g', lw=.3)
ax[0, 1].set(xlim=(0., RE.p['duration'] / second), ylabel='mV', xlabel='Time (s)')

ax[1, 0].plot(I_ext_range * 1e9, ExcData, label='Exc')
ax[1, 0].plot(I_ext_range * 1e9, InhData, label='Inh')
ax[1, 0].axvline(float(I_ext_range[I_index_for_ISI]) * 1e9,
                 label='displayed value', color='grey', ls='--')
ax[1, 0].set_xlabel('Current (nA)')
ax[1, 0].set_ylabel('Firing Rate (Hz)')
ax[1, 0].legend()

ISIExc = diff(RE.spikeTrainsExc[()][I_index_for_ISI])
ISIInh = diff(RE.spikeTrainsInh[()][I_index_for_ISI])
ax[1, 1].plot(arange(1, len(ISIExc) + 1), ISIExc * 1000, label='Exc')
ax[1, 1].plot(arange(1, len(ISIInh) + 1), ISIInh * 1000, label='Inh')
ax[1, 1].set_xlabel('ISI number')
ax[1, 1].set_ylabel('ISI (ms)')
ax[1, 1].legend()

f.tight_layout()
f.subplots_adjust(top=.9)

In [40]:
RE.spikeTrainsExc[()]

{0: array([], dtype=float64) * second,
 1: array([], dtype=float64) * second,
 2: array([], dtype=float64) * second,
 3: array([], dtype=float64) * second,
 4: array([], dtype=float64) * second,
 5: array([], dtype=float64) * second,
 6: array([], dtype=float64) * second,
 7: array([], dtype=float64) * second,
 8: array([], dtype=float64) * second,
 9: array([], dtype=float64) * second,
 10: array([], dtype=float64) * second,
 11: array([], dtype=float64) * second,
 12: array([], dtype=float64) * second,
 13: array([], dtype=float64) * second,
 14: array([], dtype=float64) * second,
 15: array([], dtype=float64) * second,
 16: array([], dtype=float64) * second,
 17: array([], dtype=float64) * second,
 18: array([], dtype=float64) * second,
 19: array([], dtype=float64) * second,
 20: array([], dtype=float64) * second,
 21: array([], dtype=float64) * second,
 22: array([], dtype=float64) * second,
 23: array([], dtype=float64) * second,
 24: array([], dtype=float64) * second,
 25: array

In [23]:
# average net positive current to sustain set FR

In [24]:
p['threshExc'] = RE.threshExc
p['threshInh'] = RE.threshInh
p['gainExc'] = RE.gainExc
p['gainInh'] = RE.gainInh

In [25]:
print(RE.threshExc, RE.threshInh, RE.gainExc, RE.gainInh, )

124.1 pA 135.1 pA 0.34659091 Hz/pA 1.50909091 Hz/pA


In [22]:
# rough estimates
sumExcInputToExc = nExc * usejEE * float(p['setUpFRExc']) * p['duration']
sumExcInputToInh = nExc * usejIE * float(p['setUpFRExc']) * p['duration']

In [23]:
# rough estimates
sumInhInputToExc = nInh * usejEI * float(p['setUpFRInh']) * p['duration']
sumInhInputToInh = nInh * usejII * float(p['setUpFRInh']) * p['duration']

In [24]:
print(sumExcInputToExc, sumExcInputToInh, sumInhInputToExc, sumInhInputToInh)

286.99018236 nC 247.84452029 nC 191.10969495 nC 141.51233522 nC


In [25]:
print(sESumExc, sESumInh, sISumExc, sISumInh)

289.86654356 nC 250.32854381 nC 193.64915285 nC 143.3927454 nC


In [24]:
# my calculation of the total excitatory charge is accurate!
# now we compare the excitatory to inhibitory currents to get an estimate of the balancing inhibitory weights

In [25]:
jEESum = nExc * usejEE * float(p['setUpFRExc'])
jIESum = nExc * usejIE * float(p['setUpFRExc'])

# NOTE: these are observed empirically!!!!
jEISumgood = sISumExc / p['duration']
jIISumgood = sISumInh / p['duration']

In [94]:
usejEI

136.50692497 * pamp

In [26]:
# additive model: I = E + b
# solve for b as b = I - E

sustainingWeightImbalanceExc = jEISumgood - jEESum
sustainingWeightImbalanceInh = jIISumgood - jIESum

# now calculate I = E + b

proposedjEISum = (jEESum + sustainingWeightImbalanceExc)
proposedjIISum = (jIESum + sustainingWeightImbalanceInh)
proposedjEI = proposedjEISum / nInh / float(p['setUpFRInh'])
proposedjII = proposedjIISum / nInh / float(p['setUpFRInh'])

In [27]:
print(sustainingWeightImbalanceExc, sustainingWeightImbalanceInh)

-76.47829669 nA -84.89833122 nA


In [28]:
print(proposedjEI, proposedjII)

136.70879393 pA 101.22971876 pA


In [29]:
print(usejEI, usejII)

136.50692497 pA 101.08023944 pA


In [30]:
# multiplicative model: I = m * E

# solve for m as m = I / E
sustainingWeightRatioExc = jEISumgood / jEESum
sustainingWeightRatioInh = jIISumgood / jIESum

# claculate I = m * E
proposedjEISum = (jEESum * sustainingWeightRatioExc)
proposedjIISum = (jIESum * sustainingWeightRatioInh)
proposedjEI = proposedjEISum / nInh / float(p['setUpFRInh'])
proposedjII = proposedjIISum / nInh / float(p['setUpFRInh'])

In [31]:
print(sustainingWeightRatioExc, sustainingWeightRatioInh)

0.6668949784973939 0.5718165812278555


In [32]:
print(proposedjEI, proposedjII)

136.70879393 pA 101.22971876 pA


In [33]:
print(usejEI, usejII)

136.50692497 pA 101.08023944 pA


In [83]:
# linear model with bias: I = m * E + b

# m is known (or is it?)...
# given enough realizations of E and I with balanced E I to generate Ups with the set points,
# could do linear fit to find out if this is a good estimate of the actual slope
slopeExc = p['setUpFRExc'] / p['setUpFRInh']
slopeInh = p['setUpFRExc'] / p['setUpFRInh']

# assuming m is known, solve for b as b = I - m * E
offsetExc = jEISumgood - jEESum * p['setUpFRExc'] / p['setUpFRInh']
offsetInh = jIISumgood - jIESum * p['setUpFRExc'] / p['setUpFRInh']

# calculate I as I = m * E + b
proposedjEISum = jEESum * slopeExc + offsetExc
proposedjIISum = jIESum * slopeInh + offsetInh
proposedjEI = proposedjEISum / nInh / float(p['setUpFRInh'])
proposedjII = proposedjIISum / nInh / float(p['setUpFRInh'])

In [84]:
print(slopeExc, slopeInh, offsetExc, offsetInh)

0.35714285714285715 0.35714285714285715 71.11665424 nA 42.56456493 nA


In [85]:
print(proposedjEI, proposedjII)

136.70879393 pA 101.22971876 pA


In [86]:
print(usejEI, usejII)

136.50692497 pA 101.08023944 pA


In [38]:
usejEE

143.49509118 * pamp

In [44]:
# how do the purely divisive and slope models compare?

In [47]:
xVals = np.linspace(0, 200, 1000)  # pretend in pA
yValsDiv = sustainingWeightRatioExc * xVals
yValsSlope = offsetExc / nA + slopeExc * xVals

In [53]:
f, ax = plt.subplots()
ax.plot(xVals, yValsDiv, label='I = m * E')
ax.plot(xVals, yValsSlope, label='I = m * E + b (m assumed)')
ax.scatter(usejEE / pA, proposedjEI / pA)
ax.legend()

<matplotlib.legend.Legend at 0x2a4ab847dc8>

In [66]:
# what happens if i don't multiply and divide by the set-points twice?

In [71]:
jEESum = nExc * usejEE
jIESum = nExc * usejIE

slopeExc = p['setUpFRExc'] / p['setUpFRInh']
slopeInh = p['setUpFRExc'] / p['setUpFRInh']

# assuming m is known, solve for b as b = I - m * E
offsetExc = jEISumgood - jEESum * slopeExc
offsetInh = jIISumgood - jIESum * slopeInh

# calculate I as I = m * E + b
proposedjEISum = jEESum * slopeExc + offsetExc
proposedjIISum = jIESum * slopeInh + offsetInh
proposedjEI = proposedjEISum / nInh
proposedjII = proposedjIISum / nInh

# this seems quite wrong!!!

In [87]:
print(proposedjEI, proposedjII)   # objectively good weights

136.70879393 pA 101.22971876 pA


In [61]:
p['setUpFRExc'] / p['setUpFRInh'] * sumExcInputToExc

81.99719496 * namp

In [89]:
jEESum = nExc * usejEE * float(p['setUpFRExc'])

sumExcInputToExc = jEESum.copy()
sumInhInputToExc = p['setUpFRExc'] / p['setUpFRInh'] * sumExcInputToExc - \
                   p['setUpFRExc'] / p['setUpFRInh'] / p['gainExc'] / second - \
                   p['threshExc'] / p['setUpFRInh'] / second  # amp

proposedjEI = sumInhInputToExc / nInh / float(p['setUpFRInh'])

In [91]:
p['setUpFRExc'] / p['setUpFRInh'] * sumExcInputToExc

81.99719496 * namp

In [76]:
sumInhInputToExc

81.98730023 * namp

In [77]:
# i believe this is essentially what my old version was driving towards... 8x too big

sumInhInputToExc / nInh

1.02484125 * namp

In [78]:
sumInhInputToExc / nInh / float(p['setUpFRInh'])  # now that's too low?

73.20294663 * pamp

In [79]:
jIESum = nExc * usejIE * float(p['setUpFRExc'])

sumExcInputToInh = jIESum.copy()
sumInhInputToInh = p['setUpFRExc'] / p['setUpFRInh'] * sumExcInputToInh - \
                   1 / p['gainInh'] / second - \
                   p['threshInh'] / p['setUpFRInh'] / second

In [80]:
sumInhInputToInh

70.80240743 * namp

In [81]:
sumInhInputToInh / nInh / float(p['setUpFRInh'])

63.21643521 * pamp

In [82]:
plt.close('all')