In [None]:
from typing import Optional, Union

import numpy as np
import torch
from scipy import constants

from cheetah import Species
from cheetah.utils import compute_relativistic_factors

In [2]:
# Electron rest energy m_e * c**2 expressed in eV (joules divided by the elementary
# charge). Kept both as a Python float (for the float/NumPy implementation) and as a
# 0-dim tensor (for the torch implementations).
REST_ENERGY_PYTHON = (
    constants.electron_mass * constants.speed_of_light**2 / constants.elementary_charge
)
REST_ENERGY_TORCH = torch.tensor(REST_ENERGY_PYTHON)

In [3]:
def base_rmatrix(
    length: torch.Tensor,
    k1: torch.Tensor,
    hx: torch.Tensor,
    tilt: Optional[torch.Tensor] = None,
    energy: Optional[torch.Tensor] = None,
    device: Union[str, torch.device] = "auto",
) -> torch.Tensor:
    """
    Create a universal transfer matrix for a beamline element.

    Scalar (0-dim tensor) version that fills a float32 7x7 matrix entry by entry.

    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m.
    :param tilt: Rotation of the element relative to the longitudinal axis in rad.
        Accepted for interface compatibility; this implementation does not use it.
    :param energy: Beam energy in eV. ``None`` or 0 means no relativistic
        correction is applied (``1/gamma**2 = 0``, ``beta = 1``).
    :param device: Device where the transfer matrix is created. If "auto", CUDA is
        used when available, otherwise the CPU.
    :return: 7x7 float32 transfer matrix for the element.
    """
    energy = energy if energy is not None else torch.tensor(0.0)

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    gamma = energy / REST_ENERGY_TORCH
    igamma2 = 1 / gamma**2 if gamma != 0 else torch.tensor(0.0)

    beta = torch.sqrt(1 - igamma2)

    kx2 = k1 + hx**2
    ky2 = -k1
    # torch.complex needs real and imaginary parts of the same dtype: derive the
    # zero imaginary part from each operand so float64 inputs work as well (a
    # float32 literal here raises for float64 inputs).
    kx = torch.sqrt(torch.complex(kx2, torch.zeros_like(kx2)))
    ky = torch.sqrt(torch.complex(ky2, torch.zeros_like(ky2)))
    cx = torch.cos(kx * length).real
    cy = torch.cos(ky * length).real
    # sin(kL) / k -> L as k -> 0; the explicit branch avoids 0 / 0.
    sy = (torch.sin(ky * length) / ky).real if ky != 0 else length

    if kx != 0:
        sx = (torch.sin(kx * length) / kx).real
        dx = hx / kx2 * (1.0 - cx)
        r56 = hx**2 * (length - sx) / kx2 / beta**2
    else:
        # k -> 0 limits: (1 - cos(kL)) / k**2 -> L**2 / 2 and
        # (L - sin(kL) / k) / k**2 -> L**3 / 6.
        sx = length
        dx = length**2 * hx / 2
        r56 = hx**2 * length**3 / 6 / beta**2

    # Subtract L / (beta**2 * gamma**2); zero when igamma2 == 0 (no energy given).
    r56 = r56 - length / beta**2 * igamma2

    R = torch.eye(7, dtype=torch.float32, device=device)
    R[0, 0] = cx
    R[0, 1] = sx
    R[0, 5] = dx / beta
    R[1, 0] = -kx2 * sx
    R[1, 1] = cx
    R[1, 5] = sx * hx / beta
    R[2, 2] = cy
    R[2, 3] = sy
    R[3, 2] = -ky2 * sy
    R[3, 3] = cy
    R[4, 0] = sx * hx / beta
    R[4, 1] = dx / beta
    R[4, 5] = r56

    return R

In [4]:
%%timeit
# Time the tensor-based `base_rmatrix` from the cell above (i.e. the most recently
# executed definition of that name). The input tensors are created inside the
# timed statement, so their allocation cost is part of the measurement.
_ = base_rmatrix(
    length=torch.tensor(1.0),
    k1=torch.tensor(4.2),
    hx=torch.tensor(0.0),
    )

143 μs ± 195 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
# Pre-built 0-dim input tensors, so the next benchmark times only the function call.
my_length, my_k1, my_hx = torch.tensor(1.0), torch.tensor(4.2), torch.tensor(0.0)

In [6]:
%%timeit
# Same call as the previous benchmark, but with inputs built once beforehand
# (`my_length`, `my_k1`, `my_hx`), so only the function call itself is timed.
_ = base_rmatrix(
    length=my_length,
    k1=my_k1,
    hx=my_hx,
)

137 μs ± 182 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
def base_rmatrix(
    length: float,
    k1: float,
    hx: float,
    tilt: float = 0.0,
    energy: float = 0.0,
    device: Union[str, torch.device] = "auto",
):
    """
    Build the 7x7 first-order transfer matrix from plain Python floats.

    All intermediate quantities are NumPy scalars; only the final matrix is turned
    into a float32 tensor, with a single ``torch.tensor`` call.

    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m.
    :param tilt: Accepted for interface compatibility; not used here.
    :param energy: Beam energy in eV; 0 disables the relativistic correction.
    :param device: Target device, or "auto" to pick CUDA when available.
    :return: 7x7 float32 transfer matrix.
    """
    if device == "auto":
        target = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        target = device

    gamma = energy / REST_ENERGY_PYTHON
    if gamma != 0:
        inv_gamma_sq = 1 / gamma**2
    else:
        inv_gamma_sq = 0
    beta = np.sqrt(1 - inv_gamma_sq)

    kx2 = k1 + hx**2
    ky2 = -k1
    # Complex square roots: a negative k**2 turns cos/sin into cosh/sinh.
    kx = np.sqrt(kx2 + 0.0j)
    ky = np.sqrt(ky2 + 0.0j)
    cx = np.cos(kx * length).real
    cy = np.cos(ky * length).real
    sy = length if ky == 0 else (np.sin(ky * length) / ky).real

    if kx == 0:
        # Zero-focusing limits of sin(kL)/k, dispersion and r56.
        sx = length
        dx = length**2 * hx / 2
        r56 = hx**2 * length**3 / 6 / beta**2
    else:
        sx = (np.sin(kx * length) / kx).real
        dx = hx / kx2 * (1.0 - cx)
        r56 = hx**2 * (length - sx) / kx2 / beta**2
    r56 -= length / beta**2 * inv_gamma_sq

    x_dispersion = dx / beta
    xp_dispersion = sx * hx / beta
    rows = [
        [cx, sx, 0, 0, 0, x_dispersion, 0],
        [-kx2 * sx, cx, 0, 0, 0, xp_dispersion, 0],
        [0, 0, cy, sy, 0, 0, 0],
        [0, 0, -ky2 * sy, cy, 0, 0, 0],
        [xp_dispersion, x_dispersion, 0, 0, 1, r56, 0],
        [0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=target)

In [8]:
%%timeit
# Time the float/NumPy `base_rmatrix` defined in the cell above: plain Python floats
# in, and a single `torch.tensor(...)` call builds the 7x7 matrix at the end.
_ = base_rmatrix(
    length=1.0,
    k1=4.2,
    hx=0.0,
    tilt=0.0,
    energy=0.0)

11.9 μs ± 51.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [None]:
def base_rmatrix(
    length: torch.Tensor,
    k1: torch.Tensor,
    hx: torch.Tensor,
    species: Species,
    tilt: torch.Tensor | None = None,
    energy: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Create a first order universal transfer map for a beamline element.

    Inputs may be batched: the result has shape ``(*batch_shape, 7, 7)``, where
    ``batch_shape`` is the broadcast of the shapes of ``length``, ``k1``, ``hx``,
    ``tilt`` and ``energy``. The map is created on the device and with the dtype of
    ``length``.

    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m.
    :param species: Particle species of the beam.
    :param tilt: Rotation of the element relative to the longitudinal axis in rad.
        Only its shape enters (through the batch shape); the rotation itself is not
        applied here.
    :param energy: Beam energy in eV.
    :return: First order transfer map for the element.
    """
    device = length.device
    dtype = length.dtype

    zero = torch.tensor(0.0, device=device, dtype=dtype)

    tilt = tilt if tilt is not None else zero
    energy = energy if energy is not None else zero

    _, igamma2, beta = compute_relativistic_factors(energy, species.mass_eV)

    kx2 = k1 + hx**2
    ky2 = -k1
    kx = torch.sqrt(torch.complex(kx2, zero))
    ky = torch.sqrt(torch.complex(ky2, zero))
    cx = torch.cos(kx * length).real
    cy = torch.cos(ky * length).real
    # torch.sinc(x / pi) == sin(x) / x with the x == 0 case defined as 1, so
    # sx, sy -> length as k -> 0 without an explicit branch.
    sx = (torch.sinc(kx * length / torch.pi) * length).real
    sy = (torch.sinc(ky * length / torch.pi) * length).real

    # For kx2 == 0 the dispersion and r56 take their finite k -> 0 limits,
    # (1 - cos(kL)) / k**2 -> L**2 / 2 and (L - sin(kL) / k) / k**2 -> L**3 / 6.
    # These are non-zero whenever hx != 0 (e.g. k1 == -hx**2), so they must not be
    # replaced by zero. Dividing by a "safe" kx2 keeps the discarded branch of
    # torch.where free of x / 0.
    kx2_nonzero = kx2 != 0
    safe_kx2 = torch.where(kx2_nonzero, kx2, torch.ones_like(kx2))
    dx = torch.where(kx2_nonzero, hx / safe_kx2 * (1.0 - cx), length**2 * hx / 2)
    r56 = torch.where(
        kx2_nonzero,
        hx**2 * (length - sx) / safe_kx2 / beta**2,
        hx**2 * length**3 / 6 / beta**2,
    )

    # Subtract L / (beta**2 * gamma**2).
    r56 = r56 - length / beta**2 * igamma2

    vector_shape = torch.broadcast_shapes(
        length.shape, k1.shape, hx.shape, tilt.shape, energy.shape
    )

    R = torch.eye(7, dtype=dtype, device=device).repeat(*vector_shape, 1, 1)
    R[..., 0, 0] = cx
    R[..., 0, 1] = sx
    R[..., 0, 5] = dx / beta
    R[..., 1, 0] = -kx2 * sx
    R[..., 1, 1] = cx
    R[..., 1, 5] = sx * hx / beta
    R[..., 2, 2] = cy
    R[..., 2, 3] = sy
    R[..., 3, 2] = -ky2 * sy
    R[..., 3, 3] = cy
    R[..., 4, 0] = sx * hx / beta
    R[..., 4, 1] = dx / beta
    R[..., 4, 5] = r56

    return R

In [None]:
# Inputs for timing the species-aware `base_rmatrix`: 0-dim tensors created once up
# front, plus an electron `Species`.
my_length, my_k1, my_hx = (torch.tensor(value) for value in (1.0, 4.2, 0.0))
my_species = Species("electron")

In [None]:
%%timeit
# Time the species-aware `base_rmatrix` (cell above) with pre-built 0-dim tensor
# inputs; `tilt` and `energy` are left to their `None` defaults.
_ = base_rmatrix(
    length=my_length,
    k1=my_k1,
    hx=my_hx,
    species=my_species
    )

In [9]:
# Plain Python floats standing in for the matrix entries, used to time building R
# with a single nested-list `torch.tensor(...)` call. `dy` and `hy` are defined but
# not referenced by the timing cell that follows.
cx, cy, sx, sy = 1.0, 2.0, 3.0, 4.0
dx, dy = 5.0, 6.0
beta = 7.0
kx2, ky2 = 8.0, 9.0
hx, hy = 7.0, 8.0
r56 = 9.0

In [10]:
%%timeit
# Assembly cost only: build the 7x7 matrix from the plain Python floats defined above
# with one nested-list `torch.tensor(...)` call (default dtype and device).
R = torch.tensor(
        [
            [cx, sx, 0, 0, 0, dx / beta, 0],
            [-kx2 * sx, cx, 0, 0, 0, sx * hx / beta, 0],
            [0, 0, cy, sy, 0, 0, 0],
            [0, 0, -ky2 * sy, cy, 0, 0, 0],
            [sx * hx / beta, dx / beta, 0, 0, 1, r56, 0],
            [0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1],
        ]
    )

9.85 μs ± 13.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [11]:
# The same entry values as 0-dim tensors (default dtype), for timing the
# tensor-based assembly variants below.
(cx, cy, sx, sy, dx, dy, beta, kx2, ky2, hx, hy, r56) = (
    torch.tensor(v)
    for v in (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 7.0, 8.0, 9.0)
)

In [12]:
%%timeit
# Assembly variant: start from a 7x7 identity and write the 13 non-trivial entries
# one at a time (13 separate indexed assignments). The arithmetic on the 0-dim
# input tensors (e.g. `dx / beta`) is timed as well.
R = torch.eye(7)
R[0, 0] = cx
R[0, 1] = sx
R[0, 5] = dx / beta
R[1, 0] = -kx2 * sx
R[1, 1] = cx
R[1, 5] = sx * hx / beta
R[2, 2] = cy
R[2, 3] = sy
R[3, 2] = -ky2 * sy
R[3, 3] = cy
R[4, 0] = sx * hx / beta
R[4, 1] = dx / beta
R[4, 5] = r56

48.1 μs ± 108 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
%%timeit
# Assembly variant: stack 0-dim tensors into rows, then the rows into the matrix.
# Only out-of-place ops are used, so the autograd graph back to the inputs is kept.
# z / o: zero and one tensors with the dtype and device of `cx`.
z = torch.zeros_like(cx)
o = torch.ones_like(cx)

row0 = torch.stack([cx, sx, z, z, z, dx / beta, z])
row1 = torch.stack([-kx2 * sx, cx, z, z, z, sx * hx / beta, z])
row2 = torch.stack([z, z, cy, sy, z, z, z])
row3 = torch.stack([z, z, -ky2 * sy, cy, z, z, z])
row4 = torch.stack([sx * hx / beta, dx / beta, z, z, o, r56, z])
row5 = torch.stack([z, z, z, z, z, o, z])
row6 = torch.stack([z, z, z, z, z, z, o])

R = torch.stack([row0, row1, row2, row3, row4, row5, row6])

46.1 μs ± 78.8 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
%%timeit
# Assembly variant: scatter all 13 values into an identity matrix with one in-place
# `index_put_` call; (rows[i], cols[i]) is the target position of vals[i].
# NOTE(review): `rows` and `cols` are rebuilt inside the timed statement on every
# loop, so their creation is part of the measurement. Hoisting them into a setup
# cell would isolate the assembly cost — TODO confirm which comparison is intended.
R = torch.eye(7)
rows = torch.tensor([0,0,0,1,1,1,2,2,3,3,4,4,4])
cols = torch.tensor([0,1,5,0,1,5,2,3,2,3,0,1,5])
vals = torch.stack([
    cx, sx, dx/beta,
    -kx2*sx, cx, sx*hx/beta,
    cy, sy,
    -ky2*sy, cy,
    sx*hx/beta, dx/beta, r56
])
R.index_put_((rows, cols), vals, accumulate=False)  # single kernel, autograd-safe

30.7 μs ± 210 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%%timeit
# Assembly variant: the same scatter as the `index_put_` cell, but written as the
# advanced-indexing assignment `R[rows, cols] = vals` with plain Python lists as
# indices.
# NOTE(review): the index lists are presumably converted to index tensors on every
# loop inside the timed statement — confirm before comparing with other variants.
R = torch.eye(7)
rows = [0,0,0,1,1,1,2,2,3,3,4,4,4]
cols = [0,1,5,0,1,5,2,3,2,3,0,1,5]
vals = torch.stack([
    cx, sx, dx/beta,
    -kx2*sx, cx, sx*hx/beta,
    cy, sy,
    -ky2*sy, cy,
    sx*hx/beta, dx/beta, r56
])
R[rows, cols] = vals  # single kernel, autograd-safe


29.9 μs ± 923 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [24]:
# Build once (e.g., module __init__)
def make_basis(dtype, device):
    """
    Return a (10, 7, 7) tensor of 0/1 basis matrices.

    Slice ``B[i]`` marks every position of the 7x7 transfer matrix that takes the
    value of the i-th scalar term, so ``R = sum_i s[i] * B[i]``:

    0: cx         -> (0,0), (1,1)    5: -ky2*sy    -> (3,2)
    1: sx         -> (0,1)           6: dx/beta    -> (0,5), (4,1)
    2: -kx2*sx    -> (1,0)           7: sx*hx/beta -> (1,5), (4,0)
    3: cy         -> (2,2), (3,3)    8: r56        -> (4,5)
    4: sy         -> (2,3)           9: constant 1 -> (4,4), (5,5), (6,6)

    :param dtype: dtype of the returned tensor.
    :param device: device of the returned tensor.
    :return: Basis tensor of shape (10, 7, 7).
    """
    positions = (
        ((0, 0), (1, 1)),
        ((0, 1),),
        ((1, 0),),
        ((2, 2), (3, 3)),
        ((2, 3),),
        ((3, 2),),
        ((0, 5), (4, 1)),
        ((1, 5), (4, 0)),
        ((4, 5),),
        ((4, 4), (5, 5), (6, 6)),
    )
    basis = torch.zeros(len(positions), 7, 7, dtype=dtype, device=device)
    for term, cells in enumerate(positions):
        for row, col in cells:
            basis[term, row, col] = 1
    return basis


# call-time: stack the terms, then contract with the basis in one tensordot (autograd-safe)
def build_R(cx, sx, dx, beta, kx2, cy, sy, ky2, hx, r56, B):
    """
    Assemble the transfer matrix as a linear combination of basis matrices.

    ``B`` is a (10, 7, 7) basis such as the one returned by ``make_basis``: ``B[i]``
    marks where the i-th term below goes. The element inputs may be 0-dim tensors or
    tensors of any mutually broadcastable shape. Only out-of-place ops (stack,
    tensordot) are used, so gradients flow back to the inputs.

    :return: Transfer matrix of shape ``(*batch_shape, 7, 7)``; ``(7, 7)`` for
        0-dim inputs.
    """
    terms = torch.broadcast_tensors(
        cx,  # 0
        sx,  # 1
        -kx2 * sx,  # 2
        cy,  # 3
        sy,  # 4
        -ky2 * sy,  # 5
        dx / beta,  # 6
        sx * hx / beta,  # 7
        r56,  # 8
    )
    # 9: constant ones with the broadcast batch shape (a 0-dim scalar here would
    # make torch.stack fail for batched inputs).
    one = torch.ones_like(terms[0])
    s = torch.stack([*terms, one])
    # Contract the term axis: (10, *batch) x (10, 7, 7) -> (*batch, 7, 7)
    R = torch.tensordot(s, B, dims=([0], [0]))
    return R

In [29]:
# Precompute, outside the timed cell below, the (10, 7, 7) basis and the constant
# scalar 1 used as the last stacked term; both match the dtype/device of `cx`.
one = torch.ones((), dtype=cx.dtype, device=cx.device)
B = make_basis(dtype=one.dtype, device=one.device)

In [30]:
%%timeit
# Assembly variant: stack the 10 scalar terms into `s` and contract them with the
# precomputed (10, 7, 7) basis `B` in one `tensordot` (the body of `build_R`,
# inlined). `B` and `one` are built in the previous cell and are not timed.
s = torch.stack(
    [
        cx,  # 0
        sx,  # 1
        -kx2 * sx,  # 2
        cy,  # 3
        sy,  # 4
        -ky2 * sy,  # 5
        dx / beta,  # 6
        sx * hx / beta,  # 7
        r56,  # 8
        one,  # 9 (constant)
    ]
)
# single kernel does the whole linear combo
R = torch.tensordot(s, B, dims=([0], [0]))  # shape [7, 7]

22.2 μs ± 217 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
class MyNormalClass:
    """Plain Python class that only stores ``value``; baseline for construction cost."""

    def __init__(self, value: torch.Tensor):
        self.value = value


class MyModuleClass(torch.nn.Module):
    """``torch.nn.Module`` subclass that only stores ``value``.

    Used to compare its construction cost with that of ``MyNormalClass``.
    """

    def __init__(self, value: torch.Tensor):
        super().__init__()
        self.value = value

In [16]:
%%timeit
# Construction cost of a plain Python object holding a tensor (baseline); creating
# the `torch.tensor(1.0)` argument is included in the timing.
obj = MyNormalClass(value=torch.tensor(1.0))

1.91 μs ± 3.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [17]:
%%timeit
# Same as the plain-class benchmark, but with a `torch.nn.Module` subclass. The
# difference is presumably the overhead of `nn.Module.__init__` and its attribute
# handling — confirm with a profiler if it matters.
obj = MyModuleClass(value=torch.tensor(1.0))

5.7 μs ± 8.64 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
