# Neural network automatic differentiation with PyTorch 2.0

In [1]:
import os
import sys

sys.path.append(os.path.join(os.path.abspath(''), ".."))

import torch
from torch import nn
from torch.func import jacrev, jacfwd, hessian, vmap

from nnbma.networks import FullyConnected

## Introduction to differentiation with PyTorch

We will use the following functions from `torch.func`:
* `jacrev`: computes the Jacobian using reverse-mode autodiff
* `jacfwd`: computes the Jacobian using forward-mode autodiff
* `hessian`: computes the Hessian using both reverse-mode and forward-mode autodiff
* `vmap`: vectorizing map, used here to apply a derivative function to every sample of a batched input
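
As a minimal sketch of how these functions combine, the example below computes per-sample Jacobians of a toy function and compares them with the analytical result. The function `f` and the matrix `W` are hypothetical and chosen only because their Jacobian is known in closed form.

```python
import torch
from torch.func import jacrev, jacfwd, vmap

# Hypothetical toy function from R^3 to R^2: f(x) = W @ sin(x)
W = torch.arange(6.0).reshape(2, 3)

def f(x):
    return W @ torch.sin(x)

x = torch.randn(5, 3)  # batch of 5 inputs

# vmap applies the single-sample Jacobian function to each row of x
jac_rev = vmap(jacrev(f))(x)  # shape (5, 2, 3)
jac_fwd = vmap(jacfwd(f))(x)  # same values, computed in forward mode

# Analytical Jacobian: J[b, i, j] = W[i, j] * cos(x[b, j])
expected = W * torch.cos(x)[:, None, :]

print(jac_rev.shape)
print(torch.allclose(jac_rev, expected, atol=1e-5))
print(torch.allclose(jac_fwd, expected, atol=1e-5))
```

Without `vmap`, calling `jacrev(f)` directly on the batched tensor would differentiate every output with respect to every input across the whole batch, which is not the per-sample Jacobian wanted here.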

Higher-order derivatives can be computed by composing `jacrev` and/or `jacfwd` several times. Note that `hessian` is just a convenience function defined as `hessian(f) = jacfwd(jacrev(f))`; the Hessian can also be computed through the other compositions.
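
The composition rule can be illustrated on a toy scalar function. This is only a sketch: the function `g` is hypothetical and was picked because its Hessian, `diag(6 * x)`, is easy to verify by hand.

```python
import torch
from torch.func import jacrev, jacfwd, hessian

def g(x):
    # Hypothetical toy function: g(x) = sum_i x_i^3
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0, 3.0])

# Four ways of composing two first-order transforms, plus the shortcut
h_ref = hessian(g)(x)
h_fr = jacfwd(jacrev(g))(x)
h_rf = jacrev(jacfwd(g))(x)
h_rr = jacrev(jacrev(g))(x)
h_ff = jacfwd(jacfwd(g))(x)

# Analytical Hessian of g
expected = torch.diag(6 * x)

for name, h in [("hessian", h_ref), ("fwd(rev)", h_fr), ("rev(fwd)", h_rf),
                ("rev(rev)", h_rr), ("fwd(fwd)", h_ff)]:
    print(name, torch.allclose(h, expected))

# A third-order derivative is obtained by composing one more time
third = jacfwd(jacfwd(jacrev(g)))(x)  # shape (3, 3, 3)
print(third.shape)
```

All compositions return the same values; they differ only in computational cost, which is what the timings in the next section measure.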

## Comparison of computation times

We now create a large neural network to compare the computation times of the different ways of computing the derivatives.

In [2]:
layers_sizes = [4, 1000, 1000, 25]
activation = nn.ELU()

huge_net = FullyConnected(
    layers_sizes,
    activation,
)

In [3]:
# Jacobian matrix
jacr = vmap(jacrev(huge_net))
jacf = vmap(jacfwd(huge_net))

# Hessian matrix
hess = vmap(hessian(huge_net))
jacrr = vmap(jacrev(jacrev(huge_net)))
jacrf = vmap(jacrev(jacfwd(huge_net)))
jacfr = vmap(jacfwd(jacrev(huge_net)))
jacff = vmap(jacfwd(jacfwd(huge_net)))
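
Before timing the variants, it can be worth checking that they all return the same values. The sketch below does this on a hypothetical stand-in network built with `nn.Sequential`. It is not the `FullyConnected` class from `nnbma`, and its layer sizes are reduced and use double precision so that the check runs quickly with a tight tolerance.

```python
import torch
from torch import nn
from torch.func import jacrev, jacfwd, hessian, vmap

torch.manual_seed(0)

# Hypothetical stand-in network (assumption: 4 inputs, 5 outputs, ELU activations)
small_net = nn.Sequential(
    nn.Linear(4, 32), nn.ELU(),
    nn.Linear(32, 32), nn.ELU(),
    nn.Linear(32, 5),
).double()

x_check = torch.randn(8, 4, dtype=torch.float64)  # batch of 8 inputs

# Per-sample Jacobians, shape (8, 5, 4)
jac_r = vmap(jacrev(small_net))(x_check)
jac_f = vmap(jacfwd(small_net))(x_check)
print("jacobians agree:", torch.allclose(jac_r, jac_f, atol=1e-8))

# Per-sample Hessians, shape (8, 5, 4, 4)
hessians = {
    "hessian": vmap(hessian(small_net))(x_check),
    "jacrr": vmap(jacrev(jacrev(small_net)))(x_check),
    "jacrf": vmap(jacrev(jacfwd(small_net)))(x_check),
    "jacfr": vmap(jacfwd(jacrev(small_net)))(x_check),
    "jacff": vmap(jacfwd(jacfwd(small_net)))(x_check),
}
reference = hessians["hessian"]
for name, value in hessians.items():
    print(name, torch.allclose(value, reference, atol=1e-8))
```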

In [4]:
n_batchs = 100
x = torch.normal(0, torch.ones(n_batchs, layers_sizes[0]))
x_numpy = x.numpy()

For comparison, here are the evaluation time of the network and the time needed to convert between NumPy arrays and PyTorch tensors:

In [5]:
%%timeit
huge_net.evaluate(x.numpy())

1.18 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
%%timeit
torch.from_numpy(x_numpy)

783 ns ± 62.5 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [7]:
%%timeit
x.numpy()

1.23 µs ± 58.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


### Jacobian computation

In [8]:
%%timeit
jacr(x)

24 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
%%timeit
jacf(x)

8 ms ± 251 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Hessian computation

In [10]:
%%timeit
hess(x)

194 ms ± 5.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
%%timeit
jacrr(x)

3.48 s ± 240 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
jacrf(x)

818 ms ± 88.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%%timeit
jacfr(x)

258 ms ± 40.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%%timeit
jacff(x)

85 ms ± 5.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
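
The `%%timeit` magic is only available in IPython and Jupyter. A rough equivalent for a plain Python script, using the standard `timeit` module, could look like the sketch below. The network is a hypothetical, smaller stand-in for `huge_net`, so the numbers it prints are not comparable to the measurements above. It also reports the best of several repetitions, whereas `%%timeit` reports a mean and standard deviation.

```python
import timeit

import torch
from torch import nn
from torch.func import jacrev, jacfwd, vmap

# Hypothetical stand-in network (assumption: 4 inputs, 25 outputs, one hidden layer)
net = nn.Sequential(nn.Linear(4, 64), nn.ELU(), nn.Linear(64, 25))
x_bench = torch.randn(100, 4)

candidates = {
    "jacrev": vmap(jacrev(net)),
    "jacfwd": vmap(jacfwd(net)),
}

n_calls = 10
timings = {}
for name, fn in candidates.items():
    fn(x_bench)  # warm-up call, excluded from the measurement
    runs = timeit.repeat(lambda: fn(x_bench), number=n_calls, repeat=3)
    timings[name] = min(runs) / n_calls  # best time per call, in seconds
    print(f"{name}: {timings[name] * 1e3:.2f} ms per call")
```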
