In [1]:
# import sys
# sys.path.append('./models/')
# from BPMQ_model import BPMQ_model

In [2]:
import re
from pprint import pprint

def parse_lattice(lattice_text):
    """Efficiently parses the lattice definition and returns dictionaries of element properties and line definitions."""
    
    # Pre-compile regex patterns
    element_pattern = re.compile(r'([a-zA-Z0-9_:]+):\s*([a-zA-Z]+)(.*?);', re.DOTALL)
    line_pattern = re.compile(r'([a-zA-Z0-9_]+):\s*LINE\s*=\s*\((.*?)\);', re.DOTALL)
    property_pattern = re.compile(r'([a-zA-Z0-9_]+)\s*=\s*([+-]?\d*\.?\d+([eE][+-]?\d+)?)')
    
    # Clean and preprocess lines
    cleaned_lines = (re.sub(r'#.*', '', line).strip() for line in lattice_text.splitlines())
    lattice_text = "\n".join(filter(None, cleaned_lines))  # Remove empty lines
    
    elements = {}
    lines = {}
    
    # Parse element definitions
    for match in element_pattern.finditer(lattice_text):
        name, type_, properties_str = match.groups()
        properties = {"name": name, "type": type_, "L": 0}
        for prop_match in property_pattern.finditer(properties_str.strip()):
            prop_name, prop_value = prop_match.group(1), prop_match.group(2)
            properties[prop_name] = float(prop_value) if '.' in prop_value or 'e' in prop_value.lower() else int(prop_value)
        elements[name] = properties
    
    # Parse LINE definitions
    for match in line_pattern.finditer(lattice_text):
        line_name, line_content = match.groups()
        lines[line_name] = [elem.strip() for elem in line_content.split(",")]
    
    return elements, lines

def get_all_line_elements(line_name, lines, elements):
    """
    Recursively retrieves all elements from a nested LINE definition with indices, positions, and original properties.
    """
    visited = set()  # Avoid infinite recursion

    def recursive_helper(line_name):
        if line_name in visited:
            return []  # Avoid infinite recursion
        visited.add(line_name)
        
        all_elements = []
        if line_name in lines:  # If line_name exists in lines
            for elem in lines[line_name]:
                if elem in lines:  # Nested LINE definition
                    all_elements.extend(recursive_helper(elem))
                else:  # Single element or multiplied element
                    match = re.match(r"([a-zA-Z0-9_:]+)\*(\d+)", elem)
                    if match:
                        base_elem, multiplier = match.groups()
                        for _ in range(int(multiplier)):
                            all_elements.append(elements.get(base_elem, {"name": base_elem}))
                    else:
                        all_elements.append(elements.get(elem, {"name": elem}))
        return all_elements

    # Retrieve elements
    flat_elements = recursive_helper(line_name)
    
    # Add indices and positions
    cumulative_pos = 0
    for idx, elem in enumerate(flat_elements):
        elem["index"] = idx
        elem["pos"] = cumulative_pos
        cumulative_pos += elem.get("L", 0)  # Increment cumulative position by the element length (L)
    return flat_elements


def combine_lattice_elements_quads_only(filename, from_elem, to_elem, marker_types=['bpm'], Brho=None, line_name=None):    
    with open(filename, 'r') as file:
        lattice_text = file.read()
    
    # Find default line_name if not provided
    if line_name is None:
        use_match = re.search(r'USE:\s*([a-zA-Z0-9_]+);', lattice_text)
        if use_match:
            line_name = use_match.group(1)
        else:
            raise ValueError("No default LINE name found in the lattice file, and no line_name was provided.")

    elements, lines = parse_lattice(lattice_text)
    all_line_elements = get_all_line_elements(line_name, lines, elements)

    # Map from element name to element data
    name_to_data = {elem["name"]: elem for elem in all_line_elements}

    # Ensure from_elem and to_elem exist in the element list
    if from_elem not in name_to_data:
        raise ValueError(f"Starting element '{from_elem}' not found in the LINE.")
    if to_elem not in name_to_data:
        raise ValueError(f"Ending element '{to_elem}' not found in the LINE.")

    # Get indices of from_elem and to_elem
    start_index = name_to_data[from_elem]["index"]
    end_index = name_to_data[to_elem]["index"]

    if start_index > end_index:
        raise ValueError("Starting element must come before ending element in the LINE.")

    # Subset of elements between from_elem and to_elem
    sub_elements = [elem for elem in all_line_elements if start_index <= elem["index"] <= end_index]

    combined_elements = []
    current_drift = None

    for elem in sub_elements[:-1]:
        elem_type = elem.get("type", "").lower()

        if elem_type == "quadrupole":
            # Handle quadrupoles
            if current_drift:
                combined_elements.append(current_drift)
                current_drift = None     
            elem['Brho'] = Brho
            combined_elements.append(elem)

        elif elem_type in marker_types:
            # Convert BPMs to drifts
            if current_drift:
                combined_elements.append(current_drift)
            current_drift = elem
            current_drift["type"] = "drift"

        else:
            # Combine other element types into a drift
            if current_drift:
                current_drift["L"] += elem.get("L", 0)
            else:
                current_drift = {
                    "name": elem["name"],
                    "type": "drift",
                    "index": elem["index"],
                    "pos": elem["pos"],
                    "L": elem.get("L", 0),
                    "aper": elem.get("aper", 0.1),   
                }
    # Append the last drift
    if current_drift:
        combined_elements.append(current_drift)

    combined_elements.append(sub_elements[-1])
    
    return combined_elements


# main function

In [3]:
# Example usage (assuming you have the helper functions defined as before):
filename = "lattice.txt"  # Replace with your lattice file
from_element = "BDS_BTS:QV_D5501"
to_element = "BDS_BTS:PM_D5567"
elements = combine_lattice_elements_quads_only(filename, from_element, to_element)#, line_name = "cell")
elements

[{'name': 'BDS_BTS:QV_D5501',
  'type': 'quadrupole',
  'L': 0.261,
  'B2': -10.8047349,
  'aper': 0.025,
  'index': 157,
  'pos': 80.116998,
  'Brho': None},
 {'name': 'drift_567',
  'type': 'drift',
  'index': 158,
  'pos': 80.37799799999999,
  'L': 0.489,
  'aper': 0.1},
 {'name': 'BDS_BTS:QH_D5509',
  'type': 'quadrupole',
  'L': 0.261,
  'B2': 11.477140200000001,
  'aper': 0.025,
  'index': 159,
  'pos': 80.866998,
  'Brho': None},
 {'name': 'drift_568',
  'type': 'drift',
  'index': 160,
  'pos': 81.12799799999999,
  'L': 0.268080814,
  'aper': 0.1},
 {'name': 'BDS_BTS:BPM_D5513',
  'type': 'drift',
  'L': 3.779195186,
  'index': 161,
  'pos': 81.39607881399999},
 {'name': 'BDS_BTS:QV_D5552',
  'type': 'quadrupole',
  'L': 0.261,
  'B2': -16.661776409999998,
  'aper': 0.025,
  'index': 169,
  'pos': 85.17527399999999,
  'Brho': None},
 {'name': 'drift_573',
  'type': 'drift',
  'index': 170,
  'pos': 85.43627399999998,
  'L': 0.489,
  'aper': 0.1},
 {'name': 'BDS_BTS:QH_D5559',
 

### check internal function

In [4]:
filename = "lattice.txt"  # Replace with your lattice file
with open(filename, 'r') as file:
    lattice_text = file.read()

elements, lines = parse_lattice(lattice_text)

# Retrieve all elements in a line with indices and properties
line_name = "cell"
all_line_elements = get_all_line_elements(line_name, lines, elements)

# print("Elements:", elements)
i = 33
print(lines['ls3bts'][i:i+4])
print()
print(all_line_elements[i+1:i+5])


['LS3_BTS:PM_D4827', 'drift_514', 'LS3_BTS:DCH_D4841', 'LS3_BTS:DCV_D4841']

[{'name': 'LS3_BTS:PM_D4827', 'type': 'marker', 'L': 0, 'index': 34, 'pos': 12.843081999999997}, {'name': 'drift_514', 'type': 'drift', 'L': 1.391397, 'index': 35, 'pos': 12.843081999999997}, {'name': 'LS3_BTS:DCH_D4841', 'type': 'orbtrim', 'L': 0, 'realpara': 1.0, 'tm_xkick': 0.00030752094, 'index': 36, 'pos': 14.234478999999997}, {'name': 'LS3_BTS:DCV_D4841', 'type': 'orbtrim', 'L': 0, 'realpara': 1.0, 'tm_ykick': 0.0, 'index': 37, 'pos': 14.234478999999997}]


# write flame file

In [5]:
import numpy as np
import torch

def update_lattice_file(
    filename,
    IonEk=None,
    IonQ=None,
    IonA=None,
    IonChargeStates=None,
    NCharge=None,
    BaryCenter0=None,
    S0=None,
):
    """
    Updates the lattice file with new values for specified parameters.

    Parameters:
        filename (str): The path to the lattice file.
        IonEk (float, optional): New kinetic energy [eV/u].
        IonQ (float, optional): New charge state.
        IonA (float, optional): New mass number.
        IonChargeStates (list[float], optional): New charge states.
        NCharge (list[float], optional): New number of charges.
        BaryCenter0 (list, np.array, or torch.Tensor, optional): New barycenter (shape 7).
        S0 (list, np.array, or torch.Tensor, optional): New beam envelope parameters (shape 49 or 7x7).

    Returns:
        None: Modifies the file in place.
    """
    # Read the file
    with open(filename, 'r') as file:
        lines = file.readlines()

    updated_lines = []
    skip_until_semicolon = False

    for line in lines:
        stripped = line.strip()

        # Keep commented lines intact
        if stripped.startswith("#"):
            updated_lines.append(line)
            continue

        # Skip lines until the end of the current block (marked by ';')
        if skip_until_semicolon:
            if stripped.endswith(";"):
                skip_until_semicolon = False
            continue

        # Update IonEk
        if IonEk is not None and stripped.startswith("IonEk"):
            updated_lines.append(f"IonEk = {IonEk}; \n")
            continue

        # Update IonQ
        if IonQ is not None and stripped.startswith("IonQ"):
            updated_lines.append(f"IonQ = {IonQ}; \n")
            continue

        # Update IonA
        if IonA is not None and stripped.startswith("IonA"):
            updated_lines.append(f"IonA = {IonA}; \n")
            continue

        # Update IonChargeStates
        if IonChargeStates is not None and stripped.startswith("IonChargeStates"):
            charges = ", ".join(map(str, IonChargeStates))
            updated_lines.append(f"IonChargeStates = [{charges}]; \n")
            continue

        # Update NCharge
        if NCharge is not None and stripped.startswith("NCharge"):
            charges = ", ".join(map(str, NCharge))
            updated_lines.append(f"NCharge = [{charges}]; \n")
            continue

        # Update BaryCenter0
        if BaryCenter0 is not None and stripped.startswith("BaryCenter0"):
            if isinstance(BaryCenter0, (np.ndarray, torch.Tensor)):
                BaryCenter0 = BaryCenter0.tolist()
            barycenter = ", ".join(map(str, BaryCenter0))
            updated_lines.append(f"BaryCenter0 = [{barycenter}]; \n")
            continue

        # Update S0
        if S0 is not None and stripped.startswith("S0"):
            skip_until_semicolon = True  # Skip existing S0 block
            if isinstance(S0, (np.ndarray, torch.Tensor)):
                S0 = S0.flatten().tolist()
            s_values = ",\n    ".join(", ".join(map(str, S0[i:i+7])) for i in range(0, len(S0), 7))
            updated_lines.append(f"S0 = [\n    {s_values}\n]; \n")
            continue

        # Keep other lines intact
        updated_lines.append(line)

    # Combine updated lines
    updated_content = "".join(updated_lines)

    # Write back to the file
    with open(filename, 'w') as file:
        file.write(updated_content)

# Example usage
if __name__ == "__main__":
    update_lattice_file(
        "lattice.txt",
        IonEk=300000.0,
        BaryCenter0=torch.zeros(7),#[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        S0=np.zeros((7, 7))
    )


In [6]:
# def twiss2covar(xalpha: float, xbeta: float, xnemit: float,
#                 yalpha: float, ybeta: float, ynemit: float,
#                 cxy: float, cxyp: float, cxpy: float, cxpyp: float,
#                 betagamma: float) -> Optional[np.ndarray]:
#     """
#     Map twiss parameters to covariance matrix.

#     Parameters:
#     - xalpha (float): Initial xalpha value.
#     - xbeta (float): Initial xbeta value.
#     - xnemit (float): Initial xnemit value.
#     - yalpha (float): Initial yalpha value.
#     - ybeta (float): Initial ybeta value.
#     - ynemit (float): Initial ynemit value.
#     - cxy (float): Initial cxy value.
#     - cxyp (float): Initial cxyp value.
#     - cxpy (float): Initial cxpy value.
#     - cxpyp (float): Initial cxpyp value.
#     - betagamma (float): Product of beta and gamma.

#     Returns:
#     np.ndarray: Covariance matrix.

#     This function maps twiss parameters to a covariance matrix using predefined formulas.
#     """
#     if xbeta < 1e-1 or ybeta < 1e-1 or \
#        xnemit < 1e-4 or ynemit < 1e-4 or \
#        xnemit > 1.2 or ynemit > 1.2 or \
#        math.fabs(cxy) >= 1 or math.fabs(cxyp) >= 1 or math.fabs(cxpy) >= 1 or math.fabs(cxpyp) >= 1:
#         raise ValueError('RMS beam quantities are not physical')
#     xalpha = -xalpha
#     yalpha = -yalpha
#     xemit = xnemit / betagamma
#     xgamma, xrms, xprms = (1 + xalpha**2) / xbeta, (xemit * xbeta)**0.5, (xemit * ((1 + xalpha**2) / xbeta))**0.5
#     yemit = ynemit / betagamma
#     ygamma, yrms, yprms = (1 + yalpha**2) / ybeta, (yemit * ybeta)**0.5, (yemit * ((1 + yalpha**2) / ybeta))**0.5
#     c13, c14, c23, c24 = cxy * xrms * yrms, cxyp * xrms * yprms, cxpy * xprms * yrms, cxpyp * xprms * yprms
#     moment1 = np.array([
#         [xemit * xbeta, xemit * xalpha * 1e-3, c13, c14 * 1e-3],
#         [xemit * xalpha * 1e-3, xemit * xgamma * 1e-6, c23 * 1e-3, c24 * 1e-6],
#         [c13, c23 * 1e-3, yemit * ybeta, yemit * yalpha * 1e-3],
#         [c14 * 1e-3, c24 * 1e-6, yemit * yalpha * 1e-3, yemit * ygamma * 1e-6]
#     ])

#     Det4D = np.linalg.det(moment1[0:4, 0:4])
#     if Det4D < 1.0e-30:
#         print('RMS beam quantities are not physical')
#         return None

#     return moment1

NameError: name 'Optional' is not defined