In [1]:
import sys
sys.path.append('../')
from msBO import MultiStateBO
from msBO.objective import BPMvar_minimization
sys.path.append('../../machineIO/')
from machineIO import construct_machineIO, StatefulOracleEvaluator
from machineIO.preset import get_limits

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from epics import caget, caget_many, caput_many
from phantasy import fetch_data
import datetime
from phantasy import ensure_set, fetch_data

# import importlib, msBO  # or from boom import msBO if that's your import
# importlib.reload(msBO)

In [2]:
# Optimization budget.
#   n_init   : initial evaluations per state (passed to msbo.init)
#   n_each   : consecutive msbo.step calls per state within one outer iteration
#   n_states : number of machine states; recomputed later as len(states)
#   n_iter   : number of outer iterations over all states
n_init, n_each, n_states, n_iter = 16, 3, 2, 3

In [3]:
# Identify the active ion source and the beam species booked on it.
SCS = caget("ACS_DIAG:DEST:ACTIVE_ION_SOURCE")
_beam_prefix = "FE_ISRC" + str(SCS) + ":BEAM:"
ion = caget(_beam_prefix + "ELMT_BOOK")   # element symbol (e.g. 'Zn' in the output below)
Q = caget(_beam_prefix + "Q_BOOK")
A = caget(_beam_prefix + "A_BOOK")
# Alternative (hard-coded to source 2): AQ = caget("FE_ISRC2:BEAM:MOVRQ_BOOK")
AQ = A / Q
# Species tag such as '64Zn19' (mass number + element + charge state).
ion = "".join([str(A), ion, str(Q)])
print('SCS' + str(SCS), ion, 'A/Q=', AQ)

SCS2 64Zn19 A/Q= 3.3684210526315788


In [4]:
# File-name stem for this run: <YYYYmmdd_HHMM>[<ion>][msBO]test
now0 = datetime.datetime.now()
fname = f"{now0:%Y%m%d_%H%M}[{ion}][msBO]test"
fname

'20251119_2342[64Zn19][msBO]test'

### construct machineIO

In [5]:
# Machine I/O handle used by the oracle. Time values are presumably in seconds
# (the run-time estimate below divides their sum by 60 to get minutes).
mIO = construct_machineIO(
    fetch_data_time_span=2.0,
    ensure_set_timewait_after_ramp=0.3,
    ensure_set_timeout=30,
)

In [None]:
# Show the machineIO configuration attributes.
vars(mIO)

In [None]:
# Rough wall-clock estimate of the whole run (times in seconds).
# Uses n_states from the budget cell; the previous version referenced n_state,
# which is only defined further down and raised NameError on a fresh kernel.
ramping_time = 1    # assumed ramp time per evaluation (rough guess)
BO_comp_time = 10   # assumed BO computation time per step (rough guess)
dt = (mIO._ensure_set_timewait_after_ramp + mIO._fetch_data_time_span
      + ramping_time + BO_comp_time)
n_evals_expected = n_init*n_states + n_iter*n_states*n_each
print(f'expected run time: {int(n_evals_expected*dt/60)} min')
print(f'expected number of data: {n_evals_expected}')

### control knobs

In [6]:
# Primary decision knobs (current set-points of FS2 PSC/PSQ/PSS supplies).
control_CSETs = [
    'FS2_BTS:PSC2_D3930:I_CSET',
    'FS2_BTS:PSC1_D3930:I_CSET',
    'FS2_BTS:PSC2_D3962:I_CSET',
    'FS2_BTS:PSC1_D3962:I_CSET',
#     'FS2_BBS:PSC2_D4010:I_CSET',
#     'FS2_BBS:PSC1_D4010:I_CSET',
#     'FS2_BBS:PSC2_D4055:I_CSET',
#     'FS2_BBS:PSC1_D4055:I_CSET',
#     'FS2_BBS:PSC2_D4096:I_CSET',
#     'FS2_BBS:PSC1_D4096:I_CSET',
    'FS2_BBS:PSQ_D3996:I_CSET',
    'FS2_BBS:PSS_D4000:I_CSET',
    'FS2_BMS:PSC2_D4146:I_CSET',
    'FS2_BMS:PSC1_D4146:I_CSET',
]
control_RDs = [pv.replace('CSET', 'RD') for pv in control_CSETs]

# caget_many returns a plain list; convert to float arrays so the element-wise
# difference below works (list - list raises TypeError). A PV that could not be
# read comes back as None, which becomes NaN here and is rejected.
control_ref = np.asarray(caget_many(control_CSETs), dtype=float)
control_rd_ref = np.asarray(caget_many(control_RDs), dtype=float)
assert np.all(np.isfinite(control_ref)) and np.all(np.isfinite(control_rd_ref)), \
    'failed to read some control CSET/RD PVs'
# Twice the present CSET/RD mismatch serves as a lower bound for each tolerance.
control_tols_ref = 2*np.abs(control_ref - control_rd_ref)

# Decision window around the present value, by power-supply family:
#   PSC: absolute +-3,  PSQ: +-5 %,  PSS: +-70 %
control_min, control_max, control_tols = [], [], []
for PV, v, tol_ref in zip(control_CSETs, control_ref, control_tols_ref):
    if ':PSC' in PV:
        lo, hi, tol_floor = v - 3, v + 3, 0.1
    elif ':PSQ_' in PV:
        lo, hi, tol_floor = v*0.95, v*1.05, 0.1
    elif ':PSS_' in PV:
        lo, hi, tol_floor = v*0.3, v*1.7, 0.05
    else:
        raise ValueError(f'control min/max cannot be determined automatically for {PV}')
    # Relative windows flip for a negative set-point; keep min <= max.
    control_min.append(min(lo, hi))
    control_max.append(max(lo, hi))
    control_tols.append(max(tol_floor, tol_ref))


##== Manually set decision bounds and tolerance
# control_min = [ -5, -5, -5, -5]
# control_max = [  5,  5,  5,  5]
# control_tols = [0.2,0.2,0.2,0.2]

assert len(control_CSETs) == len(control_min) == len(control_max) == len(control_tols)
# Never let the decision range exceed the power-supply limits.
control_Lo_limit, control_Hi_limit = get_limits(control_CSETs)
control_min = np.clip(control_min, a_min=control_Lo_limit, a_max=None)
control_max = np.clip(control_max, a_min=None, a_max=control_Hi_limit)
empty_range_PVs = [pv for pv, lo, hi in zip(control_CSETs, control_min, control_max) if not hi > lo]
assert not empty_range_PVs, \
    f'empty decision range (zero set-point for a relative window, or outside limits?) for {empty_range_PVs}'

pd.DataFrame(np.array([control_ref, control_min, control_max, control_tols, control_Lo_limit, control_Hi_limit]).T,
             index=control_CSETs,
             columns=['current value', 'decision min', 'decision max', 'tol', 'LoLim', 'HiLim'])

Unnamed: 0,current value,decision min,decision max,tol,LoLim,HiLim
FS2_BTS:PSC2_D3930:I_CSET,0.0,-5.0,5.0,0.1,-39.0,39.0
FS2_BTS:PSC1_D3930:I_CSET,1.16,-3.84,6.16,0.1,-19.0,19.0
FS2_BBS:PSC1_D4010:I_CSET,0.0,-5.0,5.0,0.1,-186.0,186.0
FS2_BBS:PSC1_D4055:I_CSET,7.704,2.704,12.704,0.1,-186.0,186.0
FS2_BBS:PSC1_D4096:I_CSET,0.0,-5.0,5.0,0.1,-186.0,186.0
FS2_BBS:PSQ_D3996:I_CSET,181.735,172.64825,190.82175,0.1,0.0,309.0
FS2_BBS:PSS_D4000:I_CSET,1.769,0.3538,3.1842,0.05,0.0,4.8


In [7]:
# Magnets driven together with a primary knob. For each primary CSET the coupling
# coefficients are coupled/primary, computed from the present set-points; the
# RDs, tols and coeffs entries are filled in by the loop below.
control_couplings = {
    'FS2_BBS:PSQ_D3996:I_CSET':
        {
            'CSETs' :['FS2_BBS:PSQ_D4004:I_CSET','FS2_BBS:PSQ_D4014:I_CSET',
                      'FS2_BBS:PSQ_D4092:I_CSET','FS2_BBS:PSQ_D4102:I_CSET','FS2_BBS:PSQ_D4109:I_CSET'],
            'RDs'   :[],
            'tols'  :[],
            'coeffs':[],
        },
     'FS2_BBS:PSS_D4000:I_CSET':
        {
            'CSETs' :['FS2_BBS:PSS_D4106:I_CSET'],
            'RDs'   :[],
            'tols'  :[],
            'coeffs':[],
        },
}

for primary_PV, coupling in control_couplings.items():
    x = caget(primary_PV)
    y = np.array(caget_many(coupling['CSETs']), dtype=float)
    # A zero or unreadable set-point would otherwise give inf/nan coefficients silently.
    if x is None or x == 0:
        raise ValueError(f'cannot derive coupling coefficients for {primary_PV}: primary set-point is {x}')
    if not np.all(np.isfinite(y)):
        raise ValueError(f'failed to read coupled CSETs of {primary_PV}: {coupling["CSETs"]}')
    # Coupled supplies share the tolerance of their primary knob.
    tol = control_tols[control_CSETs.index(primary_PV)]
    coupling['RDs'] = [cset.replace('CSET', 'RD') for cset in coupling['CSETs']]
    coupling['tols'] = [tol]*len(coupling['CSETs'])
    coupling['coeffs'] = y/x

control_couplings

{'FS2_BBS:PSQ_D3996:I_CSET': {'CSETs': ['FS2_BBS:PSQ_D4004:I_CSET',
   'FS2_BBS:PSQ_D4014:I_CSET',
   'FS2_BBS:PSQ_D4092:I_CSET',
   'FS2_BBS:PSQ_D4102:I_CSET',
   'FS2_BBS:PSQ_D4109:I_CSET'],
  'RDs': ['FS2_BBS:PSQ_D4004:I_RD',
   'FS2_BBS:PSQ_D4014:I_RD',
   'FS2_BBS:PSQ_D4092:I_RD',
   'FS2_BBS:PSQ_D4102:I_RD',
   'FS2_BBS:PSQ_D4109:I_RD'],
  'tols': [0.1, 0.1, 0.1, 0.1, 0.1],
  'coeffs': array([0.89769169, 0.92529232, 0.92529232, 0.89769169, 1.        ])},
 'FS2_BBS:PSS_D4000:I_CSET': {'CSETs': ['FS2_BBS:PSS_D4106:I_CSET'],
  'RDs': ['FS2_BBS:PSS_D4106:I_RD'],
  'tols': [0.05],
  'coeffs': array([1.])}}

### state definition

In [9]:
# The machine "state" is set by a single actuator (FS1_BBS:CSEL_D2405 motor).
state_CSETs = ['FS1_BBS:CSEL_D2405:CTR_MTR.VAL']
state_RDs = ['FS1_BBS:CSEL_D2405:CTR_MTR.RBV']
state_tols = [0.1 for _ in state_CSETs]
state_val0 = caget_many(state_CSETs)   # kept so the state can be restored at the end
# state name -> set-point values for state_CSETs (one value per state PV)
state_key_vals = dict([
    ('28+', [10]),
    ('29+', [-20]),
])
states = [*state_key_vals]
n_states = len(states)
state_key_vals

{'28': [10], '29': [-20]}

### monitors

In [None]:
monitor_BPMs = ['FS2_BMS:BPM_D4142', 'FS2_BMS:BPM_D4164', 'FS2_BMS:BPM_D4177', 'FS2_BMS:BPM_D4216']
# For every BPM, in order: X position, Y position, magnitude readbacks.
monitor_RDs = [f'{bpm}:{field}' for bpm in monitor_BPMs
               for field in ('XPOS_RD', 'YPOS_RD', 'MAG_RD')]
# Alternative without X positions: use fields ('YPOS_RD', 'MAG_RD').

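### Define oracle
**Warning:** running the next cell moves the state actuator to every state in `state_key_vals` (to record reference BPM magnitudes) and leaves it at the last one.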

In [None]:
# Split the monitor readbacks into positions and magnitudes.
bpm_posRDs = [name for name in monitor_RDs if 'POS_RD' in name]
bpm_magRDs = [name for name in monitor_RDs if 'MAG_RD' in name]

# Reference BPM magnitudes per state, taken at the present (pre-optimization)
# control settings. This moves the state actuator through every state and
# leaves it at the last one.
BPM_MAGs_ref = {}
for state, goal in state_key_vals.items():
    ret = ensure_set(state_CSETs, state_RDs, goal, state_tols, timeout=30)  # status is not checked
    mags, _ = fetch_data(bpm_magRDs, 5)
    BPM_MAGs_ref[state] = mags
BPM_MAGs_ref

In [None]:
def BPM_MAG_obj(df, s, mag_RDs=None, mag_ref=None):
    """Add column 'BPM:MAG_min_ratio' to ``df`` (in place) and return ``df``.

    For each row, the BPM magnitude readbacks are divided element-wise by the
    reference magnitudes recorded for state ``s``, and the smallest ratio of
    that row is stored. (The previous ``.min()`` reduced over all rows as well,
    so every row received the same global minimum.)

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns listed in ``mag_RDs``.
    s : hashable
        State key into ``mag_ref``.
    mag_RDs : list of str, optional
        Magnitude readback columns; defaults to the notebook-level ``bpm_magRDs``.
    mag_ref : dict, optional
        State -> reference magnitudes in the same order as ``mag_RDs``;
        defaults to the notebook-level ``BPM_MAGs_ref``.
    """
    if mag_RDs is None:
        mag_RDs = bpm_magRDs
    if mag_ref is None:
        mag_ref = BPM_MAGs_ref
    # np.asarray lets the reference be any sequence, not only an ndarray.
    ref = np.asarray(mag_ref[s], dtype=float)[None, :]
    ratios = df[mag_RDs].to_numpy(dtype=float) / ref
    df['BPM:MAG_min_ratio'] = ratios.min(axis=1)
    return df

In [None]:
# Oracle I/O names: 'x' = control readbacks (decision variables),
# 'y' = task outputs (BPM positions plus the column added by BPM_MAG_obj).
oracle_key_names = dict(
    x=control_RDs,
    y=[*bpm_posRDs, 'BPM:MAG_min_ratio'],
)
oracle_key_names

In [None]:
# Stateful oracle built from the machine I/O handle and the control, state and
# monitor definitions above; BPM_MAG_obj is registered as a per-state data-frame
# manipulator that adds the 'BPM:MAG_min_ratio' task column.
oracleEvaluator = StatefulOracleEvaluator(
    mIO,
    # decision variables (and magnets coupled to them)
    control_CSETs=control_CSETs,
    control_RDs=control_RDs,
    control_tols=control_tols,
    control_couplings=control_couplings,
    # machine states
    state_CSETs=state_CSETs,
    state_RDs=state_RDs,
    state_tols=state_tols,
    state_key_vals=state_key_vals,
    # observations
    monitor_RDs=monitor_RDs,
    oracle_key_names=oracle_key_names,
    state_df_manipulators=[BPM_MAG_obj],
)

# Define composite objective

In [None]:
# BPMvar_minimization takes S = number of states and J = number of tasks, where
# the tasks are the oracle 'y' outputs (BPM positions + 'BPM:MAG_min_ratio').
# The previous len(bpm_posRDs+1) raised TypeError (list + int).
n_state = len(states)
n_task = len(oracle_key_names['y'])   # == len(bpm_posRDs) + 1
composite_objective_function = BPMvar_minimization(S=n_state, J=n_task)

# Define msBO

In [None]:
# msBO run options (used by msbo.init / msbo.step below).
local_optimization = False
acq_type = 'EI'
fix_acq_state = False

# local_bound_size is 20 % of each knob's decision range.
msbo = MultiStateBO(
    states=states,
    tasks=oracle_key_names['y'],
    multistate_oracle_evaluator=oracleEvaluator,
    composite_objective_function=composite_objective_function,
    control_min=control_min,
    control_max=control_max,
    local_bound_size=0.2 * (np.asarray(control_max) - np.asarray(control_min)),
)

# run msBO

In [None]:
msbo.init(n_init=n_init, local_optimization=local_optimization)
# Visit the states in reversed order on even iterations and in the given order on
# odd ones, so consecutive iterations start with the state the previous one ended
# with (presumably to save one actuator move per iteration).
for i_iter in range(n_iter):
    states_order = states if i_iter % 2 else states[::-1]
    for s in states_order:
        print(i_iter, s)
        for i_each in range(n_each):
            msbo.step(s=s, local_optimization=local_optimization,
                      acq_type=acq_type, fix_acq_state=fix_acq_state)

### visualize optimization result

In [None]:
fig, ax, virtualObjMean = msbo.plot_composite_objective()
# Zoom the first panel onto the 0.97-1 range.
ax[0].set_ylim(bottom=0.97, top=1)

In [None]:
# Per-state prediction history recorded by msBO.
(fig, axes) = msbo.plot_state_predictions_history()
plt.show()

### Set controls to the best solution

In [None]:
# Best decision vector from the latest msBO prediction, ordered like control_CSETs.
x_best = msbo.history['predictions'][-1]['x_best']
if isinstance(x_best, torch.Tensor):
    x_best = x_best.detach().cpu().numpy()
x_best = np.asarray(x_best, dtype=float).reshape(-1)
assert x_best.size == len(control_CSETs), \
    f'x_best has {x_best.size} entries, expected {len(control_CSETs)}'

# Coupled magnets must follow their primary knob; writing only control_CSETs would
# leave them wherever the last evaluation put them (presumably the oracle drives
# them from control_couplings - confirm). Coupled value = coeff * primary, which
# matches coeffs = coupled/primary from the couplings cell.
best_CSETs = list(control_CSETs)
best_vals = list(x_best)
for primary_PV, coupling in control_couplings.items():
    primary_val = x_best[control_CSETs.index(primary_PV)]
    best_CSETs += list(coupling['CSETs'])
    best_vals += list(np.asarray(coupling['coeffs'], dtype=float) * primary_val)
caput_many(best_CSETs, best_vals)

### restore states

In [None]:
# Return the state actuator to the position read before any state change.
caput_many(pvlist=state_CSETs, values=state_val0)

### Restore the original control settings if the optimization result is not good

In [None]:
# caput_many(control_CSETs,control_ref)