In [1]:
from typing import Union

import numpy as np
import torch
from scipy import constants

from cheetah import Species
from cheetah.utils import compute_relativistic_factors
from cheetah.track_methods import rotation_matrix

In [2]:
REST_ENERGY_PYTHON = (
    constants.electron_mass * constants.speed_of_light**2 / constants.elementary_charge
)  # electron mass
REST_ENERGY_TORCH = torch.tensor(
    constants.electron_mass * constants.speed_of_light**2 / constants.elementary_charge
)  # electron mass

In [3]:
my_length = 1.0
my_k1 = 4.2
my_hx = 0.5
my_tilt = 0.1

my_length_th = torch.tensor(my_length)
my_k1_th = torch.tensor(my_k1)
my_hx_th = torch.tensor(my_hx)
my_tilt_th = torch.tensor(my_tilt)

my_species = Species("electron")

mass_eV = my_species.mass_eV.item()

In [4]:
def base_rmatrix_05(
    length: float,
    k1: float,
    hx: float,
    tilt: float = 0.0,
    energy: float = 0.0,
    device: Union[str, torch.device] = "auto",
):
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    gamma = energy / REST_ENERGY_PYTHON
    igamma2 = 1 / gamma**2 if gamma != 0 else 0

    beta = np.sqrt(1 - igamma2)

    kx2 = k1 + hx**2
    ky2 = -k1
    kx = np.sqrt(kx2 + 0.0j)
    ky = np.sqrt(ky2 + 0.0j)
    cx = np.cos(kx * length).real
    cy = np.cos(ky * length).real
    sy = (np.sin(ky * length) / ky).real if ky != 0 else length

    if kx != 0:
        sx = (np.sin(kx * length) / kx).real
        dx = hx / kx2 * (1.0 - cx)
        r56 = hx**2 * (length - sx) / kx2 / beta**2
    else:
        sx = length
        dx = length**2 * hx / 2
        r56 = hx**2 * length**3 / 6 / beta**2

    r56 -= length / beta**2 * igamma2

    R = torch.tensor(
        [
            [cx, sx, 0, 0, 0, dx / beta, 0],
            [-kx2 * sx, cx, 0, 0, 0, sx * hx / beta, 0],
            [0, 0, cy, sy, 0, 0, 0],
            [0, 0, -ky2 * sy, cy, 0, 0, 0],
            [sx * hx / beta, dx / beta, 0, 0, 1, r56, 0],
            [0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=torch.float32,
        device=device,
    )

    return R

In [5]:
def base_rmatrix_master(
    length: torch.Tensor,
    k1: torch.Tensor,
    hx: torch.Tensor,
    species: Species,
    tilt: torch.Tensor | None = None,
    energy: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Create a first order universal transfer map for a beamline element.

    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m.
    :param species: Particle species of the beam.
    :param tilt: Roation of the element relative to the longitudinal axis in rad.
    :param energy: Beam energy in eV.
    :return: First order transfer map for the element.
    """
    device = length.device
    dtype = length.dtype

    zero = torch.tensor(0.0, device=device, dtype=dtype)

    tilt = tilt if tilt is not None else zero
    energy = energy if energy is not None else species.mass_eV

    _, igamma2, beta = compute_relativistic_factors(energy, species.mass_eV)

    kx2 = k1 + hx**2
    ky2 = -k1
    kx = torch.sqrt(torch.complex(kx2, zero))
    ky = torch.sqrt(torch.complex(ky2, zero))
    cx = torch.cos(kx * length).real
    cy = torch.cos(ky * length).real
    sx = (torch.sinc(kx * length / torch.pi) * length).real
    sy = (torch.sinc(ky * length / torch.pi) * length).real
    dx = torch.where(kx2 != 0, hx / kx2 * (1.0 - cx), zero)
    r56 = torch.where(kx2 != 0, hx**2 * (length - sx) / kx2 / beta**2, zero)

    r56 = r56 - length / beta**2 * igamma2

    vector_shape = torch.broadcast_shapes(
        length.shape, k1.shape, hx.shape, tilt.shape, energy.shape
    )

    R = torch.eye(7, dtype=dtype, device=device).repeat(*vector_shape, 1, 1)
    R[..., 0, 0] = cx
    R[..., 0, 1] = sx
    R[..., 0, 5] = dx / beta
    R[..., 1, 0] = -kx2 * sx
    R[..., 1, 1] = cx
    R[..., 1, 5] = sx * hx / beta
    R[..., 2, 2] = cy
    R[..., 2, 3] = sy
    R[..., 3, 2] = -ky2 * sy
    R[..., 3, 3] = cy
    R[..., 4, 0] = sx * hx / beta
    R[..., 4, 1] = dx / beta
    R[..., 4, 5] = r56

    # Rotate the R matrix for skew / vertical magnets. The rotation only has an effect
    # if hx != 0 or k1 != 0. Note that the first if is here to improve speed when no
    # rotation needs to be applied accross all vector dimensions. The torch.where is
    # here to improve numerical stability for the vector elements where no rotation
    # needs to be applied.
    if torch.any((tilt != 0) & ((hx != 0) | (k1 != 0))):
        rotation = rotation_matrix(tilt)
        R = torch.where(
            ((tilt != 0) & ((hx != 0) | (k1 != 0))).unsqueeze(-1).unsqueeze(-1),
            rotation.transpose(-1, -2) @ R @ rotation,
            R,
        )

    return R

In [6]:
def base_rmatrix_improved(
    length: torch.Tensor,
    k1: torch.Tensor,
    hx: torch.Tensor,
    species: Species,
    tilt: torch.Tensor | None = None,
    energy: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Create a first order universal transfer map for a beamline element.

    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m.
    :param species: Particle species of the beam.
    :param tilt: Roation of the element relative to the longitudinal axis in rad.
    :param energy: Beam energy in eV.
    :return: First order transfer map for the element.
    """
    zero = length.new_zeros(())

    if tilt is None:
        tilt = zero
    if energy is None:
        energy = species.mass_eV

    _, igamma2, beta = compute_relativistic_factors(energy, species.mass_eV)
    ibeta2 = torch.reciprocal(torch.square(beta))

    kx2 = k1 + torch.square(hx)
    ky2 = -k1
    kx = torch.sqrt(torch.complex(kx2, zero))
    ky = torch.sqrt(torch.complex(ky2, zero))
    kLx = kx * length
    kLy = ky * length
    cx = torch.cos(kLx).real
    cy = torch.cos(kLy).real
    sx = (torch.sinc(kLx / torch.pi) * length).real
    sy = (torch.sinc(kLy / torch.pi) * length).real

    kx2_is_not_zero = kx2 != 0
    dx = torch.where(kx2_is_not_zero, hx / kx2 * (1.0 - cx), zero)
    r56 = (
        torch.where(kx2_is_not_zero, torch.square(hx) * (length - sx) / kx2, zero)
        - length * igamma2
    ) * ibeta2

    dx_ibeta2 = dx * ibeta2
    sx_hx_ibeta2 = sx * hx * ibeta2

    cx, sx, dx, cy, sy, r56, dx_ibeta2, sx_hx_ibeta2 = torch.broadcast_tensors(
        cx, sx, dx, cy, sy, r56, dx_ibeta2, sx_hx_ibeta2
    )

    R = torch.eye(7, dtype=cx.dtype, device=cx.device).expand(*cx.shape, 7, 7).clone()
    R[
        ...,
        (0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4),
        (0, 1, 5, 0, 1, 5, 2, 3, 2, 3, 0, 1, 5),
    ] = torch.stack(
        [
            cx,
            sx,
            dx_ibeta2,
            -kx2 * sx,
            cx,
            sx_hx_ibeta2,
            cy,
            sy,
            -ky2 * sy,
            cy,
            sx_hx_ibeta2,
            dx_ibeta2,
            r56,
        ],
        dim=-1,
    )

    # Rotate the R matrix for skew / vertical magnets. The rotation only has an effect
    # if hx != 0 or k1 != 0. Note that the first if is here to improve speed when no
    # rotation needs to be applied accross all vector dimensions. The torch.where is
    # here to improve numerical stability for the vector elements where no rotation
    # needs to be applied.
    needs_rotation = (tilt != 0) & ((hx != 0) | (k1 != 0))
    if torch.any(needs_rotation):
        rotation = rotation_matrix(tilt)
        R = torch.where(
            needs_rotation.unsqueeze(-1).unsqueeze(-1),
            rotation.transpose(-1, -2) @ R @ rotation,
            R,
        )

    return R

In [7]:
%%timeit
_ = base_rmatrix_05(
    length=my_length,
    k1=my_k1,
    hx=my_hx,
    tilt=my_tilt,
    energy=mass_eV,
    device="cpu",
)

  beta = np.sqrt(1 - igamma2)


11.9 μs ± 16.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [8]:
%%timeit
_ = base_rmatrix_master(
    length=my_length_th,
    k1=my_k1_th,
    hx=my_hx_th,
    species=my_species,
    tilt=my_tilt_th,
    energy=my_species.mass_eV,
)

258 μs ± 90.9 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit
_ = base_rmatrix_improved(
    length=my_length_th,
    k1=my_k1_th,
    hx=my_hx_th,
    species=my_species,
    tilt=my_tilt_th,
    energy=my_species.mass_eV,
)

157 μs ± 736 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
