# Example-06: Sextupole element factory

This example illustrates the sextupole element factory.

The sextupole Hamiltonian is:

$$
\begin{aligned}
& H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\
& \\
& P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\
& P_x = p_x - a_x(q_x, q_y, q_s; s) \\
& P_y = p_y - a_y(q_x, q_y, q_s; s) \\
& \\
& (a_x, a_y, a_s) = \left(0, 0, -\frac{1}{2!} k_n \left(\frac{q_x^3}{3} - q_x q_y^2\right) - \frac{1}{2!} k_s \left(\frac{q_y^3}{3} - q_x^2 q_y\right)\right) \\
& \varphi = 0 \\
& t = h = 0
\end{aligned}
$$
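As a sketch, the Hamiltonian above can be transcribed into a JAX function with $t = h = 0$, $\varphi = 0$ and $a_x = a_y = 0$ substituted, so that Hamilton's equations $q' = \partial H / \partial p$, $p' = -\partial H / \partial q$ follow from `jax.grad`. The names `vector_potential`, `hamiltonian` and `vector_field` are hypothetical and not part of `elementary`. The last lines compare the automatic transverse force with the one obtained by differentiating $a_s$ by hand, $p_x' = -\tfrac{1}{2} k_n (q_x^2 - q_y^2) + k_s q_x q_y$ and $p_y' = k_n q_x q_y - \tfrac{1}{2} k_s (q_y^2 - q_x^2)$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def vector_potential(qs, kn, ks):
    """Longitudinal component a_s of the sextupole vector potential."""
    q_x, q_y, _ = qs
    return (-0.5 * kn * (q_x**3 / 3 - q_x * q_y**2)
            - 0.5 * ks * (q_y**3 / 3 - q_x**2 * q_y))


def hamiltonian(qs, ps, kn, ks, beta, gamma):
    """Sextupole Hamiltonian with t = h = 0, phi = 0 and a_x = a_y = 0."""
    p_x, p_y, p_s = ps
    P_s = p_s + 1 / beta
    root = jnp.sqrt(P_s**2 - p_x**2 - p_y**2 - 1 / (beta**2 * gamma**2))
    return p_s / beta - (root + vector_potential(qs, kn, ks))


def vector_field(qs, ps, kn, ks, beta, gamma):
    """Hamilton's equations: dq/ds = dH/dp, dp/ds = -dH/dq."""
    dH_dq, dH_dp = jax.grad(hamiltonian, argnums=(0, 1))(qs, ps, kn, ks, beta, gamma)
    return dH_dp, -dH_dq


gamma_value = 1000.0
beta_value = (1 - 1 / gamma_value**2) ** 0.5
qs = jnp.array([-0.01, 0.005, 0.001])
ps = jnp.array([0.001, 0.001, -0.0001])
kn, ks = -50.0, 75.0

dq, dp = vector_field(qs, ps, kn, ks, beta_value, gamma_value)

# Transverse force from differentiating a_s by hand (dp/ds = +grad a_s here)
q_x, q_y, _ = qs
dp_x = -0.5 * kn * (q_x**2 - q_y**2) + ks * q_x * q_y
dp_y = kn * q_x * q_y - 0.5 * ks * (q_y**2 - q_x**2)
print(dp[:2])
print(jnp.array([dp_x, dp_y]))
```

Since $a_s$ does not depend on $q_s$, the force on $p_s$ vanishes and $p_s$ is conserved, consistent with the unchanged last component in the PTC comparison.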

The constructed element has the signature:

```python
from jax import Array

def sextupole(qsps: Array, length: Array, kn: Array, ks: Array) -> Array:
    ...
```

Note that `kn` and `ks` must both be passed on every call; for a pure normal (or skew) sextupole, set `ks` (or `kn`) to zero.
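To illustrate the calling convention without the `elementary` package, the sketch below defines a hypothetical stand-in, `thin_sextupole`, with the same signature. It is not the factory's integrator: it swaps in a symmetric drift-kick-drift splitting that composes the exact flow of the field-free part of the Hamiltonian with the exact kick generated by $-a_s$, which is symplectic and second-order accurate in the length. The fixed `GAMMA = 1000` and the relation `BETA = sqrt(1 - 1/GAMMA**2)` are assumptions of the sketch.

```python
import jax
import jax.numpy as jnp
from jax import Array, jacrev, jit

jax.config.update("jax_enable_x64", True)

GAMMA = 1000.0                    # assumed Lorentz factor
BETA = (1 - 1 / GAMMA**2) ** 0.5  # assumed relation beta(gamma)


def drift(qsps: Array, length: Array) -> Array:
    """Exact flow of the field-free part H0 = p_s/beta - sqrt(...)."""
    q_x, q_y, q_s, p_x, p_y, p_s = qsps
    P_s = p_s + 1 / BETA
    root = jnp.sqrt(P_s**2 - p_x**2 - p_y**2 - 1 / (BETA**2 * GAMMA**2))
    return jnp.stack([
        q_x + length * p_x / root,
        q_y + length * p_y / root,
        q_s + length * (1 / BETA - P_s / root),
        p_x, p_y, p_s,
    ])


def kick(qsps: Array, length: Array, kn: Array, ks: Array) -> Array:
    """Exact flow of H = -a_s: positions fixed, momenta += length * grad(a_s)."""
    q_x, q_y, q_s, p_x, p_y, p_s = qsps
    dp_x = -0.5 * kn * (q_x**2 - q_y**2) + ks * q_x * q_y
    dp_y = kn * q_x * q_y - 0.5 * ks * (q_y**2 - q_x**2)
    return jnp.stack([q_x, q_y, q_s, p_x + length * dp_x, p_y + length * dp_y, p_s])


def thin_sextupole(qsps: Array, length: Array, kn: Array, ks: Array) -> Array:
    """Hypothetical stand-in: drift(L/2), kick(L), drift(L/2)."""
    qsps = drift(qsps, length / 2)
    qsps = kick(qsps, length, kn, ks)
    return drift(qsps, length / 2)


toy_element = jit(thin_sextupole)

qsps = jnp.array([-0.01, 0.005, 0.001, 0.001, 0.001, -0.0001])
length = jnp.float64(0.2)

# Both strengths are always passed; zero one of them for a pure normal or skew magnet
print(toy_element(qsps, length, jnp.float64(-50.0), jnp.float64(+75.0)))
print(toy_element(qsps, length, jnp.float64(-50.0), jnp.float64(0.0)))
print(jacrev(toy_element)(qsps, length, jnp.float64(-50.0), jnp.float64(+75.0)).shape)  # (6, 6)
```

Because every argument is an array, the same function can be wrapped in `jit`, differentiated with `jacrev`, or mapped over strengths with `jax.vmap` without changes.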

```python
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import beta
from elementary.sextupole import sextupole_factory

jax.numpy.set_printoptions(linewidth=256, precision=12)
```

```python
# Enable double precision (float64)

jax.config.update("jax_enable_x64", True)
```
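Enabling x64 matters here: by default JAX creates 32-bit floats, whose machine epsilon (about $1.2 \times 10^{-7}$) is far too coarse for comparing against PTC at the 12 printed digits. A quick check, assuming the flag is set at startup before any arrays are created:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types at startup, before any arrays are created
jax.config.update("jax_enable_x64", True)

x = jnp.array([-0.01, 0.005, 0.001])
print(x.dtype)                     # float64
print(jnp.finfo(jnp.float32).eps)  # about 1.19e-07, roughly 7 significant digits
print(jnp.finfo(jnp.float64).eps)  # about 2.22e-16, roughly 16 significant digits
```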

```python
# Use the first CPU device as the default device

device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
```

```python
# Set initial condition, qsps = (q_x, q_y, q_s, p_x, p_y, p_s)

(q_x, q_y, q_s) = qs = jax.numpy.array([-0.01, 0.005, 0.001])
(p_x, p_y, p_s) = ps = jax.numpy.array([0.001, 0.001, -0.0001])
qsps = jax.numpy.hstack([qs, ps])
```

```python
# Define a generic sextupole element (length, kn and ks are passed at call time)

gamma = 10**3
element = jit(sextupole_factory(beta=beta(gamma), gamma=gamma, order=2**1, iterations=100))
```
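The factory receives both `beta` and `gamma`. Assuming `elementary.util.beta` implements the standard relation $\beta = \sqrt{1 - 1/\gamma^2}$ (not confirmed here; `beta_from_gamma` is a hypothetical stand-in), the numbers for $\gamma = 10^3$ can be worked out directly. The second part checks the identity $1/\beta^2 - 1/(\beta^2 \gamma^2) = 1$, which is why the square root in the Hamiltonian equals one on the reference orbit $p_x = p_y = p_s = 0$, where $P_s = 1/\beta$.

```python
import math


def beta_from_gamma(gamma: float) -> float:
    """Relativistic beta from the Lorentz factor (hypothetical helper)."""
    return math.sqrt(1.0 - 1.0 / gamma**2)


gamma = 10**3
beta_value = beta_from_gamma(gamma)
print(beta_value)      # slightly below 1
print(1 - beta_value)  # about 5e-07

# On the reference orbit P_s = 1/beta and p_x = p_y = 0, so the square root
# in the Hamiltonian reduces to sqrt(1/beta^2 - 1/(beta^2 gamma^2)) = 1
root = math.sqrt((1 / beta_value) ** 2 - 1 / (beta_value**2 * gamma**2))
print(root)
```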

```python
# Compare with PTC (PTC parameters: l = length, k2 = kn, k2s = ks)

length = jax.numpy.float64(0.2)
kn = jax.numpy.float64(-50.0)
ks = jax.numpy.float64(+75.0)

print(res := element(qsps, length, kn, ks))
print(ref := ptc(qsps, 'sextupole', {'l': float(length), 'k2': float(kn), 'k2s': float(ks)}, gamma=gamma))
print(jax.numpy.allclose(res, ref))
```

```
[-0.009839311476  0.005305274154  0.000999691938  0.000595944971  0.002048184556 -0.0001        ]
[-0.009839311476  0.005305274154  0.000999691938  0.000595944971  0.002048184556 -0.0001        ]
True
```
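`jax.numpy.allclose` uses NumPy-style default tolerances (`rtol=1e-05`, `atol=1e-08`), so `True` alone does not establish agreement to all printed digits. A minimal sketch of a stricter comparison, reusing the printed PTC values and an arbitrary artificial discrepancy of `1e-9`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# PTC reference values as printed above
ref_ptc = jnp.array([-0.009839311476, 0.005305274154, 0.000999691938,
                     0.000595944971, 0.002048184556, -0.0001])
perturbed = ref_ptc + 1e-9  # artificial 1e-9 discrepancy in every component

print(jnp.allclose(perturbed, ref_ptc))                        # True: default tolerances accept it
print(jnp.allclose(perturbed, ref_ptc, rtol=0.0, atol=1e-12))  # False: tight tolerance rejects it
print(jnp.max(jnp.abs(perturbed - ref_ptc)))                   # largest component-wise deviation
```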


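```python
# Differentiability

# Jacobian with respect to the phase-space state qsps (6 x 6)
print(jacrev(element)(qsps, length, kn, ks))
print()

# Derivative with respect to the length (argument 1)
print(jacrev(element, 1)(qsps, length, kn, ks))
print()
```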

```
[[ 9.977539230851e-01 -1.999683876014e-02  0.000000000000e+00  1.998755038515e-01 -1.333336463994e-03 -1.588878813703e-04]
 [-1.999528286417e-02  1.002381087687e+00  0.000000000000e+00 -1.333283393256e-03  2.001711272451e-01 -3.042893998520e-04]
 [ 3.567744778423e-05  1.067886761186e-05  1.000000000000e+00 -1.582182316560e-04 -3.046376164518e-04  8.131286746188e-07]
 [-2.075140977962e-02 -2.000692596676e-01  0.000000000000e+00  9.979602547484e-01 -2.001825891400e-02  2.884097120735e-05]
 [-2.000275114867e-01  2.345269216804e-02  0.000000000000e+00 -2.001650181701e-02  1.002174743629e+00  1.451547319575e-05]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000000000000e+00]]

[ 5.960059282262e-04  2.048394055811e-03 -2.275669201308e-06 -2.198365452755e-03  5.184991611235e-03  0.000000000000e+00]
```
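Since the map is generated by a Hamiltonian, its Jacobian $M$ with respect to `qsps` is expected to be symplectic, $M^T S M = S$, where in the ordering $(q_x, q_y, q_s, p_x, p_y, p_s)$ the matrix $S$ has $I_3$ in its upper-right block and $-I_3$ in its lower-left block. This holds provided the factory's integrator is itself symplectic, an assumption not checked here. The helper `symplectic_defect` below is hypothetical; it is exercised on a drift-like and a kick-like matrix, which are symplectic by construction, and on a scaled identity, which is not.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def symplectic_form(n: int = 3) -> jax.Array:
    """S = [[0, I], [-I, 0]] for the ordering (q_1..q_n, p_1..p_n)."""
    zero, one = jnp.zeros((n, n)), jnp.eye(n)
    return jnp.block([[zero, one], [-one, zero]])


def symplectic_defect(m: jax.Array) -> float:
    """Largest absolute entry of M^T S M - S (zero for a symplectic M)."""
    s = symplectic_form(m.shape[0] // 2)
    return float(jnp.max(jnp.abs(m.T @ s @ m - s)))


one, zero = jnp.eye(3), jnp.zeros((3, 3))
drift = jnp.block([[one, 0.2 * one], [zero, one]])  # q += L p
k = jnp.array([[1.0, 0.5, 0.0], [0.5, -2.0, 0.0], [0.0, 0.0, 0.0]])
kick = jnp.block([[one, zero], [k, one]])           # p += K q, with K symmetric

print(symplectic_defect(drift @ kick @ drift))  # zero up to rounding
print(symplectic_defect(1.1 * jnp.eye(6)))      # clearly non-zero

# With the notebook's element one would call, e.g.:
# symplectic_defect(jacrev(element)(qsps, length, kn, ks))
```

Derivatives with respect to the strengths follow the same pattern: `jacrev(element, argnums=(2, 3))(qsps, length, kn, ks)` returns a pair of Jacobians, one for `kn` and one for `ks`.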

