-
Notifications
You must be signed in to change notification settings - Fork 2
/
sim.py
272 lines (235 loc) · 8.83 KB
/
sim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
# Public API of this module: the names exported by a star-import.
__all__ = ["LetterNetSim"]
import numpy as np
import pandas as pd
from numba import njit
from .core import *
from .lnet import *
from . import bkh
class LetterNetSim:
    """
    Simulator with a LetterNet

    Keeps the membrane voltage of every cell of ``lnet`` (``self.cell_volts``)
    across successive :meth:`simulate` calls. Every recorded spike is
    accumulated into a Bokeh scatter plot (``self.fig``), with the time step
    on the x-axis and the column index on the y-axis.
    """

    # the ratio of minicolumn span, on the x-axis of plot, against a unit time step
    COL_PLOT_WIDTH = 0.8

    def __init__(
        self,
        # LetterNet, unit synaptic efficacy assumed to be valued 1.0
        lnet: LetterNet,
        # global scaling factor, to accommodate a unit synaptic efficacy value of 1.0
        # roughly this specifies that:
        # how many presynaptic spikes is enough to trigger a postsynaptic spike,
        # when each incoming firing synapse has a unit efficacy value of 1.0
        SYNAP_FACTOR=5,
        # reset voltage, negative to enable refractory period
        VOLT_RESET=-0.1,
        # membrane time constant
        τ_m=10,
        # fire plot params
        plot_width=800,
        plot_height=600,
        plot_n_steps=100,
        fire_dots_glyph="square",
        fire_dots_alpha=0.01,
        fire_dots_size=3,
        fire_dots_color="#0000FF",
    ):
        """
        Create a simulator with all cell voltages at 0 and an empty spike plot.

        The plot initially shows time steps ``[0, plot_n_steps)`` and all
        columns of ``lnet``. Its y-axis has one tick per letter of
        ``lnet.ALPHABET``, labelled with that letter.
        """
        self.lnet = lnet
        self.VOLT_RESET = VOLT_RESET
        self.τ_m = τ_m
        self.SYNAP_FACTOR = SYNAP_FACTOR
        # float32 voltage per cell, carried over between simulate() calls
        self.cell_volts = np.full(lnet.CELLS_SHAPE, 0, "f4")
        # number of time steps simulated so far == x offset of the next run
        self.done_n_steps = 0

        self.ds_spikes = bkh.ColumnDataSource(
            {
                "x": [],
                "y": [],
            }
        )
        p = bkh.figure(
            title="Letter SDR Spikes",
            x_axis_label="Time Step",
            y_axis_label="Column (with Letter Spans)",
            width=plot_width,
            height=plot_height,
            tools=[
                "pan",
                "box_zoom",
                "xwheel_zoom",
                # "ywheel_zoom",
                "undo",
                "redo",
                "reset",
                "crosshair",
            ],
            y_range=(0, lnet.CELLS_SHAPE[0]),
            x_range=(0, plot_n_steps),
        )
        # One tick per letter, at column k * N_SPARSE_COLS_PER_LETTER - 1 for
        # the k-th letter (1-based). Presumably this is the last column of that
        # letter's span.
        label_ys = (
            np.arange(1, lnet.ALPHABET_SIZE + 1) * lnet.N_SPARSE_COLS_PER_LETTER - 1
        )
        p.yaxis.ticker = bkh.FixedTicker(ticks=label_ys)
        # Invert the tick formula above in JS to index into the alphabet list.
        p.yaxis.formatter = bkh.CustomJSTickFormatter(
            code=f"""
return {list(lnet.ALPHABET.alphabet())!r}[(tick+1)/{lnet.N_SPARSE_COLS_PER_LETTER}-1];
"""
        )
        p.scatter(
            source=self.ds_spikes,
            marker=fire_dots_glyph,
            alpha=fire_dots_alpha,
            size=fire_dots_size,
            color=fire_dots_color,
        )
        self.fig = p

    def simulate(
        self,
        n_steps,
        prompt_words,  # a single word or list of words to prompt
        prompt_blur=0.8,  # reduce voltage of other cells than the prompted letter
    ):
        """
        Run ``n_steps`` more time steps, continuing from the current state.

        The words are encoded by ``lnet.ALPHABET.encode_words``, and the
        resulting letter codes are prompted one per step from the first step of
        this run. ``self.cell_volts`` is updated in place to the state after the
        last step. Every recorded spike is streamed to the plot at its absolute
        time step, counted across all calls. Returns None.
        """
        lnet = self.lnet
        _w_bound, w_lcode = lnet.ALPHABET.encode_words(
            [prompt_words] if isinstance(prompt_words, str) else prompt_words
        )
        spikes, step_n_spikes = _simulate_lnet(
            n_steps,
            self.cell_volts,
            *lnet._excitatory_synapses(),
            *lnet._inhibitory_synapses(),
            lnet.sdr_indices,
            w_lcode,
            prompt_blur,
            self.VOLT_RESET,
            self.τ_m,
            self.SYNAP_FACTOR,
        )

        # absolute time step of this run's first step
        x_base = self.done_n_steps
        # Simulated time advances by the full run length even when no cell
        # spiked, so spikes of later runs still land at their true time steps.
        self.done_n_steps = x_base + step_n_spikes.size
        if spikes.size <= 0:
            return  # no spike at all

        # Spikes are recorded step by step, so the step offset of each spike is
        # its step index repeated by that step's spike count.
        spike_steps = np.repeat(np.arange(step_n_spikes.size), step_n_spikes)
        ci, ici = np.divmod(spikes, lnet.N_CELLS_PER_COL)
        # spread the cells of one column horizontally within their time step
        xs = (x_base + spike_steps) + (
            self.COL_PLOT_WIDTH * ici / lnet.N_CELLS_PER_COL
        )
        self.ds_spikes.stream({"x": xs, "y": ci})
# A cell spikes once its membrane voltage reaches this value. For simplicity
# it is a single module-wide constant (1.0), not a per-cell parameter.
SPIKE_THRES: float = 1.0
@njit
def _simulate_lnet(
    n_steps,  # total number of time steps to simulate
    # neuron voltages, both as input for initial states, and as output for final states
    cell_volts,
    # excitatory synapse links/efficacies
    excit_links,
    excit_effis,
    # inhibitory synapse links/efficacies
    inhib_links,
    inhib_effis,
    sdr_indices,  # letter SDR indices
    prompt_lcodes,  # letter code sequence
    prompt_blur=0.8,  # reduce voltage of other cells than the prompted letter
    VOLT_RESET=-0.1,  # reset voltage, negative to enable refractory period
    τ_m=10,  # membrane time constant
    # global scaling factor, to facilitate a unit synaptic efficacy value of 1.0
    # roughly this specifies that:
    # how many presynaptic spikes is enough to trigger a postsynaptic spike,
    # when each synapse has a unit efficacy value of 1.0
    # TODO: justify, this is sorta global inhibition?
    SYNAP_FACTOR=5,
):
    """
    Simulate a LetterNet for ``n_steps`` time steps (numba-compiled).

    Each time step does the following:

    1. If prompt letters remain, the next one is applied (see
       ``prompt_letter`` below).
    2. Every cell whose voltage at the start of the step is >= SPIKE_THRES
       spikes.
    3. Each synapse whose presynaptic cell spikes adds its efficacy
       (excitatory) to, or subtracts it (inhibitory) from, the postsynaptic
       cell's input. The summed input is divided by ``SYNAP_FACTOR``.
    4. Spiking cells are set to ``VOLT_RESET``, discarding their input. Every
       other cell becomes ``input + v + (0 - v) / τ_m``, i.e. it leaks toward 0.

    Args:
        n_steps: number of steps, must satisfy 0 < n_steps <= 20000.
        cell_volts: 2-D array (N_COLS, N_CELLS_PER_COL) of voltages. It is
            read as the initial state and overwritten in place with the final
            state.
        excit_links, inhib_links: 1-D structured arrays whose fields "i0"
            (presynaptic) and "i1" (postsynaptic) hold flat cell indices
            ``ci * N_CELLS_PER_COL + ici``.
        excit_effis, inhib_effis: 1-D efficacies, same shape as their links.
        sdr_indices: indexed by letter code. ``sdr_indices[lcode]`` gives the
            column indices of that letter (presumably shaped
            (ALPHABET_SIZE, N_COLS_PER_LETTER); TODO confirm).
        prompt_lcodes: 1-D letter codes, one applied per step starting from
            step 0. Its size must not exceed n_steps.
        prompt_blur: factor in [0, 1] applied to positive voltages of the
            cells outside the prompted letter.
        VOLT_RESET, τ_m, SYNAP_FACTOR: see the parameter comments above.

    Returns:
        spikes: int32 flat indices of all spiking cells. They are ordered by
            step, and within a step by column, then by cell in the column.
        step_n_spikes: int32 array of length n_steps, the spike count of
            each step.

    Raises:
        AssertionError: when any precondition asserted below fails.
    """
    assert 0 < n_steps <= 20000
    assert excit_links.ndim == excit_effis.ndim == 1
    assert excit_links.shape == excit_effis.shape
    assert inhib_links.ndim == inhib_effis.ndim == 1
    assert inhib_links.shape == inhib_effis.shape
    assert prompt_lcodes.ndim == 1
    assert 0 <= prompt_lcodes.size <= n_steps
    assert 0 <= prompt_blur <= 1.0

    # ALPHABET_SIZE, N_COLS_PER_LETTER = sdr_indices.shape
    N_COLS, N_CELLS_PER_COL = cell_volts.shape

    # Split flat cell indices into (column, cell-in-column) pairs, so the
    # inner loops can index the 2-D cell_volts directly.
    excit_presynap_ci, excit_presynap_ici = np.divmod(
        excit_links["i0"], N_CELLS_PER_COL
    )
    excit_postsynap_ci, excit_postsynap_ici = np.divmod(
        excit_links["i1"], N_CELLS_PER_COL
    )
    inhib_presynap_ci, inhib_presynap_ici = np.divmod(
        inhib_links["i0"], N_CELLS_PER_COL
    )
    inhib_postsynap_ci, inhib_postsynap_ici = np.divmod(
        inhib_links["i1"], N_CELLS_PER_COL
    )

    def prompt_letter(lcode):
        # Prompt letter `lcode`, mutating cell_volts in place: every positive
        # voltage is scaled by prompt_blur, except the letter's own cells. If
        # none of those cells would spike, all of them are forced to
        # SPIKE_THRES so they spike in this step.
        # Array (advanced) indexing returns a copy, which keeps the letter's
        # voltages from before the blur below.
        letter_volts = cell_volts[sdr_indices[lcode], :]
        # suppress all cells first
        # https://github.com/numba/numba/issues/8616
        # NOTE(review): boolean-mask assignment goes through ravel(). The
        # write-through relies on ravel() returning a view, which holds for a
        # C-contiguous cell_volts. Confirm this for other layouts.
        cell_volts.ravel()[cell_volts.ravel() > 0] *= prompt_blur
        if np.any(letter_volts >= SPIKE_THRES):
            # some cell(s) of prompted letter would fire
            # restore letter cell voltages
            cell_volts[sdr_indices[lcode], :] = letter_volts
        else:  # no cell of prompted letter would fire
            # force fire all of the letter's cells
            cell_volts[sdr_indices[lcode], :] = SPIKE_THRES

    # we serialize the indices of spiked cells as the output record of simulation
    # pre-allocate sufficient capacity to store maximally possible spike info
    # NOTE(review): the worst case is n_steps * cell_volts.size int32 entries
    # (4 bytes each), which can be very large for big nets or long runs.
    spikes = np.empty(n_steps * cell_volts.size, "int32")
    n_spikes = 0  # total number of individual spikes as recorded
    # record number of spikes per each time step, it may vary across steps
    step_n_spikes = np.zeros(n_steps, "int32")

    # Intermediate state data for cell voltages. cell_volts is only read during
    # a step and replaced at its end, so all cells update from the same
    # start-of-step state.
    cell_volts_tobe = np.empty_like(cell_volts)

    prompt_i = 0
    for i_step in range(n_steps):
        if prompt_i < prompt_lcodes.size:  # apply prompt
            prompt_letter(prompt_lcodes[prompt_i])
            prompt_i += 1

        # accumulate input current, according to presynaptic spikes
        cell_volts_tobe[:] = 0
        for i in range(excit_links.size):
            v = cell_volts[excit_presynap_ci[i], excit_presynap_ici[i]]
            if v >= SPIKE_THRES:
                cell_volts_tobe[
                    excit_postsynap_ci[i], excit_postsynap_ici[i]
                ] += excit_effis[i]
        for i in range(inhib_links.size):
            v = cell_volts[inhib_presynap_ci[i], inhib_presynap_ici[i]]
            if v >= SPIKE_THRES:
                cell_volts_tobe[
                    inhib_postsynap_ci[i], inhib_postsynap_ici[i]
                ] -= inhib_effis[i]

        # apply the global scaling factor
        cell_volts_tobe[:] /= SYNAP_FACTOR

        # reset voltage if fired, or update the voltage
        for ci in range(N_COLS):
            for ici in range(N_CELLS_PER_COL):
                v = cell_volts[ci, ici]
                if v >= SPIKE_THRES:
                    # fired, reset (the input accumulated above is discarded)
                    cell_volts_tobe[ci, ici] = VOLT_RESET
                    # record the spike as a flat cell index
                    spikes[n_spikes] = ci * N_CELLS_PER_COL + ici
                    n_spikes += 1
                    step_n_spikes[i_step] += 1
                else:  # add back previous-voltage, plus leakage
                    # note it's just input-current before this update
                    cell_volts_tobe[ci, ici] += v + (0 - v) / τ_m

        # update the final state at end of this time step
        cell_volts[:] = cell_volts_tobe

    assert n_spikes == np.sum(step_n_spikes), "bug?!"

    return (
        # return a copy of valid slice, to release extraneous memory allocated
        spikes[:n_spikes].copy(),
        step_n_spikes,
    )