/
action_selection.py
202 lines (169 loc) · 6.64 KB
/
action_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Implementation of action syntax."""
from collections import OrderedDict
from collections.abc import Mapping
import nengo
from nengo_spa.ast import dynamic
from nengo_spa.connectors import ModuleInput, RoutedConnection, input_vocab_registry
from nengo_spa.exceptions import SpaActionSelectionError, SpaTypeError
from nengo_spa.types import TScalar
class ActionSelection(Mapping):
    """
    Implements an action selection system with basal ganglia and thalamus.

    The *ActionSelection* instance has to be used as context manager and each
    potential action is defined by an `.ifmax` call providing an expression
    for the utility value and any number of effects (routing of information)
    to activate when this utility value is highest of all.

    The instance is also a read-only mapping from action keys to the utility
    input of each action. Named actions are keyed by their name, unnamed
    actions by their integer index (the order in which they were added).
    Integer indices are accepted for named actions, too.

    Attributes
    ----------
    active : ActionSelection
        Class attribute providing the currently active ActionSelection
        instance (if any).
    built : bool
        Indicates whether the action selection system has been built
        successfully.
    bias : nengo.Node
        Node with constant output 1.0 that is passed to the thalamus when
        constructing gates. Available after the action selection system has
        been built.
    bg : nengo.Network
        Basal ganglia network. Available after the action selection system has
        been built.
    thalamus : nengo.Network
        Thalamus network. Available after the action selection system has
        been built.

    See Also
    --------
    nengo_spa.modules.BasalGanglia : Default basal ganglia network
    nengo_spa.modules.Thalamus : Default thalamus network

    Examples
    --------
    .. code-block:: python

        with ActionSelection():
            ifmax("A to B", dot(state, sym.A), sym.B >> state)
            ifmax("B to C", dot(state, sym.B), sym.C >> state)
            ifmax("C to A", dot(state, sym.C), sym.A >> state)

    This will route the *B* Semantic Pointer to *state* when *state* is more
    similar to *A* than any of the other Semantic Pointers. Similarly, *C*
    will be routed to *state* when *state* is *B*. Once *state* is *C*, it
    will be reset to *A* and the cycle begins anew.

    Further action selection examples:

    * :ref:`/examples/question-control.ipynb`
    * :ref:`/examples/spa-parser.ipynb`
    * :ref:`/examples/spa-sequence.ipynb`
    * :ref:`/examples/spa-sequence-routed.ipynb`
    """

    active = None

    def __init__(self):
        super(ActionSelection, self).__init__()
        self.built = False
        self.bias = None
        self.bg = None
        self.thalamus = None
        # One utility connector and one tuple of effects per action, stored in
        # the order the actions were added; the list index is the action index.
        self._utilities = []
        self._actions = []
        # Maps labels of actions to the index of that action
        self._name2idx = OrderedDict()

    def __enter__(self):
        """Activate this instance as the current action selection context.

        Raises
        ------
        SpaActionSelectionError
            If this instance has already been built or if another action
            selection context is already active.
        """
        # Re-entering a built instance would construct a second basal
        # ganglia/thalamus from all accumulated actions. Raise explicitly
        # (instead of assert) so the check also holds under ``python -O``.
        if self.built:
            raise SpaActionSelectionError(
                "An ActionSelection instance can only be built once."
            )
        if ActionSelection.active is not None:
            raise SpaActionSelectionError("Must not nest action selection contexts.")
        ActionSelection.active = self
        ModuleInput.routed_mode = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Deactivate the context and build the networks on a clean exit."""
        # Always release the active slot and leave routed mode, even on error.
        ActionSelection.active = None
        ModuleInput.routed_mode = False

        if exc_type is not None:
            # Discard partially declared routes; returning None lets the
            # original exception propagate.
            RoutedConnection.free_floating.clear()
            return

        self._build()

    def _build(self):
        """Create bias, basal ganglia, and thalamus and wire up all actions.

        Raises
        ------
        SpaActionSelectionError
            If a routed connection was created in the context without being
            passed to an ifmax call.
        """
        # Every routed connection created inside the context must have been
        # claimed by add_action; the registry is cleared either way so stale
        # entries do not leak into a later context.
        try:
            if RoutedConnection.free_floating:
                raise SpaActionSelectionError(
                    "All actions in an action selection context must be part "
                    "of an ifmax call."
                )
        finally:
            RoutedConnection.free_floating.clear()

        # Nothing to build without actions; `built` stays False in that case.
        if not self._utilities:
            return

        self.bias = nengo.Node(1.0, label="bias")
        self.bg = dynamic.BasalGangliaRealization(len(self._utilities))
        self.thalamus = dynamic.ThalamusRealization(len(self._utilities))
        self.thalamus.connect_bg(self.bg)

        # The utility of action i feeds basal ganglia input i.
        for index, utility in enumerate(self._utilities):
            self.bg.connect_input(utility, index=index)

        for index, action in enumerate(self._actions):
            for effect in action:
                if effect.fixed:
                    # Fixed effects are connected via `connect_fixed` using
                    # the effect's transform.
                    self.thalamus.connect_fixed(
                        index, effect.sink.input, transform=effect.transform()
                    )
                else:
                    # Other effects get a gated channel: the effect feeds the
                    # channel input and the gate for action `index` controls it.
                    self.thalamus.construct_gate(index, self.bias)
                    channel = self.thalamus.construct_channel(
                        effect.sink.input, effect.type
                    )
                    effect.connect_to(channel.input)
                    self.thalamus.connect_gate(index, channel)

        self.built = True

    def __getitem__(self, key):
        """Return the utility input of the action identified by *key*.

        *key* is either an action name (str) or an integer action index.

        Raises
        ------
        KeyError
            For unknown names and out-of-range indices, as required by the
            `Mapping` protocol so that ``in`` and ``get`` behave correctly.
        """
        index = self._name2idx[key] if isinstance(key, str) else key
        try:
            return self._utilities[index]
        except IndexError:
            raise KeyError(key) from None

    def __iter__(self):
        # Given not all actions have names, there will be actions whose keys
        # are numbers and not names. Keys are produced in action order, each
        # action exactly once (consistent with __len__).
        idx2name = {idx: name for name, idx in self._name2idx.items()}
        for index in range(len(self)):
            yield idx2name.get(index, index)

    def __len__(self):
        return len(self._actions)

    def add_action(self, name, *actions):
        """Register an action and return the connector for its utility.

        Parameters
        ----------
        name : str or None
            Name of the action. If None, the action is only addressable by its
            index. Reusing a name re-binds it to the new action.
        actions : RoutedConnection
            Effects to activate when this action is selected.

        Returns
        -------
        Connector for the utility input of the new action.
        """
        assert ActionSelection.active is self
        if name is not None:
            self._name2idx[name] = len(self._actions)
        else:
            name = str(len(self._actions))
        utility = input_vocab_registry.declare_connector(
            nengo.Node(size_in=1, label="Utility for action " + name), None
        )
        self._utilities.append(utility)
        self._actions.append(actions)
        # These connections are now owned by an action, so _build must not
        # report them as free-floating.
        RoutedConnection.free_floating.difference_update(actions)
        return utility
def ifmax(name, condition=None, *actions):
    """
    Defines a potential action within an `ActionSelection` context.

    The name may be omitted: both ``ifmax(name, condition, *actions)`` and
    ``ifmax(condition, *actions)`` are accepted. A first argument that is
    neither a string nor None is taken to be the condition.

    Parameters
    ----------
    name : str, optional
        Name for the action. May be omitted (or given as None) for an unnamed
        action, which is then only addressable by its index.
    condition : nengo_spa.ast.base.Node
        The utility value for the given actions.
    actions : sequence of `RoutedConnection`
        The actions to activate if the given utility is the highest.

    Returns
    -------
    NengoObject
        Nengo object that can be connected to, to provide additional input to
        the utility value.

    Raises
    ------
    TypeError
        If no condition is given.
    SpaActionSelectionError
        If called outside of an `ActionSelection` context or if an action is
        not a routing expression.
    SpaTypeError
        If the condition does not evaluate to a scalar.
    """
    if name is not None and not isinstance(name, str):
        # Name omitted: shift the positional arguments one slot to the right.
        if condition is not None:
            actions = (condition,) + actions
        condition = name
        name = None
    if condition is None:
        raise TypeError("ifmax() missing required argument: 'condition'")

    if ActionSelection.active is None:
        raise SpaActionSelectionError(
            "ifmax must be used within the context of an ActionSelection instance."
        )
    if condition.type != TScalar:
        raise SpaTypeError(
            f"ifmax condition must evaluate to a scalar, but got {condition.type}."
        )
    if any(not isinstance(a, RoutedConnection) for a in actions):
        raise SpaActionSelectionError(
            "ifmax actions must be routing expressions like 'a >> b'."
        )

    utility = ActionSelection.active.add_action(name, *actions)
    condition.connect_to(utility)
    return utility