/
builder.py
264 lines (215 loc) · 9.38 KB
/
builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from collections import defaultdict
import warnings
import numpy as np
from nengo.builder.signal import Signal, SignalDict
from nengo.builder.operator import TimeUpdate
from nengo.cache import NoDecoderCache
from nengo.exceptions import BuildError
from nengo.rc import rc
class Model:
    """Stores artifacts from the build process, which are used by `.Simulator`.

    Parameters
    ----------
    dt : float, optional
        The length of a simulator timestep, in seconds.
    label : str, optional
        A name or description to differentiate models.
    decoder_cache : DecoderCache, optional
        Interface to a cache for expensive parts of the build process.
        A ``NoDecoderCache()`` is used when this is ``None``.
    builder : `nengo.builder.Builder`, optional
        A ``Builder`` instance to use for building. Defaults to a
        new ``Builder()``.

    Attributes
    ----------
    build_callback : callable or None
        ``None`` after construction. When set, `.Model.build` calls it with
        the object that was just built.
    builder : Builder
        The builder that `.Model.build` delegates to.
    config : Config or None
        Build functions can set a config object here to affect sub-builders.
    decoder_cache : DecoderCache
        Interface to a cache for expensive parts of the build process.
    dt : float
        The length of each timestep, in seconds.
    label : str or None
        A name or description to differentiate models.
    operators : list
        List of all operators created in the build process.
        All operators must be added to this list, as it is used by Simulator.
    params : dict
        Mapping from objects to namedtuples containing parameters generated
        in the build process.
    probes : list
        List of all probes. Probes must be added to this list in the build
        process, as this list is used by Simulator.
    seeded : dict
        All objects are assigned a seed, whether the user defined the seed
        or it was automatically generated. 'seeded' keeps track of whether
        the seed is user-defined. We consider the seed to be user-defined
        if it was set directly on the object, or if a seed was set on the
        network in which the object resides, or if a seed was set on any
        ancestor network of the network in which the object resides.
    seeds : dict
        Mapping from objects to the integer seed assigned to that object.
    sig : dict
        A dictionary of dictionaries that organizes all of the signals
        created in the build process, as build functions often need to
        access signals created by other build functions.
    step : Signal
        The current step (i.e., how many timesteps have occurred thus far).
    time : Signal
        The current point in time.
    toplevel : Network
        The top-level network being built.
        This is sometimes useful for accessing network elements after build,
        or for the network builder to determine if it is the top-level network.
    """

    def __init__(self, dt=0.001, label=None, decoder_cache=None, builder=None):
        self.dt = dt
        self.label = label

        if decoder_cache is None:
            decoder_cache = NoDecoderCache()
        self.decoder_cache = decoder_cache

        # Left empty here; the network builder fills these in.
        self.toplevel = None
        self.config = None

        # Containers that accumulate the results of the build process.
        self.operators = []
        self.params = {}
        self.probes = []
        self.seeds = {}
        self.seeded = {}

        # Read-only constant signals, stored under the value they hold.
        self.sig = defaultdict(dict)
        common = self.sig["common"]
        zero = np.array(0.0, dtype=rc.float_dtype)
        common[0] = Signal(zero, readonly=True, name="ZERO")
        one = np.array(1.0, dtype=rc.float_dtype)
        common[1] = Signal(one, readonly=True, name="ONE")

        # Step counter and clock, advanced by a single TimeUpdate operator.
        self.step = Signal(np.array(0, dtype=rc.int_dtype), name="step")
        self.time = Signal(np.array(0, dtype=rc.float_dtype), name="time")
        self.add_op(TimeUpdate(self.step, self.time))

        if builder is None:
            builder = Builder()
        self.builder = builder
        self.build_callback = None

    def __str__(self):
        return "Model: %s" % self.label

    def add_op(self, op):
        """Add an operator to the model.

        Besides appending the operator to ``operators``, this method can
        exercise the operator's ``init_signals`` and ``make_step`` functions
        against a throwaway ``SignalDict``. Doing so catches errors early,
        such as when signals are not properly initialized, which aids
        debugging. For that reason, we recommend calling this method over
        directly accessing the ``operators`` attribute.

        The early check only runs when the ``fail_fast`` option in the
        ``nengo.Simulator`` section of the rc configuration is true.

        Parameters
        ----------
        op : Operator
            The operator to add.
        """
        self.operators.append(op)

        if not rc["nengo.Simulator"].getboolean("fail_fast"):
            return

        # Fail fast by trying make_step with a temporary sigdict
        scratch = SignalDict()
        op.init_signals(scratch)
        op.make_step(scratch, self.dt, np.random)

    def build(self, obj, *args, **kwargs):
        """Build an object into this model.

        See `.Builder.build` for more details. Once the builder returns,
        ``build_callback`` (if it is not ``None``) is called with ``obj``.

        Parameters
        ----------
        obj : object
            The object to build into this model.

        Returns
        -------
        object
            Whatever ``self.builder.build`` returned for ``obj``.
        """
        result = self.builder.build(self, obj, *args, **kwargs)
        callback = self.build_callback
        if callback is not None:
            callback(obj)
        return result

    def has_built(self, obj):
        """Returns true if the object has already been built in this model.

        .. note:: Some objects (e.g. synapses) can be built multiple times,
                  and therefore will always result in this method returning
                  ``False`` even though they have been built.

        This check is implemented by checking if the object is in the
        ``params`` dictionary. Build functions should therefore add
        themselves to ``model.params`` if they cannot be built multiple times.

        Parameters
        ----------
        obj : object
            The object to query.
        """
        return obj in self.params
class Builder:
    """Manages the build functions known to the Nengo build process.

    Consists of two class methods to encapsulate the build function registry.
    All build functions should use the `.Builder.register` method as a
    decorator. For example:

    .. testcode::

       class MyRule(nengo.learning_rules.LearningRuleType):
           modifies = "decoders"
           ...

       @nengo.builder.Builder.register(MyRule)
       def build_my_rule(model, my_rule, rule):
           ...

    registers a build function for ``MyRule`` objects.

    Build functions should not be called directly, but instead called through
    the `.Model.build` method. `.Model.build` uses the `.Builder.build` method
    to ensure that the correct build function is called based on the type of
    the object passed to it.
    For example, to build the learning rule type ``my_rule`` from above, do:

    .. testcode::

       with nengo.Network() as net:
           ens_a = nengo.Ensemble(10, 1)
           ens_b = nengo.Ensemble(10, 1)
           my_rule = MyRule()
           connection = nengo.Connection(ens_a, ens_b, learning_rule_type=my_rule)

       model = nengo.builder.Model()
       model.build(my_rule, connection.learning_rule)

    This will call the ``build_my_rule`` function from above with the arguments
    ``model, my_rule, connection.learning_rule``.

    Attributes
    ----------
    builders : dict
        Mapping from types to the build function associated with that type.
    """

    # Class-level registry: shared by every Builder instance and subclass
    # that does not define its own ``builders`` attribute.
    builders = {}

    @classmethod
    def build(cls, model, obj, *args, **kwargs):
        """Build ``obj`` into ``model``.

        This method looks up the appropriate build function for ``obj`` and
        calls it with the model and other arguments provided.

        Note that if a build function is not specified for a particular type
        (e.g., `.EnsembleArray`), the type's method resolution order will be
        examined to look for superclasses
        with defined build functions (e.g., `.Network` in the case of
        `.EnsembleArray`).

        This indirection (calling `.Builder.build` instead of the build
        function directly) enables users to augment the build process in their
        own models, rather than having to modify Nengo itself.

        In addition to the parameters listed below, further positional and
        keyword arguments will be passed unchanged into the build function.

        Parameters
        ----------
        model : Model
            The `.Model` instance in which to store build artifacts.
        obj : object
            The object to build into the model.

        Returns
        -------
        object or None
            The return value of the selected build function, or ``None``
            (after emitting a warning) if ``model.has_built(obj)`` is true.

        Raises
        ------
        BuildError
            If no type in the method resolution order of ``type(obj)`` has a
            registered build function.
        """
        if model.has_built(obj):
            # TODO: Prevent this at pre-build validation time.
            # The 1-tuple keeps ``%`` from unpacking ``obj`` when it is
            # itself a tuple (e.g. a namedtuple).
            warnings.warn("Object %s has already been built." % (obj,))
            return None

        # Walk the MRO so that the most specific registered type wins.
        for obj_cls in type(obj).__mro__:
            if obj_cls in cls.builders:
                return cls.builders[obj_cls](model, obj, *args, **kwargs)

        raise BuildError("Cannot build object of type %r" % type(obj).__name__)

    @classmethod
    def register(cls, nengo_class):
        """A decorator for adding a class to the build function registry.

        Raises a warning if a build function already exists for the class.

        Parameters
        ----------
        nengo_class : Class
            The type associated with the build function being decorated.

        Returns
        -------
        callable
            A decorator that stores the decorated function in ``builders``
            under ``nengo_class`` and returns that function unchanged.
        """

        def register_builder(build_fn):
            if nengo_class in cls.builders:
                warnings.warn(
                    "Type '%s' already has a builder. Overwriting." % (nengo_class,)
                )
            cls.builders[nengo_class] = build_fn
            return build_fn

        return register_builder