-
Notifications
You must be signed in to change notification settings - Fork 45
/
unflaxify.py
465 lines (410 loc) · 18 KB
/
unflaxify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recursively transforms `flax.linen` modules into Penzai layers.

This utility is designed to make it easier to:

* Patch and inspect `flax.linen` Modules using Penzai tools.
* Simplify the process of migrating from `flax.linen` to Penzai by swapping out
  components one by one.

It relies on inspecting Flax internals and may not be completely robust.
Transformed layers in particular are not currently supported.
"""
from __future__ import annotations
import collections
import dataclasses
from typing import Any, Callable, Sequence
import flax
import flax.typing
import jax
from penzai import pz
@pz.pytree_dataclass
class ArgsAndKwargs(pz.Struct):
  """Bundle holding the positional and keyword arguments of one call.

  Attributes:
    args: The positional arguments, in call order.
    kwargs: The keyword arguments, keyed by parameter name.
  """

  args: tuple[Any, ...]
  kwargs: dict[str, Any]

  @classmethod
  def capture(cls, *args, **kwargs) -> ArgsAndKwargs:
    """Builds an instance out of whatever arguments this call receives."""
    bundle = cls(args=args, kwargs=kwargs)
    return bundle
@pz.pytree_dataclass
class InterceptedFlaxScopeData(pz.Struct):
  """A frozen representation of data in a particular Flax scope.

  Flax implements its modules using a "functional core" which is a stateful
  manager of variables, parameters, and random keys for a module and all its
  submodules. This class represents a "Penzai view" of the data held in the
  scope for a particular module, not including its submodules.

  Attributes:
    parameters: The collection of named parameters used directly by this module
      (not a submodule), represented as Penzai parameters. If this method was
      called multiple times, the parameters may be shared parameter references.
    variables: The collection of other variables used directly by this module
      (not a submodule), represented as Penzai state effects.
    immutable_variables: The collection of immutable variables used directly by
      this module (not a submodule).
    rngs: The collection of RNGs used by this module method. Note that the
      random numbers generated by Penzai will NOT exactly match the random
      numbers generated by Flax, because Flax has custom logic for splitting and
      seeding RNGs that is not easy to directly reproduce in Penzai.
  """

  # Keyed by the parameter's local name inside the module.
  parameters: dict[str, pz.nn.ParameterLike[Any]]
  # Keyed by collection name, then by the variable's local name.
  variables: dict[str, dict[str, pz.de.LocalStateEffect]]
  # Keyed by collection name, then by the variable's local name.
  immutable_variables: dict[str, dict[str, Any]]
  # Keyed by RNG stream name.
  rngs: dict[str, pz.de.RandomEffect]
@pz.pytree_dataclass
class InterceptedFlaxModuleMethod(pz.Layer):
  """Penzai layer that replays a single intercepted Flax module method call.

  Calling this layer re-runs the logic of one Flax module method. The data the
  module uses directly, and every call it makes into a submodule, are reified
  as fields of this layer so that they are reachable through the PyTree
  structure of the model.

  Attributes:
    module: The unbound Flax module.
    method_name: The name of the method being called.
    scope_data: Data associated with this Flax module's scope, including
      parameters, variables, and random keys used directly by this module (not
      including its submodules). Can be None if this module does not have any
      parameters or variables of its own and instead merely defers to its
      submodules.
    submodule_calls: The collection of all submodule calls made by this module
      method, in call order, keyed by ``(call index, "<module name>.<method
      name>")``. Each call is reified as a Penzai layer and can be patched to
      run arbitrary logic instead of the original Flax module method.
  """

  module: flax.linen.Module = dataclasses.field(metadata={"pytree_node": False})
  method_name: str = dataclasses.field(metadata={"pytree_node": False})
  scope_data: InterceptedFlaxScopeData | None
  submodule_calls: dict[tuple[int, str], pz.LayerLike]

  def __call__(self, args_and_kwargs: ArgsAndKwargs) -> Any:
    """Runs the intercepted Flax method on the given arguments.

    Arguments:
      args_and_kwargs: The positional and keyword arguments to forward to the
        Flax method.

    Returns:
      Whatever the Flax method returns.

    Raises:
      ValueError: If `scope_data.variables` contains a "params" collection.
    """
    scope = self.scope_data
    if scope is None:
      # No data of its own: behave as if every collection were empty.
      scope = InterceptedFlaxScopeData({}, {}, {}, {})

    if "params" in scope.variables:
      raise ValueError(
          "The 'params' variable collection should be part of `parameters` not"
          " `variables`"
      )

    # Unwrap the Penzai-side holders into the plain nested dict that
    # `Module.apply` takes: parameters first, then mutable variable
    # collections, then immutable ones.
    flax_variables = {
        "params": {name: param.value for name, param in scope.parameters.items()}
    }
    for collection, entries in scope.variables.items():
      flax_variables[collection] = {
          name: effect.get() for name, effect in entries.items()
      }
    flax_variables.update(scope.immutable_variables)

    flax_rngs = {stream: rng.next_key() for stream, rng in scope.rngs.items()}

    # State shared with the interceptor below: the bound copy of our own
    # module (identified on the first non-setup interception) and the index of
    # the next submodule call.
    bound_self = None
    next_call_index = 0

    def redirecting_interceptor(next_fun, args, kwargs, context):
      nonlocal bound_self, next_call_index
      is_setup = context.method_name == "setup"

      if bound_self is None:
        # Still on our own `apply`. `setup` may run before the target method;
        # let it through without recording anything.
        if not is_setup:
          assert context.module.name == self.module.name
          assert type(context.module) is type(self.module)  # pylint: disable=unidiomatic-typecheck
          assert context.method_name == self.method_name
          bound_self = context.module
        return next_fun(*args, **kwargs)

      if context.module.parent is not bound_self:
        # Not one of our direct children (for instance, a call that belongs to
        # an enclosing InterceptedFlaxModuleMethod). Leave it alone.
        return next_fun(*args, **kwargs)

      if is_setup:
        # Skip `setup` on direct children: their real call is redirected
        # below, and this wrapper does not hold the parameters a child's
        # setup would need.
        return None

      call_key = (
          next_call_index,
          f"{context.module.name}.{context.method_name}",
      )
      next_call_index += 1
      replacement = self.submodule_calls[call_key]
      return replacement(ArgsAndKwargs(args=args, kwargs=kwargs))

    # Run the Flax method with every direct submodule call rerouted to the
    # corresponding Penzai sublayer.
    with flax.linen.intercept_methods(redirecting_interceptor):
      result, updated_collections = self.module.apply(
          flax_variables,
          *args_and_kwargs.args,
          rngs=flax_rngs,
          method=self.method_name,
          mutable=list(scope.variables.keys()),
          **args_and_kwargs.kwargs,
      )

    # Write the values Flax reports as mutated back into our state effects.
    for collection, entries in updated_collections.items():
      for name, new_value in entries.items():
        scope.variables[collection][name].set(new_value)

    return result

  def treescope_color(self) -> str:
    """Returns a color derived from the wrapped Flax module's class name."""
    module_class_name = type(self.module).__name__
    return pz.color_from_string(module_class_name)
@dataclasses.dataclass
class _FlaxModelInterceptState:
  """Mutable bookkeeping for one module call that is currently being traced.

  Attributes:
    module: The bound Flax module whose method is running.
    unclaimed_collections: Copy of the module's variable collections, keyed by
      collection name and then by entry name. Entries that belong to a
      submodule are removed as that submodule's call is intercepted.
    submodule_call_path: Call keys leading from the root call down to this
      call; empty for the root call.
    intercept_counter: Number of submodule calls intercepted so far under this
      call; used as the index part of the next call key.
    submodule_calls: Converted layers for the submodule calls made so far,
      keyed by ``(call index, "<module name>.<method name>")``.
  """

  module: flax.linen.Module
  unclaimed_collections: dict[str, dict[str, Any]]
  submodule_call_path: tuple[tuple[int, str], ...]
  intercept_counter: int
  submodule_calls: dict[tuple[int, str], pz.LayerLike]
def _common_prefix(parts: Sequence[tuple[Any, ...]]) -> tuple[Any, ...]:
  """Returns the longest leading run of elements shared by every tuple.

  Args:
    parts: The tuples to compare position by position.

  Returns:
    The shared leading elements as a tuple. Empty when `parts` is empty or the
    tuples already differ at the first position.
  """
  shared = []
  # `zip` stops at the shortest tuple, so the prefix never exceeds it.
  for column in zip(*parts):
    head = column[0]
    if not all(item == head for item in column[1:]):
      break
    shared.append(head)
  return tuple(shared)
def unflaxify_apply(
    module: flax.linen.Module,
    variables: flax.typing.VariableDict,
    *dummy_args,
    rngs: flax.typing.PRNGKey | flax.typing.RNGSequences | None = None,
    method: Callable[..., Any] | str | None = None,
    mutable: flax.core.scope.CollectionFilter = False,
    **dummy_kwargs,
) -> InterceptedFlaxModuleMethod:
  """Creates an `InterceptedFlaxModuleMethod` from applying a Flax module.

  Note that this function is intended for interactive exploration and to help
  migrate Flax code to Penzai. It is not intended to be used in production
  code. Not all Flax features are supported yet; in particular, transformed
  layers are not supported and have not been tested.

  Args:
    module: The flax module to apply.
    variables: A dictionary containing variables keyed by variable collections,
      with same interpretation as for `flax.linen.Module.apply`.
    *dummy_args: Positional arguments passed to the specified apply method.
      These can be arbitrary values; their purpose is to enable tracing through
      the Flax logic.
    rngs: A dict of PRNGKeys to initialize the PRNG sequences, with same
      interpretation as for `flax.linen.Module.apply`. A single PRNGKey is
      treated as the "params" sequence.
    method: A function to call apply on. This is generally a function in the
      module. If provided, applies this method. If not provided, applies the
      ``__call__`` method of the module. A string can also be provided to
      specify a method by name.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable. ``str``: The
      name of a single mutable collection. ``list``: A list of names of mutable
      collections.
    **dummy_kwargs: Keyword arguments passed to the specified apply method.
      These can be arbitrary values; their purpose is to enable tracing through
      the Flax logic.

  Returns:
    An intercepted version of the Flax module call, which can be manipulated
    using Penzai tools.
  """
  if rngs is None:
    rngs = {}
  elif not isinstance(rngs, dict):
    # The code below iterates `rngs` as a mapping from stream name to key, so
    # a bare key has to be wrapped first. The "params" name is intended to
    # mirror how `flax.linen.Module.apply` treats a single key.
    rngs = {"params": rngs}
  original_rngs = rngs

  # In-flight calls, keyed by the Flax scope path of the module being called.
  intercept_states_by_path: dict[Any, _FlaxModelInterceptState] = {}
  root_intercept = None
  # Names of state variables that already have an owning request.
  known_states = set()
  # For each parameter name: the call paths that used it, and its value.
  parameter_usages = collections.defaultdict(list)
  parameter_values = {}

  def reifying_interceptor(next_fun, args, kwargs, context):
    nonlocal root_intercept

    if context.method_name == "setup":
      # Ignore setup method.
      return next_fun(*args, **kwargs)

    # We make a depth-2 copy of the variable collections because Flax
    # represents variables as nested dicts, with the first key corresponding to
    # collection name and the second to the variable or submodule name. Making
    # a copy allows us to mutate them to remove variables that we've already
    # accounted for.
    current_vars = {
        col: dict(col_vars)
        for col, col_vars in context.module.scope.variables().items()
    }

    if not intercept_states_by_path:
      # This is the root module call.
      submodule_call_path = ()
      submodule_call_name = None
    else:
      # We are intercepting a submodule call.
      assert context.module.scope.path not in intercept_states_by_path
      parent_state = intercept_states_by_path[context.module.parent.scope.path]
      assert parent_state.module is context.module.parent
      # Figure out a name for this call.
      submodule_call_name = (
          parent_state.intercept_counter,
          f"{context.module.name}.{context.method_name}",
      )
      submodule_call_path = parent_state.submodule_call_path + (
          submodule_call_name,
      )
      parent_state.intercept_counter += 1
      # Claim all of the variables/parameters owned by this module, so the
      # parent does not also convert them as its own.
      for col in current_vars:
        assert col in parent_state.unclaimed_collections
        if context.module.name in parent_state.unclaimed_collections[col]:
          del parent_state.unclaimed_collections[col][context.module.name]

    # Set up the state for any submodule calls.
    current_state = _FlaxModelInterceptState(
        module=context.module,
        unclaimed_collections=current_vars,
        submodule_call_path=submodule_call_path,
        intercept_counter=0,
        submodule_calls={},
    )

    # Run the submodule logic, allowing our interceptor to recursively populate
    # the state we just built.
    intercept_states_by_path[context.module.scope.path] = current_state
    the_fn_output = next_fun(*args, **kwargs)
    del intercept_states_by_path[context.module.scope.path]

    # Assemble the reified layer by replacing Flax variables and parameters
    # with their Penzai equivalents. Note that any variables that are owned
    # by submodules have already been removed from `unclaimed_collections`.
    # We detect which of the variables were actually used by this method
    # by inspecting the "reservations", which Flax uses to ensure that
    # variables or submodules aren't accidentally duplicated.
    parameter_vars, variables, immutable_variables = (
        flax.core.scope.group_collections(
            current_state.unclaimed_collections, ["params", mutable, True]
        )
    )

    name_prefix = "".join(f"{s}." for s in context.module.scope.path)

    converted_parameters = {}
    # The "params" group is looked up with `.get` so that a model with no
    # "params" collection at all does not fail here.
    for k, v in parameter_vars.get("params", {}).items():
      if context.module.scope.name_reserved(k, col="params"):
        param_name = name_prefix + k
        # Every usage starts out as a shared lookup; `bind_params` below
        # decides where the actual parameter lives.
        converted_parameters[k] = pz.nn.SharedParameterLookup(
            pz.de.SideInputRequest(pz.nn.SharedParamTag(param_name)),
            value_structure=pz.chk.as_array_structure(v),
        )
        parameter_usages[param_name].append(submodule_call_path)
        if param_name in parameter_values:
          assert parameter_values[param_name] is v
        else:
          parameter_values[param_name] = v

    converted_variables = {}
    for col, col_vars in variables.items():
      converted_variables[col] = {}
      for k, v in col_vars.items():
        if context.module.scope.name_reserved(k, col=col):
          state_name = col + ":" + name_prefix + k
          if state_name in known_states:
            # A previous call already holds this state; refer to it by name.
            converted_variables[col][k] = pz.de.SharedLocalStateRequest(
                name=state_name, category=col
            )
          else:
            known_states.add(state_name)
            converted_variables[col][k] = pz.de.FrozenLocalStateRequest(
                state=v, name=state_name, category=col
            )

    converted_immutable_variables = {}
    for col, col_vars in immutable_variables.items():
      converted_immutable_variables[col] = {}
      for k, v in col_vars.items():
        if context.module.scope.name_reserved(k, col=col):
          converted_immutable_variables[col][k] = v

    scope_data = InterceptedFlaxScopeData(
        parameters=converted_parameters,
        variables=converted_variables,
        immutable_variables=converted_immutable_variables,
        # Only keep the RNG streams this scope actually drew from.
        rngs={
            rng_name: pz.de.TaggedRandomRequest(tag=rng_name)
            for rng_name in original_rngs
            if context.module.scope.rng_counters[rng_name] > 0
        },
    )
    if (
        not scope_data.parameters
        and not scope_data.variables
        and not scope_data.immutable_variables
        and not scope_data.rngs
    ):
      scope_data = None

    converted_layer = InterceptedFlaxModuleMethod(
        module=context.module.clone(),
        method_name=context.method_name,
        scope_data=scope_data,
        submodule_calls=current_state.submodule_calls,
    )

    if not intercept_states_by_path:
      # No call is in flight any more, so this was the root call.
      root_intercept = converted_layer
    else:
      # Add this call to the parent's submodule calls.
      intercept_states_by_path[
          context.module.parent.scope.path
      ].submodule_calls[submodule_call_name] = converted_layer

    return the_fn_output

  def go():
    with flax.linen.intercept_methods(reifying_interceptor):
      _ = module.apply(
          variables,
          *dummy_args,
          rngs=rngs,
          method=method,
          mutable=mutable,
          **dummy_kwargs,
      )

  # Only the side effects of the interceptor matter, so the Flax logic is run
  # under `jax.eval_shape` and its output is discarded.
  jax.eval_shape(go)
  assert root_intercept is not None

  # Handle parameter sharing: a parameter used from a single call path lives
  # directly in that call; a parameter used from several call paths is owned
  # at the longest call path that is a prefix of all of its usages.
  params_by_owner_path = collections.defaultdict(list)
  param_singletons = set()
  for param_name, paths in parameter_usages.items():
    if len(paths) == 1:
      param_singletons.add(param_name)
    else:
      params_by_owner_path[_common_prefix(paths)].append(param_name)

  def bind_params(
      intercepted: InterceptedFlaxModuleMethod,
      path: tuple[tuple[int, str], ...],
  ) -> pz.Layer:
    # Insert parameters that should live here.
    scope_data = intercepted.scope_data
    if scope_data is not None:
      # Iterate over a snapshot because entries are replaced in the loop.
      for local_name, stub in list(scope_data.parameters.items()):
        param_name = stub.ref.tag.name  # pylint: disable=attribute-error
        if param_name in param_singletons:
          scope_data.parameters[local_name] = pz.nn.Parameter(
              name=param_name, value=parameter_values[param_name]
          )

    # Update submodule calls recursively.
    for call_name, call in list(intercepted.submodule_calls.items()):
      intercepted.submodule_calls[call_name] = bind_params(
          call, path + (call_name,)
      )

    if path in params_by_owner_path:
      # Take ownership of parameters that were used across multiple submodule
      # calls at this level.
      return pz.de.WithConstantSideInputs.handling(
          body=intercepted,
          side_inputs={
              pz.nn.SharedParamTag(name): pz.nn.Parameter(
                  name=name, value=parameter_values[name]
              )
              for name in params_by_owner_path[path]
          },
          handler_id=pz.de.infer_or_check_handler_id(
              "flax_shared_params", intercepted
          ),
      )
    else:
      return intercepted

  return bind_params(root_intercept, ())