# Example-03: Drift element factory

In [1]:
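# This example illustrates the `drift` element factory: a drift with its length fixed at construction is compared against PTC, and a drift whose length is passed at call time is differentiated with respect to that length.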

In [2]:
import jax
from jax import jit

# Enable double precision before the `elementary` modules are imported: JAX
# recommends setting `jax_enable_x64` at startup, and any arrays those modules
# might create at import time could otherwise end up with 32-bit dtypes.
# (Updating the flag again later in the notebook is harmless.)
jax.config.update("jax_enable_x64", True)

from elementary.util import ptc
from elementary.util import beta
from elementary.drift import drift

# Wide lines and 12 fractional digits so a full phase-space state prints on one row
jax.numpy.set_printoptions(linewidth=256, precision=12)

In [3]:
# Switch JAX to 64-bit (double precision) types for every array created from
# here on; without this flag JAX defaults to 32-bit floats and ints.
jax.config.update('jax_enable_x64', True)

In [4]:
# Pin all computations to the first CPU device reported by JAX
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)

In [5]:
# Initial condition: coordinates (q_x, q_y, q_s) and momenta (p_x, p_y, p_s),
# stacked into a single state vector x = [q_x, q_y, q_s, p_x, p_y, p_s]

qs = jax.numpy.array([0.0, 0.0, 0.001])
ps = jax.numpy.array([0.001, 0.001, -0.0001])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps
x = jax.numpy.concatenate([qs, ps])

In [6]:
# Define a drift with its length fixed at construction (non-parametric) and
# compare the transformed state with the PTC result for the same element.

gamma = 10**4
# Single source for the length, shared by the JAX element and the PTC element
# definition, so the two sides of the comparison cannot silently diverge.
drift_length = 1.0
element = jit(drift(length=drift_length, beta=beta(gamma), gamma=gamma))

print(res := element(x))
# For drift_length = 1.0 this builds exactly 'drift, l=1.0'
print(ref := ptc(x, f'drift, l={drift_length}'))
print(jax.numpy.allclose(res, ref))

[ 0.00100010101   0.00100010101   0.000998999797  0.001           0.001          -0.0001        ]
[ 0.00100010101   0.00100010101   0.000998999798  0.001           0.001          -0.0001        ]
True


In [7]:
# Parameters fixed when the element is created (here `length`) are kept in the
# element's 'rc' attribute; printing it shows the stored parameter dict.

print(element.rc)

{'length': 1.0}


In [8]:
# Define a drift whose length is NOT fixed at construction (parametric): the
# length is supplied as the second positional argument on every call, which
# also lets JAX differentiate the result with respect to it.

gamma = 10**4
length = jax.numpy.float64(1.0)
element = jit(drift(beta=beta(gamma), gamma=gamma))

# Transformed state for the given length
print(element(x, length))
# Jacobian of the transformed state with respect to `length`
# (positional argument index 1, i.e. the last one)
print(jax.jacrev(element, argnums=1)(x, length))

[ 0.00100010101   0.00100010101   0.000998999797  0.001           0.001          -0.0001        ]
[ 1.000101010303e-03  1.000101010303e-03 -1.000202535512e-06 -4.365361716197e-20  1.638774596422e-20 -2.007338881283e-19]
