from .neurons import LIF
import torch
from torch import nn


class Leaky(LIF):
"""
First-order leaky integrate-and-fire neuron model.
Input is assumed to be a current injection.
Membrane potential decays exponentially with rate beta.
    For :math:`U[T] > U_{\\rm thr} \\Rightarrow S[T+1] = 1`.
If `reset_mechanism = "subtract"`, then :math:`U[t+1]` will have
`threshold` subtracted from it whenever the neuron emits a spike:
    .. math::

            U[t+1] = \\beta U[t] + I_{\\rm in}[t+1] - R U_{\\rm thr}
If `reset_mechanism = "zero"`, then :math:`U[t+1]` will be set to `0`
whenever the neuron emits a spike:
    .. math::

            U[t+1] = \\beta U[t] + I_{\\rm in}[t+1] - R(\\beta U[t] + I_{\\rm in}[t+1])
* :math:`I_{\\rm in}` - Input current
* :math:`U` - Membrane potential
* :math:`U_{\\rm thr}` - Membrane threshold
* :math:`R` - Reset mechanism: if active, :math:`R = 1`, otherwise \
:math:`R = 0`
    * :math:`\\beta` - Membrane potential decay rate
Example::
import torch
import torch.nn as nn
import snntorch as snn
beta = 0.5
# Define Network
class Net(nn.Module):
def __init__(self):
super().__init__()
# initialize layers
self.fc1 = nn.Linear(num_inputs, num_hidden)
self.lif1 = snn.Leaky(beta=beta)
self.fc2 = nn.Linear(num_hidden, num_outputs)
self.lif2 = snn.Leaky(beta=beta)
            def forward(self, x, mem1, mem2):
                cur1 = self.fc1(x)
                spk1, mem1 = self.lif1(cur1, mem1)
                cur2 = self.fc2(spk1)
                spk2, mem2 = self.lif2(cur2, mem2)
                return spk1, mem1, spk2, mem2
:param beta: membrane potential decay rate. Clipped between 0 and 1
during the forward-pass. May be a single-valued tensor (i.e., equal
decay rate for all neurons in a layer), or multi-valued (one weight per
neuron).
:type beta: float or torch.tensor
:param threshold: Threshold for :math:`mem` to reach in order to
generate a spike `S=1`. Defaults to 1
:type threshold: float, optional
    :param spike_grad: Surrogate gradient for the term dS/dU. Defaults to
        None (corresponds to the ATan surrogate gradient; see
        `snntorch.surrogate` for more options)
:type spike_grad: surrogate gradient function from snntorch.surrogate,
optional
:param surrogate_disable: Disables surrogate gradients regardless of
`spike_grad` argument. Useful for ONNX compatibility. Defaults
to False
    :type surrogate_disable: bool, optional
:param init_hidden: Instantiates state variables as instance variables.
Defaults to False
:type init_hidden: bool, optional
:param inhibition: If `True`, suppresses all spiking other than the
neuron with the highest state. Defaults to False
:type inhibition: bool, optional
:param learn_beta: Option to enable learnable beta. Defaults to False
:type learn_beta: bool, optional
:param learn_threshold: Option to enable learnable threshold. Defaults
to False
:type learn_threshold: bool, optional
:param reset_mechanism: Defines the reset mechanism applied to \
:math:`mem` each time the threshold is met. Reset-by-subtraction: \
"subtract", reset-to-zero: "zero", none: "none". Defaults to "subtract"
:type reset_mechanism: str, optional
:param state_quant: If specified, hidden state :math:`mem` is quantized
to a valid state for the forward pass. Defaults to False
:type state_quant: quantization function from snntorch.quant, optional
:param output: If `True` as well as `init_hidden=True`, states are
returned when neuron is called. Defaults to False
:type output: bool, optional
    :param graded_spikes_factor: Output spikes are scaled by this value,
        if specified. Defaults to 1.0
:type graded_spikes_factor: float or torch.tensor
:param learn_graded_spikes_factor: Option to enable learnable graded spikes. Defaults to False
:type learn_graded_spikes_factor: bool, optional
:param reset_delay: If `True`, a spike is returned with a one-step delay after the threshold is reached.
Defaults to True
:type reset_delay: bool, optional
    Inputs: input_, mem_0
- **input_** of shape `(batch, input_size)`: tensor containing input
features
- **mem_0** of shape `(batch, input_size)`: tensor containing the
initial membrane potential for each element in the batch.
Outputs: spk, mem_1
- **spk** of shape `(batch, input_size)`: tensor containing the
output spikes.
- **mem_1** of shape `(batch, input_size)`: tensor containing the
      next membrane potential for each element in the batch.
Learnable Parameters:
    - **Leaky.beta** (torch.Tensor) - optional learnable weights must be
      manually passed in, of shape `1` or `(input_size)`.
    - **Leaky.threshold** (torch.Tensor) - optional learnable thresholds
      must be manually passed in, of shape `1` or `(input_size)`.
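
    A minimal single-step usage sketch (the value of `beta` and the shapes
    below are illustrative only, not part of the API)::

        import torch
        import snntorch as snn

        lif = snn.Leaky(beta=0.9)

        x = torch.rand(1, 10)        # input current at one time step
        mem = torch.zeros(1, 10)     # initial membrane potential

        spk, mem = lif(x, mem)       # advance one step: spikes + new state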
"""
def __init__(
self,
beta,
threshold=1.0,
spike_grad=None,
surrogate_disable=False,
init_hidden=False,
inhibition=False,
learn_beta=False,
learn_threshold=False,
reset_mechanism="subtract",
state_quant=False,
output=False,
graded_spikes_factor=1.0,
learn_graded_spikes_factor=False,
reset_delay=True,
):
super().__init__(
beta,
threshold,
spike_grad,
surrogate_disable,
init_hidden,
inhibition,
learn_beta,
learn_threshold,
reset_mechanism,
state_quant,
output,
graded_spikes_factor,
learn_graded_spikes_factor,
)
self._init_mem()
if self.reset_mechanism_val == 0: # reset by subtraction
self.state_function = self._base_sub
elif self.reset_mechanism_val == 1: # reset to zero
self.state_function = self._base_zero
elif self.reset_mechanism_val == 2: # no reset, pure integration
self.state_function = self._base_int
self.reset_delay = reset_delay
def _init_mem(self):
mem = torch.zeros(1)
self.register_buffer("mem", mem)
def reset_mem(self):
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
def init_leaky(self):
"""Deprecated, use :class:`Leaky.reset_mem` instead"""
self.reset_mem()
return self.mem
def forward(self, input_, mem=None):
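        """Advance the neuron by one time step.

        ``input_`` is treated as an injected current. Returns ``(spk, mem)``,
        or only ``spk`` when ``init_hidden=True`` and ``output=False``."""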
        if self.init_hidden and mem is not None:
            raise TypeError(
                "`mem` should not be passed as an argument while `init_hidden=True`"
            )
        if mem is not None:
            self.mem = mem
        if self.mem.shape != input_.shape:
            self.mem = torch.zeros_like(input_, device=self.mem.device)
self.reset = self.mem_reset(self.mem)
self.mem = self.state_function(input_)
if self.state_quant:
self.mem = self.state_quant(self.mem)
        if self.inhibition:
            spk = self.fire_inhibition(self.mem.size(0), self.mem)  # batch size
else:
spk = self.fire(self.mem)
if not self.reset_delay:
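            # dividing by graded_spikes_factor recovers the binary spike,
            # so the immediate reset is not scaled by the graded output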
do_reset = (
spk / self.graded_spikes_factor - self.reset
) # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
self.mem = self.mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
self.mem = self.mem - do_reset * self.mem
if self.output:
return spk, self.mem
elif self.init_hidden:
return spk
else:
return spk, self.mem
def _base_state_function(self, input_):
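        # U[t+1] = beta * U[t] + I_in[t+1], with beta clamped to [0, 1]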
base_fn = self.beta.clamp(0, 1) * self.mem + input_
return base_fn
def _base_sub(self, input_):
return self._base_state_function(input_) - self.reset * self.threshold
def _base_zero(self, input_):
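        # zero the membrane of any neuron that crossed threshold on the
        # previous step, then integrate as usual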
self.mem = (1 - self.reset) * self.mem
return self._base_state_function(input_)
def _base_int(self, input_):
return self._base_state_function(input_)
@classmethod
def detach_hidden(cls):
"""Returns the hidden states, detached from the current graph.
Intended for use in truncated backpropagation through time where
hidden state variables are instance variables."""
        for instance in cls.instances:
            if isinstance(instance, Leaky):
                instance.mem.detach_()
@classmethod
def reset_hidden(cls):
"""Used to clear hidden state variables to zero.
Intended for use where hidden state variables are instance variables.
Assumes hidden states have a batch dimension already."""
        for instance in cls.instances:
            if isinstance(instance, Leaky):
                instance.mem = torch.zeros_like(
                    instance.mem, device=instance.mem.device
                )
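

# A minimal truncated-BPTT sketch using the classmethods above (illustrative
# only; `net`, `loss_fn`, `optimizer`, and `train_loader` are assumed to exist
# elsewhere, and the network's Leaky layers use `init_hidden=True`):
#
#     for data, targets in train_loader:
#         spk = net(data)               # hidden states live on the layers
#         loss = loss_fn(spk, targets)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         Leaky.detach_hidden()         # cut the graph between batches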