# Example-02: Hamiltonian factory (non-autonomous integration in extended phase space)

In [1]:
# In this example a non-autonomous hamiltonian is integrated with the implicit midpoint and tao integrators,
# both for an explicitly defined hamiltonian function and for one generated by the hamiltonian factory

In [2]:
# Import 

import jax
from jax import Array
from jax import jit
from jax import vmap

from elementary import fold
from elementary import nest
from elementary import midpoint
from elementary import tao
from elementary import sequence

from elementary.hamiltonian import hamiltonian_factory
from elementary.hamiltonian import autonomize

# Wide rows and 12 digits of precision so each 8-component state vector prints on one line
jax.numpy.set_printoptions(precision=12, linewidth=256)

In [3]:
# Set data type

# Enable 64-bit arrays. This must run before any arrays are created; otherwise the
# float64 parameters and states below would be created as float32.
jax.config.update("jax_enable_x64", True)

In [4]:
# Set device

# Use the first CPU device as JAX's default device for all computations below.
# Index with [0] instead of `device, *_ = ...`: in IPython, assigning to the global `_`
# shadows the output-history variable, and IPython then stops updating `_`/`__`/`___`.
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)

In [5]:
# Set parameters

# Scalar float64 parameters:
#   si -- initial value of the independent variable s (also used as the initial q_t below)
#   ds -- second argument of `element` (presumably the integration step length)
#   kn -- coefficient of the quadratic (q_x**2 + q_y**2) term of the hamiltonian
si, ds, kn = map(jax.numpy.float64, (0.5, 0.01, 1.0))

In [6]:
# Set initial condition

# Initial coordinates (q_x, q_y, q_s) and momenta (p_x, p_y, p_s)
qs = jax.numpy.array([0.001, -0.005, 0.0])
ps = jax.numpy.array([0.005, -0.001, 0.0001])
# Flat phase-space state: coordinates first, then momenta
x = jax.numpy.concatenate((qs, ps))

In [7]:
# Define non-autonomous and extended hamiltonian (explicit)

def hamiltonian(qs, ps, s, kn, *args):
    """Explicit non-autonomous hamiltonian H(q, p; s).

    H = p_s - sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + 1/2*kn*(1 + cos(s))*(q_x**2 + q_y**2)

    qs    : (q_x, q_y, q_s); q_s does not enter H
    ps    : (p_x, p_y, p_s)
    s     : independent variable, enters explicitly through cos(s)
    kn    : coefficient of the quadratic term
    *args : accepted and ignored
    """
    x, y, _ = qs
    px, py, pl = ps
    root = jax.numpy.sqrt((1 + pl)**2 - px**2 - py**2)
    quad = 1/2*kn*(1 + jax.numpy.cos(s))*(x**2 + y**2)
    return (pl - root) + quad

def extended(qs, ps, s, kn, *args):
    """Autonomous hamiltonian on the extended phase space (explicit form).

    K = p_t + p_s - sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + 1/2*kn*(1 + cos(q_t))*(q_x**2 + q_y**2)

    The explicit time dependence is carried by the extra coordinate q_t, so the
    argument `s` is ignored, as are `*args`.

    qs : (q_x, q_y, q_s, q_t); q_s does not enter K
    ps : (p_x, p_y, p_s, p_t)
    kn : coefficient of the quadratic term
    """
    x, y, _, t = qs
    px, py, pl, pt = ps
    root = jax.numpy.sqrt((1 + pl)**2 - px**2 - py**2)
    quad = 1/2*kn*(1 + jax.numpy.cos(t))*(x**2 + y**2)
    return pt + ((pl - root) + quad)

In [8]:
# Set extended initial condition

# Extend the state with q_t = si and p_t = -H(qs, ps; si); since `extended` is p_t + H,
# the extended hamiltonian starts at zero
Qs = jax.numpy.concat([qs, jax.numpy.ravel(si)])
Ps = jax.numpy.concat([ps, jax.numpy.ravel(-hamiltonian(qs, ps, si, kn))])
# Flat extended state: (q_x, q_y, q_s, q_t, p_x, p_y, p_s, p_t)
X = jax.numpy.concat([Qs, Ps])

In [9]:
# Set implicit midpoint integration step

# Integration step from the implicit midpoint map of the extended hamiltonian;
# `ns=2**1` is forwarded to `midpoint` (presumably its number of solver iterations — confirm),
# and `sequence(0, 2**1, ...)` composes the step (semantics defined in `elementary`)
integrator = jit(
    fold(
        sequence(0, 2**1, [midpoint(extended, ns=2**1)], merge=False)
    )
)

In [10]:
# Set and compile element

# Build the full map from 10**2 applications of `integrator` (presumably what `nest`
# does — confirm) and JIT-compile it; the first call traces and compiles
element = jit(nest(10**2, integrator))
out = element(X, ds, si, kn)
print(out)

[ 4.093772217780e-03 -2.130920641490e-03 -1.157542898301e-05  1.500000000000e+00  7.828730523443e-04  5.455056721550e-03  1.000000000000e-04  9.999734128953e-01]


In [11]:
# Set tao integration step

# Same step construction as for midpoint, but with the `tao` map of the extended hamiltonian
integrator = jit(
    fold(
        sequence(0, 2**1, [tao(extended)], merge=False)
    )
)

In [12]:
# Set and compile element

# Rebuild the 10**2-step map with the tao step and JIT-compile it;
# the printed state can be compared with the midpoint result above
element = jit(nest(10**2, integrator))
out = element(X, ds, si, kn)
print(out)

[ 4.093772217780e-03 -2.130920641487e-03 -1.157542898302e-05  1.500000000000e+00  7.828730523431e-04  5.455056721550e-03  1.000000000000e-04  9.999734128953e-01]


In [13]:
# Define non-autonomous and extended hamiltonian (factory)

def vector(qs:Array, s:Array, kn:Array, *args:Array) -> tuple[Array, Array, Array]:
    """Vector potential components (a_x, a_y, a_s) for `hamiltonian_factory`.

    Only the a_s component is non-zero:
        a_s = -1/2*kn*(1 + cos(s))*(q_x**2 + q_y**2)
    The explicit and factory runs in this notebook print identical final states,
    so this presumably reproduces the quadratic term of the explicit hamiltonian.

    Parameters
    ----------
    qs : (q_x, q_y, q_s); q_s is unused
    s : independent variable
    kn : coefficient of the quadratic term
    *args : accepted and ignored

    Returns
    -------
    tuple[Array, Array, Array]
        (a_x, a_y, a_s) as a 3-tuple, not a single stacked Array.
    """
    q_x, q_y, _ = qs
    # Transverse components vanish; taken from zeros_like(qs) so dtype follows the coordinates
    a_x, a_y, _ = jax.numpy.zeros_like(qs)
    a_s = - 1/2*kn*(1 + jax.numpy.cos(s))*(q_x**2 + q_y**2)
    return a_x, a_y, a_s

def scalar(qs:Array, s:Array, kn:Array, *args:Array) -> Array:
    """Scalar potential for `hamiltonian_factory`: identically zero.

    Returns zeros with the shape and dtype of `s`. `qs` must still unpack into
    exactly three coordinates (q_x, q_y, q_s); `kn` and `*args` are ignored.
    """
    _, _, _ = qs  # only enforces the three-coordinate layout; values are unused
    zero = jax.numpy.zeros_like(s)
    return zero

# Build the hamiltonian from the vector and scalar potentials.
# NOTE: this rebinds the names `hamiltonian` and `extended` defined in the explicit section;
# every cell below uses these factory-generated versions.
hamiltonian = hamiltonian_factory(vector, scalar)

# `autonomize` presumably lifts the non-autonomous hamiltonian to extended phase space
# (adding q_t, p_t), like the explicit `extended` — confirm in elementary.hamiltonian
extended = autonomize(hamiltonian)

In [14]:
# Set extended initial condition

# Extend the state with q_t = si and p_t = -H(qs, ps; si), using the factory hamiltonian;
# if `autonomize` builds p_t + H (presumably — confirm), the extended hamiltonian starts at zero
Qs = jax.numpy.concat([qs, jax.numpy.ravel(si)])
Ps = jax.numpy.concat([ps, jax.numpy.ravel(-hamiltonian(qs, ps, si, kn))])
# Flat extended state: (q_x, q_y, q_s, q_t, p_x, p_y, p_s, p_t)
X = jax.numpy.concat([Qs, Ps])

In [15]:
# Set implicit midpoint integration step

# Implicit midpoint step for the factory-generated extended hamiltonian
# (same construction and `ns=2**1` as in the explicit section)
integrator = jit(
    fold(
        sequence(0, 2**1, [midpoint(extended, ns=2**1)], merge=False)
    )
)

In [16]:
# Set and compile element

# 10**2-step map with the factory hamiltonian, JIT-compiled on the first call;
# the printed state matches the explicit midpoint run
element = jit(nest(10**2, integrator))
out = element(X, ds, si, kn)
print(out)

[ 4.093772217780e-03 -2.130920641490e-03 -1.157542898301e-05  1.500000000000e+00  7.828730523443e-04  5.455056721550e-03  1.000000000000e-04  9.999734128953e-01]


In [17]:
# Set tao integration step

# Tao step for the factory-generated extended hamiltonian
integrator = jit(
    fold(
        sequence(0, 2**1, [tao(extended)], merge=False)
    )
)

In [18]:
# Set and compile element

# 10**2-step map with the tao step and the factory hamiltonian, JIT-compiled on the first call;
# the printed state matches the explicit tao run
element = jit(nest(10**2, integrator))
out = element(X, ds, si, kn)
print(out)

[ 4.093772217780e-03 -2.130920641487e-03 -1.157542898302e-05  1.500000000000e+00  7.828730523431e-04  5.455056721550e-03  1.000000000000e-04  9.999734128953e-01]


In [19]:
# Differentiability (initial condition)

# Jacobian of the final state with respect to the initial extended state X (argnums defaults to 0)
jax.jacrev(element)(X, ds, si, kn)

Array([[ 2.733638829123e-01, -5.073806420139e-06,  0.000000000000e+00,  8.507003015250e-04,  7.640867440566e-01,  3.911335010850e-06, -2.099797209399e-03,  0.000000000000e+00],
       [-5.073787794666e-06,  2.733566344460e-01,  0.000000000000e+00, -1.445226318313e-03,  3.911295686010e-06,  7.640874081955e-01, -2.608613878958e-03,  0.000000000000e+00],
       [ 1.815088641299e-03,  2.754207969965e-03,  1.000000000000e+00,  3.912375857701e-06, -2.607918123372e-03, -1.844346610617e-03,  1.977826577988e-05,  0.000000000000e+00],
       [ 0.000000000000e+00,  0.000000000000e+00,  0.000000000000e+00,  1.000000000000e+00,  0.000000000000e+00,  0.000000000000e+00,  0.000000000000e+00,  0.000000000000e+00],
       [-1.169086112155e+00,  2.508455236989e-06,  0.000000000000e+00,  2.194571127834e-03,  3.903836816403e-01, -2.470173985962e-06,  2.340302050291e-03,  0.000000000000e+00],
       [ 2.508450612498e-06, -1.169086571022e+00,  0.000000000000e+00, -2.552292481525e-03, -2.470160046831e-06,  3

In [20]:
# Differentiability (parameter)

# Derivative of the final state with respect to the last positional argument, kn (argnums=-1)
jax.jacrev(element, argnums=-1)(X, ds, si, kn)

Array([-1.720422619828e-03,  3.372942592063e-03, -8.584999600196e-06,  0.000000000000e+00, -3.509583906528e-03,  4.764367089934e-03,  0.000000000000e+00,  5.002086081004e-06], dtype=float64)