### Welcome to the functorch tutorial on Jacobians, Hessians and more - on colab! 

## Configuring your colab to run functorch 


**Getting set up**: running functorch currently requires PyTorch Nightly.  
So we'll install a PyTorch nightly build and then build functorch.

After that and a restart, you'll be ready to run the tutorial here on colab.

Let's set up a restart function:

In [2]:
def colab_restart():
  print("--> Restarting colab instance") 
  get_ipython().kernel.do_shutdown(True)

Next, let's confirm that we have a GPU.  
(If not, select Runtime -> Change runtime type above,
 and select GPU under Hardware accelerator.)

In [3]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0
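Note that `nvcc --version` only reports the installed CUDA compiler; it does not confirm that a GPU is attached to this runtime. A minimal sketch of a more direct check from Python, assuming the preinstalled PyTorch is still present at this point (it is removed in the next step). The variable `has_gpu` is just an illustrative name:

```python
import torch

# True only if PyTorch can see a CUDA-capable GPU in this runtime
has_gpu = torch.cuda.is_available()
print(f"GPU available: {has_gpu}")
if has_gpu:
    # Name of the first visible GPU (the accelerator Colab assigned)
    print(torch.cuda.get_device_name(0))
```

Running `!nvidia-smi` in a cell is another common way to list the attached GPU.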


Let's remove the default PyTorch install:

In [4]:
!pip uninstall -y torch

Found existing installation: torch 1.10.0+cu111
Uninstalling torch-1.10.0+cu111:
  Successfully uninstalled torch-1.10.0+cu111


Then install the matching nightly version (this defaults to CUDA 11.1, which works on most Colab instances).

In [5]:
cuda_version = "cu111"  # optionally, use "cu113" (CUDA 11.3) if the nvcc output above lists release 11.3

In [6]:
!pip install --pre torch -f https://download.pytorch.org/whl/nightly/{cuda_version}/torch_nightly.html --upgrade

Looking in links: https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
Collecting torch
  Downloading https://download.pytorch.org/whl/nightly/cu111/torch-1.12.0.dev20220216%2Bcu111-cp37-cp37m-linux_x86_64.whl (1922.9 MB)

Let's install Ninja to speed up the functorch build:

In [7]:
!pip install ninja

Collecting ninja
  Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)
[?25l[K     |███                             | 10 kB 19.8 MB/s eta 0:00:01[K     |██████                          | 20 kB 8.7 MB/s eta 0:00:01[K     |█████████                       | 30 kB 7.4 MB/s eta 0:00:01[K     |████████████▏                   | 40 kB 6.8 MB/s eta 0:00:01[K     |███████████████▏                | 51 kB 4.0 MB/s eta 0:00:01[K     |██████████████████▏             | 61 kB 4.2 MB/s eta 0:00:01[K     |█████████████████████▏          | 71 kB 4.3 MB/s eta 0:00:01[K     |████████████████████████▎       | 81 kB 4.8 MB/s eta 0:00:01[K     |███████████████████████████▎    | 92 kB 3.7 MB/s eta 0:00:01[K     |██████████████████████████████▎ | 102 kB 4.0 MB/s eta 0:00:01[K     |████████████████████████████████| 108 kB 4.0 MB/s 
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.10.2.3


Next, we'll install and build functorch (this takes about 6 minutes):

In [8]:
!pip install --user "git+https://github.com/pytorch/functorch.git"

Collecting git+https://github.com/pytorch/functorch.git
  Cloning https://github.com/pytorch/functorch.git to /tmp/pip-req-build-htz8t0jk
  Running command git clone -q https://github.com/pytorch/functorch.git /tmp/pip-req-build-htz8t0jk
Building wheels for collected packages: functorch
  Building wheel for functorch (setup.py) ... done
  Created wheel for functorch: filename=functorch-0.2.0a0+2cf76f3-cp37-cp37m-linux_x86_64.whl size=21457003 sha256=be6cfe683ff09d15bac0a66e14d6d2d476a15a18273ceb0fc64a1d13fa0e37d7
  Stored in directory: /tmp/pip-ephem-wheel-cache-zrhpj6mp/wheels/b0/a9/4a/ffec50dda854c8d9f2ba21e4ffc0f2489ea97946cb1102c5ab
Successfully built functorch
Installing collected packages: functorch
Successfully installed functorch-0.2.0a0+2cf76f3


Finally, restart Colab, then skip directly down to the '-- Tutorial Start --' section to get underway.

In [9]:
colab_restart() 

--> Restarting colab instance


## -- Tutorial Start -- 



In [2]:
# Confirm we are ready to start.  
# If this fails, please make sure you have completed the install steps above first and then return here.

import functorch    

# Jacobians, hessians, and more: composing functorch transforms

Computing Jacobians or Hessians is useful in a number of non-traditional deep learning models.

It is difficult (or annoying) to compute these quantities efficiently using a standard autodiff system like PyTorch Autograd; functorch provides ways of computing various higher-order autodiff quantities efficiently.


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)


**Comparing functorch vs the naive approach:**

Let’s start with a function whose Jacobian we’d like to compute. It is a simple linear layer followed by a tanh non-linearity.



In [4]:
def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

Let's add some dummy data: a weight, a bias, and a feature vector x.



In [55]:
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector

Let's think of `predict` as a function that maps the input `x` from $\mathbb{R}^D \to \mathbb{R}^D$.
PyTorch Autograd computes vector-Jacobian products. In order to compute the full
Jacobian of this $\mathbb{R}^D \to \mathbb{R}^D$ function, we would have to compute it row-by-row,
using a different unit vector each time.

In [56]:
def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

In [57]:
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)

jacobian = compute_jac(xp)

print(jacobian.shape)
print(jacobian[0])  # show first row

torch.Size([16, 16])
tensor([-5.9625e-06,  1.9876e-05,  7.0103e-06,  1.1086e-05, -1.1939e-05,
         1.0975e-05,  8.3484e-06, -1.4599e-06, -1.9937e-05,  1.4976e-05,
        -7.4515e-06, -2.2042e-06,  5.0195e-07,  1.5267e-05, -7.8227e-06,
         6.9435e-06])


Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. 
We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:



In [58]:
from functorch import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)

ft_jacobian, = vmap(vjp_fn)(unit_vectors)

# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)

In another tutorial, composing reverse-mode AD and vmap gave us per-sample gradients.
In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation!
Various compositions of vmap and autodiff transforms can give us different interesting quantities.
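As a reminder of that earlier composition, here is a minimal sketch that assumes a hypothetical squared-error `loss` on a linear model (not part of this tutorial): `vmap(grad(loss))` returns one gradient per sample. The result is checked against the hand-derived gradient `2 * (x @ weight - target) * x`.

```python
import torch
from functorch import grad, vmap

torch.manual_seed(0)
N, D = 8, 4
weight = torch.randn(D)
xs = torch.randn(N, D)      # a batch of N samples
targets = torch.randn(N)

# Hypothetical per-sample loss: squared error of a linear model on ONE sample
def loss(weight, x, target):
    return (x @ weight - target) ** 2

# grad (reverse-mode AD) differentiates w.r.t. weight; vmap maps over samples
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weight, xs, targets)

# Closed-form per-sample gradients for comparison, shape (N, D)
reference = 2 * (xs @ weight - targets).unsqueeze(1) * xs

print(per_sample_grads.shape)  # torch.Size([8, 4])
assert torch.allclose(per_sample_grads, reference, atol=1e-6)
```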

functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute Jacobians. **jacrev** accepts an argnums argument that specifies which argument(s) we would like to compute Jacobians with respect to.



In [59]:
from functorch import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

# confirm 
assert torch.allclose(ft_jacobian, jacobian)

Let’s compare the performance of the two ways to compute the jacobian. The functorch version is much faster (and becomes even faster the more outputs there are). 

In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.

vmap does this by pushing the outer loop down into the function's primitive operations, so that each operation processes the whole batch at once.
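A minimal sketch of what this means in practice, reusing `predict` on a small hypothetical batch `xs`: the vmapped call returns the same result as an explicit Python loop, but as a single call in which `F.linear` and `tanh` see all rows at once. This only checks equivalence; it does not show vmap's internal batching rules.

```python
import torch
import torch.nn.functional as F
from functorch import vmap

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

torch.manual_seed(0)
D, B = 16, 8
weight, bias = torch.randn(D, D), torch.randn(D)
xs = torch.randn(B, D)  # B feature vectors

# Explicit Python loop: one small call to predict per feature vector
looped = torch.stack([predict(weight, bias, x) for x in xs])

# vmap: one call; the batch dimension is pushed into the primitive ops
vmapped = vmap(predict, in_dims=(None, None, 0))(weight, bias, xs)

assert torch.allclose(looped, vmapped, atol=1e-6)
print(vmapped.shape)  # torch.Size([8, 16])
```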




Let's write a quick helper that expresses the difference between two measurements as a percentage, so we don't have to compare microsecond and millisecond readings by eye:

In [60]:
def get_perf(first, first_descriptor, second, second_descriptor):
  """Takes torch.utils.benchmark Measurement objects and reports the time saved by `second` relative to `first`.

  Assumes `second` is the faster of the two: the absolute value below hides which one actually won.
  """
  faster = second.times[0]
  slower = first.times[0]
  gain = (slower-faster)/slower
  if gain < 0: gain *= -1
  final_gain = gain*100
  print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
  

And then run the performance comparison:

In [61]:
from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)

print(no_vmap_timer)
print(with_vmap_timer)


<torch.utils.benchmark.utils.common.Measurement object at 0x7f682eb5a450>
compute_jac(xp)
  2.04 ms
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f6733f08810>
jacrev(predict, argnums=2)(weight, bias, x)
  810.29 us
  1 measurement, 500 runs , 1 thread


Let's do a relative performance comparison of the above with our get_perf function:

In [62]:
get_perf(no_vmap_timer, "without vmap",  with_vmap_timer, "vmap");

 Performance delta: 60.3299 percent improvement with vmap 


Furthermore, it’s easy to flip the problem around and compute Jacobians with respect to the parameters of our model (weight, bias) instead of the input.

In [63]:
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)  # argnums=(0, 1) selects weight and bias instead of x
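Since `predict` is a single tanh layer, these parameter Jacobians have a simple closed form, which makes a useful sanity check. A sketch, assuming fresh random inputs of the same size as above; the expected tensors are derived by hand, not by functorch: for $y = \tanh(Wx + b)$, $\partial y / \partial b = \mathrm{diag}(1 - y^2)$ and $\partial y_m / \partial W_{ij} = (1 - y_m^2)\,\delta_{mi}\,x_j$.

```python
import torch
import torch.nn.functional as F
from functorch import jacrev

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

torch.manual_seed(0)
D = 16
weight, bias, x = torch.randn(D, D), torch.randn(D), torch.randn(D)

jac_weight, jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
print(jac_weight.shape, jac_bias.shape)  # torch.Size([16, 16, 16]) torch.Size([16, 16])

# Closed form: dy/db = diag(1 - y**2)
y = predict(weight, bias, x)
expected_jac_bias = torch.diag(1 - y ** 2)
# dy[m]/dW[i, j] = (1 - y[m]**2) * (m == i) * x[j]
expected_jac_weight = torch.einsum("mi,j->mij", expected_jac_bias, x)

assert torch.allclose(jac_bias, expected_jac_bias, atol=1e-6)
assert torch.allclose(jac_weight, expected_jac_weight, atol=1e-6)
```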

## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)


We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: 

- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. 

- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. 
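To mirror the vjp-based construction earlier, here is a sketch of the forward-mode side: vmapping `jvp` over unit vectors yields the Jacobian one column per unit vector. This illustrates the idea and is not necessarily how functorch implements jacfwd internally; the sizes and the helper `jvp_column` are hypothetical.

```python
import torch
import torch.nn.functional as F
from functools import partial
from functorch import jacfwd, jvp, vmap

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

torch.manual_seed(0)
Din, Dout = 8, 5
weight, bias, x = torch.randn(Dout, Din), torch.randn(Dout), torch.randn(Din)
f = partial(predict, weight, bias)

def jvp_column(v):
    # jvp returns (f(x), J @ v); with v = e_i this is the i-th column of J
    return jvp(f, (x,), (v,))[1]

# vmap over the Din unit vectors computes every column at once -> (Din, Dout)
columns = vmap(jvp_column)(torch.eye(Din))
manual_jacfwd = columns.T  # transpose so rows index outputs: (Dout, Din)

assert torch.allclose(manual_jacfwd, jacfwd(f)(x), atol=1e-6)
```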

jacfwd and jacrev can be substituted for each other but they have different performance characteristics.

As a general rule of thumb, if you’re computing the Jacobian of an $\mathbb{R}^N \to \mathbb{R}^M$ function and there are many more outputs than inputs (i.e. M > N), then jacfwd is preferred; otherwise, use jacrev.

There are exceptions to this rule, but a non-rigorous argument for this follows:

In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. 

The Jacobian matrix has M rows and N columns, so we prefer the method that needs fewer passes: jacfwd when the Jacobian is tall (M > N, fewer columns) and jacrev when it is wide (M < N, fewer rows).
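In symbols, writing $u_i \in \mathbb{R}^M$ and $e_j \in \mathbb{R}^N$ for unit vectors: reverse mode assembles the $M \times N$ Jacobian from $M$ vector-Jacobian products (one per row), while forward mode assembles it from $N$ Jacobian-vector products (one per column).

$$
J = \begin{bmatrix} u_1^\top J \\ \vdots \\ u_M^\top J \end{bmatrix}
  = \begin{bmatrix} J e_1 & \cdots & J e_N \end{bmatrix},
\qquad J \in \mathbb{R}^{M \times N}
$$

For instance, with M = 2048 outputs and N = 32 inputs, as in the first benchmark that follows, jacrev batches 2048 VJPs while jacfwd batches only 32 JVPs.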



In [64]:
from functorch import jacrev, jacfwd

First, let's benchmark with more outputs than inputs (M > N):



In [65]:
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)

bias = torch.randn(Dout)
x = torch.randn(Din)

# Recall the rule of thumb: weight (and hence the Jacobian) is taller (M = 2048) than it is wide (N = 32)
print(weight.shape)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')


torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f6734014c90>
jacfwd(predict, argnums=2)(weight, bias, x)
  1.18 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67344b8650>
jacrev(predict, argnums=2)(weight, bias, x)
  14.98 ms
  1 measurement, 500 runs , 1 thread


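Then compare them. jacrev is the slower method here, so we compute how much longer it takes than jacfwd:

In [67]:
# Relative slowdown of jacrev vs. jacfwd, as a percentage of jacfwd's runtime
slowdown = (jacrev_timing.times[0] - jacfwd_timing.times[0]) / jacfwd_timing.times[0] * 100
print(f" Performance delta: jacrev takes {slowdown:.4f} percent longer than jacfwd ")

 Performance delta: jacrev takes 1170.0622 percent longer than jacfwd 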


And now the reverse: more inputs (N) than outputs (M):

In [71]:
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')

jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67340145d0>
jacfwd(predict, argnums=2)(weight, bias, x)
  8.99 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67344b8110>
jacrev(predict, argnums=2)(weight, bias, x)
  1.03 ms
  1 measurement, 500 runs , 1 thread


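And a relative comparison; this time jacfwd is the slower method:

In [72]:
# Relative slowdown of jacfwd vs. jacrev, as a percentage of jacrev's runtime
slowdown = (jacfwd_timing.times[0] - jacrev_timing.times[0]) / jacrev_timing.times[0] * 100
print(f" Performance delta: jacfwd takes {slowdown:.4f} percent longer than jacrev ")

 Performance delta: jacfwd takes 775.3424 percent longer than jacrev 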


## Hessian computation with functorch.hessian


We offer a convenience API to compute Hessians: functorch.hessian.
The Hessian is the Jacobian of the Jacobian: the matrix of second-order partial derivatives.

This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. 
Indeed, under the hood, hessian(f) is simply jacfwd(jacrev(f)).
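A quick check of this composition on a small hypothetical scalar function, $f(x) = \sum_i x_i^3$, whose Hessian is $\mathrm{diag}(6x)$. The sketch also checks the jacfwd(jacfwd(f)) and jacrev(jacrev(f)) compositions, which compute the same matrix:

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    # Hypothetical scalar function; its Hessian is diag(6 * x)
    return (x ** 3).sum()

x = torch.tensor([1.0, -2.0, 0.5])

h_api = hessian(f)(x)
h_fwd_rev = jacfwd(jacrev(f))(x)  # the composition described above
h_fwd_fwd = jacfwd(jacfwd(f))(x)  # forward-over-forward
h_rev_rev = jacrev(jacrev(f))(x)  # reverse-over-reverse

expected = torch.diag(6 * x)
for h in (h_api, h_fwd_rev, h_fwd_fwd, h_rev_rev):
    assert torch.allclose(h, expected)
print(h_api)
```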



Note: depending on your model, you may get better performance by computing Hessians with jacfwd(jacfwd(f)) or jacrev(jacrev(f)) instead, following the rule of thumb above about tall vs. wide Jacobians.



In [87]:
from functorch import hessian

# Let's reduce the size so we don't run out of memory on Colab; Hessians require significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)


Let's verify that we get the same result from the hessian API and from jacfwd(jacfwd()):

In [89]:
torch.allclose(hess_api, hess_fwdfwd)

True

## Batch Jacobian and Batch Hessian


In the above examples we’ve been operating with a single feature vector. 

In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. 

That is, given a batch of inputs of shape (B, N) and a function that goes from $\mathbb{R}^N \to \mathbb{R}^M$, we would like a Jacobian of shape (B, M, N).

The easiest way to do this is to use vmap:



In [91]:
batch_size = 64
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")

bias = torch.randn(Dout)

x = torch.randn(batch_size, Din)

weight shape = torch.Size([33, 31])


In [92]:
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)

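If you instead have a function that goes from (B, N) -> (B, M) and are certain that each input produces an independent output, then it’s also sometimes possible to do this without vmap: sum the outputs over the batch and compute the Jacobian of that function. This works because output b does not depend on any other input, so all cross-sample derivatives are zero and only the per-sample Jacobians remain: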



In [93]:
def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)

batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
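The independence condition matters. A sketch with a hypothetical batch-coupled variant, `predict_centered`, which subtracts the batch mean before the layer: the sum trick still matches vmap for `predict`, but for the coupled model it returns per-sample Jacobians plus cross-sample terms, so it no longer equals the diagonal blocks $\partial y_b / \partial x_b$ of the full Jacobian. The sizes and the helper `sum_trick` are illustrative.

```python
import torch
import torch.nn.functional as F
from functorch import jacrev, vmap

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

def predict_centered(weight, bias, x):
    # Hypothetical batch-coupled model: subtracting the batch mean makes every
    # output row depend on every input row
    return predict(weight, bias, x - x.mean(0))

torch.manual_seed(0)
B, Din, Dout = 4, 3, 2
weight, bias, x = torch.randn(Dout, Din), torch.randn(Dout), torch.randn(B, Din)

def sum_trick(fn):
    # Jacobian of the batch-summed output, reshaped to (B, Dout, Din)
    summed = lambda w, bb, xx: fn(w, bb, xx).sum(0)
    return jacrev(summed, argnums=2)(weight, bias, x).movedim(1, 0)

# Independent samples: the sum trick agrees with vmap(jacrev(...))
per_sample = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))(weight, bias, x)
independent_ok = torch.allclose(sum_trick(predict), per_sample, atol=1e-6)

# Coupled samples: compare against the diagonal blocks dy_b/dx_b of the full Jacobian
full = jacrev(predict_centered, argnums=2)(weight, bias, x)  # (B, Dout, B, Din)
diag_blocks = torch.stack([full[b, :, b, :] for b in range(B)])
coupled_ok = torch.allclose(sum_trick(predict_centered), diag_blocks, atol=1e-6)

print(independent_ok, coupled_ok)
```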

To recap: if you have a function that goes from $\mathbb{R}^N \to \mathbb{R}^M$ but batched inputs, compose vmap with jacrev to compute batched Jacobians, as `compute_batch_jacobian` does above.

Finally, batch Hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over the Hessian computation, but in some cases the sum trick also works.



In [95]:
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape

torch.Size([64, 33, 31, 31])
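For completeness, here is a sketch of the sum trick for batch Hessians, applied twice: sum the outputs over the batch, take the Jacobian, sum that over the batch, and differentiate once more. It relies on the same independence assumption as the batch Jacobian case and is checked against the vmap version on small hypothetical sizes; treat it as an illustration rather than a recommended replacement for vmap. The helper `summed_batch_jacobian` is illustrative.

```python
import torch
import torch.nn.functional as F
from functorch import hessian, jacrev, vmap

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

torch.manual_seed(0)
B, Din, Dout = 4, 5, 3
weight, bias, x = torch.randn(Dout, Din), torch.randn(Dout), torch.randn(B, Din)

# Reference: vmap over the per-sample Hessian -> (B, Dout, Din, Din)
batch_hess_vmap = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))(weight, bias, x)

def summed_batch_jacobian(weight, bias, x):
    # Jacobian of the batch-summed output: (Dout, B, Din), entry [m, b, n] = dy_b[m]/dx_b[n]
    summed = lambda w, bb, xx: predict(w, bb, xx).sum(0)
    # Sum over the batch again -> (Dout, Din), still a function of every sample
    return jacrev(summed, argnums=2)(weight, bias, x).sum(1)

# Differentiate once more w.r.t. the batch: (Dout, Din, B, Din) -> move B to the front
batch_hess_sum = jacrev(summed_batch_jacobian, argnums=2)(weight, bias, x).movedim(2, 0)

print(batch_hess_sum.shape)  # torch.Size([4, 3, 5, 5])
assert torch.allclose(batch_hess_vmap, batch_hess_sum, atol=1e-5)
```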