In [1]:
# optionally add scripts location to path
if True:
    import sys
    sys.path.append("/home/physics3/rroussel/SLAC_Xopt/")
    sys.path.append("/home/physics3/rroussel/SLAC_Xopt/NN_prior/")

print(sys.path)

['/home/physics3/rroussel/SLAC_Xopt/lcls/nn_prior', '/usr/local/lcls/tools/python/toolbox', '/usr/local/lcls/package/anaconda/envs/rhel7_devel/lib/python39.zip', '/usr/local/lcls/package/anaconda/envs/rhel7_devel/lib/python3.9', '/usr/local/lcls/package/anaconda/envs/rhel7_devel/lib/python3.9/lib-dynload', '', '/home/physics3/.local/lib/python3.9/site-packages', '/usr/local/lcls/package/anaconda/envs/rhel7_devel/lib/python3.9/site-packages', '/home/physics3/rroussel/SLAC_Xopt/', '/home/physics3/rroussel/SLAC_Xopt/NN_prior/']


In [2]:
import time
from typing import Dict

import torch
import numpy as np
import matplotlib.pyplot as plt

from scripts.evaluate_function.screen_image import measure_beamsize, measure_background
from scripts.optimize_function import optimize_function
from scripts.characterize_emittance import characterize_emittance

from xopt import Xopt, VOCS
from xopt.evaluator import Evaluator
from xopt.numerical_optimizer import LBFGSOptimizer
from xopt.generators import ExpectedImprovementGenerator
from xopt.generators.bayesian.models.standard import StandardModelConstructor
from lume_model.utils import variables_from_yaml
from lume_model.torch import LUMEModule, PyTorchModel

from custom_mean import CustomMean
from dynamic_custom_mean import DynamicCustomMean
from metric_informed_custom_mean import MetricInformedCustomMean

## Initialization

In [3]:
# set up data saving locations
data_dir = "/home/physics3/ml_tuning/20230729_LCLS_Injector"

run_name = "nn_optimize_1"
run_dir = f"{data_dir}/{run_name}"
import os
if not os.path.exists(run_dir):
    os.mkdir(run_dir)

In [4]:
## import variable range
import pandas as pd
filename = "../variables.csv"
VARIABLE_RANGES = pd.read_csv(filename, index_col=0, header=None).T.to_dict(orient='list')
SCREEN_NAME = "OTRS:IN20:621" # OTR 3

In [5]:
# set up background
BACKGROUND_FILE = f"{data_dir}/{SCREEN_NAME}_background.npy".replace(":","_")

## define evaluate function

In [6]:
from epics import caget_many
import json
#with open("../../secondary_variables.json", "r") as f:
#    secondary_observables = json.loads(f)

# define function to measure the total size on OTR4
def eval_beamsize(input_dict):
    results = measure_beamsize(input_dict)

    # get secondary PV settings/readbacks
    # secondary_results = caget_many(secondary_observables)
    # results = results | secondary_results

    results["S_x_mm"] = np.array(results["Sx"]) * 1e-3
    results["S_y_mm"] = np.array(results["Sy"]) * 1e-3

    #add total beam size
    xnp = results["S_x_mm"]
    ynp = results["S_y_mm"]
    results["total_size"] = np.sqrt(xnp**2 + ynp**2)
    return results

## Create model

In [7]:
model_path = "/home/physics3/rroussel/lcls_cu_injector_nn_model/"

In [8]:
# load sim_to_nn transformers
input_sim_to_nn = torch.load(model_path + "model/input_sim_to_nn.pt")
output_sim_to_nn = torch.load(model_path + "model/output_sim_to_nn.pt")

# load pv_to_sim transformers
input_pv_to_sim = torch.load(model_path + "model/input_pv_to_sim.pt")
output_pv_to_sim = torch.load(model_path + "model/output_pv_to_sim.pt")

In [9]:
# load in- and output variable specification
input_variables, output_variables = variables_from_yaml(open(model_path + "model/pv_variables.yml"))

In [10]:
# create LUME-model
# replace keys in input variables
input_variables = {name.replace("BACT", "BCTRL"): ele for name, ele in input_variables.items()}

lume_model = PyTorchModel(
    model_file=model_path + "model/model.pt",
    input_variables=input_variables,
    output_variables=output_variables,
    input_transformers=[input_pv_to_sim, input_sim_to_nn],  # pv_to_sim before sim_to_nn
    output_transformers=[output_sim_to_nn, output_pv_to_sim],  # sim_to_nn before pv_to_sim
)
print(lume_model.features)

['CAMR:IN20:186:R_DIST', 'Pulse_length', 'FBCK:BCI0:1:CHRG_S', 'SOLN:IN20:121:BCTRL', 'QUAD:IN20:121:BCTRL', 'QUAD:IN20:122:BCTRL', 'ACCL:IN20:300:L0A_ADES', 'ACCL:IN20:300:L0A_PDES', 'ACCL:IN20:400:L0B_ADES', 'ACCL:IN20:400:L0B_PDES', 'QUAD:IN20:361:BCTRL', 'QUAD:IN20:371:BCTRL', 'QUAD:IN20:425:BCTRL', 'QUAD:IN20:441:BCTRL', 'QUAD:IN20:511:BCTRL', 'QUAD:IN20:525:BCTRL']


## Build vocs

In [11]:
# inputs
TUNING_VARIABLES = [
    "SOLN:IN20:121:BCTRL",
    "QUAD:IN20:121:BCTRL",
    "QUAD:IN20:122:BCTRL",
    "QUAD:IN20:361:BCTRL",
    "QUAD:IN20:371:BCTRL",
    "QUAD:IN20:425:BCTRL",
    "QUAD:IN20:441:BCTRL",
    "QUAD:IN20:511:BCTRL",
    "QUAD:IN20:525:BCTRL"
]

ROI = None
THRESHOLD = 3

measurement_options = {
    "screen": SCREEN_NAME,
    "background": BACKGROUND_FILE,
    "threshold": THRESHOLD,
    "roi": ROI,
    "bb_half_width": 2.0, # half width of the bounding box in terms of std
    "visualize": False,
    "save_img_location": run_dir,
    "sleep_time": 3.0,
    "n_shots": 5
}

image_constraints = {
    "bb_penalty": ["LESS_THAN", 0.0],
    "log10_total_intensity": ["GREATER_THAN", 4]
}

image_constraints = {
    "bb_penalty": ["LESS_THAN", 0.0],
    "log10_total_intensity": ["GREATER_THAN", 4]
}

vocs = VOCS(
    variables={ele: VARIABLE_RANGES[ele] for ele in TUNING_VARIABLES},
    constants=measurement_options,
    objectives={"total_size": "MINIMIZE"},
    constraints=image_constraints,
)

In [12]:
print(vocs.as_yaml())

variables:
  SOLN:IN20:121:BCTRL: [0.465, 0.48]
  QUAD:IN20:121:BCTRL: [-0.015, 0.015]
  QUAD:IN20:122:BCTRL: [-0.015, 0.015]
  QUAD:IN20:361:BCTRL: [-3.7, -3.0]
  QUAD:IN20:371:BCTRL: [2.448, 2.992]
  QUAD:IN20:425:BCTRL: [-3.0, 3.0]
  QUAD:IN20:441:BCTRL: [-3.0, 3.0]
  QUAD:IN20:511:BCTRL: [-3.0, 3.0]
  QUAD:IN20:525:BCTRL: [-5.0, 5.0]
constraints:
  bb_penalty: [LESS_THAN, 0.0]
  log10_total_intensity: [GREATER_THAN, 4.0]
objectives: {total_size: MINIMIZE}
constants: {screen: 'OTRS:IN20:621', background: /home/physics3/ml_tuning/20230729_LCLS_Injector/OTRS_IN20_621_background.npy,
  threshold: 3, roi: null, bb_half_width: 2.0, visualize: false, save_img_location: /home/physics3/ml_tuning/20230729_LCLS_Injector/nn_optimize_1,
  sleep_time: 3.0, n_shots: 5}
observables: []



## Define prior mean

In [13]:
# define custom mean
# wrap in LUMEModule
lume_module = LUMEModule(
    model=lume_model,
    feature_order=vocs.variable_names,
    output_order=lume_model.outputs[0:2],
)

# define objective model
class ObjectiveModel(torch.nn.Module):
    def __init__(self, model: LUMEModule):
        super(ObjectiveModel, self).__init__()
        self.model = model

    @staticmethod
    def function(sigma_x: torch.Tensor, sigma_y: torch.Tensor) -> torch.Tensor:
        # using this calculation due to occasional negative values
        return torch.sqrt(sigma_x ** 2 + sigma_y ** 2)

    def forward(self, x) -> torch.Tensor:
        idx_sigma_x = self.model.output_order.index("OTRS:IN20:571:XRMS")
        idx_sigma_y = self.model.output_order.index("OTRS:IN20:571:YRMS")
        sigma_x = self.model(x)[..., idx_sigma_x]
        sigma_y = self.model(x)[..., idx_sigma_y]
        return self.function(sigma_x, sigma_y)


objective_model = ObjectiveModel(lume_module)

print(lume_module.feature_order)
print(lume_module.output_order)

mean_class = CustomMean
mean_kwargs = {"model": objective_model}

['QUAD:IN20:121:BCTRL', 'QUAD:IN20:122:BCTRL', 'QUAD:IN20:361:BCTRL', 'QUAD:IN20:371:BCTRL', 'QUAD:IN20:425:BCTRL', 'QUAD:IN20:441:BCTRL', 'QUAD:IN20:511:BCTRL', 'QUAD:IN20:525:BCTRL', 'SOLN:IN20:121:BCTRL']
['OTRS:IN20:571:XRMS', 'OTRS:IN20:571:YRMS']


## Create Xopt instance

In [14]:
# Xopt definition
class ConstantPrior(torch.nn.Module):
    def forward(self, X):
        return torch.ones_like(X).squeeze(dim=-1)


model_constructor = StandardModelConstructor(
    mean_modules={"total_size": ConstantPrior()},
)
generator = ExpectedImprovementGenerator(
    vocs=vocs,
    model_constructor=model_constructor,
)
generator.numerical_optimizer.max_iter = 200
evaluator = Evaluator(function=eval_beamsize)
X = Xopt(generator=generator, evaluator=evaluator, vocs=vocs)

## Create initial samples

In [15]:
# create initial samples
n_init = 3
X.random_evaluate(n_samples=n_init)

CAPUT SOLN:IN20:121:BCTRL 0.4764640625054597
CAPUT QUAD:IN20:121:BCTRL 0.007687674501379412
CAPUT QUAD:IN20:122:BCTRL 0.013061132028389787
CAPUT QUAD:IN20:361:BCTRL -3.1647115149501728
CAPUT QUAD:IN20:371:BCTRL 2.5338658596085732
CAPUT QUAD:IN20:425:BCTRL -1.7235986031966783
CAPUT QUAD:IN20:441:BCTRL 0.720413205408446
CAPUT QUAD:IN20:511:BCTRL -1.3736112190369523
CAPUT QUAD:IN20:525:BCTRL -0.23788606607317853
CAPUT SOLN:IN20:121:BCTRL 0.4696208708140325
CAPUT QUAD:IN20:121:BCTRL -0.011665122522723066
CAPUT QUAD:IN20:122:BCTRL -0.002656760512756374
CAPUT QUAD:IN20:361:BCTRL -3.661616381157516
CAPUT QUAD:IN20:371:BCTRL 2.568857379837244
CAPUT QUAD:IN20:425:BCTRL -0.19324449612886596
CAPUT QUAD:IN20:441:BCTRL -1.2940203892881696
CAPUT QUAD:IN20:511:BCTRL -1.4279072911245496
CAPUT QUAD:IN20:525:BCTRL 4.2669434230676995
CAPUT SOLN:IN20:121:BCTRL 0.47997329899409985
CAPUT QUAD:IN20:121:BCTRL -0.012658097764569512
CAPUT QUAD:IN20:122:BCTRL 0.01154296391900237
CAPUT QUAD:IN20:361:BCTRL -3.0331

Unnamed: 0,SOLN:IN20:121:BCTRL,QUAD:IN20:121:BCTRL,QUAD:IN20:122:BCTRL,QUAD:IN20:361:BCTRL,QUAD:IN20:371:BCTRL,QUAD:IN20:425:BCTRL,QUAD:IN20:441:BCTRL,QUAD:IN20:511:BCTRL,QUAD:IN20:525:BCTRL,screen,...,Sy,bb_penalty,total_intensity,log10_total_intensity,time,S_x_mm,S_y_mm,total_size,xopt_runtime,xopt_error
1,0.476464,0.007688,0.013061,-3.164712,2.533866,-1.723599,0.720413,-1.373611,-0.237886,OTRS:IN20:621,...,"[101.27724349942272, 102.46104405310868, 106.5...",-201.010268,"[20943.305274115268, 20810.73854983238, 20212....",4.321045,"[1690675773.1961293, 1690675774.611186, 169067...","[0.1794187282388385, 0.18236752032125514, 0.16...","[0.10127724349942273, 0.10246104405310869, 0.1...",0.20603,10.762008,False
1,0.476464,0.007688,0.013061,-3.164712,2.533866,-1.723599,0.720413,-1.373611,-0.237886,OTRS:IN20:621,...,"[101.27724349942272, 102.46104405310868, 106.5...",-199.826507,"[20943.305274115268, 20810.73854983238, 20212....",4.318287,"[1690675773.1961293, 1690675774.611186, 169067...","[0.1794187282388385, 0.18236752032125514, 0.16...","[0.10127724349942273, 0.10246104405310869, 0.1...",0.20918,10.762008,False
1,0.476464,0.007688,0.013061,-3.164712,2.533866,-1.723599,0.720413,-1.373611,-0.237886,OTRS:IN20:621,...,"[101.27724349942272, 102.46104405310868, 106.5...",-207.311361,"[20943.305274115268, 20810.73854983238, 20212....",4.305617,"[1690675773.1961293, 1690675774.611186, 169067...","[0.1794187282388385, 0.18236752032125514, 0.16...","[0.10127724349942273, 0.10246104405310869, 0.1...",0.196092,10.762008,False
1,0.476464,0.007688,0.013061,-3.164712,2.533866,-1.723599,0.720413,-1.373611,-0.237886,OTRS:IN20:621,...,"[101.27724349942272, 102.46104405310868, 106.5...",-206.961309,"[20943.305274115268, 20810.73854983238, 20212....",4.336802,"[1690675773.1961293, 1690675774.611186, 169067...","[0.1794187282388385, 0.18236752032125514, 0.16...","[0.10127724349942273, 0.10246104405310869, 0.1...",0.187008,10.762008,False
1,0.476464,0.007688,0.013061,-3.164712,2.533866,-1.723599,0.720413,-1.373611,-0.237886,OTRS:IN20:621,...,"[101.27724349942272, 102.46104405310868, 106.5...",-209.323529,"[20943.305274115268, 20810.73854983238, 20212....",4.318746,"[1690675773.1961293, 1690675774.611186, 169067...","[0.1794187282388385, 0.18236752032125514, 0.16...","[0.10127724349942273, 0.10246104405310869, 0.1...",0.180666,10.762008,False
2,0.469621,-0.011665,-0.002657,-3.661616,2.568857,-0.193244,-1.29402,-1.427907,4.266943,OTRS:IN20:621,...,"[162.60438155432823, 186.08864949210735, 309.2...",-230.884829,"[9901.015827648345, 3777.588430796379, 3775.73...",3.99568,"[1690675783.4952655, 1690675785.0601184, 16906...","[0.33197943075664893, 0.48595313059818623, 0.4...","[0.16260438155432824, 0.18608864949210735, 0.3...",0.369663,10.32026,False
2,0.469621,-0.011665,-0.002657,-3.661616,2.568857,-0.193244,-1.29402,-1.427907,4.266943,OTRS:IN20:621,...,"[162.60438155432823, 186.08864949210735, 309.2...",-216.797266,"[9901.015827648345, 3777.588430796379, 3775.73...",3.577215,"[1690675783.4952655, 1690675785.0601184, 16906...","[0.33197943075664893, 0.48595313059818623, 0.4...","[0.16260438155432824, 0.18608864949210735, 0.3...",0.520365,10.32026,False
2,0.469621,-0.011665,-0.002657,-3.661616,2.568857,-0.193244,-1.29402,-1.427907,4.266943,OTRS:IN20:621,...,"[162.60438155432823, 186.08864949210735, 309.2...",-211.894361,"[9901.015827648345, 3777.588430796379, 3775.73...",3.577002,"[1690675783.4952655, 1690675785.0601184, 16906...","[0.33197943075664893, 0.48595313059818623, 0.4...","[0.16260438155432824, 0.18608864949210735, 0.3...",0.567757,10.32026,False
2,0.469621,-0.011665,-0.002657,-3.661616,2.568857,-0.193244,-1.29402,-1.427907,4.266943,OTRS:IN20:621,...,"[162.60438155432823, 186.08864949210735, 309.2...",-227.79147,"[9901.015827648345, 3777.588430796379, 3775.73...",3.583598,"[1690675783.4952655, 1690675785.0601184, 16906...","[0.33197943075664893, 0.48595313059818623, 0.4...","[0.16260438155432824, 0.18608864949210735, 0.3...",0.514341,10.32026,False
2,0.469621,-0.011665,-0.002657,-3.661616,2.568857,-0.193244,-1.29402,-1.427907,4.266943,OTRS:IN20:621,...,"[162.60438155432823, 186.08864949210735, 309.2...",-236.466396,"[9901.015827648345, 3777.588430796379, 3775.73...",3.710182,"[1690675783.4952655, 1690675785.0601184, 16906...","[0.33197943075664893, 0.48595313059818623, 0.4...","[0.16260438155432824, 0.18608864949210735, 0.3...",0.507349,10.32026,False


## Run optimization

In [16]:
n_step = 50
for step in range(n_step):
    # define prior mean
    # optimization step
    t0 = time.time()
    X.step()
    print("Completed step {:d} ({:.2f} sec)".format(step, time.time() - t0))

RuntimeError: Shape mismatch: objects cannot be broadcast to a single shape