"""Vector-Jacobian products for NumPy functions.
This library consists of implementations of vector-Jacobian products (vjps, gradients)
for functions implemented in numpy. Each function-argument index pair is
provided a gradient function registered with defvjp().
Notice that vjps are implemented with (autograd-wrapped) functions as well. This
is the magic that allows one to compute gradients-of-gradients!
"""
from __future__ import absolute_import
import numpy as onp
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox
from autograd.tracer import primitive
from autograd.core import defvjp
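
# A sketch of how these registrations are consumed (added commentary; assumes
# the autograd-style `grad` wrapper built on defvjp, which is not defined in
# this file). Every VJP below has the signature vjp(g, ans, *args), where `g`
# is the cotangent flowing back from the output and `ans` is the forward value:
#
#     from autograd import grad
#     dtanh = grad(anp.tanh)  # looks up the anp.tanh VJP registered below
#     dtanh(0.5)              # == 1 / anp.cosh(0.5) ** 2
#     grad(dtanh)(0.5)        # second derivative; works because VJPs are traced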

# ----- Binary ufuncs -----

defvjp(anp.add,         lambda g, ans, x, y: unbroadcast(x, g),
                        lambda g, ans, x, y: unbroadcast(y, g))
defvjp(anp.multiply,    lambda g, ans, x, y: unbroadcast(x, y * g),
                        lambda g, ans, x, y: unbroadcast(y, x * g))
defvjp(anp.subtract,    lambda g, ans, x, y: unbroadcast(x, g),
                        lambda g, ans, x, y: unbroadcast(y, -g))
defvjp(anp.divide,      lambda g, ans, x, y: unbroadcast(x, g / y),
                        lambda g, ans, x, y: unbroadcast(y, -g * x / y**2))
defvjp(anp.true_divide, lambda g, ans, x, y: unbroadcast(x, g / y),
                        lambda g, ans, x, y: unbroadcast(y, -g * x / y**2))
defvjp(anp.power,
       lambda g, ans, x, y: unbroadcast(x, g * y * x ** anp.where(y, y - 1, 1.)),
       lambda g, ans, x, y: unbroadcast(y, g * anp.log(replace_zero(x, 1.)) * x ** y))
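
# Added commentary on the guards above: anp.where(y, y - 1, 1.) avoids
# evaluating x ** (y - 1) where y == 0, which could inject inf/nan factors
# even though the whole term g * y * (...) is zero there; likewise
# replace_zero(x, 1.) keeps anp.log away from log(0) where x == 0.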

def replace_zero(x, val):
    """Replace all zeros in 'x' with 'val'."""
    return anp.where(x, x, val)

def unbroadcast(target, g, broadcast_idx=0):
    """Remove broadcasted dimensions by summing along them.

    When computing the gradient of a broadcasted value, summing the cotangent
    over the broadcasted dimensions is the right way to account for the
    implicit cloning that broadcasting performs.
    """
    # Sum away the extra leading dimensions that broadcasting prepended.
    while anp.ndim(g) > anp.ndim(target):
        g = anp.sum(g, axis=broadcast_idx)
    # Sum over dimensions that were broadcast up from size 1, keeping the axis.
    for axis, size in enumerate(anp.shape(target)):
        if size == 1:
            g = anp.sum(g, axis=axis, keepdims=True)
    # A real-valued target cannot receive a complex gradient.
    if anp.iscomplexobj(g) and not anp.iscomplexobj(target):
        g = anp.real(g)
    return g
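
# Shape sketch (added commentary, hypothetical shapes): if x has shape (3, 1)
# and y has shape (3, 4), then anp.add(x, y) broadcasts to (3, 4), so the
# incoming cotangent g has shape (3, 4). unbroadcast(x, g) sums over axis 1
# with keepdims=True, returning x's gradient with shape (3, 1):
#
#     x = onp.ones((3, 1))
#     g = onp.ones((3, 4))
#     unbroadcast(x, g)  # shape (3, 1); each entry is 4.0, a row sum of g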

# ----- Simple grads -----

defvjp(anp.negative, lambda g, ans, x: -g)
defvjp(anp.exp,      lambda g, ans, x: ans * g)
defvjp(anp.log,      lambda g, ans, x: g / x)
defvjp(anp.tanh,     lambda g, ans, x: g / anp.cosh(x) ** 2)
defvjp(anp.sinh,     lambda g, ans, x: g * anp.cosh(x))
defvjp(anp.cosh,     lambda g, ans, x: g * anp.sinh(x))
defvjp(anp.where, None,  # no VJP for the boolean condition argument
       lambda g, ans, c, x=None, y=None: anp.where(c, g, anp.zeros(g.shape)),
       lambda g, ans, c, x=None, y=None: anp.where(c, anp.zeros(g.shape), g))
defvjp(anp.reshape, lambda g, ans, x, shape, order=None:
       anp.reshape(g, anp.shape(x), order=order))

# ----- Dot grads -----

def _dot_vjp_0(g, ans, lhs, rhs):
    if max(anp.ndim(lhs), anp.ndim(rhs)) > 2:
        raise NotImplementedError("Current dot vjps only support ndim <= 2.")

    if anp.ndim(lhs) == 0:
        return anp.sum(rhs * g)
    if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 1:
        return g * rhs
    if anp.ndim(lhs) == 2 and anp.ndim(rhs) == 1:
        return g[:, None] * rhs
    if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 2:
        return anp.dot(rhs, g)
    return anp.dot(g, rhs.T)

def _dot_vjp_1(g, ans, lhs, rhs):
    if max(anp.ndim(lhs), anp.ndim(rhs)) > 2:
        raise NotImplementedError("Current dot vjps only support ndim <= 2.")

    if anp.ndim(rhs) == 0:
        return anp.sum(lhs * g)
    if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 1:
        return g * lhs
    if anp.ndim(lhs) == 2 and anp.ndim(rhs) == 1:
        return anp.dot(g, lhs)
    if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 2:
        return lhs[:, None] * g
    return anp.dot(lhs.T, g)

defvjp(anp.dot, _dot_vjp_0, _dot_vjp_1)
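
# Consistency sketch (added commentary, hypothetical arrays): in the
# matrix-matrix case ans = anp.dot(A, B), the two VJPs reduce to dot(g, B.T)
# and dot(A.T, g), the standard adjoints of matrix multiplication:
#
#     A = onp.random.randn(2, 3)
#     B = onp.random.randn(3, 4)
#     g = onp.random.randn(2, 4)          # cotangent matching ans's shape
#     _dot_vjp_0(g, onp.dot(A, B), A, B)  # shape (2, 3), matches A
#     _dot_vjp_1(g, onp.dot(A, B), A, B)  # shape (3, 4), matches B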