# Neural networks auto-differentiation using PyTorch 2.0

In [1]:
import os
import sys

sys.path.append(os.path.join(os.path.abspath(""), ".."))

import torch
from torch import nn
from torch.func import jacrev, jacfwd, hessian, vmap

from nnbma.networks import FullyConnected

## Introduction to differentiation with PyTorch

We will use the following functions from `torch.func`:
* `jacrev`: computes the Jacobian using reverse-mode autodiff
* `jacfwd`: computes the Jacobian using forward-mode autodiff
* `hessian`: computes the Hessian by combining reverse-mode and forward-mode autodiff
* `vmap`: vectorizing transform, used here to compute the derivatives of batched inputs

Higher-order derivatives can be computed by composing `jacrev` and/or `jacfwd` several times. Note that `hessian` is just a convenience function defined as `hessian(f) = jacfwd(jacrev(f))`; the Hessian can also be computed through other compositions.
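
As a minimal sketch of these compositions, the example below applies the transforms to a small toy function. The function `f` and the input `x` are hypothetical and chosen only for illustration; they are not part of the notebook.

```python
import torch
from torch.func import jacrev, jacfwd, hessian

def f(x):
    # Hypothetical toy function R^3 -> R^2, used only for illustration
    return torch.stack([x[0] * x[1] ** 2, torch.sin(x[2]) * x[0]])

x = torch.tensor([1.0, 2.0, 3.0])

# First-order derivatives: both modes return a Jacobian of shape (2, 3)
jac_rev = jacrev(f)(x)
jac_fwd = jacfwd(f)(x)

# Second-order derivatives: each result has shape (2, 3, 3)
hess_ref = hessian(f)(x)         # convenience function
hess_fr = jacfwd(jacrev(f))(x)   # forward-over-reverse
hess_rr = jacrev(jacrev(f))(x)   # reverse-over-reverse

print(torch.allclose(jac_rev, jac_fwd))
print(torch.allclose(hess_ref, hess_fr), torch.allclose(hess_ref, hess_rr))
```

The different compositions should agree up to floating-point error; they differ only in how the derivative is computed, which is what the timings further down compare.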

## Comparison of computation times

We will create a large neural network in order to compare the computation times of the different ways of computing the derivatives.

In [2]:
n_inputs = 5
n_outputs = 25

layers_sizes = [n_inputs, 1000, 1000, n_outputs]
activation = nn.ELU()

huge_net = FullyConnected(
    layers_sizes,
    activation,
)

In [3]:
# Jacobian matrix
jacr = vmap(jacrev(huge_net.forward))
jacf = vmap(jacfwd(huge_net.forward))

# Hessian matrix
hess = vmap(hessian(huge_net.forward))
jacrr = vmap(jacrev(jacrev(huge_net.forward)))
jacrf = vmap(jacrev(jacfwd(huge_net.forward)))
jacfr = vmap(jacfwd(jacrev(huge_net.forward)))
jacff = vmap(jacfwd(jacfwd(huge_net.forward)))
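
Before timing these transforms, it can be useful to check the shapes they return and that the different compositions agree. The sketch below does this on a small stand-in network: `net` is a plain `nn.Sequential` MLP introduced here as an assumption, not the `FullyConnected` class from `nnbma`, and the hidden size of 16 is arbitrary.

```python
import torch
from torch import nn
from torch.func import jacrev, jacfwd, hessian, vmap

torch.manual_seed(0)
n_inputs, n_outputs, n_batch = 5, 25, 4

# Stand-in network (assumption): a small plain PyTorch MLP, not nnbma's FullyConnected
net = nn.Sequential(nn.Linear(n_inputs, 16), nn.ELU(), nn.Linear(16, n_outputs))

x = torch.randn(n_batch, n_inputs)

# vmap maps the per-sample derivative over the leading (batch) dimension
jac_r = vmap(jacrev(net))(x)           # (n_batch, n_outputs, n_inputs)
jac_f = vmap(jacfwd(net))(x)
hess = vmap(hessian(net))(x)           # (n_batch, n_outputs, n_inputs, n_inputs)
hess_ff = vmap(jacfwd(jacfwd(net)))(x)

print(jac_r.shape, hess.shape)
print(torch.allclose(jac_r, jac_f, atol=1e-5))
print(torch.allclose(hess, hess_ff, atol=1e-5))
```

Without `vmap`, the transforms would treat the whole batch as a single input and also return the cross-sample derivatives, which are zero for a network that processes each sample independently.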

In [4]:
n_batchs = 100

As a baseline, here are the evaluation time of the network and the time needed to convert between NumPy arrays and PyTorch tensors:

In [5]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
huge_net.evaluate(x.numpy())

1.13 ms ± 92.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
x_numpy = x.numpy()
torch.from_numpy(x_numpy)

17.8 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
x.numpy()

17.4 µs ± 571 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### Jacobian computation

In [8]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
jacr(x)

28.6 ms ± 3.73 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
jacf(x)

8.15 ms ± 521 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Hessian computation

In [10]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
hess(x)

214 ms ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
jacrr(x)

3.88 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
jacrf(x)

1.16 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
jacfr(x)

218 ms ± 5.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%%timeit
x = torch.normal(0, torch.ones(n_batchs, n_inputs))
jacff(x)

68.4 ms ± 8.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
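
The timings above are consistent with the usual rule of thumb: forward mode builds the Jacobian one input direction at a time and reverse mode one output at a time, so forward mode tends to be cheaper when there are fewer inputs than outputs, as here (5 inputs, 25 outputs). To repeat this kind of comparison outside a notebook, where the `%%timeit` magic is not available, a sketch along the following lines could be used. The stand-in network `net`, its hidden size and the repetition counts are assumptions made for illustration, so the measured values will not match the ones reported above.

```python
import timeit

import torch
from torch import nn
from torch.func import jacrev, jacfwd, vmap

n_inputs, n_outputs, n_batch = 5, 25, 100

# Stand-in network (assumption): a small plain PyTorch MLP, not nnbma's FullyConnected
net = nn.Sequential(nn.Linear(n_inputs, 64), nn.ELU(), nn.Linear(64, n_outputs))
x = torch.randn(n_batch, n_inputs)

candidates = {
    "jacrev": vmap(jacrev(net)),
    "jacfwd": vmap(jacfwd(net)),
}

n_calls = 5
timings = {}
for name, fn in candidates.items():
    fn(x)  # warm-up call, excluded from the measurement
    runs = timeit.repeat(lambda: fn(x), number=n_calls, repeat=3)
    timings[name] = min(runs) / n_calls  # best average time per call, in seconds

for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.2f} ms per call")
```

Unlike the cells above, this sketch creates the input once outside the timed call, so only the derivative computation is measured.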
