# Example-11: Alignment errors (curved layout)

In [1]:
# This example illustrates alignment errors for a curved (sbend) layout and compares the tracked result with PTC

In [2]:
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import beta
from elementary.dipole import dipole_factory
from elementary.alignment import alignment_factory

# Print arrays with 12 significant decimals on wide lines so each 6-component state fits on one row
jax.numpy.set_printoptions(precision=12, linewidth=256)

In [3]:
# Set data type

# Enable 64-bit floats in JAX so that the jax.numpy.float64(...) parameters below are really float64.
# This must happen before those arrays are created.
# NOTE(review): the elementary.* modules are imported in an earlier cell; if they create JAX arrays at
# import time, those would predate this flag — confirm they do not.
jax.config.update("jax_enable_x64", True)

In [4]:
# Set device

# Use the first CPU device as JAX's default device for all following computations.
# Plain indexing (instead of `device, *_ = ...`) avoids rebinding `_`, which IPython
# reserves for the most recent cell output.
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)

In [5]:
# Set initial condition

# Initial state: coordinates qs = (q_x, q_y, q_s) and conjugate variables ps = (p_x, p_y, p_s);
# the tracked state x is qs followed by ps (6 components).
qs = jax.numpy.array([-0.01, 0.005, 0.001])
ps = jax.numpy.array([0.001, 0.001, -0.0001])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps
x = jax.numpy.concatenate([qs, ps])

In [6]:
# Define generic dipole element

# Relativistic factor gamma (passed to beta(...) and to PTC)
gamma = 10**3

# Geometry: element length and bending angle
length = jax.numpy.float64(1.0)
angle = jax.numpy.float64(0.05)

# Multipole strengths, normal (_n) and skew (_s): kq quadrupole, ks sextupole, ko octupole.
# The PTC comparison multiplies each by `length` to build the knl/ksl lists.
kq_n, kq_s = jax.numpy.float64(-2.0), jax.numpy.float64(+1.5)
ks_n, ks_s = jax.numpy.float64(-50.0), jax.numpy.float64(+75.0)
ko_n, ko_s = jax.numpy.float64(-100.0), jax.numpy.float64(+500.0)

# Dipole body (multipoles enabled) and the entrance/exit alignment transforms,
# both built for the same beam (beta(gamma), gamma).
# NOTE(review): meaning of `order`, `iterations`, `settings['ns']` and `flag` is defined in
# elementary.dipole / elementary.alignment, not visible here.
body = dipole_factory(
    multipole=True,
    beta=beta(gamma),
    gamma=gamma,
    order=2**1,
    iterations=200,
    settings={'ns': 2**1},
)
xyz_entrance, xyz_exit = alignment_factory(
    beta=beta(gamma),
    gamma=gamma,
    flag=True,
)

@jit
def element(x, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s, dx, dy, dz, wx, wy, wz):
    """Track state `x` through a misaligned dipole (JIT-compiled).

    The state passes, in order, through the entrance alignment transform,
    the dipole body, and the exit alignment transform. The exit transform
    also receives `length` and `angle`.

    Parameters
    ----------
    x : array
        State vector (in this notebook: qs followed by ps).
    length, angle : scalar
        Element length and bending angle, forwarded to `body` and `xyz_exit`.
    kq_n, kq_s, ks_n, ks_s, ko_n, ko_s : scalar
        Normal/skew multipole strengths, forwarded to `body`.
    dx, dy, dz, wx, wy, wz : scalar
        Alignment errors, forwarded to both alignment transforms.

    Returns
    -------
    array
        State after the exit alignment transform.
    """
    errors = (dx, dy, dz, wx, wy, wz)
    state = xyz_entrance(x, *errors)
    state = body(state, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s)
    return xyz_exit(state, *errors, length, angle)

In [7]:
# Set alignment errors

# Alignment errors, used by `element` and passed to PTC as tx/ty/tz (dx, dy, dz) and rx/ry/rz (wx, wy, wz).
# The trailing *1 factors are kept as per-component scale knobs.
(dx, dy, dz), (wx, wy, wz) = jax.numpy.array([
    [0.05*1, -0.02*1, 0.05*1],    # dx, dy, dz
    [0.005*1, -0.005*1, 0.1*1],   # wx, wy, wz
])

In [8]:
# Compare with PTC

# Track with the JAX element, then with PTC, print both final states, and require agreement.
res = element(x, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s, dx, dy, dz, wx, wy, wz)
print(res)

# PTC sbend with the same integrated multipoles (first entry 0.0, then quadrupole, sextupole,
# octupole strengths times length), fringe fields disabled, and the same alignment errors.
knl = f'{{0.0,{float(kq_n*length)}, {float(ks_n*length)}, {float(ko_n*length)}}}'
ksl = f'{{0.0,{float(kq_s*length)}, {float(ks_s*length)}, {float(ko_s*length)}}}'
sbend = {'l': float(length), 'angle': float(angle), 'knl': knl, 'ksl': ksl, 'kill_ent_fringe': 'true', 'kill_exi_fringe': 'true'}
ref = ptc(x, 'sbend', sbend, gamma=gamma, tx=float(dx), ty=float(dy), tz=float(dz), rx=float(wx), ry=float(wy), rz=float(wz))
print(ref)

agree = jax.numpy.allclose(res, ref)
print(agree)
# Make a mismatch fail Restart & Run All instead of only printing False
assert agree, 'element() disagrees with the PTC reference'

[-6.489404978724e-02  2.078529965698e-02 -2.211225001039e-04 -1.450307157165e-01  8.291770017977e-02 -1.000000000000e-04]
[-6.489404978720e-02  2.078529965695e-02 -2.211225001033e-04 -1.450307157164e-01  8.291770017969e-02 -1.000000000000e-04]
True
