In [1]:
import torch
from torch import nn, autograd
import torch.nn.functional as F
from torch.autograd import Function
import numpy as np

In [2]:
class Exp(Function):
    
    @staticmethod
    def forward(ctx: autograd.function.FunctionCtx, x: torch.Tensor):
        res = x.exp()
        print('x: ', x)
        print('res: ', res)
        print('ctx: ', ctx)
        ctx.save_for_backward(res)
        return res
    
    @staticmethod
    def backward(ctx, *grad_outputs):
        print("backward: ----------")
        print('Grad outputs: ', grad_outputs)
        out,  = ctx.saved_tensors
        print('out:  ', out)
        
        return grad_outputs[0] * out


In [3]:
x = torch.arange(3).float().to(float).requires_grad_()
x

tensor([0., 1., 2.], dtype=torch.float64, requires_grad=True)

In [4]:
exp = Exp.apply(x)

x:  tensor([0., 1., 2.], dtype=torch.float64, requires_grad=True)
res:  tensor([1.0000, 2.7183, 7.3891], dtype=torch.float64)
ctx:  <torch.autograd.function.ExpBackward object at 0x7f26747ba9a0>


In [5]:
y = exp.sum()
y

tensor(11.1073, dtype=torch.float64, grad_fn=<SumBackward0>)

In [6]:
y.backward()

backward: ----------
Grad outputs:  (tensor([1., 1., 1.], dtype=torch.float64),)
out:   tensor([1.0000, 2.7183, 7.3891], dtype=torch.float64, grad_fn=<ExpBackward>)


In [7]:
x.grad

tensor([1.0000, 2.7183, 7.3891], dtype=torch.float64)

In [8]:
x.requires_grad

True

In [131]:
class Linear(Function):
    
    @staticmethod
    def forward(ctx, input, weights, bias) -> torch.Tensor:
        # out = w @ x  + b
        out = weights @ input + bias
        ctx.save_for_backward(input, weights, bias, out)
        return out
    
    @staticmethod
    def backward(ctx, *grad_outputs):
        print("Linear backward ", grad_outputs)
        # print(ctx.saved_tensors)
        input, weights, bias, out = ctx.saved_tensors
        # grad_ouputs: dl / dout
        # grad_w = dl / dw =  dout/dw * dl/dout = x
        grad_w = grad_outputs[0].unsqueeze(1) * input.expand(out.size(0), -1)
        grad_b =  bias.new_ones(bias.size()) * grad_outputs[0]
        print("grad w shape ", grad_w.shape)
        print("grad b shape ", grad_b.shape)
        inp_grad = None
        
        if ctx.needs_input_grad[0]:
            print("Need inp grad ")
            inp_grad = weights.t() @ grad_outputs[0]
        return  inp_grad, grad_w, grad_b,
    
class Sum(Function):
    @staticmethod
    def forward(ctx, x):
        out = x.sum()
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        # J * v
        # Jacobian dy.T/dx
        # grad_outputs: v
        
        print("Sum backward: ", grad_outputs)
        x,  = ctx.saved_tensors
        grad = grad_outputs[0] * torch.ones_like(x)
        print("Returning sum grad: ", grad)
        return grad
        
        

In [132]:
s = torch.arange(3).float().requires_grad_()
t = 2 * Sum.apply(s)
t.backward()
s.grad

Sum backward:  (tensor(2.),)
Returning sum grad:  tensor([2., 2., 2.])


tensor([2., 2., 2.])

In [133]:
w = torch.arange(4*3, dtype=torch.float).view(4, 3).requires_grad_()
b = torch.zeros(4).type_as(w).requires_grad_()
x = torch.arange(3).float()
x, w, b

(tensor([0., 1., 2.]),
 tensor([[ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.]], requires_grad=True),
 tensor([0., 0., 0., 0.], requires_grad=True))

In [134]:
y = Linear.apply(x, w, b)
y

tensor([ 5., 14., 23., 32.], grad_fn=<LinearBackward>)

In [135]:
z = 2 * Sum.apply(y)
z

tensor(148., grad_fn=<MulBackward0>)

In [136]:
z.backward()

Sum backward:  (tensor(2.),)
Returning sum grad:  tensor([2., 2., 2., 2.])
Linear backward  (tensor([2., 2., 2., 2.]),)
grad w shape  torch.Size([4, 3])
grad b shape  torch.Size([4])


In [137]:
w.grad

tensor([[0., 2., 4.],
        [0., 2., 4.],
        [0., 2., 4.],
        [0., 2., 4.]])

In [138]:
b.grad

tensor([2., 2., 2., 2.])

In [139]:
x.grad

In [140]:
w.grad -= 0.1 * w.grad

In [141]:
w.grad

tensor([[0.0000, 1.8000, 3.6000],
        [0.0000, 1.8000, 3.6000],
        [0.0000, 1.8000, 3.6000],
        [0.0000, 1.8000, 3.6000]])

In [142]:
w = torch.arange(4*3, dtype=torch.float).view(4, 3).requires_grad_()
b = torch.zeros(4).type_as(w).requires_grad_()
x = torch.arange(3).float()
q = torch.ones(3)
u =  2 * x * q  + 5
y = Linear.apply(u, w, b)
w, u, y,y.sum()

(tensor([[ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.]], requires_grad=True),
 tensor([5., 7., 9.]),
 tensor([ 25.,  88., 151., 214.], grad_fn=<LinearBackward>),
 tensor(478., grad_fn=<SumBackward0>))

In [145]:
y.sum().backward()

Linear backward  (tensor([1., 1., 1., 1.]),)
grad w shape  torch.Size([4, 3])
grad b shape  torch.Size([4])


In [146]:
w.grad

tensor([[5., 7., 9.],
        [5., 7., 9.],
        [5., 7., 9.],
        [5., 7., 9.]])

In [147]:
b.grad

tensor([1., 1., 1., 1.])

In [24]:
class Hadamard(Function):
    
    @staticmethod
    def forward(ctx, x, scale) -> torch.Tensor:
        print("forward hadamard: ", x)
        out = x * scale
        ctx.save_for_backward(x)
        return out
    
    @staticmethod
    def backward(ctx, *grad_outputs):
        print("Hadamard backward ", grad_outputs)
        print(ctx.saved_tensors)
        input, = ctx.saved_tensors
        print("grad hadamard ")
        return  None, input * grad_outputs[0]
        
        

In [155]:
x = torch.arange(3).float()
q = torch.empty_like(x).copy_(x).requires_grad_()
h = Hadamard.apply(x, q)
h.retain_grad() # because not leaf this is needed for getting the gradients
w = torch.arange(4*3, dtype=torch.float).view(4, 3).requires_grad_()
b = torch.zeros(4).type_as(w).requires_grad_()
y = Linear.apply(h, w, b)

forward hadamard:  tensor([0., 1., 2.])


In [156]:
y

tensor([ 9., 24., 39., 54.], grad_fn=<LinearBackward>)

In [157]:
y.sum()

tensor(126., grad_fn=<SumBackward0>)

In [158]:
y.sum().backward()

Linear backward  (tensor([1., 1., 1., 1.]),)
grad w shape  torch.Size([4, 3])
grad b shape  torch.Size([4])
Need inp grad 
Hadamard backward  (tensor([18., 22., 26.]),)
(tensor([0., 1., 2.]),)
grad hadamard 


In [159]:
w, w.grad

(tensor([[ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.]], requires_grad=True),
 tensor([[0., 1., 4.],
         [0., 1., 4.],
         [0., 1., 4.],
         [0., 1., 4.]]))

In [160]:
h, h.grad

(tensor([0., 1., 4.], grad_fn=<HadamardBackward>), tensor([18., 22., 26.]))

In [31]:
q.grad

tensor([ 0., 22., 52.])

In [32]:
b.grad

tensor([1., 1., 1., 1.])

In [163]:
x = torch.arange(3).float()
q = torch.arange(3).float().requires_grad_()
h = Hadamard.apply(x, q)
h.retain_grad()
w = torch.arange(4*3, dtype=torch.float).view(4, 3).requires_grad_()
b = torch.zeros(4).type_as(w).requires_grad_()
# y = w @ h + b
y = Linear.apply(h, w, b)

forward hadamard:  tensor([0., 1., 2.])


In [164]:
y.sum().backward()

Linear backward  (tensor([1., 1., 1., 1.]),)
grad w shape  torch.Size([4, 3])
grad b shape  torch.Size([4])
Need inp grad 
Hadamard backward  (tensor([18., 22., 26.]),)
(tensor([0., 1., 2.]),)
grad hadamard 


In [165]:
w.grad

tensor([[0., 1., 4.],
        [0., 1., 4.],
        [0., 1., 4.],
        [0., 1., 4.]])

In [166]:
q.requires_grad

True

In [167]:
q.grad

tensor([ 0., 22., 52.])

In [168]:
h.grad

tensor([18., 22., 26.])

In [39]:
y.grad_fn

<AddBackward0 at 0x7f2560527c70>

In [40]:
y.grad_fn.next_functions

((<MvBackward0 at 0x7f2560527d90>, 0), (<AccumulateGrad at 0x7f2560527dc0>, 0))

In [169]:
y.grad_fn.next_functions[0]

(<torch.autograd.function.HadamardBackward at 0x7f2537b71130>, 0)

In [171]:
y.grad_fn.next_functions[0][0].next_functions

((None, 0), (<AccumulateGrad at 0x7f2537b5a850>, 0))

In [42]:
y.grad_fn.next_functions[0][0].next_functions[1][0]

<torch.autograd.function.HadamardBackward at 0x7f25604cd400>

In [43]:
y.grad_fn.next_functions[0][0].next_functions[1][0].next_functions

((None, 0), (<AccumulateGrad at 0x7f25604ccc40>, 0))

In [44]:
y.grad_fn.next_functions[0][0].next_functions[1][0].next_functions[1]

(<AccumulateGrad at 0x7f25604ccc40>, 0)

In [45]:
x = torch.tensor([1., 2, 3]).requires_grad_()
y = x
y.sum().backward()
x.grad

tensor([1., 1., 1.])

In [46]:
x = torch.tensor([1., 2, 3]).requires_grad_()
y = x
x.requires_grad_(False)
z = x
y.requires_grad

False

In [47]:
def add_tensor1(x, y):
    return x + y

@torch.no_grad()
def add_tensor2(x, y):
    return x + y

x = torch.tensor([1., 2, 3]).requires_grad_()
y = x

In [48]:
add_tensor1(x, y)

tensor([2., 4., 6.], grad_fn=<AddBackward0>)

In [49]:
add_tensor2(x, y)

tensor([2., 4., 6.])

In [50]:
W = torch.randn(3, 4).requires_grad_()
x = torch.arange(4).float()
y = W @ x
y

tensor([ 0.4446,  3.3463, -3.2221], grad_fn=<MvBackward0>)

In [51]:
y.backward(torch.ones_like(y), retain_graph=True)

In [52]:
W.grad

tensor([[0., 1., 2., 3.],
        [0., 1., 2., 3.],
        [0., 1., 2., 3.]])

In [53]:
y.backward(torch.ones_like(y),)

In [54]:
W.grad

tensor([[0., 2., 4., 6.],
        [0., 2., 4., 6.],
        [0., 2., 4., 6.]])

In [178]:
def func(x):
    z = x**2 
    u = 2*x
    y = z  + u + 1
    return y


In [56]:
y = autograd.functional.jacobian(func, torch.tensor(1.0))
y

tensor(4.)

In [57]:
y = autograd.functional.jacobian(func, torch.tensor(0.))
y

tensor(2.)

In [58]:
y = autograd.functional.jacobian(func, torch.tensor([0., 1.]))
y # [dy1/dx, dy2/dx].T

tensor([[2., 0.],
        [0., 4.]])

In [59]:
y = autograd.functional.jacobian(lambda x: torch.dot(x, x), 
                                torch.tensor([0., 1.]))
y

tensor([0., 2.])

In [179]:
def hess_fn(x):
    return x.t() @ torch.eye(x.size(0)) @ x

In [61]:
autograd.functional.hessian(hess_fn, torch.tensor([1.0, 2.0]))

tensor([[2., 0.],
        [0., 2.]])

In [62]:
autograd.functional.jacobian(hess_fn, torch.tensor([1.0, 2.0]))

tensor([2., 4.])

In [63]:
# (output, J*v)
autograd.functional.jvp(hess_fn, torch.tensor([1.0, 2.0]), v=torch.tensor([1.0, 2.0]))

(tensor(5.), tensor(10.))

In [64]:
autograd.functional.vjp(hess_fn, torch.tensor([1.0, 2.0]), v=torch.tensor(2.))

(tensor(5.), tensor([4., 8.]))

In [65]:
def adder(x, y):
    print(x)
    print(y)
    return 2 * x + 3 *y


In [66]:
autograd.functional.jacobian(adder, (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])))

tensor([1., 2., 3.], requires_grad=True)
tensor([4., 5., 6.], requires_grad=True)


(tensor([[2., 0., 0.],
         [0., 2., 0.],
         [0., 0., 2.]]),
 tensor([[3., 0., 0.],
         [0., 3., 0.],
         [0., 0., 3.]]))

In [67]:
autograd.functional.jvp(adder, (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])), v=(torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0, 1.0])))

tensor([1., 2., 3.], requires_grad=True)
tensor([4., 5., 6.], requires_grad=True)


(tensor([14., 19., 24.]), tensor([3., 3., 3.]))

In [68]:
autograd.functional.jvp(adder, (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])), v=(torch.tensor([1.0] * 3), torch.tensor([1.0, 1.0, 1.0])))

tensor([1., 2., 3.], requires_grad=True)
tensor([4., 5., 6.], requires_grad=True)


(tensor([14., 19., 24.]), tensor([5., 5., 5.]))

In [69]:
def square(x):
    return x ** 3 + x + sum(x)
jac = autograd.functional.jacobian(square, torch.tensor([1.0, 2.0]))
jac

tensor([[ 5.,  1.],
        [ 1., 14.]])

In [70]:
# J = dy/dx.T
# dy_i = 3 *x**2 + 1 + I.roll(1)
# dy_1 = [5  1]
# dy_2 = [1  14]
# jvp = inner_prod(J, v)
out, jvp = autograd.functional.jvp(square, torch.tensor([1.0, 2.0]), v=torch.ones(2))
out, jvp

(tensor([ 5., 13.]), tensor([ 6., 15.]))

In [71]:
x = jac
x

tensor([[ 5.,  1.],
        [ 1., 14.]])

In [72]:
y = torch.ones(2)

In [73]:
x @ y

tensor([ 6., 15.])

In [74]:
jac.inner(y)

tensor([ 6., 15.])

In [75]:
def exp_func(x):
    return x.exp().sum()
autograd.functional.jacobian(exp_func, torch.tensor([1.0, 2.0]))

tensor([2.7183, 7.3891])

In [76]:
exp_func(torch.tensor([1.0, 2.0]))

tensor(10.1073)

In [77]:
autograd.functional.jacobian(exp_func, torch.tensor([[1.0, 2.0], [3.0, 4.0]]))

tensor([[ 2.7183,  7.3891],
        [20.0855, 54.5981]])

In [78]:
autograd.functional.vjp(exp_func, torch.tensor([[1.0, 2.0], [3.0, 4.0]]), torch.tensor(1.0))

(tensor(84.7910),
 tensor([[ 2.7183,  7.3891],
         [20.0855, 54.5981]]))

In [181]:
def exp_func(x):
    return x.exp().sum(dim=1)
# J = block(dy / dX)
autograd.functional.jacobian(exp_func, torch.tensor([[1.0, 2.0], [3.0, 4.0]]))

tensor([[[ 2.7183,  7.3891],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.0000],
         [20.0855, 54.5981]]])

In [182]:
autograd.functional.jacobian(exp_func, torch.tensor([[1.0, 2.0], [3.0, 4.0]]), create_graph=True)

tensor([[[ 2.7183,  7.3891],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.0000],
         [20.0855, 54.5981]]], grad_fn=<ViewBackward0>)

In [183]:
#  v = dl / dy
# vjp_i = v[0] * dy1 / dx1 + v[1] * dy2 / dx1
autograd.functional.vjp(exp_func, torch.tensor([[1.0, 2.0], [3.0, 4.0]]), v=torch.tensor([1.0, 1.0]))

(tensor([10.1073, 74.6837]),
 tensor([[ 2.7183,  7.3891],
         [20.0855, 54.5981]]))

In [184]:
#  v = dl / dy
# vjp_i = v[0] * dy1 / dx1 + v[1] * dy2 / dx1
autograd.functional.vjp(exp_func, torch.tensor([[1.0, 2.0], [3.0, 4.0]]), v=torch.tensor([0, 1.0]))

(tensor([10.1073, 74.6837]),
 tensor([[ 0.0000,  0.0000],
         [20.0855, 54.5981]]))

In [82]:
torch.manual_seed(2)
N = 10
D = 5
x = torch.randn(N, D)
y = torch.randn(N, D)
res = (x * y).sum(dim=1)
res

tensor([ 0.9642,  0.4815, -1.2837, -1.4077,  0.7736, -3.2275,  1.5103,  4.7469,
        -0.0962, -1.6041])

In [83]:
def batched_dot(x, y):
    print(x, y)
    return (x * y).sum()

torch.manual_seed(2)
N = 10
D = 5
x = torch.randn(N, D)
y = torch.randn(N, D)
res = torch.vmap(batched_dot)(x, y)
res

BatchedTensor(lvl=1, bdim=0, value=
    tensor([[-1.0408,  0.9166, -1.3042, -1.1097, -1.2188],
            [ 1.1676, -1.0574, -0.1188, -0.9078,  0.3452],
            [-0.5713, -0.2351,  1.0076, -0.7529, -0.2250],
            [-0.4327, -1.5071, -0.4586, -0.8480,  0.5266],
            [ 0.0299, -0.0498,  1.0651,  0.8860,  0.4640],
            [-0.4986,  0.1289,  2.7631,  0.1405,  1.1191],
            [ 0.3152,  1.7528, -0.7650,  1.8299, -1.6036],
            [ 1.8493,  0.0447,  1.5853, -0.5912,  1.1312],
            [ 0.9466, -1.7669, -0.5833, -0.4407, -1.9791],
            [ 0.7787, -0.7749, -0.1398, -0.3467,  0.0873]])
) BatchedTensor(lvl=1, bdim=0, value=
    tensor([[-1.4702, -0.2134, -0.8707,  1.6159, -0.2356],
            [ 0.9444,  0.5461, -1.3575,  0.1757, -0.1319],
            [-0.2735,  0.3355,  0.1885,  2.1432, -0.2779],
            [ 0.5511, -0.0625,  0.8269,  0.5599, -0.7776],
            [ 0.3339,  0.1759,  0.4863,  0.2769,  0.0195],
            [ 1.1213, -1.4873, -0.2043, 

tensor([ 0.9642,  0.4815, -1.2837, -1.4077,  0.7736, -3.2275,  1.5103,  4.7469,
        -0.0962, -1.6041])

In [84]:
@torch.jit.script
def batched_dot(x, y):
    return (x * y).sum()

torch.manual_seed(2)
N = 10
D = 5
x = torch.randn(N, D)
y = torch.randn(N, D)
res = torch.vmap(batched_dot, in_dims=1)(x, y,)
res

tensor([ 4.0021,  2.0832,  2.7936, -4.8749, -3.1466])

In [85]:
@torch.jit.script
def batched_dot(x, y):
    return (x * y).sum()

def mult_dot(x, y):
    return (x * y).sum(dim=1)
torch.manual_seed(2)
N = 40000
D = 5000
x = torch.randn(N, D)
y = torch.randn(N, D)

In [86]:
%timeit mult_dot(x, y)

360 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [87]:
vmap = torch.vmap(batched_dot, in_dims=1)

In [88]:
%timeit  vmap(x, y,)

393 ms ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [89]:
%timeit  vmap(x, y,)

375 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [90]:
%timeit mult_dot(x, y)

369 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [91]:
%timeit  vmap(x, y,)

398 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [92]:
x = torch.tensor(1.).requires_grad_()
y = x ** 2
y

tensor(1., grad_fn=<PowBackward0>)

In [93]:
autograd.grad(y, x, retain_graph=True)

(tensor(2.),)

In [94]:
autograd.grad(y, x)

(tensor(2.),)

In [95]:
y.requires_grad

True

In [96]:
z = y 

In [97]:
x = torch.tensor(1.).requires_grad_()
u = torch.tensor(2.0).requires_grad_()
y = x ** 2 + u
z = 3 * y  + 1
autograd.grad(z, x)

(tensor(6.),)

In [98]:
x = torch.tensor(1.).requires_grad_()
u = torch.tensor(2.0).requires_grad_()
y = x ** 2 + u
z = 3 * y  + 1
autograd.grad(z, (x, u))

(tensor(6.), tensor(3.))

In [99]:
x = torch.tensor(1.).requires_grad_()
u = torch.tensor(2.0).requires_grad_()
y = x ** 2 + u
z = 3 * y  + 1
autograd.grad(z, (x, u), grad_outputs=(torch.tensor(3.0), torch.tensor(4.0), torch.tensor(4.0)))

(tensor(18.), tensor(9.))

In [100]:
x = torch.tensor([1., 2.]).requires_grad_()
u = torch.tensor(2.0).requires_grad_()
y = x ** 2 + u  # [x1 ** 2 + u   x2 ** 2 + u]
z = (3 * y  + 1).sum()
# dz / du = 3.0 * [3 3] * [1, 1].T
autograd.grad(z, (x, u), grad_outputs=(torch.tensor(3.0), torch.tensor(4.0), torch.tensor(4.0)))

(tensor([18., 36.]), tensor(18.))

In [101]:
x = torch.tensor([1., 2.]).requires_grad_()
u = torch.tensor(2.0).requires_grad_()
y = x ** 2 + u  
z = y.sum()
autograd.grad(z, (x, u), grad_outputs=(torch.tensor(1.0)))

(tensor([2., 4.]), tensor(2.))

In [102]:
x = torch.tensor([1., 2.]).requires_grad_()
u = torch.tensor(2.0).requires_grad_()
y = x ** 2 + u  # [x1 + u   x2 + u]
z1 = (3 * y  + 1).sum() # dz1/dx = [6*x1 6*x2], [6]
z2 = y.sum() # dz1/dx = [2*x1 2*x2], [2]
autograd.grad((z1, z2), (x, u), grad_outputs=(torch.tensor(3.0), torch.tensor(1.0),))
# aggeration of grad wrt to inputs: x.grad = dz1 / dx + dz2 / dx

(tensor([20., 40.]), tensor(20.))

In [103]:
x = torch.tensor([1., 2.]).requires_grad_()
u = torch.tensor(2.0).requires_grad_()
y = x ** 2 + u  # [x1 + u   x2 + u]
z1 = (3 * y  + 1).sum()
z2 = y.sum()
autograd.grad((z1, z2), (x, u), grad_outputs=(torch.tensor(3.0), torch.tensor(1.0),), create_graph=True)
# aggeration of grad wrt to inputs: x.grad = dz1 / dx + dz2 / dx

(tensor([20., 40.], grad_fn=<MulBackward0>), tensor(20.))

In [104]:
# calculate jacobian from Rn -> R
D = 10
x = torch.arange(D).float().requires_grad_()
y = sum(x)
autograd.grad(y, x)

(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),)

In [105]:
# calculate jacobian from Rn -> Rn
D = 10
x = torch.arange(D).float().requires_grad_()
y = x
v = torch.eye(D)
def get_grad(v):
    return autograd.grad(x, x, v)
torch.vmap(get_grad)(v)


(tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]),)

In [177]:
# calculate jacobian from Rn -> Rn
D = 10
x = torch.arange(D).float().requires_grad_()
y = x + x.sum()
v = torch.eye(D)
def get_grad(v):
    print(v.size(), y.size())
    return autograd.grad(y, x, v) # vjp
                         
torch.vmap(get_grad)(v)

torch.Size([10]) torch.Size([10])


(tensor([[2., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 2., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 2., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 2., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 2., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 2., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 2., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 2., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 2., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 2.]]),)

In [107]:
y# calculate jacobian from Rn -> Rn
def multiply(x, y):
    return x * y
torch.manual_seed(5)
x = torch.randn(4, 2, 5)
y = torch.zeros(2, 5)
y[:, -1] = 1.
x, torch.vmap(multiply, in_dims=(0, None))(x, y)

(tensor([[[ 1.8423,  0.5189, -1.7119, -1.7014,  2.0194],
          [-0.2686, -0.1307, -1.4374,  0.3908, -0.0190]],
 
         [[-1.3527, -0.7308,  0.9879, -0.4194, -0.5849],
          [-0.7823,  2.7799,  1.2220, -0.3364, -0.9651]],
 
         [[-0.1297, -0.6018,  0.1450, -0.1498,  0.8183],
          [-0.6633,  0.2653, -1.5660, -1.6407, -0.0197]],
 
         [[ 0.2278, -0.3985, -1.0365,  0.6705, -0.1777],
          [ 0.4314,  1.2417,  2.1503, -2.2281, -1.2897]]]),
 tensor([[[ 0.0000,  0.0000, -0.0000, -0.0000,  2.0194],
          [-0.0000, -0.0000, -0.0000,  0.0000, -0.0190]],
 
         [[-0.0000, -0.0000,  0.0000, -0.0000, -0.5849],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.9651]],
 
         [[-0.0000, -0.0000,  0.0000, -0.0000,  0.8183],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0197]],
 
         [[ 0.0000, -0.0000, -0.0000,  0.0000, -0.1777],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -1.2897]]]))

In [108]:
# calculate jacobian from Rn -> Rn
def multiply(x, y):
    return x * y
torch.manual_seed(5)
x = torch.randn(4, 2, 5)
y = torch.zeros(2, 5)
y[:, -1] = 1.
z = torch.vmap(multiply, in_dims=(0, None), out_dims=-1)(x, y)
x, z, z.shape

(tensor([[[ 1.8423,  0.5189, -1.7119, -1.7014,  2.0194],
          [-0.2686, -0.1307, -1.4374,  0.3908, -0.0190]],
 
         [[-1.3527, -0.7308,  0.9879, -0.4194, -0.5849],
          [-0.7823,  2.7799,  1.2220, -0.3364, -0.9651]],
 
         [[-0.1297, -0.6018,  0.1450, -0.1498,  0.8183],
          [-0.6633,  0.2653, -1.5660, -1.6407, -0.0197]],
 
         [[ 0.2278, -0.3985, -1.0365,  0.6705, -0.1777],
          [ 0.4314,  1.2417,  2.1503, -2.2281, -1.2897]]]),
 tensor([[[ 0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000],
          [ 2.0194, -0.5849,  0.8183, -0.1777]],
 
         [[-0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0190, -0.9651, -0.0197, -1.2897]]]),
 torch.Size([2, 5, 4]))

In [109]:
# calculate jacobian from Rn -> Rn
def multiply(x, y):
    return x * y
torch.manual_seed(5)
x = torch.randn(4, 2, 5)
y = torch.zeros(2, 5)
y[:, -1] = 1.
z = torch.vmap(multiply, in_dims=(0, None), out_dims=2)(x, y)
x, z, z.shape

(tensor([[[ 1.8423,  0.5189, -1.7119, -1.7014,  2.0194],
          [-0.2686, -0.1307, -1.4374,  0.3908, -0.0190]],
 
         [[-1.3527, -0.7308,  0.9879, -0.4194, -0.5849],
          [-0.7823,  2.7799,  1.2220, -0.3364, -0.9651]],
 
         [[-0.1297, -0.6018,  0.1450, -0.1498,  0.8183],
          [-0.6633,  0.2653, -1.5660, -1.6407, -0.0197]],
 
         [[ 0.2278, -0.3985, -1.0365,  0.6705, -0.1777],
          [ 0.4314,  1.2417,  2.1503, -2.2281, -1.2897]]]),
 tensor([[[ 0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000],
          [ 2.0194, -0.5849,  0.8183, -0.1777]],
 
         [[-0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0190, -0.9651, -0.0197, -1.2897]]]),
 torch.Size([2, 5, 4]))

In [110]:
# calculate jacobian from Rn -> Rn
def multiply(x, y):
    return x * y
torch.manual_seed(5)
x = torch.randn(4, 2, 5)
y = torch.zeros(2, 5)
y[:, -1] = 1.
z = torch.vmap(multiply, in_dims=(0, None), out_dims=1)(x, y)
x, z, z.shape

(tensor([[[ 1.8423,  0.5189, -1.7119, -1.7014,  2.0194],
          [-0.2686, -0.1307, -1.4374,  0.3908, -0.0190]],
 
         [[-1.3527, -0.7308,  0.9879, -0.4194, -0.5849],
          [-0.7823,  2.7799,  1.2220, -0.3364, -0.9651]],
 
         [[-0.1297, -0.6018,  0.1450, -0.1498,  0.8183],
          [-0.6633,  0.2653, -1.5660, -1.6407, -0.0197]],
 
         [[ 0.2278, -0.3985, -1.0365,  0.6705, -0.1777],
          [ 0.4314,  1.2417,  2.1503, -2.2281, -1.2897]]]),
 tensor([[[ 0.0000,  0.0000, -0.0000, -0.0000,  2.0194],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.5849],
          [-0.0000, -0.0000,  0.0000, -0.0000,  0.8183],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.1777]],
 
         [[-0.0000, -0.0000, -0.0000,  0.0000, -0.0190],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.9651],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0197],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -1.2897]]]),
 torch.Size([2, 4, 5]))

In [111]:
class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        print("Forward ")
        
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        print("Setup context: ")
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        print("grad: ", result)
        return result



In [112]:
# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

my_cube(torch.randn(3, requires_grad=True))

Forward 
Setup context: 


tensor([-7.4607, -0.1744,  0.0792], grad_fn=<MyCubeBackward>)

In [113]:
my_cube(torch.randn(3, requires_grad=True)).sum().backward()

Forward 
Setup context: 
grad:  tensor([15.3814,  0.1221,  3.4200])


In [114]:
class LinearFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias=None):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

In [115]:
def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        print("forward ", result, ind)
        
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )

    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        print('setup context sort ', ind)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the ctx object.
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethod other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        print("backward sort ", grad_output)
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        print("forward take ")
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        print('setup context take ')
        
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        print('backward take ')
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

In [116]:
def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    print("done forward")
    return result

In [117]:
torch.manual_seed(10)
x = torch.randn(2, 3)
x

tensor([[-0.6014, -1.0122, -0.3023],
        [-1.2277,  0.9198, -0.3485]])

In [118]:
y = numpy_sort(x)
y

forward  [[-1.0122098  -0.6013928  -0.30226925]
 [-1.2276864  -0.3484694   0.91982824]] [[1 0 2]
 [0 2 1]]
setup context sort  tensor([[1, 0, 2],
        [0, 2, 1]])
done forward


tensor([[-1.0122, -0.6014, -0.3023],
        [-1.2277, -0.3485,  0.9198]])

In [119]:
torch.manual_seed(10)
x = torch.randn(2, 3)
x
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
grad_x


forward  [[-1.0122098  -0.6013928  -0.30226925]
 [-1.2276864  -0.3484694   0.91982824]] [[1 0 2]
 [0 2 1]]
setup context sort  tensor([[1, 0, 2],
        [0, 2, 1]])
setup context sort  GradTrackingTensor(lvl=1, value=
    tensor([[1, 0, 2],
            [0, 2, 1]])
)
done forward
backward sort  GradTrackingTensor(lvl=1, value=
    tensor([[1., 1., 1.],
            [1., 1., 1.]])
)
forward take 
setup context take 
setup context take 


tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [120]:
def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims

        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)

        # The vmap rule must return a tuple of two things
        # 1. the output. Should be the same amount of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        # is being vmapped over, and if so, the index of the
        # dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int]
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).

        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)

        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # then out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
result

tensor([[-1.1920, -0.9582, -0.8692],
        [-0.9373, -0.8465,  1.9050]])

In [186]:
x = torch.arange(10).float()
u = torch.tensor(2.0)
def grad_fn(x, u):
    print(x)
    y = x**2 + u
    print("y : ", y)
    return y.sum()
    
    
grad_x = torch.func.grad(grad_fn, (0,1))(x, u)
grad_x

GradTrackingTensor(lvl=1, value=
    tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
)
y :  GradTrackingTensor(lvl=1, value=
    tensor([ 2.,  3.,  6., 11., 18., 27., 38., 51., 66., 83.])
)


(tensor([ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.]), tensor(10.))

In [187]:
autograd.functional.jacobian(grad_fn, (x, u))

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], requires_grad=True)
y :  tensor([ 2.,  3.,  6., 11., 18., 27., 38., 51., 66., 83.],
       grad_fn=<AddBackward0>)


(tensor([ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.]), tensor(10.))

In [122]:
x = torch.arange(10).float()
u = torch.tensor(2.0)
def grad_fn(x, u):
    print(x)
    y = x**2 + u
    print("y : ", y)
    return y.sum()
    
    
grad_x = torch.func.grad(grad_fn, (1, ))(x, u)
grad_x

GradTrackingTensor(lvl=1, value=
    tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
)
y :  GradTrackingTensor(lvl=1, value=
    tensor([ 2.,  3.,  6., 11., 18., 27., 38., 51., 66., 83.])
)


(tensor(10.),)

In [123]:
from torch.func import grad, vmap
batch_size, feature_size = 3, 5
def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()
def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss
weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
grad_weight_per_example

tensor([[ 3.6802, -2.3130, -1.7118,  1.0709,  1.1529],
        [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
        [36.6034, 41.5923,  3.5197, 13.4378,  9.3961]], grad_fn=<MulBackward0>)

In [124]:
from torch.func import grad
def my_loss_func(y, y_pred):
   loss_per_sample = (0.5 * y_pred - y) ** 2
   loss = loss_per_sample.mean()
   return loss, (y_pred, loss_per_sample, loss)

torch.manual_seed(0)
fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
y_true = torch.rand(4)
y_preds = torch.rand(4, requires_grad=True)
out = fn(y_true, y_preds)
# > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample, loss))
out

((tensor([ 0.1713,  0.2256, -0.0783, -0.1581], grad_fn=<NegBackward0>),
  tensor([-0.0856, -0.1128,  0.0391,  0.0790], grad_fn=<MulBackward0>)),
 (tensor([0.3074, 0.6341, 0.4901, 0.8964], requires_grad=True),
  tensor([0.1173, 0.2036, 0.0245, 0.1000], grad_fn=<PowBackward0>),
  tensor(0.1113, grad_fn=<MeanBackward0>)))

In [125]:
def f(x):
    with torch.no_grad():
        c = x ** 2
    return x - c

with torch.no_grad():
    x = torch.tensor(2.0, requires_grad=True)
    print(grad(f)(x))

tensor(1.)
