In [15]:
import warnings
# Install a catch-all "ignore" filter so no Python warnings are shown for the
# rest of the session. This hides deprecation notices too.
warnings.filterwarnings('ignore')
# NOTE(review): with no message/module pattern the call above looks equivalent
# to this one, so one of the two appears redundant. Confirm before removing.
warnings.simplefilter('ignore')

In [16]:
import torch
from torch import vmap # If using PyTorch 1.8 or later
# from torch._vmap_internals import vmap # If using PyTorch 1.7
from typing import Callable

In [61]:
def jacobian(fun, x) -> torch.Tensor:
  """Compute the Jacobian of ``fun`` at ``x`` with a single vmapped backward pass.

  Args:
    fun: Callable mapping a tensor to a tensor. It is evaluated once, on a
      detached copy of ``x`` that has ``requires_grad`` enabled.
    x: Point at which to differentiate. It is detached first, so the autograd
      state of the caller's tensor is not modified.

  Returns:
    A detached tensor of shape ``(y.numel(), *x.shape)`` where ``y = fun(x)``.
    Row ``i`` is the gradient of the ``i``-th element of ``y`` (in flattened,
    row-major order) with respect to ``x``.
  """
  x = x.detach().requires_grad_()
  y = fun(x)

  def vjp(v):
    # Vector-Jacobian product: v^T J for a single cotangent v shaped like y.
    return torch.autograd.grad(y, x, v)[0]

  n_out = y.numel()
  # One one-hot cotangent per output element. Build the basis with y's dtype
  # and device so the backward pass also works for non-default dtypes and for
  # tensors that do not live on the CPU.
  basis = torch.eye(n_out, dtype=y.dtype, device=y.device)\
               .view(n_out, *y.shape)
  # vmap pushes all basis vectors through the backward pass at once.
  return vmap(vjp)(basis).detach()

In [62]:
# Jacobian of the elementwise cube f(x) = x**3 at x = (1, 1, 1), computed with
# vmap + autograd.grad. Analytically this is diag(3 * x**2) = 3 * I.
def f(x):
  return x ** 3

jacobian(f, torch.ones(3))

tensor([[3., 0., 0.],
        [0., 3., 0.],
        [0., 0., 3.]])

In [49]:
def hessian(f: Callable, x: torch.Tensor) -> torch.Tensor:
    """Hessian of a scalar-valued ``f`` at ``x``, as the Jacobian of its gradient.

    Args:
        f: Callable returning a 0-dim tensor. An ``AssertionError`` is raised
            when its output has any dimensions.
        x: Point at which the second derivatives are evaluated.

    Returns:
        Whatever ``jacobian`` returns for the gradient function of ``f`` at ``x``.
    """
    def gradient_of_f(point: torch.Tensor):
        value = f(point)
        # A Hessian only exists for scalar-valued functions.
        assert value.dim() == 0
        # create_graph=True keeps the gradient differentiable, so the outer
        # jacobian call can differentiate through it a second time.
        (grad,) = torch.autograd.grad(value, point, create_graph=True)
        return grad

    return jacobian(gradient_of_f, x)

In [50]:
# Hessian of g(x) = sum(x**3) at x = (1, 1, 1), computed with vmap +
# autograd.grad. Analytically this is diag(6 * x) = 6 * I.
def g(x):
    return (x ** 3).sum()

hessian(g, torch.ones(3))

tensor([[6., 0., 0.],
        [0., 6., 0.],
        [0., 0., 6.]])

In [32]:
# torch.dot contracts two vectors: ([D], [D]) -> [].
# Every vmap wrapper adds one leading batch dimension to both arguments.
vdot = torch.vmap(torch.dot)   # ([N, D], [N, D]) -> [N]
vvdot = torch.vmap(vdot)       # ([N, D, C], [N, D, C]) -> [N, D]

x = torch.ones(3, 2, 5)
y = torch.ones(3, 2, 5)
# Each entry is a dot product of two length-5 vectors of ones, i.e. 5.
vvdot(x, y)

tensor([[5., 5.],
        [5., 5.],
        [5., 5.]])

In [22]:
# Per-example model relu(<features, weights>), written for a single feature
# vector and lifted over the leading batch dimension with vmap.
batch_size = 3
feature_size = 5

# No seed is set, so the values differ from run to run.
weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)

model = torch.vmap(lambda features: features.dot(weights).relu())

In [23]:
# Show the vmapped callable; its repr still carries the wrapped lambda's name
# and signature.
model

<function __main__.<lambda>(features)>

In [24]:
# Run the vmapped model on the whole batch: one ReLU output per example row.
model(examples)

tensor([0.0000, 0.0000, 0.6167], grad_fn=<ReluBackward0>)

In [25]:
# Setup: y = x**2 elementwise, so the Jacobian dy/dx is diag(2 * x).
# No seed is set, so the values differ from run to run.
N = 5
f = lambda x: x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
I_N = torch.eye(N)  # rows are the one-hot cotangents fed to the vjp

# Vectorized gradient computation: vjp(v) is v^T J, returned as the 1-tuple
# that autograd.grad produces. retain_graph=True keeps y's graph alive, so
# the vmapped function can be evaluated more than once; without it a second
# call fails because the saved tensors were already freed.
vjp = lambda v: torch.autograd.grad(y, x, v, retain_graph=True)
# NOTE: this rebinds the name `jacobian`, which an earlier cell defines as a
# function. After this cell runs, that function is shadowed.
jacobian = torch.vmap(vjp)

In [26]:
# Push the identity basis through the vmapped vjp. The result is a 1-tuple
# (autograd.grad returns a tuple) holding the N x N Jacobian.
jacobian(I_N)

(tensor([[ 3.4172, -0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000, -1.0290, -0.0000, -0.0000,  0.0000],
         [ 0.0000, -0.0000, -1.6843, -0.0000,  0.0000],
         [ 0.0000, -0.0000, -0.0000, -0.6492,  0.0000],
         [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0559]]),)

In [49]:
class Var:
  """Minimal scalar node for reverse-mode automatic differentiation.

  Attributes:
    v: The value held by this node.
    grad_fn: Zero-argument callable returning a list of
      ``(input_node, local_derivative)`` pairs, one per operand that produced
      this node. Nodes created directly return an empty list.
  """

  def __init__(self, val, grad_fn=lambda: []):
    self.v, self.grad_fn = val, grad_fn

  def __add__(self, other):
    # d(a + b)/da = d(a + b)/db = 1
    return Var(self.v + other.v,
      lambda: [(self, 1.0), (other, 1.0)])

  def __mul__(self, other):
    # d(a * b)/da = b and d(a * b)/db = a
    return Var(self.v * other.v,
      lambda: [(self, other.v), (other, self.v)])

  def grad(self, bp=1.0, dict=None):
    """Accumulate derivatives of this node with respect to every ancestor.

    Args:
      bp: Derivative flowing in from downstream (1.0 at the root).
      dict: Accumulator mapping ``Var`` to its summed derivative. Leave it
        unset to start a fresh computation; the recursion passes it along.

    Returns:
      The accumulator, mapping each reachable node to its derivative.
    """
    # A mutable default would be shared by every call, so derivatives from
    # earlier grad() calls would leak into later ones. Create it per call.
    if dict is None:
      dict = {}
    dict[self] = dict.get(self, 0) + bp
    # Chain rule: each path to an input contributes local_derivative * bp.
    for parent, local_grad in self.grad_fn():
      parent.grad(local_grad * bp, dict)
    return dict

In [50]:
# Build f = x**4 + y from the toy autodiff nodes and read off df/dy.
x, y = Var(1.), Var(1.)
f = x * x * x * x + y
f.grad()[y]

1.0

In [31]:
# Value stored on f by the forward computation (the operands' .v combined).
f.v

2.0