-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1433 from schmitts/Tetzlaff_2015
Reproducing figure 2F of The Use of Hebbian Cell Assemblies for Nonlinear Computation
- Loading branch information
Showing
2 changed files
with
206 additions
and
1 deletion.
There are no files selected for viewing
Submodule resources
updated
2 files
+ − | examples_images/frompapers.Tetzlaff_2015.1.png | |
+ − | examples_images/frompapers.Tetzlaff_2015.2.png |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Reproduces Figure 2F of | ||
The Use of Hebbian Cell Assemblies for Nonlinear Computation | ||
by Tetzlaff C., Dasgupta S., Kulvicius T. and Wörgötter F. | ||
Sci Rep 5, 12866 (2015). | ||
https://doi.org/10.1038/srep12866 | ||
Sebastian Schmitt, 2022 | ||
""" | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from brian2 import NeuronGroup, Synapses, StateMonitor, run, defaultclock, ms, second, TimedArray, seed | ||
|
||
# random seed that gives curves similar to the ones in the publication | ||
seed(9873487) | ||
|
||
# neuron parameters (sigmoidal activation) | ||
beta = 0.03 | ||
epsilon = 120 | ||
F_max = 100 | ||
F_T = 1 | ||
tau_u = 1*ms | ||
R = 0.012 | ||
|
||
# plasticity timescales | ||
tau_ratio = 60 | ||
# hebbian | ||
tau_H = 3e4*ms | ||
# synaptic scaling | ||
tau_SS = tau_ratio * tau_H | ||
|
||
# synaptic weights | ||
W_max = np.sqrt(tau_ratio*(F_max**2/(F_max - F_T))) | ||
W_ext = W_max | ||
W_input = W_max | ||
W_I = 0.3*W_max | ||
|
||
# stimulus | ||
N_units = 100 | ||
N_stim_units = 20 | ||
stim_A_units_until = N_stim_units | ||
stim_B_units_from = N_units-N_stim_units | ||
|
||
# connection probabilities | ||
p_E = 0.1 | ||
p_I = 0.2 | ||
|
||
# paper uses 0.3*ms | ||
DT = 0.5*ms | ||
defaultclock.dt = DT | ||
|
||
# duration of a learning trial | ||
lt = 5000*DT | ||
|
||
duration = 100*lt | ||
no_input_until = 5*lt | ||
balanced_until = duration/2 | ||
|
||
# gate balanced presentation of stimulus 1 and 2 | ||
balanced = TimedArray([lt_counter*lt < balanced_until for lt_counter in range(int(duration/lt))], dt=lt) | ||
|
||
# function used for stimulus (typo in paper, +1 is not part of the argument of sin) | ||
stim_func = TimedArray([100*(np.sin(0.1*(i+1))+1) for i in range(int(duration/DT))], dt=DT) | ||
|
||
# gate learning phase of either stimulus 1 or 2 | ||
learning_phase = TimedArray([i%10 > 3 for i in range(int(duration/(0.1*lt)))], dt=0.1*lt) | ||
|
||
# if not balanced present stimulus A three times more often than stimulus B | ||
stim_A_gate = TimedArray([lt_counter % 2 == 0 if balanced(lt_counter*lt) else lt_counter % 4 in [0,1,2] | ||
for lt_counter in range(int(duration/lt))], dt=lt) | ||
|
||
stim_B_gate = TimedArray([lt_counter % 2 == 1 if balanced(lt_counter*lt) else lt_counter % 4 == 3 | ||
for lt_counter in range(int(duration/lt))], dt=lt) | ||
|
||
# noise is applied also during stimulation | ||
neurons = NeuronGroup(N_units, | ||
""" | ||
F = F_max/(1+exp(beta*(epsilon-u))) : 1 | ||
du/dt = (-u + R*(I_E - I_I + W_input*(I_stim_A + I_stim_B)))/tau_u + R*W_ext*20*sqrt((DT/ms)/ms)*xi: 1 | ||
I_E : 1 | ||
I_I : 1 | ||
index : 1 (constant) | ||
stim_units_A = index < stim_A_units_until : boolean | ||
stim_units_B = index >= (stim_B_units_from) : boolean | ||
I_stim_A = learning_phase(t)*int(stim_units_A)*stim_A_gate(t)*stim_func(t) : 1 | ||
I_stim_B = learning_phase(t)*int(stim_units_B)*stim_B_gate(t)*stim_func(t) : 1 | ||
""", | ||
method = "euler") | ||
neurons.index = range(len(neurons)) | ||
|
||
# excitatory connections with Hebbian plasticity and synaptic scaling | ||
synapses_E = Synapses(neurons, neurons, | ||
""" | ||
dw/dt = 1/tau_H*F_pre*F_post + 1/tau_SS*(F_T - F_post)*w**2 : 1 (clock-driven) | ||
I_E_post = w*F_pre : 1 (summed) | ||
""", | ||
method="euler" | ||
) | ||
# do not connect between the two populations of stimulated neurons | ||
synapses_E.connect(p=p_E, condition="((j > stim_A_units_until and i >= stim_B_units_from) or (j < stim_B_units_from and i < stim_A_units_until))" | ||
"or ((i > stim_A_units_until and i < stim_B_units_from) and (j > stim_A_units_until and j < stim_B_units_from))") | ||
|
||
# fixed weight inhibitory connections | ||
synapses_I = Synapses(neurons, neurons, | ||
""" | ||
w : 1 | ||
I_I_post = w*F_pre : 1 (summed) | ||
""" | ||
) | ||
synapses_I.connect(p=p_I) | ||
synapses_I.w = W_I | ||
|
||
statemon_neurons = StateMonitor(neurons, ["F", "I_stim_A", "I_stim_B"], record=True, dt=100*defaultclock.dt) | ||
statemon_synapses_E = StateMonitor(synapses_E, "w", record=True, dt=100*defaultclock.dt) | ||
statemon_synapses_for_assembly_analysis = StateMonitor(synapses_E, "w", record=True, dt=lt) | ||
|
||
run(duration, report="text") | ||
|
||
# threshold saying that synaptic efficacies larger than theta are | ||
# 'strong' and others are 'weak' | ||
theta = 0.5*W_max | ||
|
||
in_assembly_A = [] | ||
in_assembly_B = [] | ||
|
||
# traverse through the graph following 'strong' synapses | ||
def go(W, source, units_in_assembly): | ||
units_in_assembly.add(source) | ||
# check all possible targets | ||
for target in range(N_units): | ||
w = W[source][target] | ||
if w > theta: | ||
W[source][target] = 0 | ||
go(W, target, units_in_assembly) | ||
|
||
# for each learning trial | ||
for ws in statemon_synapses_for_assembly_analysis.w.T: | ||
|
||
# construct a full weight matrix | ||
W = np.full((N_units, N_units), np.nan) | ||
W[synapses_E.i[:], synapses_E.j[:]] = ws | ||
|
||
for in_assembly, stim_units in zip([in_assembly_A, in_assembly_B], | ||
[range(stim_A_units_until), | ||
range(stim_B_units_from, N_units)]): | ||
|
||
units_in_assembly = set() | ||
|
||
# start with units that are stimulated | ||
for stim_unit in stim_units: | ||
go(W, stim_unit, units_in_assembly) | ||
|
||
in_assembly.append(len(units_in_assembly)) | ||
|
||
# competitive development of the two competing cell assemblies A and B as a function of the input protocol | ||
fig, ax = plt.subplots() | ||
|
||
ax.plot(in_assembly_A, linestyle="None", marker='o', color='orange', label="A") | ||
ax.plot(in_assembly_B, linestyle="None", marker='o', color='olivedrab', label="B") | ||
ax.set_ylim(19, 51) | ||
ax.set_xlim(0, 100) | ||
|
||
ax.set_ylabel("Neurons in Cell Assembly [%]") | ||
ax.set_xlabel("Learning Trial") | ||
|
||
ax.axvline(balanced_until/lt, linestyle='dashed', color='k') | ||
|
||
ax.text(15, 52, " A A", color='orange', fontfamily="monospace", fontsize="xx-large") | ||
ax.text(15, 52, " B B", color='olivedrab', fontfamily="monospace", fontsize="xx-large") | ||
|
||
ax.text(65, 52, " 3A 3A", color='orange', fontfamily="monospace", fontsize="xx-large") | ||
ax.text(65, 52, " B B", color='olivedrab', fontfamily="monospace", fontsize="xx-large") | ||
|
||
plt.show() | ||
|
||
# stimulus, neuronal activity and excitatory weights as function of time | ||
fig, axes = plt.subplots(3, sharex=True) | ||
|
||
axes[0].plot(statemon_neurons.I_stim_A[0], label="A", color='orange') | ||
axes[0].plot(statemon_neurons.I_stim_B[-1], label="B", color='olivedrab') | ||
axes[0].legend(loc="upper right") | ||
axes[0].set_title("Stimulus") | ||
|
||
axes[1].imshow(statemon_neurons.F, aspect='auto') | ||
axes[1].set_title("Neuron Activity") | ||
axes[1].axhline(stim_A_units_until, linestyle='dashed', color='white') | ||
axes[1].axhline(stim_B_units_from, linestyle='dashed', color='white') | ||
|
||
axes[2].imshow(statemon_synapses_E.w, aspect='auto') | ||
axes[2].set_title("Excitatory Weights") | ||
|
||
axes[2].set_xticks(range(0, 5000, 250)) | ||
axes[2].set_xticklabels(f"{i}" for i in range(0, 100, 5)) | ||
|
||
axes[2].set_xlabel("Learning Trial") | ||
axes[2].set_xlim(0, 5000) | ||
|
||
fig.tight_layout() | ||
|
||
plt.show() |