In [1]:
from plugins.environments.awa_environment import AWAEnvironment
from plugins.interfaces.awa_interface import AWAInterface

# import data from csv file
import pandas as pd
variable_file = "plugins/environments/awa_variables.csv"
observable_file = "plugins/environments/awa_observables.csv"

env = AWAEnvironment(variable_file, observable_file, interface=AWAInterface(), target_charge=1.0)

In [2]:
env.variables

{'AWA:Drive:DS1:Ctrl': [500.0, 550.0],
 'AWA:Drive:DS3:Ctrl': [180.0, 260.0],
 'AWA:Bira3Ctrl:Ch03': [-5.0, 5.0],
 'AWA:Bira3Ctrl:Ch04': [-5.0, 5.0],
 'AWA:Bira3Ctrl:Ch05': [-5.0, 5.0],
 'AWA:Bira3Ctrl:Ch06': [-5.0, 5.0],
 'AWALLRF:K1:SetPhase': [236.0, 256.0],
 'AWA:DAC0:Ch08': [4.5, 5.9]}

In [3]:
from xopt import Xopt, Evaluator, VOCS
from xopt.generators.bayesian import BayesianExplorationGenerator, UpperConfidenceBoundGenerator
from xopt.generators.bayesian.models.standard import StandardModelConstructor
from xopt.generators.bayesian.turbo import SafetyTurboController, OptimizeTurboController

import time

def evaluate(inputs):
    env.set_variables(inputs)
    time.sleep(2.0)
    result = env.get_observables(["13ARV1:Sx"])
    result["total_rms_size"] = (result["13ARV1:Sx"]**2 + result["13ARV1:Sy"]**2)**0.5
    return result

# use only one variable
vocs = VOCS(variables=env.variables,
            objectives={"total_rms_size":"MINIMIZE"},
            constraints={"13ARV1:penalty":["LESS_THAN",0.0]})


In [4]:

model_constructor = StandardModelConstructor(use_low_noise_prior=False)
generator = BayesianExplorationGenerator(
    vocs=vocs, 
    model_constructor=model_constructor,
    turbo_controller=SafetyTurboController(vocs=vocs, length=0.1)
)
evaluator = Evaluator(function=evaluate)
X = Xopt(vocs=vocs, evaluator=evaluator, generator=generator)
X.options.dump_file = "exploration_2_nd_filter.yml"

In [5]:
X


            Xopt
________________________________
Version: 1.4.1+50.ge9fc8ac.dirty
Data size: 0
Config as YAML:
xopt: {asynch: false, strict: false, dump_file: exploration_2_nd_filter.yml, max_evaluations: null}
generator:
  name: bayesian_exploration
  model: null
  turbo_controller:
    dim: 8
    batch_size: 1
    length: 0.1
    length_min: 0.0078125
    length_max: 2.0
    failure_counter: 0
    failure_tolerance: 4
    success_counter: 0
    success_tolerance: 4
    center_x: null
    scale_factor: 1.25
    tkwargs: {dtype: torch.float64}
  use_cuda: false
  model_constructor:
    name: standard
    use_low_noise_prior: false
    covar_modules: {}
    mean_modules: {}
    trainable_mean_keys: []
    dtype: torch.float64
    device: cpu
  numerical_optimizer: {name: LBFGS, n_raw_samples: 20, n_restarts: 20, max_iter: 2000}
  max_travel_distances: null
  n_monte_carlo_samples: 128
evaluator:
  function: __main__.evaluate
  max_workers: 1
  function_kwargs: {}
  vectorized: false
v

In [6]:
# get the current quad setpoint
#current_val = env.get_variables(['AWA:Bira3Ctrl:Ch03'])
import numpy as np
default_pt = np.array([0, 0, 0, 0,5.9,550,190,246])
default_val = dict(zip(X.vocs.variable_names, default_pt))
default_val = pd.DataFrame(default_val, index=[0])
default_val

Unnamed: 0,AWA:Bira3Ctrl:Ch03,AWA:Bira3Ctrl:Ch04,AWA:Bira3Ctrl:Ch05,AWA:Bira3Ctrl:Ch06,AWA:DAC0:Ch08,AWA:Drive:DS1:Ctrl,AWA:Drive:DS3:Ctrl,AWALLRF:K1:SetPhase
0,0.0,0.0,0.0,0.0,5.9,550.0,190.0,246.0


In [7]:
# evaluate that point in xopt
X.evaluate_data(default_val)

CA.Client.Exception...............................................
    Context: "Channel: "13ARV1:image1:ArraySize1_RBV", Connecting to: 146.139.52.185:5064, Ignored: awa5:5064"
    Source File: ../cac.cpp line 1320
    Current Time: Thu Jun 15 2023 15:04:23.649051673
..................................................................
CA.Client.Exception...............................................
    Context: "Channel: "13ARV1:image1:ArrayData", Connecting to: 146.139.52.185:5064, Ignored: awa5:5064"
    Source File: ../cac.cpp line 1320
    Current Time: Thu Jun 15 2023 15:04:23.649200834
..................................................................
CA.Client.Exception...............................................
    Context: "Channel: "13ARV1:image1:ArraySize0_RBV", Connecting to: 146.139.52.185:5064, Ignored: awa5:5064"
    Source File: ../cac.cpp line 1320
    Current Time: Thu Jun 15 2023 15:04:23.649287060
................................................................

Unnamed: 0,AWA:Bira3Ctrl:Ch03,AWA:Bira3Ctrl:Ch04,AWA:Bira3Ctrl:Ch05,AWA:Bira3Ctrl:Ch06,AWA:DAC0:Ch08,AWA:Drive:DS1:Ctrl,AWA:Drive:DS3:Ctrl,AWALLRF:K1:SetPhase,13ARV1:image1:ArraySize1_RBV,AWAVXI11ICT:Ch1,...,AWAVXI11ICT:Ch1_std,13ARV1:image1:ArraySize0_RBV_std,13ARV1:Cx_std,13ARV1:Cy_std,13ARV1:Sx_std,13ARV1:Sy_std,13ARV1:penalty_std,total_rms_size,xopt_runtime,xopt_error
1,0.0,0.0,0.0,0.0,5.9,550.0,190.0,246.0,1200.0,9.637042e-10,...,5.452919e-11,0.0,1.042786,1.078548,2.934228,1.814375,6.755243,57.000278,6.519966,False


In [8]:
# evaluate a second nearby point
second_pt = np.array([0.1, 0.1, 0.1, 0.1, 5.8,540,185,245])
second_pt = dict(zip(X.vocs.variable_names, second_pt))
second_pt = pd.DataFrame(second_pt, index=[0])
X.evaluate_data(second_pt)

Unnamed: 0,AWA:Bira3Ctrl:Ch03,AWA:Bira3Ctrl:Ch04,AWA:Bira3Ctrl:Ch05,AWA:Bira3Ctrl:Ch06,AWA:DAC0:Ch08,AWA:Drive:DS1:Ctrl,AWA:Drive:DS3:Ctrl,AWALLRF:K1:SetPhase,13ARV1:image1:ArraySize1_RBV,AWAVXI11ICT:Ch1,...,AWAVXI11ICT:Ch1_std,13ARV1:image1:ArraySize0_RBV_std,13ARV1:Cx_std,13ARV1:Cy_std,13ARV1:Sx_std,13ARV1:Sy_std,13ARV1:penalty_std,total_rms_size,xopt_runtime,xopt_error
2,0.1,0.1,0.1,0.1,5.8,540.0,185.0,245.0,1200.0,9.656808e-10,...,3.334863e-11,0.0,3.716726,0.674299,5.276837,0.253519,8.346779,105.464284,7.477461,False


In [9]:
X.data

Unnamed: 0,AWA:Bira3Ctrl:Ch03,AWA:Bira3Ctrl:Ch04,AWA:Bira3Ctrl:Ch05,AWA:Bira3Ctrl:Ch06,AWA:DAC0:Ch08,AWA:Drive:DS1:Ctrl,AWA:Drive:DS3:Ctrl,AWALLRF:K1:SetPhase,13ARV1:image1:ArraySize1_RBV,AWAVXI11ICT:Ch1,...,AWAVXI11ICT:Ch1_std,13ARV1:image1:ArraySize0_RBV_std,13ARV1:Cx_std,13ARV1:Cy_std,13ARV1:Sx_std,13ARV1:Sy_std,13ARV1:penalty_std,total_rms_size,xopt_runtime,xopt_error
1,0.0,0.0,0.0,0.0,5.9,550.0,190.0,246.0,1200.0,9.637042e-10,...,5.452919e-11,0.0,1.042786,1.078548,2.934228,1.814375,6.755243,57.000278,6.519966,False
2,0.1,0.1,0.1,0.1,5.8,540.0,185.0,245.0,1200.0,9.656808e-10,...,3.334863e-11,0.0,3.716726,0.674299,5.276837,0.253519,8.346779,105.464284,7.477461,False


In [10]:
# run exploration
n_steps = 200
X.generator.numerical_optimizer.max_iter = 50
for i in range(n_steps):
    print(i)
    start = time.time()
    X.step()
    print(time.time() - start)
    

0
9.095255851745605
1
8.19329309463501
2
10.320689916610718
3
13.00560975074768
4
11.943837642669678
5
9.826449155807495
6


Trying again with a new set of initial conditions.


13.45125126838684
7
9.964389324188232
8
15.270519495010376
9
11.201481580734253
10
9.869854211807251
11
8.958905696868896
12
11.765093088150024
13
10.46632981300354
14
10.548745393753052
15
10.273592472076416
16
10.465414762496948
17
10.991957902908325
18
11.500792980194092
19
9.606010913848877
20
10.343140840530396
21
11.435575246810913
22
10.359249114990234
23
11.51128077507019
24
10.280836582183838
25
10.48784875869751
26
10.821216583251953
27
10.195933818817139
28
11.997673034667969
29
10.234082698822021
30
10.986330270767212
31
17.041923761367798
32
10.975954294204712
33
11.51518702507019
34
6.8291566371917725
35
7.4649152755737305
36
7.635048151016235
37
7.699148178100586
38
6.661015033721924
39
11.032748937606812
40
8.705167293548584
41
13.775813341140747
42
12.561385869979858
43
10.658863306045532
44
13.144157409667969
45
11.66443133354187
46
13.216874837875366
47
15.07082724571228
48
12.27759599685669
49
15.233331203460693
50
13.216281652450562
51
13.233703851699829
52
14.0247

Trying again with a new set of initial conditions.


14.784202098846436
93
13.733761310577393
94
13.796674966812134
95


Trying again with a new set of initial conditions.


17.238654136657715
96
16.49406123161316
97
17.031900882720947
98
10.283908367156982
99
10.508649349212646
100
11.854274034500122
101
8.635127782821655
102
8.590194463729858
103
10.92165470123291
104
11.808954238891602
105
19.461774826049805
106
15.994142293930054
107


Trying again with a new set of initial conditions.


14.794168472290039
108
16.72261333465576
109
12.300495862960815
110
14.766252756118774
111
11.86870789527893
112
13.092350721359253
113
11.506585597991943
114
10.044602870941162
115


Trying again with a new set of initial conditions.


12.014853954315186
116
11.967277765274048
117
11.489463806152344
118
20.556246042251587
119
14.458847761154175
120
14.02691102027893
121
14.285445213317871
122
16.7245934009552
123
14.428364038467407
124


Trying again with a new set of initial conditions.


17.368242025375366
125


Trying again with a new set of initial conditions.


14.007001399993896
126
11.733444213867188
127


Trying again with a new set of initial conditions.


13.055169343948364
128
12.060308933258057
129
12.168162822723389
130
11.794053316116333
131
20.444621562957764
132
11.562167167663574
133
15.958341598510742
134
11.31304669380188
135
17.67631435394287
136
13.518329858779907
137
13.822909355163574
138
16.180906534194946
139
14.996986627578735
140
10.612012386322021
141
16.19920539855957
142
15.500113487243652
143
17.218245029449463
144
11.512613534927368
145
13.28104305267334
146
11.071063756942749
147
14.674134016036987
148
11.783320426940918
149


Trying again with a new set of initial conditions.


9.228481531143188
150
12.662703275680542
151
14.668523788452148
152
12.156877279281616
153
11.523325204849243
154
16.745733499526978
155
13.995652675628662
156
12.701575994491577
157
12.388213157653809
158
11.419625997543335
159
15.015034675598145
160
10.50534176826477
161
16.7334725856781
162
12.780522108078003
163
14.132830381393433
164


Trying again with a new set of initial conditions.


10.941314935684204
165
16.67477250099182
166
11.99449896812439
167
11.900442838668823
168
11.887123346328735
169
13.73704981803894
170
14.509513139724731
171
12.025269269943237
172
11.98357343673706
173
15.511986494064331
174
11.35812497138977
175
16.944671154022217
176
11.410186767578125
177
15.325345516204834
178
12.509100437164307
179
16.07947301864624
180
12.43864631652832
181
17.332826614379883
182
12.738662958145142
183
12.360551118850708
184
17.902915716171265
185
12.213597297668457
186


Trying again with a new set of initial conditions.


11.751350164413452
187
12.25439190864563
188
15.027029991149902
189
10.684764623641968
190
17.33001184463501
191
15.043863773345947
192
11.779417514801025
193
13.6987144947052
194
14.069447040557861
195
15.591857433319092
196
13.733793258666992
197
11.955246925354004
198


Trying again with a new set of initial conditions.


KeyboardInterrupt: 

In [11]:
X.data

Unnamed: 0,AWA:Bira3Ctrl:Ch03,AWA:Bira3Ctrl:Ch04,AWA:Bira3Ctrl:Ch05,AWA:Bira3Ctrl:Ch06,AWA:DAC0:Ch08,AWA:Drive:DS1:Ctrl,AWA:Drive:DS3:Ctrl,AWALLRF:K1:SetPhase,13ARV1:image1:ArraySize1_RBV,AWAVXI11ICT:Ch1,...,AWAVXI11ICT:Ch1_std,13ARV1:image1:ArraySize0_RBV_std,13ARV1:Cx_std,13ARV1:Cy_std,13ARV1:Sx_std,13ARV1:Sy_std,13ARV1:penalty_std,total_rms_size,xopt_runtime,xopt_error
1,0.000000,0.000000,0.000000,0.000000,5.900000,550.000000,190.000000,246.000000,1200.0,9.637042e-10,...,5.452919e-11,0.0,1.042786,1.078548,2.934228,1.814375,6.755243e+00,57.000278,6.519966,False
2,0.100000,0.100000,0.100000,0.100000,5.800000,540.000000,185.000000,245.000000,1200.0,9.656808e-10,...,3.334863e-11,0.0,3.716726,0.674299,5.276837,0.253519,8.346779e+00,105.464284,7.477461,False
3,-0.454108,-0.454108,-0.454108,0.554108,5.779836,547.391843,183.484956,246.505424,1200.0,1.030938e-09,...,4.375419e-11,0.0,2.125831,4.200516,1.526299,2.346365,8.304717e+00,62.265338,5.903093,False
4,0.395742,-0.631813,0.395742,-0.304557,5.750015,547.670059,181.811309,244.845979,1200.0,9.630984e-10,...,6.639414e-11,0.0,0.602160,0.573510,1.498979,1.260569,3.907168e+00,55.391063,6.975663,False
5,-0.510595,0.260076,0.558230,-0.461170,5.885191,544.893596,180.652692,246.672404,1200.0,9.707082e-10,...,4.464725e-11,0.0,0.547941,1.253067,1.349200,1.924967,5.067530e+00,59.863967,6.340186,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
196,0.201306,-1.196940,0.335558,-2.702751,5.125500,510.924392,254.353161,248.218073,1200.0,9.604107e-10,...,3.400496e-11,0.0,,,,,6.355287e-14,,6.860541,False
197,-4.949508,1.660176,0.919836,-4.569611,4.500000,511.401803,255.130379,245.571934,1200.0,9.891284e-10,...,6.017652e-11,0.0,,,,,6.355287e-14,,5.877605,False
198,1.906800,-1.725656,0.819764,-1.156309,4.978984,535.988141,197.978271,249.797942,1200.0,9.576114e-10,...,4.131832e-11,0.0,0.436775,0.715766,0.135660,0.802182,7.094045e-01,26.709147,7.313716,False
199,3.436252,-2.773484,-1.583363,1.489563,4.500000,505.630724,246.203325,248.181481,1200.0,9.486493e-10,...,2.258473e-11,0.0,0.678232,2.076514,0.231962,1.494963,2.419613e+00,52.268928,6.836376,False


In [12]:
X.evaluate_data(X.data.iloc[-4])

Unnamed: 0,197,xopt_runtime,xopt_error,xopt_error_str
201,-4.949508,0.001524,True,"Traceback (most recent call last):\n File ""/h..."
202,1.660176,6.8e-05,True,"Traceback (most recent call last):\n File ""/h..."
203,0.919836,4.3e-05,True,"Traceback (most recent call last):\n File ""/h..."
204,-4.569611,3.9e-05,True,"Traceback (most recent call last):\n File ""/h..."
205,4.5,3.7e-05,True,"Traceback (most recent call last):\n File ""/h..."
206,511.401803,3.8e-05,True,"Traceback (most recent call last):\n File ""/h..."
207,255.130379,3.7e-05,True,"Traceback (most recent call last):\n File ""/h..."
208,245.571934,3.7e-05,True,"Traceback (most recent call last):\n File ""/h..."
209,1200.0,3.7e-05,True,"Traceback (most recent call last):\n File ""/h..."
210,0.0,3.6e-05,True,"Traceback (most recent call last):\n File ""/h..."


In [None]:
X.data.plot(y=X.vocs.variable_names[:4])

In [None]:
X.data.plot(y=X.vocs.variable_names[4])

In [None]:
X.data.plot(y=X.vocs.variable_names[5:])

In [None]:
import torch
import matplotlib.pyplot as plt
test_x = torch.linspace(-2,2, 100)
model = X.generator.train_model()

fig,ax = plt.subplots(2,1, sharex="all")
fig.set_size_inches(6,6)
with torch.no_grad():
    post = model.posterior(test_x.reshape(-1,1,1).double())
    for i in range(post.event_shape[-1]):
        mean = post.mean[...,i].squeeze()
        l,u = post.mvn.confidence_region()
        ax[0].plot(test_x, mean,f"C{i}", label=generator.vocs.output_names[i])
        ax[0].fill_between(test_x, l[...,i].squeeze(), u[...,i].squeeze(), alpha=0.5)


    acq = generator.get_acquisition(model)(test_x.reshape(-1,1,1).double())

    ax[1].plot(test_x, acq, label='Acquisition Function')
    ax[1].legend()

In [None]:
from emitopt.utils import get_valid_emittance_samples
beam_energy = 45*10**-3 # GeV 
q_len = 0.12 # m
distance = 1.33-0.265 #m

data = X.data

data["grad"] = data["AWA:Bira3Ctrl:Ch04"] * 100*8.93e-3
data["int_grad"] = data["grad"]*q_len*10

data["sx_m"] = 3.9232781168265036e-05 * data["13ARV1:Sx"]

data


In [None]:
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.priors import GammaPrior
from gpytorch.kernels import MaternKernel, PolynomialKernel, ScaleKernel
from gpytorch import ExactMarginalLogLikelihood

from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.fit import fit_gpytorch_mll

train_x = torch.tensor(data.dropna()["int_grad"].to_numpy()).double().unsqueeze(1)
train_y = torch.tensor(data.dropna()["sx_m"].to_numpy()).double().unsqueeze(1)

print(train_x.shape)
print(train_y.shape)
input_transform = Normalize(1)
outcome_transform = Standardize(1)
covar_module = ScaleKernel(PolynomialKernel(power=2))
#covar_module = MaternKernel()

model = SingleTaskGP(train_x, 
                     train_y, 
                     input_transform=input_transform,
                     outcome_transform=outcome_transform, 
                     covar_module = covar_module
                     )

mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)


(emits_at_target_valid,
 emits_sq_at_target,
 is_valid,
 sample_validity_rate) = get_valid_emittance_samples(model, beam_energy,
                                                     q_len,
                                                     distance, n_samples=50, n_steps_quad_scan=10, visualize=True)

In [None]:
plt.hist(emits_at_target_valid.flatten()*90)
plt.title('Distribution of Sampled Emittances')
plt.xlabel('Emittance')
plt.ylabel('Probability Density')