-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
attrs.py
132 lines (110 loc) · 4.77 KB
/
attrs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from jax._src import core
from jax._src import api_util
from jax._src import linear_util as lu
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import unzip2
JaxVal = Any
register = api_util.register_class_with_attrs
class GetAttrPrimitive(core.Primitive):
def bind_with_trace(self, trace, args, params):
() = args
return trace.process_getattr(**params)
getattr_p = GetAttrPrimitive('getattr')
class SetAttrPrimitive(core.Primitive):
def bind_with_trace(self, trace, args, params):
val, = args
return trace.process_setattr(trace.full_raise(val), **params)
setattr_p = SetAttrPrimitive('setattr')
def jax_getattr(obj: Any, attr: str):
return getattr_p.bind(obj=obj, attr=attr)
def jax_setattr(obj: Any, attr: str, val: JaxVal):
setattr_p.bind(val, obj=obj, attr=attr)
def _getattr_impl(_, *, obj, attr):
return getattr(obj, attr)
core.EvalTrace.process_getattr = _getattr_impl
def _setattr_impl(_, val, *, obj, attr):
setattr(obj, attr, val)
core.EvalTrace.process_setattr = _setattr_impl
def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
frame = trace.main.jaxpr_stack[-1] # type: ignore
if (obj, attr) not in frame.attrs_tracked:
init_val = getattr(obj, attr)
aval = core.raise_to_shaped(core.get_aval(init_val))
tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current())
var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval)
setattr(obj, attr, tracer)
frame.attrs_tracked.append((obj, attr))
frame.attrs_inits.append(init_val)
frame.attrs_vars.append(var)
pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked
def _getattr_staging(trace, *, obj, attr):
trace._ensure_tracked(obj, attr)
return getattr(obj, attr)
pe.DynamicJaxprTrace.process_getattr = _getattr_staging
def _setattr_staging(trace, tracer, *, obj, attr):
trace._ensure_tracked(obj, attr)
setattr(obj, attr, tracer)
pe.DynamicJaxprTrace.process_setattr = _setattr_staging
def jvp(f, primals, tangents, tangent_attrs_in):
primals_flat, in_tree = tree_flatten(primals)
tangents_flat, in_tree_ = tree_flatten(tangents)
if in_tree != in_tree_: raise Exception
f_, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
primals_flat, tangents_flat, tangent_attrs_in)
out_primals = tree_unflatten(out_tree(), out_primals_flat)
out_tangents = tree_unflatten(out_tree(), out_tangents_flat)
return out_primals, out_tangents, tangent_attrs_out
def _jvp(fun: lu.WrappedFun):
return jvpfun2(jvp_subtrace2(fun))
@lu.transformation
def jvpfun2(primals, tangents, tangent_attrs_in):
with core.new_main(ad.JVPTrace) as main:
out_primals, out_tangents, tangent_attrs_out = \
yield (main, primals, tangents, tangent_attrs_in), {}
del main
yield out_primals, out_tangents, tangent_attrs_out
@lu.transformation
def jvp_subtrace2(main, primals, tangents, tangent_attrs_in):
main.attrs_tracked = [] # attrs written to
trace = main.with_cur_sublevel()
for obj, name, tangent in tangent_attrs_in:
primal = jax_getattr(obj, name)
tracer = ad.JVPTracer(trace, primal, tangent)
jax_setattr(obj, name, tracer)
in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
for x, t in zip(primals, tangents)]
ans = yield in_tracers, {}
out_tracers = map(trace.full_raise, ans)
out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
tangent_attrs_out = []
for (obj, name) in main.attrs_tracked:
tracer = trace.full_raise(jax_getattr(obj, name))
jax_setattr(obj, name, tracer.primal)
if type(tracer.tangent) is not ad.Zero:
tangent_attrs_out.append((obj, name, tracer.tangent))
del main.attrs_tracked
yield out_primals, out_tangents, tangent_attrs_out
def _setattr_jvp(trace, tracer, *, obj, attr):
if (obj, attr) not in trace.main.attrs_tracked:
trace.main.attrs_tracked.append((obj, attr))
setattr(obj, attr, tracer)
ad.JVPTrace.process_setattr = _setattr_jvp