Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSoC Example #39

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions examples/IF_brette_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import numpy as np
from brian2 import *
from brian2modelfitting import *

'''
Adaptive exponential IF model introduced by Brette R. and Gerstner W. (2005). The model has 3 distinct firing regiems
depending on the parameters used in the simulation. This example shows how multiple sets of parameters can
be fitted to the same model simulatenously with the Brian2modelfitting toolbox.
'''

# Parameters
C = 281 * pF
gL = 30 * nS
taum = C / gL
EL = -70.6 * mV
VT = -50.4 * mV
DeltaT = 2 * mV
Vcut = VT + 5 * DeltaT

dt = 0.1 * ms
defaultclock.dt = dt

#Define input current

input_current1 = np.hstack([np.zeros(int(round(20*ms/dt))),
np.ones(int(round(100*ms/dt))),
np.zeros(int(round(20*ms/dt)))]) * 1
input_current2 = np.hstack([np.zeros(int(round(20*ms/dt))),
np.ones(int(round(100*ms/dt))),
np.zeros(int(round(20*ms/dt)))]) * 2
input_current = np.stack((input_current1, input_current2))*nA
I = TimedArray(input_current.T, dt=dt)

#First simulate the model with known parameters (to obtain data to fit parameters too!).
#Define model, setup monitoring & define parameters.

eqs = """
dvm/dt = (gL*(EL - vm) + gL*DeltaT*exp((vm - VT)/DeltaT) + I(t, i%2==1) - w)/C : volt
dw/dt = (a*(vm - EL) - w)/tauw : amp
tauw : second
a : siemens
b : amp
Vr : volt
"""

neuron = NeuronGroup(6, model=eqs, threshold='vm>Vcut',
reset="vm=Vr; w+=b", method='euler')
neuron.vm = EL
trace = StateMonitor(neuron, 'vm', record=True)
spikes = SpikeMonitor(neuron)

neuron.tauw = [144, 144, 20, 20, 144, 144] * ms
neuron.a = [4*nS, 4*nS, 4*nS, 4*nS, (2*C)/(144*ms), (2*C)/(144*ms)]
neuron.b = [0.0805, 0.0805, 0.5, 0.5, 0, 0] * nA
neuron.Vr = [-70.6*mV, -70.6*mV, VT+5*mV, VT+5*mV, -70.6*mV, -70.6*mV]

run(140 * ms)

spike_train = spikes.spike_trains()

#Put spike spikes and voltage traces into lists , probably lots of smarter ways to do this.
out_spikes = [[spike_train[0] / second, spike_train[1] /second], [spike_train[2] / second, spike_train[3] / second], [spike_train[4] / second, spike_train[5] / second]]
v_traces = [[trace.vm[0]/mV, trace.vm[1]/mV], [trace.vm[2]/mV, trace.vm[3]/mV], [trace.vm[4]/mV, trace.vm[5]/mV]]
Comment on lines +61 to +63
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is ok as it is. There are "smarter" ways to do this, but they are also less clear. I'd probably change the formatting a bit, though. E.e. have spike_train[0], spike_train[1] on one line, and then spike_train[2], spike_train[3] on the next, etc.
Also, no need to divide by seconds here, and I'd rather keep the v_traces with the units as well (for more consistency in the plotting code).


start_scope()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be necessary here.


#Model fitting

eqs_fit = '''
dvm/dt = (gL*(EL - vm) + gL*DeltaT*exp((vm - VT)/DeltaT) + I - w)/C : volt
dw/dt = (a*(vm - EL) - w)/tauw : amp
tauw : second (constant)
a : siemens (constant)
b : amp (constant)
Vr : volt (constant)
'''

n_opt = NevergradOptimizer()
metric = GammaFactor(delta=1*ms, time=140*ms)

fitters = []
for i in range(len(out_spikes)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be something like

for spikes in out_spikes:

instead.

fitters.append(SpikeFitter(model=eqs_fit, input_var='I', dt=dt,
input=input_current, output=out_spikes[i],
n_samples=800,
param_init={'vm': EL},
threshold='vm > Vcut',
reset='vm=Vr; w+=b',))

result_dict_error = []
predict_spikes = []
fits = []
for fitter in fitters:
result_dict_error.append(fitter.fit(n_rounds=15,
optimizer=n_opt,
metric=metric,
callback='progressbar',
tauw=[1,200]*ms,
a=[3, 4]*nS,
b=[0, 1]*nA,
Vr=[-80,-40]*mV))

predict_spikes.append(fitter.generate_spikes(params=None))
#print('spike times:', spikes)

fits.append(fitter.generate(params=None,
output_var='vm'))

#Number of samples (n_samples) and number of epochs (n_rounds) is probably overkill here, can probably be reduced.

#Print parameter results and plot fitted traces
print(f'Printing fitting results...\n')

for i in range(3):
print(f'Results for parameter set {i+1}\n')
print(f'tauw (true/predict) : {neuron.tauw[i*2]}, {result_dict_error[i][0]["tauw"]}')
print(f'Vr (true/predict) : {neuron.Vr[i*2]}, {result_dict_error[i][0]["Vr"]}')
print(f'a (true/predict) : {neuron.a[i*2]}, {result_dict_error[i][0]["a"]}')
print(f'b (true/predict) : {neuron.b[i*2]}, {result_dict_error[i][0]["b"]}\n')

fig, ax = plt.subplots(2, 3, figsize=(12, 8))

for i in range(3):
ax[0, i].plot(trace.t/ms, v_traces[i][0], label='Simulated');
ax[0, i].plot(trace.t/ms, fits[i][0]/mV, label='Fitted');
ax[1, i].plot(trace.t/ms, v_traces[i][1]);
ax[1, i].plot(trace.t/ms, fits[i][1]/mV);
Comment on lines +124 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for semicola here.


ax[0, 0].legend()
ax[0, 0].set_xlabel('Time (ms)')
ax[0, 0].set_ylabel('v (mV)')

plt.show()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting really nit-picky: please add a line break in the end :)