In [2]:
from typing import Optional, Union

import numpy as np
import torch
from scipy import constants

import cheetah
from cheetah import Species
from cheetah.utils import compute_relativistic_factors
import jax.numpy as jnp
import jax

In [2]:
%load_ext line_profiler
%load_ext pyinstrument

In [3]:
REST_ENERGY_PYTHON = (
    constants.electron_mass * constants.speed_of_light**2 / constants.elementary_charge
)  # electron mass
REST_ENERGY_TORCH = torch.tensor(
    constants.electron_mass * constants.speed_of_light**2 / constants.elementary_charge
)  # electron mass

In [4]:
def base_rmatrix(
    length: torch.Tensor,
    k1: torch.Tensor,
    hx: torch.Tensor,
    tilt: Optional[torch.Tensor] = None,
    energy: Optional[torch.Tensor] = None,
    device: Union[str, torch.device] = "auto",
) -> torch.Tensor:
    """
    Create a universal transfer matrix for a beamline element.
    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m**2.
    :param tilt: Roation of the element relative to the longitudinal axis in rad.
    :param energy: Beam energy in eV.
    :param device: Device where the transfer matrix is created. If "auto", the device
        is selected automatically.
    :return: Transfer matrix for the element.
    """

    tilt = tilt if tilt is not None else torch.tensor(0.0)
    energy = energy if energy is not None else torch.tensor(0.0)

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    gamma = energy / REST_ENERGY_TORCH
    igamma2 = 1 / gamma**2 if gamma != 0 else torch.tensor(0.0)

    beta = torch.sqrt(1 - igamma2)

    kx2 = k1 + hx**2
    ky2 = -k1
    kx = torch.sqrt(torch.complex(kx2, torch.tensor(0.0)))
    ky = torch.sqrt(torch.complex(ky2, torch.tensor(0.0)))
    cx = torch.cos(kx * length).real
    cy = torch.cos(ky * length).real
    sy = (torch.sin(ky * length) / ky).real if ky != 0 else length

    if kx != 0:
        sx = (torch.sin(kx * length) / kx).real
        dx = hx / kx2 * (1.0 - cx)
        r56 = hx**2 * (length - sx) / kx2 / beta**2
    else:
        sx = length
        dx = length**2 * hx / 2
        r56 = hx**2 * length**3 / 6 / beta**2

    r56 -= length / beta**2 * igamma2

    R = torch.eye(7, dtype=torch.float32, device=device)
    R[0, 0] = cx
    R[0, 1] = sx
    R[0, 5] = dx / beta
    R[1, 0] = -kx2 * sx
    R[1, 1] = cx
    R[1, 5] = sx * hx / beta
    R[2, 2] = cy
    R[2, 3] = sy
    R[3, 2] = -ky2 * sy
    R[3, 3] = cy
    R[4, 0] = sx * hx / beta
    R[4, 1] = dx / beta
    R[4, 5] = r56

    return R

In [5]:
%%timeit
_ = base_rmatrix(
    length=torch.tensor(1.0),
    k1=torch.tensor(4.2),
    hx=torch.tensor(0.0),
    )

139 μs ± 401 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
my_length = torch.tensor(1.0)
my_k1 = torch.tensor(4.2)
my_hx = torch.tensor(0.0)

In [7]:
%%timeit
_ = base_rmatrix(
    length=my_length,
    k1=my_k1,
    hx=my_hx,
)

134 μs ± 452 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
def base_rmatrix(
    length: float,
    k1: float,
    hx: float,
    tilt: float = 0.0,
    energy: float = 0.0,
    device: Union[str, torch.device] = "auto",
):
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    gamma = energy / REST_ENERGY_PYTHON
    igamma2 = 1 / gamma**2 if gamma != 0 else 0

    beta = np.sqrt(1 - igamma2)

    kx2 = k1 + hx**2
    ky2 = -k1
    kx = np.sqrt(kx2 + 0.0j)
    ky = np.sqrt(ky2 + 0.0j)
    cx = np.cos(kx * length).real
    cy = np.cos(ky * length).real
    sy = (np.sin(ky * length) / ky).real if ky != 0 else length

    if kx != 0:
        sx = (np.sin(kx * length) / kx).real
        dx = hx / kx2 * (1.0 - cx)
        r56 = hx**2 * (length - sx) / kx2 / beta**2
    else:
        sx = length
        dx = length**2 * hx / 2
        r56 = hx**2 * length**3 / 6 / beta**2

    r56 -= length / beta**2 * igamma2

    R = torch.tensor(
        [
            [cx, sx, 0, 0, 0, dx / beta, 0],
            [-kx2 * sx, cx, 0, 0, 0, sx * hx / beta, 0],
            [0, 0, cy, sy, 0, 0, 0],
            [0, 0, -ky2 * sy, cy, 0, 0, 0],
            [sx * hx / beta, dx / beta, 0, 0, 1, r56, 0],
            [0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=torch.float32,
        device=device,
    )

    return R

In [9]:
%%timeit
_ = base_rmatrix(
    length=1.0,
    k1=4.2,
    hx=0.0,
    tilt=0.0,
    energy=0.0)

12.6 μs ± 12.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [10]:
def base_rmatrix(
    length: torch.Tensor,
    k1: torch.Tensor,
    hx: torch.Tensor,
    species: Species,
    tilt: torch.Tensor | None = None,
    energy: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Create a first order universal transfer map for a beamline element.

    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m.
    :param species: Particle species of the beam.
    :param tilt: Roation of the element relative to the longitudinal axis in rad.
    :param energy: Beam energy in eV.
    :return: First order transfer map for the element.
    """
    device = length.device
    dtype = length.dtype

    zero = torch.tensor(0.0, device=device, dtype=dtype)

    tilt = tilt if tilt is not None else zero
    energy = energy if energy is not None else zero

    _, igamma2, beta = compute_relativistic_factors(energy, species.mass_eV)

    kx2 = k1 + hx**2
    ky2 = -k1
    kx = torch.sqrt(torch.complex(kx2, zero))
    ky = torch.sqrt(torch.complex(ky2, zero))
    cx = torch.cos(kx * length).real
    cy = torch.cos(ky * length).real
    sx = (torch.sinc(kx * length / torch.pi) * length).real
    sy = (torch.sinc(ky * length / torch.pi) * length).real
    dx = torch.where(kx2 != 0, hx / kx2 * (1.0 - cx), zero)
    r56 = torch.where(kx2 != 0, hx**2 * (length - sx) / kx2 / beta**2, zero)

    r56 = r56 - length / beta**2 * igamma2

    vector_shape = torch.broadcast_shapes(
        length.shape, k1.shape, hx.shape, tilt.shape, energy.shape
    )

    R = torch.eye(7, dtype=dtype, device=device).repeat(*vector_shape, 1, 1)
    R[..., 0, 0] = cx
    R[..., 0, 1] = sx
    R[..., 0, 5] = dx / beta
    R[..., 1, 0] = -kx2 * sx
    R[..., 1, 1] = cx
    R[..., 1, 5] = sx * hx / beta
    R[..., 2, 2] = cy
    R[..., 2, 3] = sy
    R[..., 3, 2] = -ky2 * sy
    R[..., 3, 3] = cy
    R[..., 4, 0] = sx * hx / beta
    R[..., 4, 1] = dx / beta
    R[..., 4, 5] = r56

    return R

In [11]:
my_length = torch.tensor(1.0)
my_k1 = torch.tensor(4.2)
my_hx = torch.tensor(0.0)
my_species = Species("electron")

In [12]:
%%timeit
_ = base_rmatrix(
    length=my_length,
    k1=my_k1,
    hx=my_hx,
    species=my_species
    )

228 μs ± 38 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%lprun -f base_rmatrix base_rmatrix(length=my_length, k1=my_k1, hx=my_hx, species=my_species)

Timer unit: 1e-09 s

Total time: 0.002546 s
File: /var/folders/z8/vzg_1dr50gg1zchydp1styc00000gn/T/ipykernel_84847/33921839.py
Function: base_rmatrix at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def base_rmatrix(
     2                                               length: torch.Tensor,
     3                                               k1: torch.Tensor,
     4                                               hx: torch.Tensor,
     5                                               species: Species,
     6                                               tilt: torch.Tensor | None = None,
     7                                               energy: torch.Tensor | None = None,
     8                                           ) -> torch.Tensor:
     9                                               """
    10                                               Create a first order universal transfer map for a beamline element.

In [14]:
# %%pyinstrument --interval=0.0000001
# _ = base_rmatrix(
#     length=my_length,
#     k1=my_k1,
#     hx=my_hx,
#     species=my_species
# )

In [15]:
energy = torch.tensor(1e9)  # 1 GeV
species = Species("electron")

In [19]:
_, _, _ = compute_relativistic_factors(energy, species.mass_eV)

In [22]:
%%timeit
_, _, _ = compute_relativistic_factors(energy, species.mass_eV)

21.3 μs ± 380 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [30]:
# %lprun -f compute_relativistic_factors compute_relativistic_factors(energy, species.mass_eV)

In [18]:
def base_rmatrix(
    length: torch.Tensor,
    k1: torch.Tensor,
    hx: torch.Tensor,
    species: Species,
    tilt: torch.Tensor | None = None,
    energy: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Create a first order universal transfer map for a beamline element.

    :param length: Length of the element in m.
    :param k1: Quadrupole strength in 1/m**2.
    :param hx: Curvature (1/radius) of the element in 1/m.
    :param species: Particle species of the beam.
    :param tilt: Roation of the element relative to the longitudinal axis in rad.
    :param energy: Beam energy in eV.
    :return: First order transfer map for the element.
    """
    device = length.device
    dtype = length.dtype

    zero = torch.tensor(0.0, device=device, dtype=dtype)

    tilt = tilt if tilt is not None else zero
    energy = energy if energy is not None else zero

    _, igamma2, beta = compute_relativistic_factors(energy, species.mass_eV)

    kx2 = k1 + hx**2
    ky2 = -k1
    kx = torch.sqrt(torch.complex(kx2, zero))
    ky = torch.sqrt(torch.complex(ky2, zero))
    cx = torch.cos(kx * length).real
    cy = torch.cos(ky * length).real
    sx = (torch.sinc(kx * length / torch.pi) * length).real
    sy = (torch.sinc(ky * length / torch.pi) * length).real
    dx = torch.where(kx2 != 0, hx / kx2 * (1.0 - cx), zero)
    r56 = torch.where(kx2 != 0, hx**2 * (length - sx) / kx2 / beta**2, zero)

    r56 = r56 - length / beta**2 * igamma2
    cx, sx, dx, cy, sy, r56 = torch.broadcast_tensors(cx, sx, dx, cy, sy, r56)

    R = torch.eye(7, dtype=dtype, device=device).repeat(*cx.shape, 1, 1)
    R[
        ...,
        (0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4),
        (0, 1, 5, 0, 1, 5, 2, 3, 2, 3, 0, 1, 5),
    ] = torch.stack(
        [
            cx,
            sx,
            dx / beta,
            -kx2 * sx,
            cx,
            sx * hx / beta,
            cy,
            sy,
            -ky2 * sy,
            cy,
            sx * hx / beta,
            dx / beta,
            r56,
        ],
        dim=-1,
    )

    return R

In [19]:
my_length = torch.tensor(1.0)
my_k1 = torch.tensor(4.2)
my_hx = torch.tensor(0.0)
my_species = Species("electron")

In [20]:
%%timeit
_ = base_rmatrix(
    length=my_length,
    k1=my_k1,
    hx=my_hx,
    species=my_species,
)

187 μs ± 8.09 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [21]:
cx = 1.0
cy = 2.0
sx = 3.0
sy = 4.0
dx = 5.0
dy = 6.0
beta = 7.0
kx2 = 8.0
ky2 = 9.0
hx = 7.0
hy = 8.0
r56 = 9.0

In [22]:
%%timeit
R = torch.tensor(
        [
            [cx, sx, 0, 0, 0, dx / beta, 0],
            [-kx2 * sx, cx, 0, 0, 0, sx * hx / beta, 0],
            [0, 0, cy, sy, 0, 0, 0],
            [0, 0, -ky2 * sy, cy, 0, 0, 0],
            [sx * hx / beta, dx / beta, 0, 0, 1, r56, 0],
            [0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1],
        ]
    )

12.3 μs ± 775 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [23]:
cx = torch.tensor(1.0)
cy = torch.tensor(2.0)
sx = torch.tensor(3.0)
sy = torch.tensor(4.0)
dx = torch.tensor(5.0)
dy = torch.tensor(6.0)
beta = torch.tensor(7.0)
kx2 = torch.tensor(8.0)
ky2 = torch.tensor(9.0)
hx = torch.tensor(7.0)
hy = torch.tensor(8.0)
r56 = torch.tensor(9.0)

In [24]:
%%timeit
R = torch.eye(7)
R[0, 0] = cx
R[0, 1] = sx
R[0, 5] = dx / beta
R[1, 0] = -kx2 * sx
R[1, 1] = cx
R[1, 5] = sx * hx / beta
R[2, 2] = cy
R[2, 3] = sy
R[3, 2] = -ky2 * sy
R[3, 3] = cy
R[4, 0] = sx * hx / beta
R[4, 1] = dx / beta
R[4, 5] = r56

54.6 μs ± 8.14 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [25]:
length = my_length
k1 = my_k1
hx = my_hx
species = my_species
another_tilt = None
another_energy = None

In [26]:
%%timeit
zero = torch.tensor(0.0)

tilt = another_tilt if another_tilt is not None else zero
energy = another_energy if another_energy is not None else zero

_, igamma2, beta = compute_relativistic_factors(energy, species.mass_eV)

kx2 = k1 + hx**2
ky2 = -k1
kx = torch.sqrt(torch.complex(kx2, zero))
ky = torch.sqrt(torch.complex(ky2, zero))
cx = torch.cos(kx * length).real
cy = torch.cos(ky * length).real
sx = (torch.sinc(kx * length / torch.pi) * length).real
sy = (torch.sinc(ky * length / torch.pi) * length).real
dx = torch.where(kx2 != 0, hx / kx2 * (1.0 - cx), zero)
r56 = torch.where(kx2 != 0, hx**2 * (length - sx) / kx2 / beta**2, zero)

r56 = r56 - length / beta**2 * igamma2

vector_shape = torch.broadcast_shapes(
    length.shape, k1.shape, hx.shape, tilt.shape, energy.shape
)

117 μs ± 18.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [27]:
vector_shape = torch.broadcast_shapes(
    cx.shape,
    cy.shape,
    sx.shape,
    sy.shape,
    dx.shape,
    dy.shape,
    beta.shape,
    kx2.shape,
    ky2.shape,
    hx.shape,
    hy.shape,
    r56.shape,
)

In [31]:
%%timeit
R = torch.eye(7).repeat(*vector_shape, 1, 1)
R[..., 0, 0] = cx
R[..., 0, 1] = sx
R[..., 0, 5] = dx / beta
R[..., 1, 0] = -kx2 * sx
R[..., 1, 1] = cx
R[..., 1, 5] = sx * hx / beta
R[..., 2, 2] = cy
R[..., 2, 3] = sy
R[..., 3, 2] = -ky2 * sy
R[..., 3, 3] = cy
R[..., 4, 0] = sx * hx / beta
R[..., 4, 1] = dx / beta
R[..., 4, 5] = r56

51.8 μs ± 94.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [24]:
%%timeit
# make reusable 0/1 tensors on the right device/dtype, keeping graph intact
z = torch.zeros_like(cx)
o = torch.ones_like(cx)

row0 = torch.stack([cx, sx, z, z, z, dx / beta, z])
row1 = torch.stack([-kx2 * sx, cx, z, z, z, sx * hx / beta, z])
row2 = torch.stack([z, z, cy, sy, z, z, z])
row3 = torch.stack([z, z, -ky2 * sy, cy, z, z, z])
row4 = torch.stack([sx * hx / beta, dx / beta, z, z, o, r56, z])
row5 = torch.stack([z, z, z, z, z, o, z])
row6 = torch.stack([z, z, z, z, z, z, o])

R = torch.stack([row0, row1, row2, row3, row4, row5, row6])

44 μs ± 232 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [25]:
%%timeit
R = torch.eye(7)
rows = torch.tensor([0,0,0,1,1,1,2,2,3,3,4,4,4])
cols = torch.tensor([0,1,5,0,1,5,2,3,2,3,0,1,5])
vals = torch.stack([
    cx, sx, dx/beta,
    -kx2*sx, cx, sx*hx/beta,
    cy, sy,
    -ky2*sy, cy,
    sx*hx/beta, dx/beta, r56
])
R.index_put_((rows, cols), vals, accumulate=False)  # single kernel, autograd-safe

35.4 μs ± 6.31 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [38]:
%%timeit
R = torch.eye(7).repeat(*vector_shape, 1, 1)
rows = [0,0,0,1,1,1,2,2,3,3,4,4,4]
cols = [0,1,5,0,1,5,2,3,2,3,0,1,5]
vals = torch.stack([
    cx, sx, dx/beta,
    -kx2*sx, cx, sx*hx/beta,
    cy, sy,
    -ky2*sy, cy,
    sx*hx/beta, dx/beta, r56
])
R[rows, cols] = vals  # single kernel, autograd-safe


32.5 μs ± 157 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [27]:
# Build once (e.g., module __init__)
def make_basis(dtype, device):
    B = torch.zeros(10, 7, 7, dtype=dtype, device=device)
    # 0: cx -> (0,0), (1,1)
    B[0, 0, 0] = 1
    B[0, 1, 1] = 1
    # 1: sx -> (0,1)
    B[1, 0, 1] = 1
    # 2: -kx2*sx -> (1,0)
    B[2, 1, 0] = 1
    # 3: cy -> (2,2), (3,3)
    B[3, 2, 2] = 1
    B[3, 3, 3] = 1
    # 4: sy -> (2,3)
    B[4, 2, 3] = 1
    # 5: -ky2*sy -> (3,2)
    B[5, 3, 2] = 1
    # 6: dx/beta -> (0,5), (4,1)
    B[6, 0, 5] = 1
    B[6, 4, 1] = 1
    # 7: sx*hx/beta -> (1,5), (4,0)
    B[7, 1, 5] = 1
    B[7, 4, 0] = 1
    # 8: r56 -> (4,5)
    B[8, 4, 5] = 1
    # 9: constant ones -> (4,4), (5,5), (6,6)
    B[9, 4, 4] = 1
    B[9, 5, 5] = 1
    B[9, 6, 6] = 1
    return B


# call-time (autograd-safe, single fused op)
def build_R(cx, sx, dx, beta, kx2, cy, sy, ky2, hx, r56, B):
    one = cx.new_ones(())
    s = torch.stack(
        [
            cx,  # 0
            sx,  # 1
            -kx2 * sx,  # 2
            cy,  # 3
            sy,  # 4
            -ky2 * sy,  # 5
            dx / beta,  # 6
            sx * hx / beta,  # 7
            r56,  # 8
            one,  # 9 (constant)
        ]
    )
    # single kernel does the whole linear combo
    R = torch.tensordot(s, B, dims=([0], [0]))  # shape [7, 7]
    return R

In [28]:
B = make_basis(dtype=cx.dtype, device=cx.device)
one = cx.new_ones(())

In [29]:
%%timeit
s = torch.stack(
    [
        cx,  # 0
        sx,  # 1
        -kx2 * sx,  # 2
        cy,  # 3
        sy,  # 4
        -ky2 * sy,  # 5
        dx / beta,  # 6
        sx * hx / beta,  # 7
        r56,  # 8
        one,  # 9 (constant)
    ]
)
# single kernel does the whole linear combo
R = torch.tensordot(s, B, dims=([0], [0]))  # shape [7, 7]

21.8 μs ± 668 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [30]:
class MyNormalClass:
    def __init__(self, value: torch.Tensor):
        self.value = value


class MyModuleClass(torch.nn.Module):
    def __init__(self, value: torch.Tensor):
        super().__init__()
        self.value = value

In [31]:
%%timeit
obj = MyNormalClass(value=torch.tensor(1.0))

1.87 μs ± 3.56 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [32]:
%%timeit
obj = MyModuleClass(value=torch.tensor(1.0))

6.81 μs ± 20.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [33]:
gamma = torch.tensor(1.0)

In [34]:
%%timeit
igamma2 = 1 / gamma**2

5.32 μs ± 18.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [35]:
%%timeit
igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2)

10.6 μs ± 26.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [55]:
k1_py = 4.2
hx_py = 1e-6

k1_th = torch.tensor(k1_py)
hx_th = torch.tensor(hx_py)

k1_np = np.array(k1_py)
hx_np = np.array(hx_py)

k1_jx = jnp.array(k1_py)
hx_jx = jnp.array(hx_py)

In [65]:
def py_fun(k1, hx):
    kx2 = k1 + hx**2
    return kx2


def th_fun(k1, hx):
    kx2 = k1 + hx**2
    return kx2


@torch.compile
def compiled_th_fun(k1, hx):
    kx2 = k1 + hx**2
    return kx2


def np_fun(k1, hx):
    kx2 = k1 + hx**2
    return kx2


def jx_fun(k1, hx):
    kx2 = k1 + hx**2
    return kx2


@jax.jit
def compiled_jx_fun(k1, hx):
    kx2 = k1 + hx**2
    return kx2

In [67]:
%%timeit
_ = py_fun(k1_py, hx_py)

101 ns ± 0.135 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [73]:
%%timeit
_ = th_fun(k1_th, hx_th)

2.63 μs ± 32.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [76]:
%%timeit
_ = compiled_th_fun(k1_th, hx_th)

14 μs ± 105 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [80]:
%%timeit
_ = np_fun(k1_np, hx_np)

967 ns ± 123 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [82]:
%%timeit
_ = jx_fun(k1_jx, hx_jx)

35.1 μs ± 240 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [85]:
%%timeit
_ = compiled_jx_fun(k1_jx, hx_jx)

4.94 μs ± 43.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [56]:
%%timeit
kx2_py = k1_py + hx_py**2

63.5 ns ± 0.0816 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [57]:
%%timeit
kx2_th = k1_th + hx_th**2

2.57 μs ± 33 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [58]:
%%timeit
kx2_np = k1_np + hx_np**2

777 ns ± 7.05 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [62]:
%%timeit
kx2_jx = k1_jx + hx_jx**2

34 μs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [2]:
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))


@torch.compile
def gelu_comp(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

In [3]:
x1 = torch.tensor(1.0)
x2 = torch.randn(100, 100)

_ = gelu(x1)
# _ = gelu(x2)
# _ = gelu_comp(x1)
_ = gelu_comp(x2)

In [4]:
_ = gelu_comp(x1)

In [8]:
%%timeit
_ = gelu(x1)

8.04 μs ± 70 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [9]:
%%timeit
_ = gelu_comp(x1)

7.96 μs ± 182 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [12]:
%%timeit
_ = gelu(x2)

64.8 μs ± 241 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
%%timeit
_ = gelu_comp(x2)

30.5 μs ± 77.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [2]:
@torch.compile
def compute_relativistic_beta(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> torch.Tensor:
    return torch.sqrt(1 - (particle_mass_eV / energy) ** 2)


@torch.compile
def compute_relativistic_igamma2(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> torch.Tensor:
    return (particle_mass_eV / energy) ** 2


@torch.compile
def compute_relativistic_gamma(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> torch.Tensor:
    return energy / particle_mass_eV


@torch.compile
def compute_relativistic_beta_comp(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> torch.Tensor:
    return torch.sqrt(1 - (particle_mass_eV / energy) ** 2)


@torch.compile
def compute_relativistic_igamma2_comp(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> torch.Tensor:
    return (particle_mass_eV / energy) ** 2


@torch.compile
def compute_relativistic_gamma_comp(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> torch.Tensor:
    return energy / particle_mass_eV


my_energy = torch.tensor(1e9)  # 1 GeV
my_mass = torch.tensor(0.511e6)  # electron mass in eV

_ = compute_relativistic_beta(my_energy, my_mass)
_ = compute_relativistic_igamma2(my_energy, my_mass)
_ = compute_relativistic_gamma(my_energy, my_mass)

_ = compute_relativistic_beta_comp(my_energy, my_mass)
_ = compute_relativistic_igamma2_comp(my_energy, my_mass)
_ = compute_relativistic_gamma_comp(my_energy, my_mass)

In [17]:
%%timeit
_ = compute_relativistic_beta(my_energy, my_mass)

8.21 μs ± 5.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [18]:
%%timeit
_ = compute_relativistic_beta_comp(my_energy, my_mass)

8.32 μs ± 45.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [21]:
%%timeit
_ = compute_relativistic_gamma(my_energy, my_mass)

8.23 μs ± 28.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [22]:
%%timeit
_ = compute_relativistic_gamma_comp(my_energy, my_mass)

8.06 μs ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [11]:
compute_relativistic_factors_compiled = torch.compile(
    compute_relativistic_factors, mode="max-autotune"
)

# Run 3 times to try and get over warmup
_ = compute_relativistic_factors_compiled(my_energy, my_mass)
_ = compute_relativistic_factors_compiled(my_energy, my_mass)
_ = compute_relativistic_factors_compiled(my_energy, my_mass)

In [12]:
%%timeit
_ = compute_relativistic_factors(my_energy, my_mass)

10.9 μs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [13]:
%%timeit
_ = compute_relativistic_factors_compiled(my_energy, my_mass)

10.4 μs ± 124 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [14]:
expanded_energy = my_energy.expand(1000)
expanded_mass = my_mass.expand(1000)

In [15]:
%%timeit
_ = compute_relativistic_factors(expanded_energy, expanded_mass)

13.3 μs ± 1.78 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [16]:
%%timeit
_ = compute_relativistic_factors_compiled(expanded_energy, expanded_mass)

13.5 μs ± 11.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [10]:
my_energy = torch.tensor(1e9)  # 1 GeV
my_mass = torch.tensor(0.511e6)  # electron mass in eV
my_gamma = my_energy / my_mass
my_igamma2 = 1 / my_gamma**2

In [7]:
%%timeit
gamma = my_energy / my_mass

1.06 μs ± 1.52 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [9]:
%%timeit
gamma = torch.div(my_energy, my_mass)

1.08 μs ± 2.59 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [17]:
%%timeit
igamma2 = 1 / my_gamma**2

6.17 μs ± 810 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [16]:
%%timeit
igamma2 = torch.square(torch.reciprocal(my_gamma))

2.1 μs ± 5.93 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [18]:
%%timeit
igamma2 = torch.reciprocal(torch.square(my_gamma))

2.07 μs ± 6.45 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [19]:
%%timeit
igamma2 = 1 / torch.square(my_gamma)

5 μs ± 18.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [21]:
%%timeit
igamma2 = torch.reciprocal(my_gamma**2)

2.34 μs ± 7.16 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [25]:
def compute_relativistic_factors_1(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the relativistic factors gamma, inverse gamma squared and beta for
    particles.

    :param energy: Energy in eV.
    :param particle_mass_eV: Mass of the particle in eV.
    :return: gamma, igamma2, beta.
    """
    gamma = energy / particle_mass_eV
    igamma2 = 1 / gamma**2  # Division by zero not handled because not physical
    beta = torch.sqrt(1 - igamma2)

    return gamma, igamma2, beta


def compute_relativistic_factors_2(
    energy: torch.Tensor, particle_mass_eV: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the relativistic factors gamma, inverse gamma squared and beta for
    particles.

    :param energy: Energy in eV.
    :param particle_mass_eV: Mass of the particle in eV.
    :return: gamma, igamma2, beta.
    """
    gamma = energy / particle_mass_eV
    # Division by zero not physical
    # reciprocal and square save on kernel launches
    igamma2 = torch.reciprocal(torch.square(gamma))
    beta = torch.sqrt(1 - igamma2)

    return gamma, igamma2, beta

In [26]:
%%timeit
_ = compute_relativistic_factors_1(my_energy, my_mass)

10.7 μs ± 29.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [27]:
%%timeit
_ = compute_relativistic_factors_2(my_energy, my_mass)

7.41 μs ± 20.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [29]:
length = torch.tensor(1.0)

In [32]:
%%timeit
zero = torch.tensor(0.0)

1.61 μs ± 4.74 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [33]:
%%timeit
zero = length.new_zeros(())

812 ns ± 3.38 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [34]:
tilt_1 = None
tilt_2 = torch.tensor(0.3)
zero = tilt_2.new_zeros(())

In [38]:
%%timeit
tilt = tilt_1 if tilt_1 is not None else zero

9.48 ns ± 0.0631 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [None]:
%%timeit
tilt_2 = tilt_2 if tilt_2 is not None else zero

9.58 ns ± 0.0134 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [39]:
%%timeit
tilt = zero if tilt_1 is None else tilt_1

10.1 ns ± 0.032 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [40]:
%%timeit
tilt = zero if tilt_2 is None else tilt_2

9.21 ns ± 0.0319 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [47]:
%%timeit
if tilt_1 is None:
    tilt = zero

9.45 ns ± 0.00709 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [48]:
%%timeit
if tilt_2 is None:
    tilt = zero

5.63 ns ± 0.0114 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [49]:
%%timeit
tilt = tilt_1 or zero

12.1 ns ± 0.00842 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [50]:
%%timeit
tilt = tilt_2 or zero

301 ns ± 6.42 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [55]:
k1 = torch.tensor(4.2)
hx = torch.tensor(0.0)

In [56]:
%%timeit
kx2 = k1 + hx**2

2.47 μs ± 8.38 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [57]:
%%timeit
kx2 = k1 + torch.square(hx)

2.15 μs ± 2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [58]:
%%timeit
kx2 = torch.add(k1, torch.square(hx))

2.17 μs ± 4.02 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [73]:
kx2 = torch.tensor(4.2)
kx2_is_not_zero = kx2 != 0
hx = torch.tensor(0.0)
length = torch.tensor(1.0)
sx = torch.tensor(0.42)
sy = torch.tensor(0.52)
ibeta2 = torch.tensor(1.0)
zero = torch.tensor(0.0)
igamma2 = torch.tensor(0.3)
cx = torch.tensor(0.5)
cy = torch.tensor(0.6)
dx = torch.tensor(0.1)

In [77]:
%%timeit
r56 = torch.where(
    kx2_is_not_zero, torch.square(hx) * (length - sx) / kx2 * ibeta2, zero
)
r56 = r56 - length * ibeta2 * igamma2

10 μs ± 19.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [78]:
%%timeit
r56 = torch.where(
    kx2_is_not_zero, torch.square(hx) * (length - sx) / kx2 * ibeta2, zero
) - length * ibeta2 * igamma2

10.2 μs ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [80]:
%%timeit
r56 = (torch.where(
    kx2_is_not_zero, torch.square(hx) * (length - sx) / kx2, zero
) - length * igamma2) * ibeta2

9.09 μs ± 19.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [14]:
cx = torch.tensor(0.5)

In [15]:
%%timeit
R = torch.eye(7, dtype=cx.dtype, device=cx.device).repeat(*cx.shape, 1, 1)

5.67 μs ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [16]:
%%timeit
R = (
    torch.eye(7, dtype=cx.dtype, device=cx.device)
    .expand(*cx.shape, 7, 7)
    .clone()
)

4.67 μs ± 16.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [2]:
cavity = cheetah.Cavity(length=torch.tensor(1.0), voltage=torch.tensor([1e6, 0.0]))
species = cheetah.Species("electron")
energy = torch.tensor(1e9)  # 1 GeV

In [5]:
%%timeit
cavity.first_order_transfer_map(energy=energy, species=species)

231 μs ± 1.18 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
cavity.first_order_transfer_map(energy=energy, species=species)

tensor([[[ 9.9950e-01,  9.9951e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-3.7447e-07,  9.9950e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  9.9950e-01,  9.9951e-01,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -3.7447e-07,  9.9950e-01,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,
          -2.6073e-07,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           9.9900e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  1.0000e+00]],

        [[ 1.0000e+00,         nan,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0

In [7]:
k1 = torch.tensor(4.2)
hx = torch.tensor(0.3)

In [11]:
%%timeit
_ = hx**2

1.33 μs ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [12]:
%%timeit
_ = hx * hx

1.02 μs ± 1.43 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [13]:
%%timeit
_ = torch.square(hx)

1.05 μs ± 2.41 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
