# Implementing custom elements in Cheetah

Cheetah is designed to be extensible, making it easy to add custom elements when the existing elements do not meet your needs. This guide uses examples to show how to implement elements that follow linear beam dynamics, as well as elements with more complex physics.

**If you implement a new element and think it could be useful for others, please consider contributing it directly to Cheetah!**


To contribute to Cheetah, fork the [Cheetah GitHub repository](https://github.com/desy-ml/cheetah), implement your new element in your fork and then open a pull request. One of the Cheetah maintainers will then review your code, possibly ask for some improvements and finally merge it into the `master` branch.


The most important part of implementing a new element is to create a subclass of the `cheetah.accelerator.Element` class. The latter is an abstract base class that defines the interface for all elements in Cheetah. By subclassing it, you can define the specific behaviour and properties of your custom element.

Start by creating a new Python file for your custom element in the `cheetah/accelerator/` directory of the Cheetah repository. For example, if you want to create a custom element called `MyCustomElement`, you would create a file named `my_custom_element.py`.
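
The file name is the class name converted to snake case. The `to_mesh` method in the template below uses the same convention for mesh files in `desy-ml/3d-assets`. If you are unsure how a multi-word name converts, the sketch below does the conversion with the standard-library `re` module. `to_snake_case` is a hypothetical helper and is not part of Cheetah.

```python
import re


def to_snake_case(class_name: str) -> str:
    """Convert a CamelCase class name into a snake_case module or file name."""
    # Insert "_" before an upper-case letter that starts a lower-case word ...
    partial = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", class_name)
    # ... and between a lower-case letter or digit and a following upper-case letter.
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", partial).lower()


print(to_snake_case("MyCustomElement") + ".py")  # my_custom_element.py
```

Acronym-only names such as `BPM` stay in one piece (`bpm`).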

The following code snippet shows how to implement a simple custom element. You can paste it into your new file and then follow the instructions in the `TODO` comments to complete the implementation of the element itself.


```python
from typing import Literal

import matplotlib.pyplot as plt
import torch

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam, Species

# Helpers for building transfer maps. This template does not use them, but you may
# find them useful when implementing your element.
from cheetah.track_methods import base_rmatrix, base_ttensor, misalignment_matrix
from cheetah.utils import verify_device_and_dtype


class MyCustomElement(Element):
    """
    TODO: Describe the purpose of your custom element here. For example: A quadrupole
    magnet that focuses the beam in the horizontal direction and defocuses it in the
    vertical direction.

    :param length: Length in meters. TODO: If your element has a physical length, keep
        this line (but delete the TODO comment). If it does not, like for example a
        `Marker`, delete this line.
    :param my_second_parameter: TODO: Specify the parameters of your element here,
        following the format `:param parameter_name: Description of the parameter.`.
    :param not_a_tensor_parameter: TODO: Document non-tensor parameters, such as
        modes, in the same way.
    :param name: Unique identifier of the element.
    """

    def __init__(
        self,
        length: torch.Tensor,
        # TODO: Replace `my_second_parameter` and `not_a_tensor_parameter` with the
        # actual parameters of your element. If your element has a length, it must be
        # the first parameter, and the last three parameters must always be `name`,
        # `device` and `dtype`. Remember to add correct type annotations and to update
        # the docstring above accordingly. `torch.Tensor` parameters with default
        # values must default to `None` here in the signature of `__init__` and have
        # their actual default value set in the body of `__init__` (see below).
        my_second_parameter: torch.Tensor | None = None,
        not_a_tensor_parameter: Literal[
            "for_example_a_mode", "another_mode"
        ] = "for_example_a_mode",
        name: str | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        # TODO: Pass all your `torch.Tensor` parameters to `verify_device_and_dtype`.
        device, dtype = verify_device_and_dtype(
            [length, my_second_parameter], device, dtype
        )
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(name=name, **factory_kwargs)

        # TODO: Delete the following line if your element does not have a physical
        # length.
        self.length = torch.as_tensor(length, **factory_kwargs)

        # TODO: Use plain assignment `self.parameter_name = ...` only for parameters
        # that are not `torch.Tensor`s.
        self.not_a_tensor_parameter = not_a_tensor_parameter

        # TODO: Register all your `torch.Tensor` parameters as shown below. Do not
        # simply assign them with `self.my_parameter = ...`! The call to
        # `super().__init__()` already registers `length`, so it does not need to be
        # registered again, and assigning it with `self.length = ...` is allowed.
        self.register_buffer_or_parameter(
            "my_second_parameter",
            torch.as_tensor(
                my_second_parameter if my_second_parameter is not None else 0.0,
                **factory_kwargs,
            ),
        )

        # TODO: Add any other initialisation code you may need here.

    def transfer_map(self, energy: torch.Tensor, species: Species) -> torch.Tensor:
        # TODO: If your element only has first-order effects, delete the `track`
        # method below and implement only this method, which computes the first-order
        # transfer map. Cheetah will then handle the tracking for you automatically.
        # The transfer map must be a tensor of shape (..., 7, 7), where all but the
        # last two dimensions are vector dimensions. The upper-left 6x6 block is the
        # first-order transfer map you probably know from accelerator physics. The last
        # row must be zeros, except for the last element, which must be 1.0. You can
        # use the last column to add zeroth-order effects.
        # You do not need to write a docstring for this method, because the superclass
        # `Element` already provides one.

        # Placeholder: the identity map, i.e. an element that does nothing. Replace it
        # with your element's actual transfer map.
        R = torch.eye(7, device=energy.device, dtype=energy.dtype)

        return R

    def track(self, incoming: Beam) -> Beam:
        # TODO: This is the main method where you implement your physics. It transforms
        # a beam entering your element into a beam exiting your element.
        # If all you need is linear first-order tracking, you can delete this method
        # and implement only `transfer_map` above. Cheetah will then handle the
        # tracking for you automatically.

        # TODO: You will probably have to distinguish between the two types of beams
        # in Cheetah, `ParameterBeam` and `ParticleBeam`. See the documentation for
        # more information on them.
        if isinstance(incoming, ParameterBeam):
            # TODO: Implement your logic for `ParameterBeam`s here. If your element has
            # complex physics, especially interactions between particles, it is often
            # legitimate to fall back to the first-order transfer map for
            # `ParameterBeam`s. To do so, call `super().track(incoming)`.
            return super().track(incoming)
        elif isinstance(incoming, ParticleBeam):
            # TODO: Implement your logic for `ParticleBeam`s here. Return a new
            # `ParticleBeam` object with the transformed particles.

            # Placeholder: pass the particles through unchanged.
            outgoing_particles = incoming.particles.clone()

            return ParticleBeam(
                particles=outgoing_particles,
                energy=incoming.energy,
                particle_charges=incoming.particle_charges,
                survival_probabilities=incoming.survival_probabilities,
                species=incoming.species,
            )
        else:
            raise TypeError(
                f"Unsupported beam type: {type(incoming)}. Expected ParameterBeam or "
                "ParticleBeam."
            )

    @property
    def is_skippable(self) -> bool:
        # TODO: If your element follows strictly linear beam dynamics, Cheetah can
        # perform speed optimisations. To decide whether it can do so, it checks
        # whether the element's `is_skippable` property returns `True`. If your element
        # has only first-order effects, you can return `True` here. If it has
        # higher-order effects, return `False`. In some cases, it makes sense to
        # determine this dynamically.
        return False

    @property
    def is_active(self) -> bool:
        # TODO: Return `True` if your element is active, i.e. "on", and `False` if it
        # is inactive, i.e. "off".
        return True

    def split(self, resolution: torch.Tensor) -> list[Element]:
        # TODO: Implement the logic for splitting your element longitudinally here.
        # Cheetah uses this to provide beams at multiple positions along the element,
        # specifically `resolution` meters apart. Not all elements can easily be split,
        # so Cheetah does not guarantee the user that `resolution` will always be
        # respected. If your element cannot be split, you may return a list containing
        # only `self.clone()`.
        return [self.clone()]

    def plot(self, ax: plt.Axes, s: float, vector_idx: tuple | None = None) -> None:
        # TODO: Cheetah can use Matplotlib to plot a straight-line representation of
        # the lattice. Implement here how your element should be drawn in that plot.
        raise NotImplementedError

    def to_mesh(
        self, cuteness: float | dict = 1.0, show_download_progress: bool = True
    ) -> "tuple[trimesh.Trimesh | None, np.ndarray]":  # noqa: F821 # type: ignore
        # TODO: Cheetah can create 3D mesh representations of elements and lattices. In
        # most cases, you do not have to implement this method. In that case, delete
        # it so that the implementation inherited from `Element` is used. If you feel
        # creative, you can open a PR on the `desy-ml/3d-assets` repository that adds a
        # mesh file named after your class name in snake case. Cheetah will then use
        # your mesh automatically. In some cases, placing the mesh or computing the 3D
        # transformation to the element's beam output position is more complex. In
        # those cases, you might need to implement this method. (See `Dipole` for an
        # example of such a case.)
        raise NotImplementedError

    @property
    def defining_features(self) -> list[str]:
        # TODO: Add to this list all the properties that define your element, i.e.
        # those that should be saved and loaded when the element is serialised. Append
        # your element's defining features to the end of the list returned by
        # `super().defining_features`.
        return super().defining_features + [
            "length",
            "my_second_parameter",
            "not_a_tensor_parameter",
        ]

    def __repr__(self) -> str:
        # TODO: Implement a string representation for your element. This is usually
        # the class name followed by the values of the defining features, written so
        # that the result could be pasted into a Python interpreter to create an
        # equivalent object.
        return (
            f"{self.__class__.__name__}(length={repr(self.length)}, "
            f"my_second_parameter={repr(self.my_second_parameter)}, "
            f"not_a_tensor_parameter={repr(self.not_a_tensor_parameter)}, "
            f"name={repr(self.name)})"
        )
```

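The sketch below makes the 7x7 convention from the `transfer_map` comments concrete. It uses only torch, not Cheetah itself. It assumes the particle coordinate order (x, px, y, py, tau, p, 1) of Cheetah's `ParticleBeam`. It also assumes that a map is applied to a particle array of shape (n_particles, 7) as `particles @ R.transpose(-2, -1)`. The helpers `drift_map`, `kick_map` and `track` are hypothetical, and the drift leaves out its longitudinal term for brevity. The sketch also shows why strictly linear elements allow speed optimisations: consecutive maps can be multiplied into a single map before tracking, and this includes zeroth-order offsets in the last column.

```python
import torch

# Torch-only sketch of the 7x7 transfer-map convention. Assumes the coordinate order
# (x, px, y, py, tau, p, 1), with a constant 1 in the seventh slot of every particle.


def drift_map(length: float, dtype: torch.dtype = torch.float64) -> torch.Tensor:
    """Transverse part of a first-order drift (longitudinal term omitted)."""
    R = torch.eye(7, dtype=dtype)
    R[0, 1] = length  # x <- x + length * px
    R[2, 3] = length  # y <- y + length * py
    return R


def kick_map(px_kick: float, dtype: torch.dtype = torch.float64) -> torch.Tensor:
    """Zeroth-order effect: a constant kick added to px through the last column."""
    R = torch.eye(7, dtype=dtype)
    R[1, 6] = px_kick
    return R


def track(particles: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
    """Apply a (7, 7) map to particles of shape (n_particles, 7)."""
    return particles @ R.transpose(-2, -1)


particles = torch.zeros(2, 7, dtype=torch.float64)
particles[:, 6] = 1.0  # homogeneous coordinate
particles[0, 1] = 1e-3  # first particle starts with px = 1 mrad

# Element by element: 2 m drift, kick, 1 m drift.
after_first_drift = track(particles, drift_map(2.0))
after_kick = track(after_first_drift, kick_map(1e-4))
sequential = track(after_kick, drift_map(1.0))

# The same three maps merged into one. Later elements multiply from the left.
merged_map = drift_map(1.0) @ kick_map(1e-4) @ drift_map(2.0)
merged = track(particles, merged_map)

print(merged[:, 0])  # x of both particles: 3.1e-3 and 1e-4 (up to rounding)
```

The order of the product matters. Later elements multiply from the left, so the merged map is `drift_map(1.0) @ kick_map(1e-4) @ drift_map(2.0)`, not the reverse. Merging of this kind is presumably what the `is_skippable` property enables. This sketch does not show how Cheetah decides which elements to merge.
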
After you have finished your main implementation, you will need to do some housekeeping to make sure that your new element is properly usable in Cheetah.

- Add your element to the `cheetah/accelerator/__init__.py` file, so that it can be imported from the `cheetah.accelerator` module.
- Add your element to the `cheetah/__init__.py` file, so that it can be imported directly from the `cheetah` module as `cheetah.MyCustomElement`.
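
The two additions might look roughly like the fragment below. This assumes both files use plain relative imports. Check how the existing elements are imported there and follow the same pattern. If they are listed in a single parenthesised import, add `MyCustomElement` to that list instead.

```python
# In cheetah/accelerator/__init__.py
from .my_custom_element import MyCustomElement  # noqa: F401

# In cheetah/__init__.py
from .accelerator import MyCustomElement  # noqa: F401
```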


If you also want to contribute your element to the main Cheetah repository, you will additionally need to take care of the following:

- In `CHANGELOG.md` ...
  - Create an entry advertising your new element under "Features" for the version currently under development.
  - Add your name and GitHub handle at the bottom of the list of "First Time Contributors".
- In `docs/accelerator.rst`, add an entry like the following for your element in the correct alphabetical place:
  ```rst
  .. automodule:: accelerator.my_custom_element
     :members:
     :undoc-members:
  ```
- Write a test for your element in a new file in the `tests` directory, named following the pattern `test_<element_name>.py`. This test should check the physics of your element. For example, it can verify something that should always hold true about those physics, or compare an example case against results from another code or from a paper.
- Add your new element to the parametrisation of the `test_element_buffer_contents_and_location` test in `tests/test_clone.py`.
- Add your new element to the parametrisations of the `test_forced_element_dtype` and `test_infer_element_dtype` tests in `tests/test_device_dtype.py`.
- Add a minimal dictionary of arguments for your element to the `ELEMENT_CLASSES_REQUIRING_ARGS` dictionary in `tests/test_elements.py`.
- Add your element to the parametrisation of the `test_drift_broadcasting_two_different_inputs` test in `tests/test_vectorized.py`.
- In both `README.md` and `docs/index.rst` ...
  - ... add your name and GitHub handle at the bottom of the list of contributors.
  - ... if not there yet, add the logo of your institution.
  - ... if needed, add to the funding string section.
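
The test item in the checklist suggests checking something that should always hold true about your element's physics. For a first-order map derived from a Hamiltonian, one such property is symplecticity: the upper-left 6x6 block `M` satisfies `M.T @ J @ M == J`. The sketch below assumes that the coordinates form canonical pairs ordered as (x, px, y, py, tau, p). It uses a hypothetical thin-lens quadrupole map (`thin_quadrupole_map`) as a stand-in for your element's `transfer_map` output. Elements that change the reference energy, such as accelerating cavities, generally fail this check in scaled coordinates, so choose an invariant that fits your physics.

```python
import torch


def symplectic_form(dtype: torch.dtype = torch.float64) -> torch.Tensor:
    """6x6 symplectic form J, one 2x2 block per (q, p) pair of canonical coordinates."""
    J2 = torch.tensor([[0.0, 1.0], [-1.0, 0.0]], dtype=dtype)
    return torch.block_diag(J2, J2, J2)


def is_symplectic(R: torch.Tensor, atol: float = 1e-10) -> bool:
    """Check M^T J M == J for the upper-left 6x6 block M of a 7x7 transfer map."""
    M = R[:6, :6]
    J = symplectic_form(dtype=M.dtype)
    return torch.allclose(M.T @ J @ M, J, atol=atol)


def thin_quadrupole_map(k1l: float) -> torch.Tensor:
    """Hypothetical stand-in for your element's transfer map: a thin-lens quadrupole."""
    R = torch.eye(7, dtype=torch.float64)
    R[1, 0] = -k1l  # focusing in x
    R[3, 2] = k1l  # defocusing in y
    return R


def test_my_custom_element_is_symplectic():
    # In a real test in `tests/test_my_custom_element.py`, you would build the map
    # from your element instead, e.g. via its `transfer_map(energy, species)` method.
    R = thin_quadrupole_map(k1l=0.5)
    assert is_symplectic(R)
```

pytest collects functions whose names start with `test_`. A map that scales x without a compensating change in px, for example, would fail this check.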
