In [1]:
import math
import os
import re
from copy import deepcopy
from pathlib import Path
from typing import Any

import numpy as np
import scipy.constants
from scipy.constants import physical_constants

import cheetah

In [2]:
# Point LCLS_LATTICE at an `lcls-lattice` checkout next to the current working
# directory's parent, unless the variable is already set in the environment (so a
# user-configured location is respected instead of being overwritten).
os.environ.setdefault("LCLS_LATTICE", str(Path.cwd().parent / "lcls-lattice"))
os.environ["LCLS_LATTICE"]

'/Users/jankaiser/Documents/DESY/lcls-lattice'

In [3]:
# Location of the lattice file below the lattice repository root. The `$LCLS_LATTICE`
# component is kept literally here and is substituted in the next cell.
lattice_file_path = Path(
    "$LCLS_LATTICE", "bmad", "models", "cu_hxr", "cu_hxr.lat.bmad"
)
lattice_file_path

PosixPath('$LCLS_LATTICE/bmad/models/cu_hxr/cu_hxr.lat.bmad')

In [4]:
# Expand every path component written as `$NAME` with the value of the environment
# variable NAME (an unset variable raises KeyError); other components pass through.
resolved_lattice_file_path = Path(
    *map(
        lambda part: os.environ[part[1:]] if part[:1] == "$" else part,
        lattice_file_path.parts,
    )
)
resolved_lattice_file_path

PosixPath('/Users/jankaiser/Documents/DESY/lcls-lattice/bmad/models/cu_hxr/cu_hxr.lat.bmad')

In [5]:
# Read the lattice file with Cheetah's `read_clean_lines` helper and preview the
# first 20 entries.
lines = cheetah.bmad.read_clean_lines(resolved_lattice_file_path)
lines[0:20]

['beginning[beta_a] =  5.91253676811640894e+000',
 'beginning[alpha_a] =  3.55631307633660354e+000',
 'beginning[beta_b] =  5.91253676811640982e+000',
 'beginning[alpha_b] =  3.55631307633660398e+000',
 'beginning[e_tot] = 6e6',
 'parameter[geometry] = open',
 'parameter[particle] = electron',
 'beginning[theta_position] = -35*pi/180',
 'beginning[z_position] = 3050.512000 - 1032.60052',
 'beginning[x_position] = 10.44893',
 'setsp = 0',
 'setcus = 0',
 'setal = 0',
 'setda = 0',
 'setxleap2 = 0',
 'sethxrss = 0',
 'setsxrss = 0',
 'setcbxfel = 0',
 'setpepx = 0',
 'intgsx = 30.0']

In [6]:
# Merge continuation lines in three successive passes: first on `&` (delimiter
# removed), then on `,` and on `{` (delimiter kept).
merged_lines = lines
for delimiter, remove_delimiter in [("&", True), (",", False), ("{", False)]:
    merged_lines = cheetah.bmad.merge_delimiter_continued_lines(
        merged_lines, delimiter=delimiter, remove_delimiter=remove_delimiter
    )
len(lines), len(merged_lines)

(14409, 12215)

In [7]:
# Regular expressions that classify each merged lattice statement. They are reused
# by the parsing cell further below. The order in which they are tried matters: line
# and overlay definitions also match the generic element definition pattern, so they
# must be tested before it.
property_assignment_pattern = r"[a-z0-9_\*:]+\[[a-z0-9_%]+\]\s*=.*"
variable_assignment_pattern = r"[a-z0-9_]+\s*=.*"
element_definition_pattern = r"[a-z0-9_]+\s*\:\s*[a-z0-9_]+.*"
line_definition_pattern = r"[a-z0-9_]+\s*\:\s*line\s*=\s*\(.*\)"
overlay_definition_pattern = r"[a-z0-9_]+\s*\:\s*overlay\s*=\s*\{.*"
use_line_pattern = r"use\s*\,\s*[a-z0-9_]+"

# (category, pattern) pairs in the order in which they are tried.
statement_categories = [
    ("property_assignment", property_assignment_pattern),
    ("variable_assignment", variable_assignment_pattern),
    ("line_definition", line_definition_pattern),
    ("overlay_definition", overlay_definition_pattern),
    ("element_definition", element_definition_pattern),
    ("use_line", use_line_pattern),
]

# Count statements per category; stop at (and show) the first unrecognised one.
category_counts = dict.fromkeys(
    (category for category, _ in statement_categories), 0
)
for line in merged_lines:
    matched_category = next(
        (
            category
            for category, pattern in statement_categories
            if re.fullmatch(pattern, line)
        ),
        None,
    )
    if matched_category is None:
        print(line)
        break
    category_counts[matched_category] += 1

num_successful = sum(category_counts.values())
num_property_assignment = category_counts["property_assignment"]
num_variable_assignment = category_counts["variable_assignment"]
num_element_definition = category_counts["element_definition"]
num_line_definition = category_counts["line_definition"]
num_overlay_definition = category_counts["overlay_definition"]
num_use_line = category_counts["use_line"]

print("")
print("######################################")
print(f"num_successful: {num_successful} / {len(merged_lines)}")
print("--------------------------------------")
print(f"{num_property_assignment = }")
print(f"{num_variable_assignment = }")
print(f"{num_element_definition = }")
print(f"{num_line_definition = }")
print(f"{num_overlay_definition = }")
print(f"{num_use_line = }")
print("######################################")


######################################
num_successful: 12215 / 12215
--------------------------------------
num_property_assignment = 4261
num_variable_assignment = 1684
num_element_definition = 4854
num_line_definition = 1309
num_overlay_definition = 106
num_use_line = 1
######################################


In [8]:
def evaluate_expression(expression: str, context: dict) -> Any:
    """
    Evaluate a Bmad expression in the context of a dictionary of variables.

    The expression is tried, in order, as an integer literal, a float literal, one of
    a few known bare keywords, the name of an entry in ``context``, and finally as a
    Python expression (after translating Bmad syntax) evaluated with ``context`` as
    its globals.

    :param expression: Expression text as read from the lattice file.
    :param context: Names (variables, elements, constants, functions) visible to the
        expression. Note that ``eval`` inserts a ``__builtins__`` entry into this
        dict.
    :return: The evaluated value. If the expression is not valid Python syntax, the
        original, untranslated expression string is returned.
    :raises Exception: Any non-syntax error raised during evaluation (e.g. a
        ``NameError`` for an undefined variable) is re-raised after the translated
        expression is printed.
    """
    # Try reading the expression as an integer
    try:
        return int(expression)
    except ValueError:
        pass

    # Try reading the expression as a float
    try:
        return float(expression)
    except ValueError:
        pass

    # Check against allowed bare keywords, which are kept as strings
    if expression in ["open", "electron", "t", "f", "traveling_wave", "full"]:
        return expression

    # Check against previously defined variables
    if expression in context:
        return context[expression]

    # Translate Bmad syntax into Python syntax on a copy, so the original text can be
    # returned unchanged if it turns out not to be an expression at all.
    # Property access `q1[k1]` becomes the dictionary lookup `q1['k1']`.
    python_expression = re.sub(r"\[([a-z0-9_%]+)\]", r"['\1']", expression)
    # Replace power operator with python equivalent
    python_expression = python_expression.replace("^", "**")
    # Replace calls to `abs` with `abs_func`, because `abs` is overwritten by a
    # variable in the LCLS lattice file. The word boundary keeps names that merely
    # end in "abs" (e.g. `fabs(`) intact.
    # NOTE: This is a hacky fix; it is not certain this leads to the intended
    # behaviour in all cases.
    python_expression = re.sub(r"\babs\(", "abs_func(", python_expression)

    try:
        # SECURITY: eval executes arbitrary Python code taken from the lattice file.
        # Only use this with trusted lattice files.
        return eval(python_expression, context)
    except SyntaxError:
        if len(expression.split(":")) not in (3, 4):  # Otherwise probably an alias
            print(
                f"DEBUG: Evaluating expression {expression}. Assuming it is a string."
            )
        return expression
    except Exception:
        print(python_expression)
        raise

In [9]:
def _bmad_wildcard_to_regex(wildcard: str) -> str:
    """Translate a Bmad wildcard (`*` = any run of characters, `%` = one character)."""
    return "".join(
        ".*" if char == "*" else "." if char == "%" else re.escape(char)
        for char in wildcard
    )


def resolve_object_name_wildcard(wildcard_pattern: str, context: dict) -> list:
    """
    Return a list of object names in ``context`` that match the given wildcard.

    The wildcard has the form ``type::name`` or just ``name``. Both parts may contain
    ``*`` (any sequence of characters) and ``%`` (exactly one character). Omitting
    the type is treated like ``*::name``, i.e. elements of any type match.

    Only entries whose value is a dict with an ``"element_type"`` key are candidates;
    variables, beam lines and other context entries never match.

    :param wildcard_pattern: Wildcard as written in the lattice file.
    :param context: Parsing context to search.
    :return: Matching names, in the insertion order of ``context``.
    """
    if "::" in wildcard_pattern:
        type_wildcard, name_wildcard = wildcard_pattern.split("::", 1)
    else:
        type_wildcard, name_wildcard = "*", wildcard_pattern

    name_regex = _bmad_wildcard_to_regex(name_wildcard)
    type_regex = _bmad_wildcard_to_regex(type_wildcard)

    return [
        name
        for name, value in context.items()
        if re.fullmatch(name_regex, name)
        and isinstance(value, dict)
        and "element_type" in value
        and re.fullmatch(type_regex, str(value["element_type"]))
    ]

In [10]:
def assign_property(line: str, context: dict) -> dict:
    """
    Parse a property assignment such as ``q1[k1] = 0.5`` and store it in ``context``.

    Object names containing a wildcard (``*``/``%``) or a type qualifier
    (``quadrupole::q1``) are resolved against the elements already in ``context``,
    and the property is set on every match. Plain names that do not exist yet are
    created as empty dicts (e.g. ``beginning`` or ``parameter``).

    :param line: A single merged lattice statement.
    :param context: Parsing context; modified in place.
    :return: The same ``context`` object.
    :raises ValueError: If ``line`` is not a property assignment.
    """
    pattern = r"([a-z0-9_\*:]+)\[([a-z0-9_%]+)\]\s*=(.*)"
    match = re.fullmatch(pattern, line)
    if match is None:
        raise ValueError(f"Property assignment {line} not understood.")

    object_name = match.group(1).strip()
    property_name = match.group(2).strip()
    property_expression = match.group(3).strip()

    # A `type::name` qualifier must be resolved even without wildcards, otherwise the
    # property would be stored under the literal key "type::name".
    if "::" in object_name or "*" in object_name or "%" in object_name:
        object_names = resolve_object_name_wildcard(object_name, context)
    else:
        object_names = [object_name]

    # Evaluate once, even when the property is assigned to many objects
    expression_result = evaluate_expression(property_expression, context)

    for name in object_names:
        context.setdefault(name, {})[property_name] = expression_result

    return context

In [11]:
def assign_variable(line: str, context: dict) -> dict:
    """
    Parse a variable assignment such as ``setsp = 0`` and store it in ``context``.

    The right-hand side is evaluated with ``evaluate_expression`` and the result is
    stored under the variable name, replacing any previous value.

    :param line: A single merged lattice statement.
    :param context: Parsing context; modified in place.
    :return: The same ``context`` object.
    :raises ValueError: If ``line`` is not a variable assignment.
    """
    pattern = r"([a-z0-9_]+)\s*=(.*)"
    match = re.fullmatch(pattern, line)
    if match is None:
        raise ValueError(f"Variable assignment {line} not understood.")

    variable_name = match.group(1).strip()
    variable_expression = match.group(2).strip()

    context[variable_name] = evaluate_expression(variable_expression, context)

    return context

In [12]:
def validate_understood_properties(understood: list[str], properties: dict) -> None:
    """
    Validate that all properties are understood. This function primarily ensures that
    properties not understood by Cheetah are not ignored silently.

    An explicit exception is raised instead of using ``assert``, so the check is not
    skipped when Python runs with optimisations enabled (``python -O``).

    :param understood: Names of the properties the caller knows how to convert.
    :param properties: Parsed element properties, normally including
        ``element_type``.
    :raises ValueError: If any key of ``properties`` is not in ``understood``. All
        offending properties are listed in the message.
    """
    not_understood = [name for name in properties if name not in understood]
    if not_understood:
        details = ", ".join(
            f"{name} with value {properties[name]}" for name in not_understood
        )
        raise ValueError(
            f"Properties ({details}) for element type"
            f" {properties.get('element_type')} are currently not understood."
        )

In [13]:
def convert_element(name: str, properties: dict):
    """
    Convert parsed element dictionary from a Bmad lattice file to a Cheetah element.

    Supported element types:

    * ``drift`` and ``pipe`` become a ``cheetah.Drift`` of length ``l``;
    * ``marker``, ``monitor`` and ``instrument`` become a zero-length
      ``cheetah.Drift`` (placeholder until Cheetah has a marker element);
    * ``quadrupole`` becomes a ``cheetah.Quadrupole`` with length ``l`` and
      strength ``k1``.

    Missing ``l`` and ``k1`` default to ``0.0``, consistent with the zero-length
    fallback used for unsupported element types. Any other element type is replaced
    by a ``cheetah.Drift`` of the element's length and a warning is printed.

    :param name: Name of the element in the lattice.
    :param properties: Parsed element properties; must contain ``element_type``.
    :return: The converted Cheetah element.
    :raises: Whatever ``validate_understood_properties`` raises when ``properties``
        contains keys that are not understood for the element type.
    """
    element_type = properties["element_type"]

    if element_type in ("drift", "pipe"):
        validate_understood_properties(
            ["element_type", "l", "type", "descrip"], properties
        )
        return cheetah.Drift(name=name, length=properties.get("l", 0.0))
    elif element_type in ("marker", "monitor", "instrument"):
        validate_understood_properties(["element_type", "type"], properties)

        # TODO: Replace this zero-length drift once Cheetah has a marker element
        return cheetah.Drift(name=name, length=0.0)
    elif element_type == "quadrupole":
        # TODO: `aperture` passes validation but is not converted yet
        validate_understood_properties(
            ["element_type", "l", "k1", "type", "aperture"], properties
        )
        return cheetah.Quadrupole(
            name=name, length=properties.get("l", 0.0), k1=properties.get("k1", 0.0)
        )
    else:
        print(
            f"WARNING: Element of type {element_type} cannot be converted"
            " correctly. Using drift section instead."
        )
        # TODO: Use a marker element for zero-length elements once Cheetah has one
        return cheetah.Drift(name=name, length=properties.get("l", 0.0))

In [14]:
def define_element(line: str, context: dict) -> dict:
    """
    Parse an element definition such as ``q1: quadrupole, l = 0.1, k1 = 2``.

    If the element's type is the name of an entry already in ``context`` (e.g.
    ``q2: q1, k1 = 3``), a deep copy of that entry is used as the starting point, so
    ``element_type`` and all other properties are inherited. Otherwise a new
    property dict ``{"element_type": <type>}`` is created.

    Properties are a comma-separated list of ``name = value`` pairs. A value is
    either a double-quoted string, which may contain ``,`` and ``=``, or an unquoted
    expression without ``,``, ``=`` or ``"``. Each value is evaluated with
    ``evaluate_expression``. Parts of the list that are not ``name = value`` pairs
    are ignored.

    :param line: A single merged lattice statement.
    :param context: Parsing context; modified in place.
    :return: The same ``context`` object.
    :raises ValueError: If ``line`` is not an element definition.
    """
    pattern = r"([a-z0-9_]+)\s*\:\s*([a-z0-9_]+)(\,(.*))?"
    match = re.fullmatch(pattern, line)
    if match is None:
        raise ValueError(f"Element definition {line} not understood.")

    element_name = match.group(1).strip()
    element_type = match.group(2).strip()

    if element_type in context:
        element_properties = deepcopy(context[element_type])
    else:
        element_properties = {"element_type": element_type}

    if match.group(3) is not None:
        element_properties_string = match.group(4).strip()

        # Both alternatives must allow `_` in the property name; otherwise names like
        # `rf_frequency` are truncated to their last underscore-free part.
        property_pattern = (
            r"([a-z0-9_]+\s*\=\s*\"[^\"]*\"|[a-z0-9_]+\s*\=\s*[^\=\,\"]+)"
        )
        for property_string in re.findall(property_pattern, element_properties_string):
            # Split on the first `=` only: quoted values may contain `=` themselves
            property_name, property_expression = property_string.split("=", 1)
            element_properties[property_name.strip()] = evaluate_expression(
                property_expression.strip(), context
            )

    context[element_name] = element_properties

    return context

In [15]:
def define_line(line: str, context: dict) -> dict:
    """
    Parse a beam line definition such as ``l1: line = (d1, q1, d2)``.

    The line is stored in ``context`` as an ordered list of the names it references.
    Names are kept verbatim; any prefixes on them (e.g. a leading ``-``) are not
    interpreted here. Empty entries are skipped, so ``()`` yields ``[]``.

    :param line: A single merged lattice statement.
    :param context: Parsing context; modified in place.
    :return: The same ``context`` object.
    :raises ValueError: If ``line`` is not a beam line definition.
    """
    pattern = r"([a-z0-9_]+)\s*\:\s*line\s*=\s*\((.*)\)"
    match = re.fullmatch(pattern, line)
    if match is None:
        raise ValueError(f"Line definition {line} not understood.")

    line_name = match.group(1).strip()
    line_elements_string = match.group(2).strip()

    context[line_name] = [
        element_name.strip()
        for element_name in line_elements_string.split(",")
        if element_name.strip()
    ]

    return context

In [16]:
def define_overlay(line: str, context: dict) -> dict:
    """
    Store a parsed overlay definition in ``context`` under the overlay's name.

    Two syntaxes are recognised (knot-based is tried first):

    * knot-based, ``name: overlay = {...}, var = {x}, x_knot = {...}``, stored with
      the keys ``overlay_definition``, ``overlay_variable`` and ``overlay_x_knot``;
    * expression-based, ``name: overlay = {...}, var = {...}[, ...]``, stored with
      the keys ``overlay_definition``, ``overlay_variables`` and
      ``overlay_parameters`` (the text after the ``var`` block without its leading
      comma, or ``None`` if there is none).

    All stored values are raw, stripped text; nothing is evaluated.

    :param line: A single merged lattice statement.
    :param context: Parsing context; modified in place.
    :return: The same ``context`` object.
    :raises ValueError: If ``line`` matches neither syntax.
    """
    knot_based_pattern = r"([a-z0-9_]+)\s*\:\s*overlay\s*=\s*\{(.*)\}\s*\,\s*var\s*=\s*\{\s*([a-z0-9_]+)\s*\}\s*\,\s*x_knot\s*=\s*\{(.*)\}"
    expression_based_pattern = r"([a-z0-9_]+)\s*\:\s*overlay\s*=\s*\{(.*)\}\s*\,\s*var\s*=\s*\{(.*)\}\s*(\,.*)*"

    knot_match = re.fullmatch(knot_based_pattern, line)
    if knot_match is not None:
        name, definition, variable, x_knot = (
            group.strip() for group in knot_match.groups()
        )
        context[name] = {
            "overlay_definition": definition,
            "overlay_variable": variable,
            "overlay_x_knot": x_knot,
        }
        return context

    expression_match = re.fullmatch(expression_based_pattern, line)
    if expression_match is None:
        raise ValueError(f"Overlay definition {line} not understood.")

    name, definition, variables, trailing = expression_match.groups()
    parameters = None if trailing is None else trailing.strip()[1:].strip()
    context[name.strip()] = {
        "overlay_definition": definition.strip(),
        "overlay_variables": variables.strip(),
        "overlay_parameters": parameters,
    }
    return context

In [17]:
def parse_use_line(line: str, context: dict) -> dict:
    """
    Parse a ``use, <line_name>`` statement and record the selected beam line name.

    The name is stored in ``context["__use__"]``; a later ``use`` statement
    overwrites an earlier one.

    :param line: A single merged lattice statement.
    :param context: Parsing context; modified in place.
    :return: The same ``context`` object.
    :raises ValueError: If ``line`` is not a ``use`` statement naming a single line.
    """
    pattern = r"use\s*\,\s*([a-z0-9_]+)"
    match = re.fullmatch(pattern, line)
    if match is None:
        raise ValueError(f"Use statement {line} not understood.")

    context["__use__"] = match.group(1).strip()

    return context

In [18]:
# Names available to lattice expressions before any statement is parsed: constants
# and the math functions that Bmad expressions may call.
context = {
    "pi": scipy.constants.pi,
    "twopi": 2 * scipy.constants.pi,
    "c_light": scipy.constants.c,
    # Electron mass: MeV * 1e-3 -> GeV
    "emass": physical_constants["electron mass energy equivalent in MeV"][0] * 1e-3,
    # Electron mass: MeV * 1e6 -> eV
    "m_electron": physical_constants["electron mass energy equivalent in MeV"][0] * 1e6,
    "sqrt": math.sqrt,
    "exp": math.exp,
    "log": math.log,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    # `abs` is overwritten by a variable in the LCLS lattice, so calls to `abs(` are
    # rewritten to `abs_func(` when expressions are evaluated.
    "abs_func": abs,
    "raddeg": scipy.constants.degree,  # pi / 180
}

# Dispatch every statement to its parser. The order of the checks matches the
# classification cell above: line and overlay definitions must be tried before the
# generic element definition pattern, which also matches them.
for line in merged_lines:
    if re.fullmatch(property_assignment_pattern, line):
        context = assign_property(line, context)
    elif re.fullmatch(variable_assignment_pattern, line):
        context = assign_variable(line, context)
    elif re.fullmatch(line_definition_pattern, line):
        context = define_line(line, context)
    elif re.fullmatch(overlay_definition_pattern, line):
        context = define_overlay(line, context)
    elif re.fullmatch(element_definition_pattern, line):
        context = define_element(line, context)
    elif re.fullmatch(use_line_pattern, line):
        context = parse_use_line(line, context)
    else:
        # Fail loudly instead of silently dropping part of the lattice
        raise ValueError(f"Lattice statement {line} not understood.")

In [23]:
# Name of the beam line selected by the lattice file's `use` statement
context["__use__"]

'cu_hxr'

In [24]:
# Beam lines are stored as ordered lists of the names they reference
context["cu_hxr"]

['gunl0a', 'l0al0b', 'lcls2cuh']

In [26]:
# Drill down into the `lcls2cuh` sub-line
context["lcls2cuh"]

['lcls2cuc', 'bsyltuh']

In [27]:
# Drill down into the `lcls2cuc` sub-line
context["lcls2cuc"]

['dl1', 'l1', 'bc1', 'l2', 'bc2', 'l3']

In [28]:
# Drill down into the `dl1` sub-line
context["dl1"]

['dl1_1', 'dl1_2']

In [29]:
# Innermost level: `dl1_1` lists individual element names
context["dl1_1"]

['begdl1_1',
 'emat',
 'de00',
 'de00a',
 'qe01_full',
 'de01a',
 'im02',
 'de01b',
 'vv02',
 'de01c',
 'qe02_full',
 'dh00',
 'lsrhtr',
 'dh06',
 'tcav0_full',
 'de02',
 'qe03_full',
 'de03a',
 'de03b',
 'sc7',
 'de03c',
 'qe04_full',
 'de04',
 'ws01',
 'de05',
 'otr1',
 'de05c',
 'vv03',
 'de06a',
 'rst1',
 'de06b',
 'ws02',
 'de05a',
 'mrk0',
 'de05a',
 'otr2',
 'de06d',
 'bpm10',
 'de06e',
 'ws03',
 'de05',
 'otr3',
 'de07',
 'qm01_full',
 'de08',
 'sc8',
 'de08a',
 'vv04',
 'de08b',
 'qm02_full',
 'de09',
 'dbmark82',
 'enddl1_1']

In [30]:
# Elements are stored as property dicts that include their `element_type`
context["vv02"]

{'element_type': 'marker'}