-
Notifications
You must be signed in to change notification settings - Fork 136
/
models.py
352 lines (302 loc) · 14.6 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
# Copyright (C) 2021-22 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/
import numpy as np
from lava.magma.core.model.py.connection import (
LearningConnectionModelFloat,
LearningConnectionModelBitApproximate,
)
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.proc.dense.process import Dense, LearningDense, DelayDense
from lava.utils.weightutils import SignMode, determine_sign_mode,\
truncate_weights, clip_weights
class AbstractPyDenseModelFloat(PyLoihiProcessModel):
    """Implementation of Conn Process with Dense synaptic connections in
    floating point precision. This short and simple ProcessModel can be used
    for quick algorithmic prototyping, without engaging with the nuances of a
    fixed point implementation.
    """

    s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
    a_buff: np.ndarray = LavaPyType(np.ndarray, float)
    # weights is a 2D matrix of form (num_flat_output_neurons,
    # num_flat_input_neurons)in C-order (row major).
    weights: np.ndarray = LavaPyType(np.ndarray, float)
    num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

    def run_spk(self):
        # Send the activation accumulated on the previous timestep before
        # receiving new input; this ordering prevents deadlock in networks
        # with recurrent connectivity structures.
        self.a_out.send(self.a_buff)
        graded_spikes = self.num_message_bits.item() > 0
        if graded_spikes:
            # Graded spikes carry payloads: full matrix-vector product.
            spike_vector = self.s_in.recv()
            self.a_buff = self.weights @ spike_vector
        else:
            # Binary spikes: sum the weight columns of the active inputs.
            spike_vector = self.s_in.recv().astype(bool)
            self.a_buff = self.weights[:, spike_vector].sum(axis=1)
@implements(proc=Dense, protocol=LoihiProtocol)
@requires(CPU)
@tag("floating_pt")
class PyDenseModelFloat(AbstractPyDenseModelFloat):
    """Floating-point ProcessModel for the Dense Process; all behavior is
    inherited from AbstractPyDenseModelFloat. This subclass only binds the
    implementation to the Dense Process via the decorators above."""
    pass
class AbstractPyDenseModelBitAcc(PyLoihiProcessModel):
    """Implementation of Conn Process with Dense synaptic connections that is
    bit-accurate with Loihi's hardware implementation of Dense, which means,
    it mimics Loihi behavior bit-by-bit.
    """

    s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=16)
    a_buff: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=16)
    # weights is a 2D matrix of form (num_flat_output_neurons,
    # num_flat_input_neurons) in C-order (row major).
    weights: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=8)
    num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

    def __init__(self, proc_params):
        super().__init__(proc_params)
        # Flag to determine whether weights have already been scaled.
        self.weights_set = False
        # Weight exponent used to shift accumulated activations; defaults
        # to 0 (no shift) when the Process did not specify one.
        self.weight_exp: int = self.proc_params.get("weight_exp", 0)

    def run_spk(self):
        # NOTE(review): weight_exp is re-read from proc_params on every
        # timestep even though __init__ already set it; redundant but
        # harmless since proc_params does not appear to change here.
        self.weight_exp = self.proc_params.get("weight_exp", 0)

        # Since this Process has no learning, weights are assumed to be static
        # and only require scaling on the first timestep of run_spk().
        if not self.weights_set:
            num_weight_bits: int = self.proc_params.get("num_weight_bits", 8)
            # Use the configured sign mode if given, else infer it from the
            # sign of the weight values.
            sign_mode: SignMode = self.proc_params.get("sign_mode") \
                or determine_sign_mode(self.weights)

            # Clip to the 8-bit hardware range, then truncate to the
            # requested bit precision, as Loihi does.
            self.weights = clip_weights(self.weights, sign_mode, num_bits=8)
            self.weights = truncate_weights(self.weights,
                                            sign_mode,
                                            num_weight_bits)
            self.weights_set = True

        # The a_out sent at each timestep is a buffered value from dendritic
        # accumulation at timestep t-1. This prevents deadlocking in
        # networks with recurrent connectivity structures.
        self.a_out.send(self.a_buff)
        if self.num_message_bits.item() > 0:
            s_in = self.s_in.recv()
            a_accum = self.weights.dot(s_in)
        else:
            s_in = self.s_in.recv().astype(bool)
            a_accum = self.weights[:, s_in].sum(axis=1)
        # Apply the weight exponent as a bit shift: left shift for positive
        # exponents, right shift for zero/negative ones (shift by 0 is a
        # no-op when weight_exp == 0).
        self.a_buff = (
            np.left_shift(a_accum, self.weight_exp)
            if self.weight_exp > 0
            else np.right_shift(a_accum, -self.weight_exp)
        )
@implements(proc=Dense, protocol=LoihiProtocol)
@requires(CPU)
@tag("bit_accurate_loihi", "fixed_pt")
class PyDenseModelBitAcc(AbstractPyDenseModelBitAcc):
    """Bit-accurate fixed-point ProcessModel for the Dense Process; all
    behavior is inherited from AbstractPyDenseModelBitAcc. This subclass
    only binds the implementation to the Dense Process via the decorators
    above."""
    pass
@implements(proc=LearningDense, protocol=LoihiProtocol)
@requires(CPU)
@tag("floating_pt")
class PyLearningDenseModelFloat(
        LearningConnectionModelFloat, AbstractPyDenseModelFloat):
    """Implementation of Conn Process with Dense synaptic connections in
    floating point precision, with plasticity provided by
    LearningConnectionModelFloat. This short and simple ProcessModel can be
    used for quick algorithmic prototyping, without engaging with the
    nuances of a fixed point implementation.
    """

    # Fix: removed a redundant __init__ override that only forwarded to
    # super().__init__(proc_params); the inherited constructor is identical.

    def run_spk(self):
        # The a_out sent at each timestep is a buffered value from dendritic
        # accumulation at timestep t-1. This prevents deadlocking in
        # networks with recurrent connectivity structures.
        self.a_out.send(self.a_buff)
        if self.num_message_bits.item() > 0:
            # Graded spikes: full matrix-vector product.
            s_in = self.s_in.recv()
            self.a_buff = self.weights.dot(s_in)
        else:
            # Binary spikes: sum weight columns of active inputs.
            s_in = self.s_in.recv().astype(bool)
            self.a_buff = self.weights[:, s_in].sum(axis=1)

        # Feed the received spikes into the learning rule's trace update
        # (provided by LearningConnectionModelFloat).
        self.recv_traces(s_in)
@implements(proc=LearningDense, protocol=LoihiProtocol)
@requires(CPU)
@tag("bit_approximate_loihi", "fixed_pt")
class PyLearningDenseModelBitApproximate(
        LearningConnectionModelBitApproximate, AbstractPyDenseModelBitAcc):
    """Implementation of Conn Process with Dense synaptic connections that
    uses similar constraints as Loihi's hardware implementation of dense
    connectivity but does not reproduce Loihi bit-by-bit.
    """

    def __init__(self, proc_params):
        super().__init__(proc_params)
        # Number of bits the weights are truncated to; defaults to Loihi's
        # 8-bit weight precision. (The weights_set flag itself is set up by
        # AbstractPyDenseModelBitAcc.__init__.)
        self.num_weight_bits: int = self.proc_params.get("num_weight_bits", 8)

    def run_spk(self):
        self.weight_exp: int = self.proc_params.get("weight_exp", 0)

        # The initial weights only require truncation once, on the first
        # timestep of run_spk(). NOTE(review): this model *does* learn;
        # presumably subsequent weight updates are constrained by the
        # learning machinery in LearningConnectionModelBitApproximate —
        # confirm against that class.
        if not self.weights_set:
            self.weights = truncate_weights(
                self.weights,
                sign_mode=self.sign_mode,
                num_weight_bits=self.num_weight_bits
            )
            self.weights_set = True

        # The a_out sent at each timestep is a buffered value from dendritic
        # accumulation at timestep t-1. This prevents deadlocking in
        # networks with recurrent connectivity structures.
        self.a_out.send(self.a_buff)
        if self.num_message_bits.item() > 0:
            s_in = self.s_in.recv()
            a_accum = self.weights.dot(s_in)
        else:
            s_in = self.s_in.recv().astype(bool)
            a_accum = self.weights[:, s_in].sum(axis=1)
        # Apply the weight exponent as a bit shift: left for positive
        # exponents, right for zero/negative ones.
        self.a_buff = (
            np.left_shift(a_accum, self.weight_exp)
            if self.weight_exp > 0
            else np.right_shift(a_accum, -self.weight_exp)
        )

        # Feed the received spikes into the learning rule's trace update.
        self.recv_traces(s_in)
class AbstractPyDelayDenseModel(PyLoihiProcessModel):
    """Abstract Conn Process with Dense synaptic connections which incorporates
    delays into the Conn Process.
    """

    # Concrete subclasses declare these as LavaPyTypes; here they only
    # reserve the attribute names used by the methods below.
    weights: np.ndarray = None
    delays: np.ndarray = None
    a_buff: np.ndarray = None

    def calc_act(self, s_in) -> np.ndarray:
        """
        Calculate the activation matrix based on s_in by performing
        delay_wgts * s_in.
        """
        # First calculating the activations through delay_wgts * s_in
        # This matrix is then summed across each row to get the
        # activations to the output neurons for different delays.
        # This activation vector is reshaped to a matrix of the form
        # (n_flat_output_neurons * (max_delay + 1), n_flat_output_neurons)
        # which is then transposed to get the activation matrix.
        # Note: max_delay is recovered from the buffer width
        # (a_buff.shape[-1] == max_delay + 1).
        return np.reshape(
            np.sum(self.get_delay_wgts_mat(self.weights,
                                           self.delays,
                                           self.a_buff.shape[-1] - 1) * s_in,
                   axis=1),
            (self.a_buff.shape[-1], self.weights.shape[0])).T

    @staticmethod
    def get_delay_wgts_mat(weights, delays, max_delay) -> np.ndarray:
        """
        Create a matrix where the synaptic weights are separated
        by their corresponding delays. The first matrix contains all the
        weights, where the delay is equal to zero. The second matrix
        contains all the weights, where the delay is equal to one and so on.
        These matrices are then stacked together vertically.

        Returns 2D matrix of form
        (num_flat_output_neurons * (max_delay + 1), num_flat_input_neurons)
        where
        delay_wgts[
            k * num_flat_output_neurons : (k + 1) * num_flat_output_neurons, :
        ]
        contains the weights for all connections with a delay equal to k.
        This allows for the updating of the activation buffer and updating
        weights.
        """
        # Each stacked slab keeps only the weights whose delay equals k;
        # all other entries are zeroed out.
        return np.vstack([
            np.where(delays == k, weights, 0)
            for k in range(max_delay + 1)
        ])

    def update_act(self, s_in):
        """
        Updates the activations for the connection.
        Clears first column of a_buff and rolls them to the last column.
        Finally, calculates the activations for the current time step and adds
        them to a_buff.
        This order of operations ensures that delays of 0 correspond to
        the next time step.
        """
        # Column 0 (already sent this timestep) is zeroed first, so the
        # flattened np.roll below effectively shifts each row left by one
        # and the values that wrap into the last column are all zeros.
        self.a_buff[:, 0] = 0
        self.a_buff = np.roll(self.a_buff, -1)
        self.a_buff += self.calc_act(s_in)
@implements(proc=DelayDense, protocol=LoihiProtocol)
@requires(CPU)
@tag("floating_pt")
class PyDelayDenseModelFloat(AbstractPyDelayDenseModel):
    """Implementation of Conn Process with Dense synaptic connections in
    floating point precision. This short and simple ProcessModel can be used
    for quick algorithmic prototyping, without engaging with the nuances of a
    fixed point implementation. DelayDense incorporates delays into the Conn
    Process.
    """

    s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
    a_buff: np.ndarray = LavaPyType(np.ndarray, float)
    # weights is a 2D matrix of form (num_flat_output_neurons,
    # num_flat_input_neurons) in C-order (row major).
    weights: np.ndarray = LavaPyType(np.ndarray, float)
    # delays is a 2D matrix of form (num_flat_output_neurons,
    # num_flat_input_neurons) in C-order (row major).
    delays: np.ndarray = LavaPyType(np.ndarray, int)
    num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

    def run_spk(self):
        # Emit the front column of the delay buffer, i.e. the dendritic
        # accumulation from timestep t-1. Sending before receiving prevents
        # deadlock in networks with recurrent connectivity structures.
        self.a_out.send(self.a_buff[:, 0])
        graded_spikes = self.num_message_bits.item() > 0
        if graded_spikes:
            spike_vector = self.s_in.recv()
        else:
            spike_vector = self.s_in.recv().astype(bool)
        # Shift the delay buffer and fold in this timestep's activations.
        self.update_act(spike_vector)
@implements(proc=DelayDense, protocol=LoihiProtocol)
@requires(CPU)
@tag("bit_accurate_loihi", "fixed_pt")
class PyDelayDenseModelBitAcc(AbstractPyDelayDenseModel):
    """Implementation of Conn Process with Dense synaptic connections that is
    bit-accurate with Loihi's hardware implementation of Dense, which means,
    it mimics Loihi behaviour bit-by-bit. DelayDense incorporates delays into
    the Conn Process. Loihi 2 has a maximum of 6 bits for delays, meaning a
    spike can be delayed by 0 to 63 time steps."""

    s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=16)
    a_buff: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=16)
    # weights is a 2D matrix of form (num_flat_output_neurons,
    # num_flat_input_neurons) in C-order (row major).
    weights: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=8)
    # delays is limited to 6 bits on Loihi 2 (values 0..63).
    delays: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=6)
    num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

    def __init__(self, proc_params):
        super().__init__(proc_params)
        # Flag to determine whether weights have already been scaled.
        self.weights_set = False

    def run_spk(self):
        self.weight_exp: int = self.proc_params.get("weight_exp", 0)

        # Since this Process has no learning, weights are assumed to be static
        # and only require scaling on the first timestep of run_spk().
        if not self.weights_set:
            num_weight_bits: int = self.proc_params.get("num_weight_bits", 8)
            # Use the configured sign mode if given, else infer it from the
            # sign of the weight values.
            sign_mode: SignMode = self.proc_params.get("sign_mode") \
                or determine_sign_mode(self.weights)

            # Clip to the 8-bit hardware range, then truncate to the
            # requested bit precision, as Loihi does.
            self.weights = clip_weights(self.weights, sign_mode, num_bits=8)
            self.weights = truncate_weights(self.weights,
                                            sign_mode,
                                            num_weight_bits)
            self.weights_set = True

            # Check if delays are within Loihi 2 constraints
            if np.max(self.delays) > 63:
                raise ValueError("DelayDense Process 'delays' expects values "
                                 f"between 0 and 63 for Loihi, got "
                                 f"{self.delays}.")

        # The a_out sent at each timestep is a buffered value from dendritic
        # accumulation at timestep t-1. This prevents deadlocking in
        # networks with recurrent connectivity structures.
        self.a_out.send(self.a_buff[:, 0])
        if self.num_message_bits.item() > 0:
            s_in = self.s_in.recv()
        else:
            s_in = self.s_in.recv().astype(bool)

        # Same buffer update as AbstractPyDelayDenseModel.update_act(), but
        # inlined so the weight-exponent bit shift can be applied to the
        # freshly computed activations before they enter the buffer.
        a_accum = self.calc_act(s_in)

        self.a_buff[:, 0] = 0
        self.a_buff = np.roll(self.a_buff, -1)
        self.a_buff += (
            np.left_shift(a_accum, self.weight_exp)
            if self.weight_exp > 0
            else np.right_shift(a_accum, -self.weight_exp)
        )