# Automatic differentiation of neural networks using PyTorch 2.0

In [None]:
import os
import sys

sys.path.append(os.path.join(os.path.abspath(""), ".."))

import torch
from torch import nn
from torch.func import jacrev, jacfwd, hessian, vmap

from nnbma.networks import FullyConnected

## Introduction to differentiation with PyTorch

We will use the following functions from `torch.func`:
* `jacrev`: compute the Jacobian using reverse-mode autodiff
* `jacfwd`: compute the Jacobian using forward-mode autodiff
* `hessian`: compute the Hessian using both reverse- and forward-mode autodiff
* `vmap`: vectorizing transform used to compute the derivatives of batched inputs

Higher-order derivatives can be computed by composing `jacrev` and/or `jacfwd` several times. Note that `hessian` is just a convenience function defined as `hessian(f) = jacfwd(jacrev(f))`; the Hessian can also be computed through the other compositions.

## Comparison of computation times

We will create a larger neural network to compare the computation times of the different ways of computing the derivatives.

In [None]:
# A deliberately wide fully connected network, so that the differences
# between the derivative-computation strategies show up in the timings.
# layers_sizes[0] is used below as the number of input features.
layers_sizes = [4, 1000, 1000, 25]
activation = nn.ELU()
huge_net = FullyConnected(layers_sizes, activation)

In [None]:
# Batched Jacobian (first derivatives): reverse mode vs forward mode.
# `vmap` maps each per-sample derivative function over the batch dimension.
jacr, jacf = vmap(jacrev(huge_net)), vmap(jacfwd(huge_net))

# Batched Hessian (second derivatives): the `hessian` convenience wrapper,
# then every outer/inner composition of jacrev and jacfwd, in the order
# rev-rev, rev-fwd, fwd-rev, fwd-fwd. Since `hessian(f)` is defined as
# `jacfwd(jacrev(f))`, `hess` and `jacfr` build the same computation.
hess = vmap(hessian(huge_net))
jacrr, jacrf, jacfr, jacff = (
    vmap(outer(inner(huge_net)))
    for outer in (jacrev, jacfwd)
    for inner in (jacrev, jacfwd)
)

In [None]:
n_batchs = 100
# Standard-normal inputs (mean 0, unit std), one row per sample; the number
# of columns matches the first entry of layers_sizes (the network input).
x = torch.normal(mean=0, std=torch.ones((n_batchs, layers_sizes[0])))
# NumPy version of the same inputs, prepared once outside the timed cells.
x_numpy = x.numpy()

For comparison, here are the evaluation time of the network and the time needed to convert between NumPy and torch:

In [None]:
%%timeit
# Time only the network evaluation: pass the pre-converted `x_numpy` so the
# torch -> NumPy conversion (timed on its own in a separate cell) is not
# folded into this measurement.
huge_net.evaluate(x_numpy)

In [None]:
%%timeit
# Time the NumPy -> torch conversion of the input batch.
torch.from_numpy(x_numpy)

In [None]:
%%timeit
# Time the torch -> NumPy conversion of the input batch.
x.numpy()

Jacobian computation

In [None]:
%%timeit
# Batched Jacobian via reverse-mode autodiff: vmap(jacrev(huge_net)).
jacr(x)

In [None]:
%%timeit
# Batched Jacobian via forward-mode autodiff: vmap(jacfwd(huge_net)).
jacf(x)

Hessian computation

In [None]:
%%timeit
# Batched Hessian via the `hessian` convenience wrapper, i.e. jacfwd(jacrev(.)).
hess(x)

In [None]:
%%timeit
# Batched Hessian as reverse-over-reverse: vmap(jacrev(jacrev(huge_net))).
jacrr(x)

In [None]:
%%timeit
# Batched Hessian as reverse-over-forward: vmap(jacrev(jacfwd(huge_net))).
jacrf(x)

In [None]:
%%timeit
# Batched Hessian as forward-over-reverse: vmap(jacfwd(jacrev(huge_net))).
# This is the same composition that `hessian` uses.
jacfr(x)

In [None]:
%%timeit
# Batched Hessian as forward-over-forward: vmap(jacfwd(jacfwd(huge_net))).
jacff(x)