-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
custom_transpose.py
192 lines (150 loc) · 6.7 KB
/
custom_transpose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Optional, Tuple
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util
from jax._src import api_util
from jax._src import custom_api_util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
# Register this file for exclusion with source_info_util and traceback_util
# (presumably so its frames are omitted from user-facing source info and
# filtered tracebacks).
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# Keep the builtins reachable as unsafe_map / unsafe_zip, then shadow map and
# zip with util's safe_map / safe_zip for the remainder of this module.
unsafe_map, unsafe_zip = map, zip
map, zip = util.safe_map, util.safe_zip
### bespoke linear_util and api_util deviations

class StoreEqual(lu.Store):
  """A ``lu.Store`` that tolerates repeated writes of an equal value.

  The first write always succeeds. Later writes succeed only if the new value
  compares equal (``==``) to the stored one; otherwise ``lu.StoreException``
  is raised. Checking for reads of an empty store is inherited from
  ``lu.Store``.
  """

  def store(self, val):
    current = self._val
    occupied = current is not lu._EMPTY_STORE_VALUE
    # Only compare when something is already stored (short-circuits otherwise).
    if occupied and val != current:
      raise lu.StoreException(
          f"Store assignment mismatch, from {self._val} to {val}")
    self._val = val
@util.curry
def transformation_with_aux(
    gen, fun: lu.WrappedFun, *gen_static_args) -> Tuple[lu.WrappedFun, Any]:
  """Wrap ``fun`` with generator ``gen``; also return a thunk for its aux output.

  This is a local variant of linear_util's ``transformation_with_aux``. The
  auxiliary output is recorded in a ``StoreEqual``, which tolerates being
  written more than once as long as every write stores an equal value.

  Returns:
    ``(wrapped_fun, thunk)``. Calling ``thunk()`` reads the stored auxiliary
    value.
  """
  aux_store = StoreEqual()
  wrapped = fun.wrap(gen, gen_static_args, aux_store)
  return wrapped, (lambda: aux_store.val)
# Rebuild api_util.flatten_fun_nokwargs on top of the local
# transformation_with_aux, so that its auxiliary (output-tree) store is a
# StoreEqual rather than a write-once store.
# NOTE(review): relies on api_util.flatten_fun_nokwargs exposing its underlying
# generator as `.args[0]` (e.g. a functools.partial); confirm on jax upgrades.
flatten_fun_nokwargs = transformation_with_aux(
    api_util.flatten_fun_nokwargs.args[0])  # type: ignore[has-type]
### api

@custom_api_util.register_custom_decorator_type
class custom_transpose:
  """Decorator pairing a function with a user-supplied transpose rule.

  Wrap a function ``fun(res_arg, lin_arg)`` and register its transpose with
  ``def_transpose``. Calling the wrapper binds ``custom_transpose_p``.
  """
  # The wrapped primal function.
  fun: Callable
  # Transpose rule registered via def_transpose; None until one is registered.
  transpose: Optional[Callable] = None

  def __init__(self, fun: Callable):
    functools.update_wrapper(self, fun)
    self.fun = fun  # type: ignore[assignment]

  __getattr__ = custom_api_util.forward_attr

  def def_transpose(self, transpose: Callable):
    """Register ``transpose`` as the transpose rule and return it unchanged."""
    self.transpose = transpose
    return transpose

  @traceback_util.api_boundary
  def __call__(self, out_types, res_arg, lin_arg):
    """Apply the wrapped function through ``custom_transpose_p``.

    Args:
      out_types: pytree describing the outputs. Its flattened leaves are
        passed to the primitive as ``out_types``, and its structure is used to
        unflatten the results.
      res_arg: pytree of residual (non-linear) inputs.
      lin_arg: pytree of linear inputs.

    Returns:
      The primitive's outputs, unflattened to the structure of ``out_types``.
    """
    res_tree = tree_structure(res_arg)
    lin_tree = tree_structure(lin_arg)
    args_flat, in_tree = tree_flatten((res_arg, lin_arg))

    # TODO(frostig,mattjj): check that out_trees match (the output-tree thunk
    # returned by flatten_fun_nokwargs is currently unused).
    # TODO(frostig,mattjj): could, and should, we avoid flattening
    # self.fun at this point?
    flat_fun, _fun_out_tree = flatten_fun_nokwargs(
        lu.wrap_init(self.fun), in_tree)

    out_types_flat, out_tree = tree_flatten(out_types)
    bind_params = dict(transpose=self.transpose,
                       out_types=out_types_flat,
                       lin_tree=lin_tree,
                       res_tree=res_tree,
                       out_tree=out_tree)
    out_flat = custom_transpose_p.bind(flat_fun, *args_flat, **bind_params)
    return tree_unflatten(out_tree, out_flat)
### utils

def tree_fill(x, treedef):
  """Return a pytree with structure ``treedef`` in which every leaf is ``x``."""
  leaves = [x for _ in range(treedef.num_leaves)]
  return tree_unflatten(treedef, leaves)
def tree_fill_like(x, tree):
  """Return a pytree shaped like ``tree`` with every leaf replaced by ``x``."""
  structure = tree_structure(tree)
  return tree_unflatten(structure, [x] * structure.num_leaves)
def tree_broadcast(full_treedef, tree, is_leaf=None):
  """Broadcast prefix pytree ``tree`` up to the structure ``full_treedef``.

  Each leaf of ``tree`` (as determined by ``is_leaf``) is copied into every
  leaf of the matching subtree of ``full_treedef``. ``tree`` must be a tree
  prefix of ``full_treedef``; otherwise ``tree_map`` raises.
  """
  template = tree_unflatten(full_treedef, [0] * full_treedef.num_leaves)

  def fill_subtree(value, subtree):
    sub_structure = tree_structure(subtree)
    return tree_unflatten(sub_structure, [value] * sub_structure.num_leaves)

  return tree_map(fill_subtree, tree, template, is_leaf=is_leaf)
def is_treedef_prefix(entire, prefix):
  """Report whether treedef ``prefix`` is a tree prefix of treedef ``entire``.

  The test builds zero-filled trees of both structures and attempts a
  ``tree_map`` over them with ``prefix`` as the primary tree. A structural
  mismatch surfaces as ``ValueError`` and yields ``False``.
  """
  entire_tree = tree_unflatten(entire, [0] * entire.num_leaves)
  prefix_tree = tree_unflatten(prefix, [0] * prefix.num_leaves)
  try:
    tree_map(lambda p, _: p, prefix_tree, entire_tree)
  except ValueError:
    return False
  else:
    return True
def rule_name(rule):
  """Return ``rule.__name__``, or a placeholder if the rule has no name."""
  try:
    return rule.__name__
  except AttributeError:
    return '<unnamed transpose rule>'
def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
  """Raise unless ``rule_out_tree`` prefix-matches ``lin_tree``.

  Args:
    rule: the custom transpose rule that produced the output being checked.
      If it has a ``_transpose_type_error`` attribute, that attribute is
      called with ``(lin_tree, rule_out_tree)`` and its result is raised.
    lin_tree: treedef of the primal function's linear inputs.
    rule_out_tree: treedef of the transpose rule's output.

  Raises:
    TypeError: on a structure mismatch when the rule supplies no custom error.
  """
  if is_treedef_prefix(lin_tree, rule_out_tree):
    return
  if hasattr(rule, '_transpose_type_error'):
    raise rule._transpose_type_error(lin_tree, rule_out_tree)
  raise TypeError(
      'structure of custom transpose rule\'s output does not prefix-match '
      'structure of primal function\'s linear inputs under '
      f'custom transpose rule ({rule_name(rule)}).\n'
      f'Transpose rule output: {rule_out_tree}\n'
      f'Linear primal inputs: {lin_tree}')
### custom_transpose primitive and rules

class CustomTransposePrimitive(core.Primitive):
  """Primitive backing ``custom_transpose``.

  ``bind`` takes the wrapped primal function ``call`` as its first positional
  argument. It raises the remaining operands to the top trace and delegates
  to that trace's ``process_custom_transpose``.
  """
  call_primitive = False
  map_primitive = False
  multiple_results = True

  def bind(self, call, *args, **params):
    # TODO(frostig,mattjj): This doesn't handle closures yet, which is
    # a bit involved. Closures are complicated by us binding `call`
    # twice in the JVP rule for custom transpose. The `env_trace_todo`
    # output by `process_env_traces` due to one of those two bindings
    # should be passable to the other, and need to be passed onward
    # since the second bind is deferred by partial eval (since it
    # typically receives unknowns)
    trace = core.find_top_trace(args)
    raised_args = [trace.full_raise(arg) for arg in args]
    return trace.process_custom_transpose(self, call, raised_args, **params)

  # TODO(frostig,mattjj): consider keeping `call` as a named parameter
  # instead of following this "call primitive" convention.
  def get_bind_params(self, params):
    """Split ``params`` into ``([call], rest)`` without mutating ``params``."""
    rest = dict(params)
    call = rest.pop('call')
    return [call], rest
# TODO(frostig,mattjj): reinstate checks
def custom_transpose_typecheck(*avals, **params):
  """Typecheck rule for ``custom_transpose_p``; currently checks nothing.

  Accepts any arguments and always returns ``None``.
  """
  del avals, params  # unused while checks are disabled
  return None
def custom_transpose_transpose_rule(
    cts, *args, call, transpose, out_types, res_tree, lin_tree, out_tree):
  """Transpose rule for ``custom_transpose_p``: delegates to the user's rule.

  Args:
    cts: flat output cotangents. ``ad_util.Zero`` entries are instantiated
      as dense zeros before being handed to the user rule.
    *args: the primitive's flat operands, laid out as the leaves of
      ``(res_arg, lin_arg)``, i.e. residual leaves first, then linear leaves.
    call: the wrapped primal function (unused here).
    transpose: the user-registered transpose rule, called as
      ``transpose(res_arg, ct_out)``.
    out_types: flat output types (unused here).
    res_tree: treedef of the residual inputs.
    lin_tree: treedef of the linear inputs.
    out_tree: treedef of the outputs.

  Returns:
    One entry per flat operand: ``None`` for every residual leaf, followed by
    one cotangent per linear leaf. The user rule may return ``None`` for any
    linear input or subtree, meaning no cotangent; that ``None`` is
    broadcast to each leaf it covers.

  Raises:
    TypeError: if no transpose rule was registered, or if the rule's output
      structure does not prefix-match ``lin_tree``.
  """
  del call, out_types  # unused
  if transpose is None:
    # Without this check, the failure would be an opaque
    # "'NoneType' object is not callable".
    raise TypeError(
        'custom_transpose function has no transpose rule; register one with '
        'def_transpose before transposing it.')

  call_in_tree = treedef_tuple((res_tree, lin_tree))

  # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
  # to which we are transposing (via `ad.is_undefined_primal`).
  # Consider passing this information to the custom transpose rule?

  res_arg, lin_arg = tree_unflatten(call_in_tree, args)
  del lin_arg
  # Residuals must be known values; only linear inputs may be undefined.
  assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

  # The user rule sees dense cotangents, never symbolic zeros.
  cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
         for ct in cts]
  ct_out = tree_unflatten(out_tree, cts)
  ct_lin = transpose(res_arg, ct_out)

  # Treat None as a leaf (a "no cotangent" marker) consistently. Plain
  # tree_structure would treat None as an empty node and spuriously reject a
  # rule that returns None for a linear input in the prefix check below.
  _, ct_lin_tree = tree_flatten(ct_lin, is_leaf=lambda x: x is None)
  check_transpose_rule_trees(transpose, lin_tree, ct_lin_tree)
  ct_lin_full = tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None)
  # Flatten exactly up to lin_tree's leaves. This yields one entry per linear
  # operand even when lin_arg itself contains None (empty) nodes, which a
  # None-as-leaf flatten would wrongly count as extra operands.
  ct_lin_flat = lin_tree.flatten_up_to(ct_lin_full)
  return [None] * len(tree_leaves(res_arg)) + ct_lin_flat
custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')

# Hook the primitive into JAX's rule tables: the transpose rule consulted by
# `ad` when transposing, and the (currently no-op) custom typecheck.
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck