/
node.py
209 lines (181 loc) · 7.58 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import functools
import inspect
import warnings

import numpy as np

import nengo.utils.numpy as npext
from nengo.base import NengoObject, ObjView
from nengo.exceptions import ValidationError
from nengo.params import Default, IntParam, Parameter
from nengo.processes import Process
from nengo.rc import rc
from nengo.utils.numpy import is_array_like
from nengo.utils.stdlib import checked_call
class OutputParam(Parameter):
    """Parameter that validates and normalizes the ``output`` of a node.

    Accepted values are ``None``, a ``Process``, a callable, or an
    array-like constant. Coercing a value also fills in ``size_in`` and
    ``size_out`` on the owning node.
    """

    def __init__(self, name, default, optional=True, readonly=False):
        assert optional  # None has meaning (passthrough node)
        super().__init__(name, default, optional, readonly)

    def _fn_args_validation_error(self, output, attr, node):
        """Build (but do not raise) the error for a wrong argument count.

        The message states how many arguments ``output`` is expected to
        accept: two when ``node.size_in > 0``, otherwise one.
        """
        n_args = 2 if node.size_in > 0 else 1
        msg = "output function '%s' is expected to accept exactly %d argument" % (
            output,
            n_args,
        )
        msg += (
            " (time, as a float)"
            if n_args == 1
            else "s (time, as a float and data, as a NumPy array)"
        )
        return ValidationError(msg, attr=attr, obj=node)

    def check_ndarray(self, node, output):
        """Validate a constant array output against the node's sizes.

        Raises
        ------
        ValidationError
            If ``output`` has more than one dimension, if the node has a
            non-zero ``size_in``, or if ``node.size_out`` is set and differs
            from ``output.size``.
        """
        if len(output.shape) > 1:
            raise ValidationError(
                "Node output must be a vector (got shape %s)" % (output.shape,),
                attr=self.name,
                obj=node,
            )
        if node.size_in != 0:
            raise ValidationError(
                "output must be callable if size_in != 0", attr=self.name, obj=node
            )
        if node.size_out is not None and node.size_out != output.size:
            raise ValidationError(
                "Size of Node output (%d) does not match "
                "size_out (%d)" % (output.size, node.size_out),
                attr=self.name,
                obj=node,
            )

    def coerce(self, node, output):
        """Validate ``output`` and derive the node's ``size_in``/``size_out``.

        Returns
        -------
        The coerced output; array-like values are returned as a NumPy array
        with at least one dimension, other values as given by the parent
        class's ``coerce``.

        Raises
        ------
        ValidationError
            If the output is of an unsupported type, is a callable with an
            unsuitable argument list, or is an array that is not a finite
            vector consistent with the node's sizes.
        """
        output = super().coerce(node, output)

        # An unset size_in is treated as 0, but remember whether it was set
        # so that a Process may still supply its own default below.
        size_in_set = node.size_in is not None
        node.size_in = node.size_in if size_in_set else 0

        # --- Validate and set the new size_out
        if output is None:
            if node.size_out is not None:
                warnings.warn(
                    "'Node.size_out' is being overwritten with "
                    "'Node.size_in' since 'Node.output=None'"
                )
            node.size_out = node.size_in
        elif isinstance(output, Process):
            if not size_in_set:
                node.size_in = output.default_size_in
            if node.size_out is None:
                node.size_out = output.default_size_out
        elif callable(output):
            self.check_callable_args_list(node, output)
            # We trust user's size_out if set, because calling output
            # may have unintended consequences (e.g., network communication)
            if node.size_out is None:
                node.size_out = self.check_callable_output(node, output)
        elif is_array_like(output):
            # Make into correctly shaped numpy array before validation
            output = npext.array(output, min_dims=1, copy=False, dtype=rc.float_dtype)
            self.check_ndarray(node, output)
            if not np.all(np.isfinite(output)):
                raise ValidationError(
                    "Output value must be finite.", attr=self.name, obj=node
                )
            node.size_out = output.size
        else:
            raise ValidationError(
                "Invalid node output type %r" % type(output).__name__,
                attr=self.name,
                obj=node,
            )

        return output

    def check_callable_output(self, node, output):
        """Call ``output`` once to discover the size of what it returns.

        The function is called with ``t = 0.0`` and, when the node has
        inputs, a zero vector of length ``node.size_in``.

        Returns
        -------
        int
            ``0`` if the call returned ``None``, otherwise the number of
            elements in the returned value.

        Raises
        ------
        ValidationError
            If ``checked_call`` reports that the function was not invoked,
            or if the returned value has more than one dimension.
        """
        t, x = 0.0, np.zeros(node.size_in)
        args = (t, x) if node.size_in > 0 else (t,)
        result, invoked = checked_call(output, *args)
        if not invoked:
            raise self._fn_args_validation_error(output, self.name, node)

        if result is not None:
            result = np.asarray(result)
            if len(result.shape) > 1:
                raise ValidationError(
                    "Node output must be a vector (got shape %s)" % (result.shape,),
                    attr=self.name,
                    obj=node,
                )

        # return callable output size
        return 0 if result is None else result.size

    def check_callable_args_list(self, node, output):
        """Check that ``output`` can be called with the expected arguments.

        One positional argument is expected when ``node.size_in`` is 0 and
        two otherwise. Callables whose argument list cannot be inspected are
        accepted without further checks.

        Raises
        ------
        ValidationError
            If the expected number of positional arguments is smaller than
            the number of required arguments or larger than the number of
            accepted ones.
        """
        # not all callables provide an argspec, such as numpy
        try:
            func_argspec = inspect.getfullargspec(output)
        except (TypeError, ValueError):
            return

        # The argspec of a functools.partial is derived from the callable it
        # wraps, so whether an implicit ``self`` appears in it depends on that
        # callable rather than on the partial object itself.
        target = output
        while isinstance(target, functools.partial):
            target = target.func

        args_len = len(func_argspec.args)
        if inspect.ismethod(target) or not inspect.isroutine(target):
            # don't count self as an argument
            args_len -= 1

        defaults_len = 0
        if func_argspec.defaults is not None:
            defaults_len = len(func_argspec.defaults)
        required_len = args_len - defaults_len

        expected_len = 2 if node.size_in > 0 else 1
        if func_argspec.varargs:
            # *args can absorb any extra positional arguments
            args_len = max(expected_len, args_len)

        if not required_len <= expected_len <= args_len:
            raise self._fn_args_validation_error(output, self.name, node)
class Node(NengoObject):
    """Supply non-neural inputs to Nengo objects and process their outputs.

    A node may receive input and run arbitrary computations in order to
    control a Nengo simulation. Nodes are usually not part of a brain model
    as such; instead they capture the assumptions made about sensory data or
    other environment variables that a brain model cannot produce by itself.

    Nodes are also useful for testing, by driving parts of a model with
    specific input signals, and for simplifying the input/output interface of
    a `~nengo.Network` by acting as a relay to or from its internal ensembles
    (see `~nengo.networks.EnsembleArray` for an example).

    Parameters
    ----------
    output : callable, array_like, or None
        A function mapping the node's inputs to its outputs, a constant
        output value, or None to pass signals through unchanged.
    size_in : int, optional
        Dimensionality of the input data parameter.
    size_out : int, optional
        Size of the output signal. When None, it is determined from
        ``output`` and ``size_in``.
    label : str, optional
        Name of the node, used for debugging and visualization.
    seed : int, optional
        Seed for random number generation. Nothing in a node is random, so
        the seed currently has no effect; passing a value other than None
        raises ``NotImplementedError``.

    Attributes
    ----------
    label : str
        The name of the node.
    output : callable, array_like, or None
        The given output.
    size_in : int
        The number of dimensions for incoming connection.
    size_out : int
        The number of output dimensions.
    """

    probeable = ("output",)

    output = OutputParam("output", default=None)
    size_in = IntParam("size_in", default=None, low=0, optional=True)
    size_out = IntParam("size_out", default=None, low=0, optional=True)

    def __init__(
        self,
        output=Default,
        size_in=Default,
        size_out=Default,
        label=Default,
        seed=Default,
    ):
        # Only the Default sentinel and None are accepted as seeds.
        if seed is not Default and seed is not None:
            raise NotImplementedError("Changing the seed of a node has no effect")

        super().__init__(label=label, seed=seed)

        self.size_in = size_in
        self.size_out = size_out
        # Assigned last: this must follow size_out and may modify it.
        self.output = output

    def __getitem__(self, key):
        """Return an ``ObjView`` of this node for the given ``key``."""
        return ObjView(self, key)

    def __len__(self):
        """Return the node's ``size_out``."""
        return self.size_out