"""Base classes for stateful models with stages.
:copyright: Copyright 2023-2024 by Matt Laporte.
:license: Apache 2.0. See LICENSE for details.
"""
from abc import abstractmethod, abstractproperty
from collections import OrderedDict
from collections.abc import Callable, Mapping, Sequence
from functools import cached_property
import logging
import os
from typing import (
TYPE_CHECKING,
Generic,
Optional,
Protocol,
Self,
TypeVar,
Union,
)
import equinox as eqx
from equinox import AbstractVar, Module, field
import jax
import jax.random as jr
from jaxtyping import Array, PRNGKeyArray, PyTree
import numpy as np
from feedbax._model import AbstractModel, ModelInput
from feedbax.intervene import AbstractIntervenor
from feedbax.misc import indent_str
from feedbax.state import AbstractState, StateT
if TYPE_CHECKING:
from feedbax.task import AbstractTaskInputs
logger = logging.getLogger(__name__)
# Type variable for the model type whose stages are being specified.
ModelT = TypeVar("ModelT", bound=Module)
# NOTE(review): this shadows the `StateT` imported from `feedbax.state` above
# (which is bound to `AbstractState`, not `Module`) — presumably intentional,
# but confirm which bound downstream annotations should see.
StateT = TypeVar("StateT", bound=Module)
class ModelStageCallable(Protocol):
    # This is part of the `ModelInput` hack.
    # Structural type for stage callables that accept a `ModelInput`, i.e. they
    # receive the intervention parameters along with the task input.
    def __call__(self, input: ModelInput, state: PyTree[Array], *, key: PRNGKeyArray) -> PyTree[Array]:
        ...
class OtherStageCallable(Protocol):
    # Structural type for stage callables that take a plain PyTree as input,
    # rather than a `ModelInput`.
    def __call__(self, input: PyTree[Array], state: PyTree[Array], *, key: PRNGKeyArray) -> PyTree[Array]:
        ...
class ModelStage(Module, Generic[ModelT, StateT]):
    """Specification for a stage in a subclass of `AbstractStagedModel`.

    Each stage of a model is a callable that performs a modification to part
    of the model state.

    !!! Note
        To ensure that references to parts of the model instance remain fresh,
        `callable` takes the instance of `AbstractStagedModel` (i.e. `self`)
        and returns the callable associated with the stage.

        It is possible for references to become stale. For example, if we
        assign `callable=self.net` for the neural network update in
        [`SimpleFeedback`][feedbax.bodies.SimpleFeedback], then it will
        continue to refer to the neural network assigned to `self.net`
        upon the model's construction, even after the network weights
        have been updated during training—so, the model will not train.

    Attributes:
        callable: The module, method, or function that transforms part of the
            model state.
        where_input: Selects the parts of the input and state to be passed
            as input to `callable`.
        where_state: Selects the substate that is passed to, and returned by,
            `callable`.
        intervenors: Optionally, a sequence of state interventions to be
            applied at the beginning of this model stage.
    """

    callable: Callable[
        [ModelT],
        Union[ModelStageCallable, OtherStageCallable],
    ]
    where_input: Callable[["AbstractTaskInputs", StateT], PyTree]
    where_state: Callable[[StateT], PyTree]
    intervenors: Sequence[AbstractIntervenor] = field(default_factory=tuple)
class AbstractStagedModel(AbstractModel[StateT]):
    """Base class for state-dependent models whose stages can be intervened upon.

    !!! Info
        To define a new staged model, the following complementary components
        must be implemented:

        1. A [final](https://docs.kidger.site/equinox/pattern/) subclass of
           `AbstractState` that defines the PyTree structure of the model
           state. The type of the fields of this PyTree are typically JAX
           arrays, or else other `AbstractState` types associated with the
           model's components.
        2. A final subclass of
           [`AbstractStagedModel`][feedbax.AbstractStagedModel]. Note that the
           abstract class is a `Generic`, and for proper type checking, the
           type argument of the subclass should be the type of `AbstractState`
           defined in (1).

           This subclass must implement the following:

           1. A `model_spec` property giving a mapping from stage labels
              to [`ModelStage`][feedbax.ModelStage] instances, each
              specifying an operation performed on the model state.
           2. An `init` method that takes a random key and returns a default
              model state.

        For an example, consider 1) [`SimpleFeedbackState`][feedbax.bodies.SimpleFeedbackState]
        and 2) [`SimpleFeedback`][feedbax.bodies.SimpleFeedback].
    """

    # Mapping from stage label to the sequence of intervenors applied at the
    # start of that stage; concrete subclasses must declare this field.
    intervenors: AbstractVar[Mapping[str, Sequence[AbstractIntervenor]]]

    def __call__(
        self,
        input: ModelInput,
        state: StateT,
        key: PRNGKeyArray,
    ) -> StateT:
        """Return an updated model state, given input and a prior state.

        Arguments:
            input: The input to the model.
            state: The prior state of the model.
            key: A random key which will be split to provide separate keys for
                each model stage and intervenor.
        """
        with jax.named_scope(type(self).__name__):
            keys = jr.split(key, len(self._stages))

            for (label, stage), key in zip(self._stages.items(), keys):
                key_intervene, key_stage = jr.split(key)

                # Apply this stage's intervenors before the stage operation.
                keys_intervene = jr.split(key_intervene, len(stage.intervenors))

                for intervenor, k in zip(stage.intervenors, keys_intervene):
                    # Use trial-specific intervention parameters when the
                    # caller supplied them; otherwise fall back to `None`.
                    if intervenor.label in input.intervene:
                        params = input.intervene[intervenor.label]
                    else:
                        params = None
                    state = intervenor(params, state, key=k)

                # Resolve the stage callable from `self` each call, so
                # references to model components stay fresh (see `ModelStage`).
                callable_ = stage.callable(self)
                subinput = stage.where_input(input.input, state)

                # TODO: What's a less hacky way of doing this?
                # I was trying to avoid introducing additional parameters to
                # `AbstractStagedModel.__call__`
                if isinstance(callable_, AbstractModel):
                    # Nested models also receive the intervention parameters.
                    callable_input = ModelInput(subinput, input.intervene)
                else:
                    callable_input = subinput

                # Replace only the substate selected by `where_state` with the
                # stage's output, leaving the rest of the state untouched.
                state = eqx.tree_at(
                    stage.where_state,
                    state,
                    callable_(
                        callable_input,
                        stage.where_state(state),
                        key=key_stage,
                    ),
                )

                # Opt-in per-stage debug logging, enabled by setting the
                # environment variable FEEDBAX_DEBUG to the string "True".
                if os.environ.get("FEEDBAX_DEBUG", False) == "True":
                    debug_strs = [
                        indent_str(eqx.tree_pformat(x), indent=4)
                        for x in (callable_, subinput, stage.where_state(state))
                    ]

                    log_str = "\n".join(
                        [
                            f"Model type: {type(self).__name__}",
                            f'Stage: "{label}"',
                            f"Callable:\n{debug_strs[0]}",
                            f"Input:\n{debug_strs[1]}",
                            f"Substate:\n{debug_strs[2]}",
                        ]
                    )

                    logger.debug(f"\n{indent_str(log_str, indent=2)}\n")

        return state

    @abstractmethod
    def init(
        self,
        *,
        key: PRNGKeyArray,
    ) -> StateT:
        """Return a default state for the model."""
        ...

    # NOTE(review): `abstractproperty` has been deprecated since Python 3.3;
    # prefer stacking `@property` over `@abstractmethod` when next touching this.
    @abstractproperty
    def model_spec(self) -> OrderedDict[str, ModelStage]:
        """Specify the model's computation in terms of state operations.

        !!! Warning
            It's necessary to return `OrderedDict` because `jax.tree_util`
            still sorts `dict` keys, which usually puts the stages out of order.
        """
        ...

    @cached_property
    def _stages(self) -> OrderedDict[str, ModelStage]:
        """Zips up the user-defined intervenors with `model_spec`.

        This should not be referred to in `__init__` before assigning `self.intervenors`!
        """
        # For each stage in `model_spec`, replace its `intervenors` field with
        # the (tuple-converted) intervenors the user registered for that label.
        return jax.tree_map(
            lambda x, y: eqx.tree_at(lambda x: x.intervenors, x, y),
            self.model_spec,
            jax.tree_map(
                tuple, self.intervenors, is_leaf=lambda x: isinstance(x, list)
            ),
            is_leaf=lambda x: isinstance(x, ModelStage),
        )

    def _get_intervenors_dict(
        self,
        intervenors: Optional[
            Union[
                Sequence[AbstractIntervenor], Mapping[str, Sequence[AbstractIntervenor]]
            ]
        ],
    ):
        # Normalize the `intervenors` argument into a per-stage dict of lists,
        # starting from an empty list for every stage label in `model_spec`.
        intervenors_dict = jax.tree_map(
            lambda _: [],
            self.model_spec,
            is_leaf=lambda x: isinstance(x, ModelStage),
        )

        if intervenors is not None:
            if isinstance(intervenors, Sequence):
                # By default, place interventions in the first stage.
                # NOTE(review): this assumes the first stage is labeled
                # "get_feedback" — confirm that holds for all subclasses.
                intervenors_dict.update({"get_feedback": list(intervenors)})
            elif isinstance(intervenors, dict):
                intervenors_dict.update(
                    jax.tree_map(
                        list, intervenors, is_leaf=lambda x: isinstance(x, Sequence)
                    )
                )
            else:
                raise ValueError("intervenors not a sequence or dict of sequences")

        return intervenors_dict

    @property
    def step(self) -> Module:
        """The model step.

        For an `AbstractStagedModel`, this is trivially the model itself.
        """
        return self

    # TODO: Avoid referencing `AbstractIntervenor` here, to avoid a circular import
    # with `feedbax.intervene`.
    @property
    def _all_intervenor_labels(self):
        # Collect the label of every intervenor found anywhere in the model
        # PyTree, including those nested inside component models.
        model_leaves = jax.tree_util.tree_leaves(
            self, is_leaf=lambda x: isinstance(x, AbstractIntervenor)
        )
        labels = [
            leaf.label for leaf in model_leaves if isinstance(leaf, AbstractIntervenor)
        ]
        return tuple(labels)
def pformat_model_spec(
    model: AbstractStagedModel,
    indent: int = 2,
    newlines: bool = False,
) -> str:
    """Returns a string representation of the model specification tree.

    Shows what is called by `model`, and by any `AbstractStagedModel`s it calls.

    !!! Warning
        This assumes that the model spec is a tree/DAG. If there are cycles in
        the model spec, this will recurse until an exception is raised.

    Arguments:
        model: The staged model to format.
        indent: Number of spaces to indent each nested level of the tree.
        newlines: Whether to add an extra blank line between each line.
    """

    def get_spec_strs(model: AbstractStagedModel):
        # One formatted line per stage, followed by indented lines for any
        # nested `AbstractStagedModel` that the stage calls.
        spec_strs = []

        for label, stage_spec in model._stages.items():
            intervenor_str = "".join(
                f"intervenor: {type(intervenor).__name__}\n"
                for intervenor in stage_spec.intervenors
            )

            # Renamed from `callable`, which shadowed the builtin.
            stage_callable = stage_spec.callable(model)
            spec_str = f"{label}: "

            if getattr(stage_callable, "__wrapped__", None) is not None:
                # Only flag that the callable is wrapped; the wrapper itself
                # (not `__wrapped__`) is what gets named below.
                spec_str += "wrapped: "

            # Bound methods: show "Owner.method".
            if (func := getattr(stage_callable, "__func__", None)) is not None:
                owner = type(getattr(stage_callable, "__self__")).__name__
                spec_str += f"{owner}.{func.__name__}"
            # Plain functions: show the function name.
            elif (name := getattr(stage_callable, "__name__", None)) is not None:
                spec_str += f"{name}"
            # Modules and other callable instances: show the type name.
            else:
                spec_str += f"{type(stage_callable).__name__}"

            spec_strs += [intervenor_str + spec_str]

            if isinstance(stage_callable, AbstractStagedModel):
                # Recurse into the nested staged model, indenting one level.
                spec_strs += [
                    " " * indent + nested_str
                    for nested_str in get_spec_strs(stage_callable)
                ]

        return spec_strs

    nl = "\n\n" if newlines else "\n"

    return nl.join(get_spec_strs(model))
def pprint_model_spec(
    model: AbstractStagedModel,
    indent: int = 2,
    newlines: bool = False,
):
    """Prints a string representation of the model specification tree.

    Shows what is called by `model`, and by any `AbstractStagedModel`s it calls.

    !!! Warning
        This assumes that the model spec is a tree. If there are cycles in
        the model spec, this will recurse until an exception is raised.

    Arguments:
        model: The staged model to format.
        indent: Number of spaces to indent each nested level of the tree.
        newlines: Whether to add an extra blank line between each line.
    """
    # Delegate all formatting to `pformat_model_spec`; this is just the
    # printing convenience wrapper.
    formatted = pformat_model_spec(model, indent=indent, newlines=newlines)
    print(formatted)