# Example-08: Dipole element factory (body + cylindrical multipole)

In [1]:
# In this example, the dipole factory is illustrated with additional cylindrical multipoles

In [2]:
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import beta
from elementary.dipole import dipole_factory

# Wide lines and 12 significant digits, so each tracked state vector prints on
# a single line and the JAX / PTC results below can be compared by eye.
print_options = dict(linewidth=256, precision=12)
jax.numpy.set_printoptions(**print_options)

In [3]:
# Set data type
# Switch JAX to 64-bit (double precision) defaults; the jax.numpy.float64
# values used below rely on this flag.
# NOTE(review): JAX expects this flag to be set before any arrays are created;
# this assumes the `elementary` imports above create no JAX arrays at import time.
enable_double_precision = True
jax.config.update("jax_enable_x64", enable_double_precision)

In [4]:
# Set device
# Run everything on the first CPU device.
# Index rather than star-unpack into `_`: in IPython, once the user assigns `_`,
# the kernel stops updating its `_` / `__` / `___` output-history variables.
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)

In [5]:
# Set initial condition
# Coordinates (q_x, q_y, q_s) and momenta (p_x, p_y, p_s) are stacked into a
# single 6-vector x = [q_x, q_y, q_s, p_x, p_y, p_s], which is tracked below.
qs = jax.numpy.array([-0.01, 0.005, 0.001])
ps = jax.numpy.array([0.001, 0.001, -0.0001])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps
x = jax.numpy.concatenate([qs, ps])

In [6]:
# Define generic dipole element
# `gamma` is passed both to `beta(...)` and directly to the factory (and later
# to `ptc`); presumably the relativistic Lorentz factor of the tracked particle.
gamma = 10**6
# multipole=True selects the variant with additional multipole terms; the
# resulting element is called below as
# element(x, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s).
# `order`, `iterations` and `settings` are forwarded unchanged to dipole_factory.
dipole = dipole_factory(
    multipole=True,
    beta=beta(gamma),
    gamma=gamma,
    order=2**1,
    iterations=200,
    settings=dict(ns=2**1),
)
element = jit(dipole)

In [7]:
# Compare with PTC
# Track the same initial condition `x` through the JAX element and through a
# PTC 'sbend' with the same length, angle and multipole strengths (entrance and
# exit fringe fields killed), then require that both results agree.

length = jax.numpy.float64(1.0)
angle = jax.numpy.float64(0.05)

# Normal (_n) and skew (_s) strengths; kq, ks, ko fill indices 1, 2, 3 of the
# PTC knl / ksl lists below (index 0 is fixed to 0.0).
kq_n = jax.numpy.float64(-2.0)
kq_s = jax.numpy.float64(+1.5)
ks_n = jax.numpy.float64(-50.0)
ks_s = jax.numpy.float64(+75.0)
ko_n = jax.numpy.float64(-100.0)
ko_s = jax.numpy.float64(+500.0)


def ptc_multipole_list(strengths, length):
    """Format integrated strengths as the knl/ksl string passed to `ptc`.

    Each strength is multiplied by `length` and rendered with `float` formatting,
    giving e.g. '{0.0,-2.0, -50.0, -100.0}' (index 0 fixed to 0.0; spacing kept
    identical to the original hand-written f-strings).
    """
    return '{0.0,' + ', '.join(f'{float(k*length)}' for k in strengths) + '}'


res = element(x, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s)
print(res)

ptc_options = {
    'l': float(length),
    'angle': float(angle),
    'knl': ptc_multipole_list((kq_n, ks_n, ko_n), length),
    'ksl': ptc_multipole_list((kq_s, ks_s, ko_s), length),
    'kill_ent_fringe': 'true',
    'kill_exi_fringe': 'true',
}
ref = ptc(x, 'sbend', ptc_options, gamma=gamma)
print(ref)

matches = jax.numpy.allclose(res, ref)
print(matches)
# Fail loudly under Restart & Run All instead of silently printing False.
assert matches, f'JAX dipole and PTC results disagree:\n{res}\n{ref}'

[-0.017200420755 -0.003042779236  0.00151939261  -0.017822768068 -0.015620313086 -0.0001        ]
[-0.017200420755 -0.003042779236  0.00151939261  -0.017822768068 -0.015620313086 -0.0001        ]
True
