This repository has been archived by the owner on Dec 2, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 435
/
grads.py
390 lines (275 loc) · 9.04 KB
/
grads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Templates for gradient expressions.
The first argument to the adjoint must be the return value of the primal.
Use `d[x]` to denote the gradient of a variable `x`.
If the primal returns a tuple, the first argument to the adjoint is a tuple,
and the adjoint is supposed to define `d[y]` as a tuple.
Templates do not support use of `**kwargs`.
If a keyword argument isn't present in the adjoint, it means that Tangent
doesn't support it, and an error will be raised if it appears in user code.
Keyword arguments that are supported should always have the default value
`DEFAULT`, which means that they will be passed the default value of the
original function.
Adjoints have access to the inputs of the primal, output of the primal, and
gradients with respect to the output. They are expected to contain expressions
for the gradient with respect to the input. They don't have access to any
intermediate variables from the primal.
"""
from __future__ import absolute_import
import math
import types
import gast
import numpy
import tangent
from tangent import tracing
# Sentinel: a keyword argument given this value will be filled in with the
# default of the function being differentiated.
DEFAULT = object()

# Functions treated as non-differentiable -- not in the mathematical sense,
# but because they contribute zero gradient: they return meta-information
# (shapes), do integer arithmetic, or construct fresh tensors.
NON_DIFFERENTIABLE = {
    len,
    numpy.shape, numpy.zeros, numpy.ones, numpy.zeros_like, numpy.ones_like,
    tangent.init_grad, tangent.array_size, tangent.Stack,
}
# TODO: Avoid requiring non-differentiables to define @tangent_s.
# Forward mode needs shadow zero-filled variables for every non-differentiable
# function. Identity @tangent_ versions of those functions currently provide
# them, but a better approach would be to generate them automatically.

# Registries mapping AST node types / concrete functions to their gradient
# templates; populated below via the `adjoint` and `primal` decorators.
adjoints = {}
primals = {}
def get_module_functions(modules):
  """Collect the function-like attributes of a sequence of Python modules.

  Every attribute of each module that is a builtin function, a plain Python
  function, or a NumPy ufunc is gathered into a single set. Used below to
  enumerate NumPy/math functions whose adjoints are not yet implemented.

  Args:
    modules: A list (or tuple) of Python modules to scan.

  Returns:
    module_fns: A set of the functions, builtins, and ufuncs found in
      `modules`.
  """
  module_fns = set()
  for module in modules:
    for key in dir(module):
      attr = getattr(module, key)
      if isinstance(
          attr, (types.BuiltinFunctionType, types.FunctionType, numpy.ufunc)):
        module_fns.add(attr)
  return module_fns
def register_non_differentiable_functions(*funcs):
  """Mark additional functions as contributing zero gradient."""
  global NON_DIFFERENTIABLE
  NON_DIFFERENTIABLE.update(funcs)
def create_register(dict_):
  """Build a decorator factory that records functions in `dict_`.

  `create_register(d)(key)` returns a decorator which stores the decorated
  function under `key` in `d` and hands the function back unchanged.
  """
  def register(key):
    def decorator(func):
      dict_[key] = func
      return func
    return decorator
  return register
# Decorators that register a gradient template under an AST node type or a
# concrete function in the corresponding global registry.
adjoint = create_register(adjoints)
primal = create_register(primals)
# Functions: f => f, df
# NOTE: this (like every @adjoint/@primal function below) is a *template*,
# never executed directly. Its body is parsed as AST and the placeholder
# names are substituted by the code generator; only `#` comments are safe to
# add here, since a docstring would become a statement in the generated code.
@adjoint(gast.FunctionDef)
def dfunction_def(adjoint_body, return_dx):
  # `adjoint_body` and `return_dx` stand for the generated reverse-mode
  # statements and the final gradient return of the derivative function.
  def df():
    adjoint_body
    return_dx
# Control flow
# The primal templates count loop iterations (and stash per-iteration loop
# variables) on an auxiliary stack; the adjoint templates pop those records
# to replay the loop the same number of times in reverse.
@primal(gast.For)
def for_(body, i, iter_, target, push, push_target, _target, _stack, op_id_iter,
         op_id_target):
  i = 0
  for target in iter_:
    _target = target
    i += 1
    body
    # One target value is pushed per iteration, matching the per-iteration
    # pop in dfor_ below.
    push_target(_stack, _target, op_id_target)
  push(_stack, i, op_id_iter)


@adjoint(gast.For)
def dfor_(adjoint_body, i, pop, pop_target, target, _stack, op_id_iter,
          op_id_target):
  i = pop(_stack, op_id_iter)
  for _ in range(i):
    target = pop_target(_stack, op_id_target)
    adjoint_body


@primal(gast.While)
def while_(body, i, test, push, _stack, op_id):
  i = 0
  while test:
    i += 1
    body
  # Only the trip count is recorded; the adjoint replays that many steps.
  push(_stack, i, op_id)


@adjoint(gast.While)
def dwhile_(adjoint_body, i, pop, _stack, op_id):
  i = pop(_stack, op_id)
  for _ in range(i):
    adjoint_body


@primal(gast.If)
def if_(cond, test, body, orelse, push, _stack, op_id):
  # Record which branch was taken so the adjoint follows the same path.
  cond = test
  if cond:
    body
  else:
    orelse
  push(_stack, cond, op_id)


@adjoint(gast.If)
def dif_(cond, adjoint_body, adjoint_orelse, pop, _stack, op_id):
  cond = pop(_stack, op_id)
  if cond:
    adjoint_body
  else:
    adjoint_orelse
# Binary ops: z = op(x, y)
# `d[v]` denotes the gradient with respect to `v`; the code generator
# rewrites it, so these bodies are templates rather than executable code.
@adjoint(gast.Mult)
def mult(z, x, y):
  # unbroadcast sums the gradient back down to each operand's own shape.
  d[x] = tangent.unbroadcast(d[z] * y, x)
  d[y] = tangent.unbroadcast(d[z] * x, y)


@adjoint(gast.Add)
def add(z, x, y):
  d[x] = tangent.unbroadcast(d[z], x)
  d[y] = tangent.unbroadcast(d[z], y)


@adjoint(gast.Pow)
def pow(z, x, y):
  d[x] = y * x ** (y - 1) * d[z]
  # NOTE: numpy.log(x) is only finite for x > 0.
  d[y] = numpy.log(x) * x ** y * d[z]


@adjoint(gast.Sub)
def sub(z, x, y):
  d[x] = tangent.unbroadcast(d[z], x)
  d[y] = -tangent.unbroadcast(d[z], y)


@adjoint(gast.Div)
def div(z, x, y):
  # NOTE(review): unlike mult/add/sub, no unbroadcast here -- presumably
  # broadcasting division is unsupported; confirm before relying on it.
  d[x] = d[z] / y
  d[y] = -d[z] * x / (y * y)


# Unary ops: y = op(x)
@adjoint(gast.USub)
def usub(y, x):
  d[x] = -d[y]


@adjoint(gast.UAdd)
def uadd(y, x):
  d[x] = d[y]
#
# NumPy adjoints
#
# Gradient templates for NumPy functions; first argument is the primal
# output, `d[v]` is the gradient placeholder rewritten by the generator.
@adjoint(numpy.tanh)
def tanh(y, x):
  # dtanh/dx = 1 - tanh(x)^2, expressed via the primal output y.
  d[x] = d[y] * (1.0 - (y * y))


@adjoint(numpy.log)
def log(y, x):
  d[x] = d[y] / x


@adjoint(numpy.sin)
def sin(y, x):
  d[x] = d[y] * numpy.cos(x)


@adjoint(numpy.cos)
def cos(y, x):
  d[x] = -d[y] * numpy.sin(x)


@adjoint(numpy.cosh)
def cosh(y, x):
  d[x] = d[y] * numpy.sinh(x)


@adjoint(numpy.sinh)
def sinh(y, x):
  d[x] = d[y] * numpy.cosh(x)


@adjoint(numpy.exp)
def exp(y, x):
  # d(exp(x))/dx = exp(x), reused from the primal output y.
  d[x] = y * d[y]


@adjoint(numpy.sqrt)
def sqrt(y, x):
  d[x] = d[y] / (2.0 * y)


@adjoint(numpy.multiply)
def multiply(z, x, y):
  d[x] = y * d[z]
  d[y] = x * d[z]


@adjoint(numpy.dot)
def dot(y, x1, x2):
  d[x1] = tangent.grad_dot(d[y], x1, x2)
  # Gradient w.r.t. the right operand via the transposed problem.
  d[x2] = numpy.transpose(tangent.grad_dot(numpy.transpose(d[y]),
                                           numpy.transpose(x2),
                                           numpy.transpose(x1)))


@adjoint(numpy.reshape)
def reshape(y, x, y_shape):
  # The requested shape (y_shape) is not needed; just restore x's shape.
  d[x] = numpy.reshape(d[y], numpy.shape(x))


@adjoint(numpy.transpose)
def transpose(y, x):
  d[x] = numpy.transpose(d[y])


@adjoint(numpy.broadcast_arrays)
def broadcast_arrays(ys, *args):
  # One gradient per input, each reduced back to that input's shape.
  d[args] = tuple(tangent.unbroadcast_to(dy, numpy.shape(arg))
                  for arg, dy in zip(args, d[ys]))


@adjoint(numpy.sum)
def sum(y, x, axis=DEFAULT, dtype=DEFAULT, keepdims=DEFAULT):
  # unreduce re-expands the reduced gradient over the summed axes; astype
  # keeps the gradient's dtype consistent with x.
  d[x] = tangent.astype(tangent.unreduce(d[y], numpy.shape(x),
                                         axis, keepdims), x)


@adjoint(numpy.mean)
def mean(y, x, axis=DEFAULT, dtype=DEFAULT, keepdims=DEFAULT):
  # n is the number of elements averaged over along `axis`.
  n = tangent.astype(tangent.array_size(x, axis), x)
  d[x] = tangent.astype(tangent.unreduce(d[y], numpy.shape(x),
                                         axis, keepdims), x) / n


@adjoint(numpy.maximum)
def maximum(ans, x, y):
  # balanced_eq routes gradient to the winning argument, splitting at ties.
  d[x] = d[ans] * tangent.balanced_eq(x, ans, y)
  d[y] = d[ans] * tangent.balanced_eq(y, ans, x)
#
# Tangent adjoints
#
# Adjoints of Tangent's own runtime helpers, so generated code can itself be
# differentiated (higher-order gradients).
@adjoint(tangent.unreduce)
def unreduce(y, x, shape, axis, keepdims):
  d[x] = tangent.unbroadcast(d[y], x)


@adjoint(tangent.unbroadcast)
def unbroadcast(y, array, shape):
  d[array] = tangent.unreduce(d[y], numpy.shape(array), None, False)


@adjoint(tangent.add_grad)
def add_grad(z, left, right):
  d[left] = tangent.unbroadcast(d[z], left)
  d[right] = tangent.unbroadcast(d[z], right)


@adjoint(tangent.astype)
def astype(z, array, y):
  d[array] = tangent.astype(d[z], array)


# A push in the primal pairs with a pop in the adjoint, and vice versa.
# NOTE(review): `d[op_id]` presumably maps the op id into the adjoint's
# namespace during template expansion -- confirm against the expander.
@adjoint(tangent.push)
def apush(stack, val, op_id):
  d[val] = tangent.pop(stack, d[op_id])


@adjoint(tangent.pop)
def apop(z, stack, op_id):
  tangent.push(stack, d[z], d[op_id])


@adjoint(tangent.push_stack)
def apush_stack(stack, val, op_id):
  d[val] = tangent.pop_stack(stack, d[op_id])


@adjoint(tangent.pop_stack)
def apop_stack(z, stack, op_id):
  tangent.push_stack(stack, d[z], d[op_id])


@adjoint(tangent.copy)
def acopy(z, x):
  d[x] = tangent.copy(d[z])
#
# Tracing primitives
#
@primal(tracing.Traceable)
def traceable_primal(result, fn, vjp, tmp, args):
  # Run the traced function and capture its vector-Jacobian product for use
  # in the adjoint below.
  result, vjp = tangent.trace_grad(fn, args)


@adjoint(tracing.Traceable)
def traceable_adjoint(result, vjp, dargs):
  dargs = vjp(d[result])
#
# Blacklist unimplemented NumPy grads
#
# Enumerate every function we would eventually like gradients for. Until an
# adjoint has been written for each, any function in this set triggers an
# explicit "no grad found" error rather than a silent failure.
UNIMPLEMENTED_ADJOINTS = get_module_functions(
    (numpy, numpy.fft, numpy.linalg, numpy.random, math)).difference(adjoints)