# Example-07: Reverse and forward differentiation modes

In [1]:
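# By default, jax.grad (reverse mode) is used to set up the equations for a given Hamiltonian.
# Alternatively, jax.jacrev or jax.jacfwd can be passed via the `gradient` argument.
# When differentiating the integrator itself, it is more efficient to nest forward and reverse modes,
# i.e. to take the outer Jacobian with jax.jacfwd (compare the timings below).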

In [2]:
# Import

import jax
from jax import jit
from jax import grad
from jax import jacrev
from jax import jacfwd

# Function iterations

from sympint import nest
from sympint import fold

# Yoshida composition

from sympint import sequence

# Tao integrator

from sympint import tao

In [3]:
# Set data type
# Enable 64-bit (double-precision) types. This runs before any arrays are
# created below, so dt, t, b and x are float64.

jax.config.update("jax_enable_x64", True)

In [4]:
# Set device
# Run JAX computations on the first CPU device.
# Index explicitly instead of `device, *_ = ...`: star-unpacking into `_`
# would assign `_` in the IPython user namespace, after which IPython stops
# updating its `_` / `__` / `___` output-history variables.

device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)

In [5]:
# Define hamiltonian function

def h(qs, ps, t, b, *args):
    """Hamiltonian evaluated at coordinates ``qs`` and momenta ``ps``.

    Both ``qs`` and ``ps`` must unpack into exactly three components.
    The time ``t``, the third coordinate and any extra ``*args`` are accepted
    but not used. ``b`` enters through the coupling ``b/2`` between
    coordinates and transverse momenta.

    Returns ``p_s - sqrt((p_s + 1)**2 - (p_x + b/2*q_y)**2 - (p_y - b/2*q_x)**2)``.
    """
    qx, qy, _ = qs
    px, py, pl = ps  # pl is the third momentum component (p_s)
    k = b / 2
    radicand = (pl + 1)**2 - (px + k * qy)**2 - (py - k * qx)**2
    return pl - jax.numpy.sqrt(radicand)

In [6]:
# Parameters and initial condition
# Scalars passed to the integrator: step size dt, initial time t and the
# Hamiltonian parameter b (all requested as float64).

dt, t, b = (jax.numpy.float64(value) for value in (0.01, 0.0, 0.1))

# Initial state with six components; presumably (q_x, q_y, q_s, p_x, p_y, p_s),
# matching how h unpacks qs/ps -- confirm against sympint.tao.
x = jax.numpy.array([1e-3, -1e-3, 1e-4, 0.0, 0.0, 0.0])

In [7]:
# jax.grad
# Inner gradient: jax.grad (reverse mode, the default).
#
# Compose the Tao step for h with sequence(0, 2, ...) (Yoshida composition)
# and fold, then iterate it with nest(100, ...). The step and its
# reverse-mode (drev) and forward-mode (dfwd) Jacobians with respect to the
# initial state x are JIT-compiled.

fs = sequence(0, 2, [tao(h, gradient=grad)], merge=False)
integrator = fold(fs)
step = jit(nest(100, integrator))
drev, dfwd = jit(jacrev(step)), jit(jacfwd(step))

# Warm-up: the first call of each jitted function triggers compilation,
# which keeps compile time out of the %timeit measurements that follow.
out = step(x, dt, t, b)
out = drev(x, dt, t, b)
out = dfwd(x, dt, t, b)

In [8]:
# Time the compiled step and both Jacobians. block_until_ready() waits for
# JAX's asynchronous dispatch, so each measurement covers the full computation.
%timeit out = step(x, dt, t, b).block_until_ready()
%timeit out = drev(x, dt, t, b).block_until_ready()
%timeit out = dfwd(x, dt, t, b).block_until_ready()

315 µs ± 9.55 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
18.5 ms ± 921 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.16 ms ± 232 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# Final state after the nested steps, then the Jacobian of that state with
# respect to the initial condition x, computed in reverse mode (drev) and
# forward mode (dfwd). Each result is evaluated once and kept under a name.
final_state = step(x, dt, t, b)
jac_rev = drev(x, dt, t, b)
jac_fwd = dfwd(x, dt, t, b)

print(final_state)
print()

print(jac_rev)
print()

print(jac_fwd)
print()

# Both AD modes differentiate the same function, so the Jacobians must agree
# up to round-off. Check this explicitly rather than relying on a visual
# comparison of (possibly truncated) printed output.
assert jax.numpy.allclose(jac_rev, jac_fwd), 'jacrev and jacfwd Jacobians disagree'

[ 9.47585374e-04 -1.04741879e-03  9.99975000e-05 -2.37093954e-06
  2.62073131e-06  0.00000000e+00]

[[ 9.97502083e-01  4.99167086e-02  0.00000000e+00  9.98334172e-01
   4.99583435e-02  5.47418783e-05]
 [-4.99167086e-02  9.97502083e-01  0.00000000e+00 -4.99583385e-02
   9.98334172e-01  4.47585392e-05]
 [-2.50000002e-06  2.50000002e-06  1.00000000e+00  5.00000004e-05
   5.00000004e-05  4.99999988e-09]
 [-2.49583543e-03 -1.24895880e-04  0.00000000e+00  9.97502082e-01
   4.99167085e-02  2.23792682e-06]
 [ 1.24895892e-04 -2.49583543e-03  0.00000000e+00 -4.99167086e-02
   9.97502082e-01 -2.73709404e-06]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  1.00000000e+00]]

[[ 9.97502083e-01  4.99167086e-02  0.00000000e+00  9.98334172e-01
   4.99583435e-02  5.47418783e-05]
 [-4.99167086e-02  9.97502083e-01  0.00000000e+00 -4.99583385e-02
   9.98334172e-01  4.47585392e-05]
 [-2.50000002e-06  2.50000002e-06  1.00000000e+00  5.00000004e-05
   5.00000004e-05  5.000

In [10]:
# jax.jacrev
# Inner gradient: jax.jacrev (reverse mode, Jacobian form).
#
# Compose the Tao step for h with sequence(0, 2, ...) (Yoshida composition)
# and fold, then iterate it with nest(100, ...). The step and its
# reverse-mode (drev) and forward-mode (dfwd) Jacobians with respect to the
# initial state x are JIT-compiled.

fs = sequence(0, 2, [tao(h, gradient=jacrev)], merge=False)
integrator = fold(fs)
step = jit(nest(100, integrator))
drev, dfwd = jit(jacrev(step)), jit(jacfwd(step))

# Warm-up: the first call of each jitted function triggers compilation,
# which keeps compile time out of the %timeit measurements that follow.
out = step(x, dt, t, b)
out = drev(x, dt, t, b)
out = dfwd(x, dt, t, b)

In [11]:
# Time the compiled step and both Jacobians. block_until_ready() waits for
# JAX's asynchronous dispatch, so each measurement covers the full computation.
%timeit out = step(x, dt, t, b).block_until_ready()
%timeit out = drev(x, dt, t, b).block_until_ready()
%timeit out = dfwd(x, dt, t, b).block_until_ready()

307 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
18.7 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.63 ms ± 89.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [12]:
# Final state after the nested steps, then the Jacobian of that state with
# respect to the initial condition x, computed in reverse mode (drev) and
# forward mode (dfwd). Each result is evaluated once and kept under a name.
final_state = step(x, dt, t, b)
jac_rev = drev(x, dt, t, b)
jac_fwd = dfwd(x, dt, t, b)

print(final_state)
print()

print(jac_rev)
print()

print(jac_fwd)
print()

# Both AD modes differentiate the same function, so the Jacobians must agree
# up to round-off. Check this explicitly rather than relying on a visual
# comparison of (possibly truncated) printed output.
assert jax.numpy.allclose(jac_rev, jac_fwd), 'jacrev and jacfwd Jacobians disagree'

[ 9.47585374e-04 -1.04741879e-03  9.99975000e-05 -2.37093954e-06
  2.62073131e-06  0.00000000e+00]

[[ 9.97502083e-01  4.99167086e-02  0.00000000e+00  9.98334172e-01
   4.99583435e-02  5.47418783e-05]
 [-4.99167086e-02  9.97502083e-01  0.00000000e+00 -4.99583385e-02
   9.98334172e-01  4.47585392e-05]
 [-2.50000002e-06  2.50000002e-06  1.00000000e+00  5.00000004e-05
   5.00000004e-05  4.99999988e-09]
 [-2.49583543e-03 -1.24895880e-04  0.00000000e+00  9.97502082e-01
   4.99167085e-02  2.23792682e-06]
 [ 1.24895892e-04 -2.49583543e-03  0.00000000e+00 -4.99167086e-02
   9.97502082e-01 -2.73709404e-06]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  1.00000000e+00]]

[[ 9.97502083e-01  4.99167086e-02  0.00000000e+00  9.98334172e-01
   4.99583435e-02  5.47418783e-05]
 [-4.99167086e-02  9.97502083e-01  0.00000000e+00 -4.99583385e-02
   9.98334172e-01  4.47585392e-05]
 [-2.50000002e-06  2.50000002e-06  1.00000000e+00  5.00000004e-05
   5.00000004e-05  5.000

In [13]:
# jax.jacfwd
# Inner gradient: jax.jacfwd (forward mode).
#
# Compose the Tao step for h with sequence(0, 2, ...) (Yoshida composition)
# and fold, then iterate it with nest(100, ...). The step and its
# reverse-mode (drev) and forward-mode (dfwd) Jacobians with respect to the
# initial state x are JIT-compiled.

fs = sequence(0, 2, [tao(h, gradient=jacfwd)], merge=False)
integrator = fold(fs)
step = jit(nest(100, integrator))
drev, dfwd = jit(jacrev(step)), jit(jacfwd(step))

# Warm-up: the first call of each jitted function triggers compilation,
# which keeps compile time out of the %timeit measurements that follow.
out = step(x, dt, t, b)
out = drev(x, dt, t, b)
out = dfwd(x, dt, t, b)

In [14]:
# Time the compiled step and both Jacobians. block_until_ready() waits for
# JAX's asynchronous dispatch, so each measurement covers the full computation.
%timeit out = step(x, dt, t, b).block_until_ready()
%timeit out = drev(x, dt, t, b).block_until_ready()
%timeit out = dfwd(x, dt, t, b).block_until_ready()

371 µs ± 30.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
17.9 ms ± 425 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.51 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [15]:
# Final state after the nested steps, then the Jacobian of that state with
# respect to the initial condition x, computed in reverse mode (drev) and
# forward mode (dfwd). Each result is evaluated once and kept under a name.
final_state = step(x, dt, t, b)
jac_rev = drev(x, dt, t, b)
jac_fwd = dfwd(x, dt, t, b)

print(final_state)
print()

print(jac_rev)
print()

print(jac_fwd)
print()

# Both AD modes differentiate the same function, so the Jacobians must agree
# up to round-off. Check this explicitly rather than relying on a visual
# comparison of (possibly truncated) printed output.
assert jax.numpy.allclose(jac_rev, jac_fwd), 'jacrev and jacfwd Jacobians disagree'

[ 9.47585374e-04 -1.04741879e-03  9.99975000e-05 -2.37093954e-06
  2.62073131e-06  0.00000000e+00]

[[ 9.97502083e-01  4.99167086e-02  0.00000000e+00  9.98334172e-01
   4.99583435e-02  5.47418783e-05]
 [-4.99167086e-02  9.97502083e-01  0.00000000e+00 -4.99583385e-02
   9.98334172e-01  4.47585392e-05]
 [-2.50000002e-06  2.50000002e-06  1.00000000e+00  5.00000004e-05
   5.00000004e-05  4.99999988e-09]
 [-2.49583543e-03 -1.24895880e-04  0.00000000e+00  9.97502082e-01
   4.99167085e-02  2.23792682e-06]
 [ 1.24895892e-04 -2.49583543e-03  0.00000000e+00 -4.99167086e-02
   9.97502082e-01 -2.73709404e-06]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  1.00000000e+00]]

[[ 9.97502083e-01  4.99167086e-02  0.00000000e+00  9.98334172e-01
   4.99583435e-02  5.47418783e-05]
 [-4.99167086e-02  9.97502083e-01  0.00000000e+00 -4.99583385e-02
   9.98334172e-01  4.47585392e-05]
 [-2.50000002e-06  2.50000002e-06  1.00000000e+00  5.00000004e-05
   5.00000004e-05  5.000