forked from google/jax
-
Notifications
You must be signed in to change notification settings - Fork 2
/
linear_util.py
212 lines (167 loc) · 6.27 KB
/
linear_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for defining linear functions composed with transformations.
"Linear" here is meant in the sense of linear types; that is, a linear function
may be called at most once.
For example:
from jax import linear_util as lu
# A transformation that scales its argument down and its result up.
@lu.transformation
def scale_transformer(scale, x):
ans = yield (x / scale,), {}
yield ans * scale
def f(x):
return x + 1
g = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
g = scale_transformer(g, 2.0) # Scale inputs/outputs by 2.0
g = scale_transformer(g, 0.7) # Scale inputs/outputs further by 0.7.
print(g.call_wrapped(3.)) # Call the transformed function.
A `WrappedFun` object represents a function `f`, together with a
sequence of nested transformations that are to be applied to the positional
arguments at call time and function return values at return time.
`WrappedFun` objects explicitly represent the set of transformations so that
they can be used as dictionary keys for memoization. `WrappedFun` objects
compare as equal only if they compute the same function.
Transformations are implemented as generators to save call stack frames.
A transformation's generator takes arguments `gen args + args` (plus any
keyword arguments), and yields a pair `(args, kwargs)` of transformed
positional and keyword arguments that should be passed to the wrapped
function. The result of the wrapped function is passed back to the generator
using `gen.send()`, and the generator yields the transformed results to pass
back to the caller.
Transformations can also return auxiliary data using the `transformation_with_aux`
decorator. For example:
@lu.transformation_with_aux
def scale_transformer_aux(scale, x):
ans = yield (x / scale,), {}
yield (ans * scale, "Auxiliary data: {}".format(x))
g = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
g, aux_thunk = scale_transformer_aux(g, 2.0) # Scale inputs/outputs by 2.0
print(g.call_wrapped(3.)) # Call the transformed function.
print(aux_thunk()) # Retrieves the auxiliary data computed during evaluation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .util import curry, partial, OrderedDict
def thunk(f):
    """Wrap the zero-argument callable `f` so it is evaluated at most once.

    The first call of the returned function runs `f()` and keeps the result in
    a `Store`; every later call returns that stored result without calling `f`
    again.
    """
    result_store = Store()

    def f_memoized():
        if result_store:
            return result_store.val
        # TODO(dougalm): save/restore relevant environment state too
        result_store.store(f())
        return result_store.val

    return f_memoized
class StoreException(Exception):
    """Error raised by `Store.val` when no value has been stored yet."""
class Store(object):
    """Write-once container for a single value.

    A fresh store is empty and falsy. `store` fills it exactly once, after
    which the store is truthy and `val` returns the stored value.
    """

    def store(self, val):
        """Put `val` into the store.

        Raises:
          AssertionError: if a value has already been stored.
        """
        # Raised explicitly rather than through an `assert` statement so the
        # write-once check is still enforced when Python runs with assertions
        # disabled (`-O`).
        if self:
            raise AssertionError("Store occupied")
        self._val = val

    @property
    def val(self):
        """The stored value.

        Raises:
          StoreException: if nothing has been stored yet.
        """
        if not self:
            raise StoreException("Store empty")
        return self._val

    def __nonzero__(self):
        # Occupancy is tracked by the presence of `_val`, so any stored value
        # (including None and other falsy objects) makes the store truthy.
        return hasattr(self, '_val')

    __bool__ = __nonzero__  # Python 3 spelling of the truthiness hook.
@curry
def staged(f, *init_args):
    """Split `f`, which returns an `(ans, aux)` pair, into two callables.

    Returns:
      A pair `(f_partial, aux_thunk)`. `f_partial(*rest)` calls
      `f(*(init_args + rest))`, records `aux` in a `Store` and returns `ans`.
      `aux_thunk()` returns the recorded `aux`; it is built with `thunk`, so
      the store is read only on its first successful call.

    NOTE(review): `curry` is imported from `.util`; how it stages the
    arguments of this function is not visible here -- confirm against `.util`.
    """
    aux_store = Store()

    def f_partial(*rest):
        outputs = f(*(init_args + rest))
        ans, aux = outputs
        aux_store.store(aux)
        return ans

    f_partial.__name__ = f.__name__ + "_staged"
    aux_thunk = thunk(lambda: aux_store.val)
    return f_partial, aux_thunk
class WrappedFun(object):
    """Represents a function `f` to which `transforms` are to be applied.

    Arguments:
      f: the function to be transformed.
      transforms: a list of `(gen, gen_args, out_store)` tuples representing
        transformations to apply to `f`.
      params: extra parameters to pass as keyword arguments to `f`.
    """

    def __init__(self, f, transforms, params):
        self.f = f
        self.transforms = transforms
        self.params = params

    def wrap(self, *transformation):
        """Return a new `WrappedFun` with `transformation` added outermost.

        The positional arguments form the `(gen, gen_args, out_store)` tuple
        that is placed at the front of the transformation list. `self` is not
        modified.
        """
        return WrappedFun(self.f, [transformation] + self.transforms,
                          self.params)

    def populate_stores(self, other):
        """Copy the values held in `other`'s output stores into this one's.

        Transformations are paired up by position; pairs in which this
        function's store is `None` are skipped.
        """
        for (_, _, self_store), (_, _, other_store) in zip(self.transforms,
                                                           other.transforms):
            if self_store is not None:
                self_store.store(other_store.val)

    def call_wrapped(self, *args, **kwargs):
        """Transform the arguments, call `f`, and transform the result.

        Each transformation generator is started with `gen_args + args` and
        the current keyword arguments, and must first yield an
        `(args, kwargs)` pair for the next stage. After `f` has run (with
        `params` merged into the keyword arguments), the generators are
        resumed in reverse order with the current result and must yield the
        transformed result. A transformation that has an output store must
        yield a `(result, aux)` pair; `aux` is written to that store.
        """
        stack = []
        for gen, gen_args, out_store in self.transforms:
            gen = gen(*(gen_args + tuple(args)), **kwargs)
            args, kwargs = next(gen)
            stack.append((gen, out_store))
        # Drop the loop variables so that `stack` holds the only references
        # to the generators. Rebinding instead of `del gen` also works when
        # there are no transformations and the loop never bound these names.
        gen = gen_args = out_store = None

        ans = self.f(*args, **dict(self.params, **kwargs))
        del args  # Release the arguments before unwinding the generators.
        while stack:
            gen, out_store = stack.pop()
            ans = gen.send(ans)
            if out_store is not None:
                ans, side = ans
                out_store.store(side)
        return ans

    def __repr__(self):
        def transform_to_str(x):
            i, (gen, args, _) = x
            return "{} : {} {}".format(i, fun_name(gen), fun_name(args))
        transformation_stack = map(transform_to_str,
                                   enumerate(self.transforms))
        return ("Wrapped function:\n" + '\n'.join(transformation_stack) +
                '\nCore: ' + fun_name(self.f) + '\n')

    def hashable_payload(self):
        """Return the tuple that defines this object's hash and equality.

        It consists of `f`, the `(gen, gen_args)` part of every
        transformation (output stores are excluded), and the sorted `params`
        items.
        """
        return (self.f,
                tuple((gen, tuple(gen_args))
                      for gen, gen_args, _ in self.transforms),
                tuple(sorted(self.params.items())))

    def __hash__(self):
        return hash(self.hashable_payload())

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing
        # with AttributeError on the missing `hashable_payload`.
        if not isinstance(other, WrappedFun):
            return NotImplemented
        return self.hashable_payload() == other.hashable_payload()

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result
@curry
def transformation(gen, fun, *transformation_args):
    """Add the generator-based transformation `gen` to the `WrappedFun` `fun`.

    `transformation_args` are recorded next to `gen`. No output store is
    attached, so the transformation carries no auxiliary output.

    Returns:
      The result of `fun.wrap(gen, transformation_args, None)`.
    """
    no_out_store = None
    wrapped = fun.wrap(gen, transformation_args, no_out_store)
    return wrapped
@curry
def transformation_with_aux(gen, fun, *transformation_args):
    """Add transformation `gen` to `fun` together with an auxiliary output.

    A fresh `Store` is attached to the transformation as its output store.

    Returns:
      A pair `(wrapped, aux_thunk)`: the result of
      `fun.wrap(gen, transformation_args, store)` and a zero-argument function
      that returns the value held in that store.
    """
    aux_store = Store()
    wrapped = fun.wrap(gen, transformation_args, aux_store)
    return wrapped, (lambda: aux_store.val)
def fun_name(f):
    """Return a display name for `f`.

    Returns `f.__name__` when that attribute can be read, and `str(f)`
    otherwise.
    """
    try:
        return f.__name__
    except Exception:
        # Best-effort fallback for objects without a usable `__name__`.
        # `Exception` rather than a bare `except` so that KeyboardInterrupt
        # and SystemExit are not swallowed.
        return str(f)
def wrap_init(f, params=None):
    """Wraps function `f` as a `WrappedFun`, suitable for transformation.

    Arguments:
      f: the function to wrap.
      params: optional dict of extra keyword arguments to pass to `f` when
        the wrapped function is called. Defaults to no extra arguments.

    Returns:
      A `WrappedFun` for `f` with an empty list of transformations.
    """
    # Build a fresh dict per call: a `{}` default would be a single dict
    # shared by every `WrappedFun` created without explicit `params`.
    if params is None:
        params = {}
    return WrappedFun(f, [], params)
def memoize(call, max_size=4096):
    """Return a version of `call` backed by a least-recently-used cache.

    Arguments:
      call: function invoked as `call(f, *args)` on a cache miss.
      max_size: maximum number of results kept in the cache.

    Returns:
      A function `memoized_fun(f, *args)`. Results are keyed on `(f, args)`,
      so `f` and every argument must be hashable. On a hit `call` is not
      invoked; instead `f.populate_stores` is called with the function object
      that produced the cached result, and the cached result is returned.
    """
    cache = OrderedDict()

    def memoized_fun(f, *args):
        key = (f, args)
        # Cached values are always tuples, so None reliably means "missing"
        # and the key only has to be hashed and compared once.
        entry = cache.get(key)
        if entry is not None:
            ans, f_prev = entry
            cache.move_to_end(key)  # Mark as most recently used.
            f.populate_stores(f_prev)
            return ans
        ans = call(f, *args)
        cache[key] = (ans, f)
        # Evict least recently used entries until the bound holds. Trimming
        # after the insertion guarantees the cache never keeps more than
        # `max_size` entries and is safe for `max_size=0`.
        while len(cache) > max_size:
            cache.popitem(last=False)
        return ans

    return memoized_fun