<a href="https://colab.research.google.com/github/cu-applied-math/appm-4600-numerics/blob/main/Demos/Ch1_AutoDiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Automatic Differentiation demo
Using JAX and PyTorch

APPM 4600

Copyright Dept of Applied Math, University of Colorado Boulder. Released under a BSD 3-clause license

Learning objectives:
1. See how to use AutoDiff in two popular frameworks (JAX and PyTorch)
2. See that reverse mode is usually faster than forward mode for functions $f:\mathbb{R}^n \to \mathbb{R}$
3. Compare to symbolic differentiation

Further reading
- another [AutoDiff](https://github.com/cu-applied-math/SciML-Class/blob/main/Demos/AutomaticDifferentiation.ipynb) demo from CU
- [JAX](https://docs.jax.dev/en/latest/index.html)
- [PyTorch](https://pytorch.org/)

## Using JAX

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import jacfwd, jacrev
from jax import nn
import numpy as np

We'll start with a simple linear function, $f(x) = \sum_i (Ax)_i$. Note that the `@` operator is matrix multiplication (in both JAX and NumPy), i.e., [jax.numpy.matmul](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html#jax.numpy.matmul)

In [None]:
n = int(1e1)
m = int(n/2)

# We want some arbitrary matrix -- e.g., we could do this randomly
# jax has some utilities for this, but if you don't want to learn them,
# just convert from numpy
A = np.random.randn(m,n)
x = np.random.randn(n,1)
A = jnp.array(A)
x = jnp.array(x)

def f(x):
    return jnp.sum( A @ x )

We can ask jax for:
- the gradient (of a function $f: \mathbb{R}^n \to \mathbb{R}$)
- the Jacobian (of a function $f: \mathbb{R}^n \to \mathbb{R}^m$)
  - if $m=1$ this *is* the gradient! (though sometimes there is a transpose difference...)

Gradients are always computed via **reverse mode**, but for Jacobians, you can choose either **reverse** or **forward** mode. Forward mode needs one pass per input and reverse mode one pass per output, so in general, if $n > m$ you want **reverse** mode. See JAX's ["Autodiff Cookbook"](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html).
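To make the one-pass-per-input versus one-pass-per-output argument concrete, here is a minimal sketch that builds a Jacobian both ways using `jax.jvp` (forward mode) and `jax.vjp` (reverse mode). The sizes `n`, `m` and the test function `F` are illustrative assumptions, not part of the demo. Each forward pass yields one *column* of the $m \times n$ Jacobian, and each reverse pass yields one *row*.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes and test function (assumptions, not from the demo): F maps R^n -> R^m
n, m = 6, 2
A = jax.random.normal(jax.random.key(1), (m, n))

def F(x):
    return jnp.tanh(A @ x)

x0 = jnp.ones(n)

# Forward mode: each jax.jvp pass pushes one input direction e_j through F
# and returns one COLUMN J[:, j], so n passes are needed.
cols = [jax.jvp(F, (x0,), (e,))[1] for e in jnp.eye(n)]
J_fwd = jnp.stack(cols, axis=1)

# Reverse mode: one forward pass records F, then each call to F_vjp pulls one
# output direction e_i back and returns one ROW J[i, :], so m passes are needed.
_, F_vjp = jax.vjp(F, x0)
rows = [F_vjp(e)[0] for e in jnp.eye(m)]
J_rev = jnp.stack(rows, axis=0)

print(J_fwd.shape, J_rev.shape)                           # (2, 6) (2, 6)
print(jnp.allclose(J_fwd, J_rev, atol=1e-5))              # forward and reverse agree
print(jnp.allclose(J_rev, jax.jacrev(F)(x0), atol=1e-5))  # and match jax.jacrev
```

Here $n > m$, so reverse mode needs fewer passes (2 instead of 6). For a scalar-valued $f$ like the one above, reverse mode needs just one.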

In [None]:
g  = grad(f)
J1 = jacfwd(f)
J2 = jacrev(f)

g(x), J1(x), J2(x)

(Array([[ 0.10704881],
        [ 4.605774  ],
        [ 1.5944297 ],
        [ 1.575459  ],
        [ 0.11993957],
        [-3.6401691 ],
        [-1.0900795 ],
        [ 1.6887466 ],
        [-2.1338854 ],
        [-6.1179714 ]], dtype=float32),
 Array([[ 0.10704881],
        [ 4.605774  ],
        [ 1.5944297 ],
        [ 1.575459  ],
        [ 0.11993957],
        [-3.6401691 ],
        [-1.0900795 ],
        [ 1.6887466 ],
        [-2.1338854 ],
        [-6.1179714 ]], dtype=float32),
 Array([[ 0.10704881],
        [ 4.605774  ],
        [ 1.5944297 ],
        [ 1.575459  ],
        [ 0.11993957],
        [-3.6401691 ],
        [-1.0900795 ],
        [ 1.6887466 ],
        [-2.1338854 ],
        [-6.1179714 ]], dtype=float32))
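Because this first $f$ is linear, $f(x) = \mathbf{1}^T A x$, its gradient is the constant vector $A^T \mathbf{1}$ (the column sums of $A$), whatever $x$ is. That gives an easy sanity check on the output above. A minimal sketch that re-creates a random `A` of the same shape:

```python
import jax
import jax.numpy as jnp
import numpy as np

n, m = 10, 5
A = jnp.array(np.random.randn(m, n))
x = jnp.array(np.random.randn(n, 1))

def f(x):
    return jnp.sum(A @ x)

# f(x) = 1^T A x is linear in x, so grad f = A^T 1 (column sums of A), independent of x
g_ad = jax.grad(f)(x)
g_exact = A.T @ jnp.ones((m, 1))
print(jnp.allclose(g_ad, g_exact, atol=1e-5))
```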

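Now let's try something slightly more interesting: $f(x) = \sum_i \sigma\big(B\,\sigma(Ax)\big)_i$, where $\sigma$ is the sigmoid function, with $n = m = k = 5000$.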

In [None]:
n = int(5e3)
m = n
k = n

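key = jax.random.key(seed=0)
# Split the key: reusing one key would make A and B (same shape) identical
kA, kB, kx = jax.random.split(key, 3)
A = jax.random.normal(kA, (m,n))
B = jax.random.normal(kB, (k,m))
x = jax.random.normal(kx, (n,1))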
# A = jnp.array(np.random.randn(m,n)) # another (slower) way to do it
# B = jnp.array(np.random.randn(k,m))
# x = jnp.array(np.random.randn(n,1))

def f(x):
    return jnp.sum( nn.sigmoid(B @ nn.sigmoid(A @ x ) ) )

# The first call includes one-time overhead (e.g., compiling each operation for these shapes)
%time y = f(x)

CPU times: user 462 ms, sys: 15.1 ms, total: 477 ms
Wall time: 239 ms


In [None]:
%time y = f(x)

CPU times: user 6.24 ms, sys: 23 µs, total: 6.26 ms
Wall time: 4.95 ms
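One caveat when timing JAX: operations are dispatched asynchronously, so `%time` can return before the result has actually been computed. A more careful sketch calls `.block_until_ready()` and excludes the first (compile-time) call. It assumes a smaller `n = 2000` for speed and wraps `f` in `jax.jit`, which the demo above does not do.

```python
import time
import jax
import jax.numpy as jnp
from jax import nn

n = 2000  # assumption: smaller than the demo's 5000 so the sketch runs quickly
kA, kB, kx = jax.random.split(jax.random.key(0), 3)
A = jax.random.normal(kA, (n, n))
B = jax.random.normal(kB, (n, n))
x = jax.random.normal(kx, (n, 1))

@jax.jit  # assumption: jit-compiling f (the demo above runs f eagerly)
def f(x, A, B):
    return jnp.sum(nn.sigmoid(B @ nn.sigmoid(A @ x)))

f(x, A, B).block_until_ready()        # warm-up call: tracing and compilation happen here

t0 = time.perf_counter()
y = f(x, A, B).block_until_ready()    # wait for the computation to actually finish
elapsed = time.perf_counter() - t0
print(f"f(x): {elapsed * 1e3:.2f} ms")
```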


In [None]:
g  = grad(f)
J1 = jacfwd(f)
J2 = jacrev(f)

In [None]:
%%time
y = g(x)  # reverse-mode

CPU times: user 304 ms, sys: 8.11 ms, total: 312 ms
Wall time: 234 ms


In [None]:
%%time
y = J1(x) # forward-mode

CPU times: user 1.16 s, sys: 758 ms, total: 1.92 s
Wall time: 1.08 s


In [None]:
%%time
y = J2(x) # reverse-mode

CPU times: user 554 ms, sys: 85.1 ms, total: 639 ms
Wall time: 342 ms


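We see that reverse mode (`J2` and `g`) is faster than forward mode (`J1`). Naively you might expect it to be **much** faster, since forward mode needs $n = 5000$ passes (one per input) while reverse mode needs a single pass for this scalar output. The gap is likely smaller because `jacfwd` vectorizes its forward passes with `vmap`, so the $n$ matrix-vector products become a few large matrix-matrix products, which run very efficiently.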
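`jax.jacfwd` is built from forward-mode `jax.jvp` calls vectorized with `jax.vmap` over the standard basis vectors. The sketch below does the same by hand and checks it against the reverse-mode gradient. It assumes a smaller, hypothetical $n = 300$ and a 1-D `x`.

```python
import jax
import jax.numpy as jnp
from jax import nn

n = 300  # hypothetical small size for illustration
kA, kB, kx = jax.random.split(jax.random.key(0), 3)
A = jax.random.normal(kA, (n, n))
B = jax.random.normal(kB, (n, n))
x = jax.random.normal(kx, (n,))

def f(x):
    return jnp.sum(nn.sigmoid(B @ nn.sigmoid(A @ x)))

def jvp_along(v):
    # one forward-mode pass: directional derivative of f at x along v
    return jax.jvp(f, (x,), (v,))[1]

# vmap batches all n forward passes into one computation; inside it, each
# matrix-vector product in f acts on all n tangent vectors at once
g_fwd = jax.vmap(jvp_along)(jnp.eye(n))   # shape (n,): one entry per basis vector
g_rev = jax.grad(f)(x)                     # one reverse-mode pass

# loose tolerance: float32 arithmetic, different summation orders
print(jnp.allclose(g_fwd, g_rev, rtol=1e-2, atol=1e-2))
```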

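## Let's repeat the same thing in PyTorch
PyTorch is another popular autodiff framework. Note that this version uses a larger problem ($n = 8000$) in `float64`, so its timings are not directly comparable to the JAX timings above.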

In [None]:
import torch
import matplotlib.pyplot as plt
import sys
import numpy as np
from torch.nn.functional import sigmoid
print("Torch version is", torch.__version__)
print("Numpy version is", np.__version__)
print("Python version is", sys.version)

torch.manual_seed(100)
# dtype = torch.float32 # the default
dtype = torch.float64

n = int(8e3)
m = n
k = n

A = torch.randn((m,n),dtype=dtype)
B = torch.randn((k,m),dtype=dtype)
x = torch.randn((n,1), dtype=dtype, requires_grad=True)

def f(x):
    return torch.sum( sigmoid(B @ sigmoid(A @ x ) ) )

y = f(x)

Torch version is 2.8.0+cu126
Numpy version is 2.0.2
Python version is 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]


In [None]:
%%time
if x.grad is not None:
    x.grad.zero_()   # .backward() accumulates into x.grad, so reset it between runs
out = f(x)
out.backward()
y = x.grad

CPU times: user 291 ms, sys: 1.26 ms, total: 292 ms
Wall time: 314 ms
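To see what `out.backward()` computes, here is a sketch that applies the chain rule by hand, from the output backward, for the same $f(x) = \sum_i \sigma\big(B\,\sigma(Ax)\big)_i$, using $\sigma'(z) = \sigma(z)\,(1-\sigma(z))$. It uses `torch.autograd.grad`, which returns the gradient directly rather than accumulating it into `x.grad`. The smaller size `n = 500` is an arbitrary choice for speed, and `torch.sigmoid` stands in for the `torch.nn.functional.sigmoid` import above.

```python
import torch

torch.manual_seed(0)
n = 500                      # arbitrary smaller size so the check runs quickly
dtype = torch.float64
A = torch.randn((n, n), dtype=dtype)
B = torch.randn((n, n), dtype=dtype)
x = torch.randn((n, 1), dtype=dtype, requires_grad=True)

def f(x):
    return torch.sum(torch.sigmoid(B @ torch.sigmoid(A @ x)))

# Autograd: returns the gradient instead of accumulating it in x.grad
(g_auto,) = torch.autograd.grad(f(x), x)

# The same gradient by hand, applying the chain rule from the output backward
with torch.no_grad():
    u = torch.sigmoid(A @ x)               # inner activation, shape (n, 1)
    v = torch.sigmoid(B @ u)               # outer activation, shape (n, 1)
    u_bar = B.T @ (v * (1 - v))            # d f / d u
    g_hand = A.T @ (u * (1 - u) * u_bar)   # d f / d x

print(torch.allclose(g_auto, g_hand))
```

Each backward step reuses activations (`u`, `v`) saved from the forward pass, which is why reverse mode needs to store intermediate values.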


## Speelpenning Function
Taken from the longer [SciML AutoDiff](https://github.com/cu-applied-math/SciML-Class/blob/main/Demos/AutomaticDifferentiation.ipynb) example
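In formulas, the next cell builds $g$ from the 10 equally spaced roots $r_i = i/9$ (that is, `np.linspace(0,1,10)`), and the product rule gives its derivative:

$$
g(x) = \prod_{i=0}^{9} (x - r_i), \qquad g'(x) = \sum_{i=0}^{9} \prod_{j \neq i} (x - r_j).
$$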

In [None]:
import sympy
from sympy.abc import x
roots = np.linspace(0,1,10)
def g(x):
    y = 1
    for i in range(len(roots)):
        y = y * (x - roots[i])
    return y
g(x)

x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.222222222222222)*(x - 0.111111111111111)

In [None]:
gprime = sympy.diff(g(x),x)
gprime

x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.222222222222222) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.222222222222222)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.333333333333333)*(x - 0.222222222222222)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.222222222222222)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.555555555555556)*(x - 0

In [None]:
gprime.evalf(16,subs={x:.88889})

-0.0001040765432583041

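That symbolic derivative is **correct**, but it is not an efficient implementation: it is a sum of 10 terms, each a product of 9 factors, so with $n$ roots it costs $O(n^2)$ operations to evaluate, compared with $O(n)$ for $g$ itself. We can get an efficient implementation if we play around a bit, but it's not automatic.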

For example, we can tell sympy to expand $g(x)$ out, and *then* differentiate:

In [None]:
sympy.expand(g(x))

x**10 - 5.0*x**9 + 10.7407407407407*x**8 - 12.962962962963*x**7 + 9.64380429812528*x**6 - 4.56104252400549*x**5 + 1.36173159391165*x**4 - 0.245182437937607*x**3 + 0.0238479488367999*x**2 - 0.000936656708416885*x

In [None]:
sympy.diff( sympy.expand(g(x)), x )

10*x**9 - 45.0*x**8 + 85.9259259259259*x**7 - 90.7407407407407*x**6 + 57.8628257887517*x**5 - 22.8052126200274*x**4 + 5.44692637564659*x**3 - 0.735547313812822*x**2 + 0.0476958976735998*x - 0.000936656708416885
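The derivative can also be evaluated in $O(n)$ operations by reusing partial products. The sketch below (the helper `g_and_dg` is hypothetical, not from the SciML demo) stores prefix products $\prod_{j<i}(x-r_j)$ on a forward sweep and accumulates suffix products $\prod_{j>i}(x-r_j)$ on a backward sweep. Reverse-mode AD applied to the loop in `g` carries out essentially this same sweep, but automatically.

```python
import numpy as np

roots = np.linspace(0, 1, 10)

def g_and_dg(x):
    """Hypothetical helper: value and derivative of g(x) = prod_i (x - r_i) in O(n) work."""
    n = len(roots)
    # Forward sweep: prefix[i] = prod_{j < i} (x - r_j), so prefix[n] = g(x)
    prefix = [1.0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] * (x - roots[i])
    # Backward sweep: suffix = prod_{j > i} (x - r_j), built up from the right
    dg, suffix = 0.0, 1.0
    for i in reversed(range(n)):
        dg += prefix[i] * suffix   # the term of g'(x) that omits factor i
        suffix *= (x - roots[i])
    return prefix[n], dg

val, der = g_and_dg(0.88889)
print(der)   # compare with the sympy value above
```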

## Showing that AutoDiff depends on the implementation

We'll define the function $f(x)=0$, but in a slow way: apply $A$ to $x$ one hundred times, then subtract the result from itself.

In [1]:
import torch
d   = int(4e3)

torch.manual_seed(100)
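# Scale by 1/sqrt(d) so repeated products stay finite; unscaled, A^100 x
# overflows float32 and x - x evaluates to nan instead of 0
A   = torch.randn( (d,d) ) / d**0.5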

def f(x, N = 100):
    """ Implements the zero function: f(x) = 0 """
    for k in range(N):
        x = A @ x

    return torch.sum(x - x)

x   = torch.randn( (d,1), requires_grad=True )

In [3]:
%%time
with torch.no_grad():
    y = f(x)

CPU times: user 564 ms, sys: 3 µs, total: 564 ms
Wall time: 567 ms


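The gradient is the all-zeros vector, but autograd does not know that: it differentiates the program, not the mathematical function, so it still backpropagates a vector of zeros through all 100 matrix-vector products. As the timing below shows, computing the gradient takes more than twice as long as evaluating $f$ alone.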

In [4]:
%%time
y = f(x)
y.backward()

CPU times: user 1.25 s, sys: 5.3 ms, total: 1.26 s
Wall time: 1.32 s
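To make the point concrete, here is a minimal sketch comparing two implementations of the same zero function. The names `f_slow`/`f_fast` and the smaller `d = 1000` are illustrative choices. Both give an all-zeros gradient, but only the slow one pays for 100 matrix-vector products in the backward pass.

```python
import time
import torch

torch.manual_seed(100)
d, N = 1000, 100
A = torch.randn((d, d)) / d**0.5   # scaled so repeated products stay finite

def f_slow(x):
    # zero function computed the long way: N matrix-vector products first
    for _ in range(N):
        x = A @ x
    return torch.sum(x - x)

def f_fast(x):
    # same mathematical function, with nothing expensive to backpropagate through
    return torch.sum(x - x)

x = torch.randn((d, 1), requires_grad=True)
grads = {}
for f in (f_slow, f_fast):
    x.grad = None                  # start each run with no accumulated gradient
    t0 = time.perf_counter()
    f(x).backward()
    print(f"{f.__name__}: {time.perf_counter() - t0:.4f} s")
    grads[f.__name__] = x.grad.clone()
```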
