# Example-03: Drift element factory

In [1]:
# This example illustrates the drift element factory

In [2]:
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import beta
from elementary.drift import drift_factory

# Print arrays with 12 significant digits on wide lines so that each row of
# the 6x6 Jacobian shown later in the notebook fits on a single line.
jax.numpy.set_printoptions(precision=12, linewidth=256)

In [3]:
# Set data type
# Turn on the `jax_enable_x64` flag before any array is created; a float64
# scalar (`jax.numpy.float64`) is used later in this notebook.

jax.config.update("jax_enable_x64", True)

In [4]:
# Set device

# Take the first CPU device by index instead of `device, *_ = ...`: the
# starred unpack binds `_` in the notebook namespace, and IPython stops
# updating its last-output variables (`_`, `__`, `___`) once the user has
# assigned one of them.
device = jax.devices('cpu')[0]

# Make this device the default one for subsequent JAX computations.
jax.config.update('jax_default_device', device)

In [5]:
# Set initial condition
# The state `x` is the six-component vector [q_x, q_y, q_s, p_x, p_y, p_s].

qs = jax.numpy.array([0.0, 0.0, 0.001])
ps = jax.numpy.array([0.001, 0.001, -0.0001])

# Individual components are kept as separate names for convenience.
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps

# Both pieces are 1-D, so concatenation along the only axis matches hstack.
x = jax.numpy.concatenate([qs, ps])

In [6]:
# Define generic drift element
# The factory receives both `beta` (obtained from `gamma` with the imported
# `beta` helper) and `gamma`; the resulting callable is JIT-compiled.
# It is invoked below as `element(x, length)`.

gamma = 10_000
beta_value = beta(gamma)
drift = drift_factory(beta=beta_value, gamma=gamma)
element = jit(drift)

In [7]:
# Compare with PTC
# Propagate the state through a drift of length 1.0 with the JAX element and
# with the `ptc` helper (presumably a wrapper around a PTC reference
# computation -- verify in elementary.util), then check agreement.

res = element(x, 1.0)
print(res)

ref = ptc(x, 'drift', {'l': 1.0}, gamma=gamma)
print(ref)

# Prints True when the two results agree within allclose's default tolerances.
print(jax.numpy.allclose(res, ref))

[ 0.00100010101   0.00100010101   0.000998999797  0.001           0.001          -0.0001        ]
[ 0.00100010101   0.00100010101   0.000998999797  0.001           0.001          -0.0001        ]
True


In [8]:
# Differentiability
# Print the Jacobian of the element output with respect to each of its two
# positional arguments: argnums=0 is the state `x`, argnums=-1 is the last
# argument, i.e. `length`.

length = jax.numpy.float64(1.0)

for argnums in (0, -1):
    print(jacrev(element, argnums=argnums)(x, length))
    print()

[[ 1.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000102010606e+00  1.000303061531e-06 -1.000203036215e-03]
 [ 0.000000000000e+00  1.000000000000e+00  0.000000000000e+00  1.000303061532e-06  1.000102010606e+00 -1.000203036215e-03]
 [ 0.000000000000e+00  0.000000000000e+00  1.000000000000e+00 -1.000203036215e-03 -1.000203036215e-03  2.010609153797e-06]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000000000000e+00  2.153530860291e-16  4.365802117678e-20]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00 -5.759110665868e-17  1.000000000000e+00 -1.666666809949e-20]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00 -1.030790028879e-17 -1.902883163971e-16  1.000000000000e+00]]

[ 1.000101010303e-03  1.000101010303e-03 -1.000202535512e-06 -4.365361716197e-20  1.638774596422e-20 -2.007338881283e-19]

