# To-Do:

- InputStandardize vs Normalize

- check hyperparam/prior definitions (scaled space vs raw space)

- keep in mind modality of NaN results in emittance

- Try different number of steps along measurement dimension

- increase dimensionality of tuning space

- fix legend location (only plot on 1 heatmap)

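# In this notebook, we fit a gpytorch GP to a simple emittance model defined over a multi-dimensional input domain (tuning parameters plus a measurement parameter; the current Settings use 3 input dimensions). We use the GP to evaluate the Expected Information Gain toward the result of a grid-scan minimization routine.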

In [1]:
import torch
from emitutils import toy_beam_size_squared_nd, fit_gp_model_emittance
from utils import unif_random_sample_domain
from matplotlib import pyplot as plt
from algorithms import GridMinimizeEmittance
from acquisition import ExpectedInformationGain
from botorch.optim import optimize_acqf
import time
from mpl_toolkits.axes_grid1 import make_axes_locatable
import copy

# Settings

In [2]:
# Acquisition domain: one row per input dimension, shape (ndim, 2).
# Row i holds [lower, upper] bounds for input dimension i; the same bounds
# are applied to the sampled execution paths.
# Earlier 2-D configuration: torch.tensor([[-2,2], [-65,35]]).double()
domain = torch.tensor(
    [[-3, 1],
     [-3, 1],
     [-40, 60]],
    dtype=torch.double,
)
ndim = domain.size(0)

# Posterior-sampling / grid-scan settings passed to GridMinimizeEmittance.
n_samples = 100                 # posterior samples on which execution paths are evaluated
n_steps_tuning_params = 11      # grid steps per tuning dimension in the posterior-sample scans
n_steps_measurement_param = 11  # grid steps along the measurement parameter
squared = False                 # minimize "emittance squared" (may be negative under the model)?

# Experiment-loop settings.
random_acq = False  # True: next point drawn uniformly at random; False: maximize EIG
n_trials = 5
n_iter = 50
n_obs_init = 5      # random observations used to initialize the model


In [3]:
# domain = torch.tensor([[-2,2], [-2,2], [-65,35]]).double() #the acquisition domain, must have shape = (ndim, 2)
#                                                         #where domain[i,0] and domain[i,1] represent
#                                                         #the lower and upper bounds of the ith input dimension
#                                                         #(these same bounds will be applied to the sampled execution paths)
        
# ndim = domain.shape[0]
    





# n_samples = 100 #number of posterior samples on which to evaluate execution paths
# n_steps_tuning_params = 11 #number of steps per dimension in the posterior sample grid scans 
# n_steps_measurement_param = 11
# squared = True #whether or not to minimize the "emittance squared" (which can be negative according to the model)





# random_acq = True
# n_trials = 20
# n_iter = 20
# n_obs_init = 5 #number of random observations on which to initialize model

# Initialize

In [4]:
# Factor applied to observed targets before GP fitting.
# NOTE(review): presumably rescales small raw beam-size values to order one for
# the GP -- confirm against emitutils.fit_gp_model_emittance.
Y_SCALE = 1.e6


def _fit_model(x, y):
    """Fit the emittance GP to inputs `x` and targets `y` rescaled by Y_SCALE."""
    return fit_gp_model_emittance(x, y * Y_SCALE)


def _propose_next(model, algo, use_random_acq, num_restarts=20, raw_samples=100):
    """Build the EIG acquisition for `model`/`algo` and, unless acquiring randomly, maximize it.

    Args:
        model: fitted GP model.
        algo: GridMinimizeEmittance instance whose execution paths EIG targets.
        use_random_acq: if True, skip optimization (the next point is drawn
            uniformly at random by the caller instead).
        num_restarts, raw_samples: forwarded to botorch's optimize_acqf.

    Returns:
        (acq_fn, x_next), where x_next is None when use_random_acq is True.
    """
    acq_fn = ExpectedInformationGain(model=model, algo=algo)
    if use_random_acq:
        return acq_fn, None
    x_next, _ = optimize_acqf(
        acq_function=acq_fn,
        bounds=acq_fn.algo.domain.T,
        q=1,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        options={},
    )
    return acq_fn, x_next


def _snapshot(x_obs, y_obs, x_next, model, rng_state):
    """Record one iteration; the model is deep-copied so the stored entry is independent of later changes."""
    return {'x_obs': x_obs,
            'y_obs': y_obs,
            'x_next': x_next,
            'model': copy.deepcopy(model),
            'rng_state': rng_state}


trial_data = {}
trial_data['settings'] = {'domain': domain,
                          'ndim': ndim,
                          'n_obs_init': n_obs_init,
                          'n_samples': n_samples,
                          'n_steps_tuning_params': n_steps_tuning_params,
                          'n_steps_measurement_param': n_steps_measurement_param,
                          'n_trials': n_trials,
                          'n_iter': n_iter,
                          'squared': squared,
                          'random_acq': random_acq}

for trial in range(n_trials):
    torch.manual_seed(trial)  # one reproducible seed per trial

    # toy beam-size-squared target function (from emitutils)
    target_func = toy_beam_size_squared_nd

    # Observe the target n_obs_init times at points drawn uniformly from the domain.
    x_obs = unif_random_sample_domain(n_samples=n_obs_init, domain=domain)
    y_obs = target_func(x_obs)

    # Fit the model on the initial observations.
    model = _fit_model(x_obs, y_obs)

    algo = GridMinimizeEmittance(domain=domain,
                                 n_samples=n_samples,
                                 n_steps_tuning_params=n_steps_tuning_params,
                                 n_steps_measurement_param=n_steps_measurement_param,
                                 squared=squared)

    # RNG state captured immediately before the acquisition step; stored with each snapshot.
    rng_state = torch.get_rng_state()
    acq_fn, x_next = _propose_next(model, algo, random_acq)

    iter_data = {0: _snapshot(x_obs, y_obs, x_next, model, rng_state)}

    for i in range(1, n_iter + 1):
        start = time.perf_counter()  # monotonic clock for elapsed-time measurement
        print('Iteration', trial*n_iter + i, '/', n_trials*n_iter)

        if random_acq:
            x_new = unif_random_sample_domain(n_samples=1, domain=domain)
        else:
            x_new = x_next

        y_new = target_func(x_new)

        x_obs = torch.cat((x_obs, x_new), dim=0)
        y_obs = torch.cat((y_obs, y_new), dim=0)

        model = _fit_model(x_obs, y_obs)

        rng_state = torch.get_rng_state()
        acq_fn, x_next = _propose_next(model, algo, random_acq)

        end = time.perf_counter()
        print('Operation took', end - start, 'seconds.')

        iter_data[i] = _snapshot(x_obs, y_obs, x_next, model, rng_state)

    trial_data[trial] = iter_data




Iteration 1 / 250




Operation took 0.5836865901947021 seconds.
Iteration 2 / 250




Operation took 69.59668326377869 seconds.
Iteration 3 / 250




Operation took 10.425024509429932 seconds.
Iteration 4 / 250




Operation took 19.855926275253296 seconds.
Iteration 5 / 250




Operation took 20.585057020187378 seconds.
Iteration 6 / 250




Operation took 5.39989447593689 seconds.
Iteration 7 / 250




Operation took 48.21183753013611 seconds.
Iteration 8 / 250




Operation took 86.91738700866699 seconds.
Iteration 9 / 250




Operation took 24.255075454711914 seconds.
Iteration 10 / 250




Operation took 33.82695984840393 seconds.
Iteration 11 / 250




Operation took 51.220815896987915 seconds.
Iteration 12 / 250




Operation took 103.52649474143982 seconds.
Iteration 13 / 250




Operation took 47.06308102607727 seconds.
Iteration 14 / 250




Operation took 281.9206931591034 seconds.
Iteration 15 / 250




Operation took 41.38756084442139 seconds.
Iteration 16 / 250




Operation took 192.7562005519867 seconds.
Iteration 17 / 250




Operation took 25.21068835258484 seconds.
Iteration 18 / 250




Operation took 25.21246337890625 seconds.
Iteration 19 / 250




Operation took 49.22711443901062 seconds.
Iteration 20 / 250




Operation took 42.14255452156067 seconds.
Iteration 21 / 250




Operation took 36.397095680236816 seconds.
Iteration 22 / 250




Operation took 39.51630139350891 seconds.
Iteration 23 / 250




Operation took 55.626506328582764 seconds.
Iteration 24 / 250




Operation took 33.254300355911255 seconds.
Iteration 25 / 250




Operation took 23.40985870361328 seconds.
Iteration 26 / 250




Operation took 17.771971940994263 seconds.
Iteration 27 / 250




Operation took 36.70347809791565 seconds.
Iteration 28 / 250




Operation took 40.45195698738098 seconds.
Iteration 29 / 250




Operation took 14.26879596710205 seconds.
Iteration 30 / 250




Operation took 46.94612908363342 seconds.
Iteration 31 / 250




Operation took 44.62866282463074 seconds.
Iteration 32 / 250




Operation took 57.37561345100403 seconds.
Iteration 33 / 250




Operation took 70.18668961524963 seconds.
Iteration 34 / 250




Operation took 48.47767782211304 seconds.
Iteration 35 / 250




Operation took 39.91523098945618 seconds.
Iteration 36 / 250




Operation took 68.7559015750885 seconds.
Iteration 37 / 250




Operation took 79.67453026771545 seconds.
Iteration 38 / 250




Operation took 28.12661862373352 seconds.
Iteration 39 / 250




Operation took 20.92165970802307 seconds.
Iteration 40 / 250




Operation took 26.120103120803833 seconds.
Iteration 41 / 250




Operation took 36.056307554244995 seconds.
Iteration 42 / 250




Operation took 174.1460738182068 seconds.
Iteration 43 / 250




Operation took 36.604684352874756 seconds.
Iteration 44 / 250




Operation took 42.56279158592224 seconds.
Iteration 45 / 250




Operation took 32.837810754776 seconds.
Iteration 46 / 250




Operation took 36.011372327804565 seconds.
Iteration 47 / 250




Operation took 54.51244878768921 seconds.
Iteration 48 / 250




Operation took 25.692259311676025 seconds.
Iteration 49 / 250




Operation took 15.91281008720398 seconds.
Iteration 50 / 250




Operation took 23.707348585128784 seconds.




Iteration 51 / 250




Operation took 28.749364137649536 seconds.
Iteration 52 / 250




Operation took 25.72080135345459 seconds.
Iteration 53 / 250




Operation took 21.6758553981781 seconds.
Iteration 54 / 250




Operation took 19.801119804382324 seconds.
Iteration 55 / 250




Operation took 10.752876996994019 seconds.
Iteration 56 / 250




Operation took 8.709319353103638 seconds.
Iteration 57 / 250




Operation took 49.13880228996277 seconds.
Iteration 58 / 250




Operation took 23.781910181045532 seconds.
Iteration 59 / 250


Trying again with a new set of initial conditions.


Operation took 364.85542154312134 seconds.
Iteration 60 / 250




Operation took 87.69872164726257 seconds.
Iteration 61 / 250




Operation took 118.17856025695801 seconds.
Iteration 62 / 250




Operation took 149.80295133590698 seconds.
Iteration 63 / 250




Operation took 98.60225772857666 seconds.
Iteration 64 / 250




Operation took 274.6373543739319 seconds.
Iteration 65 / 250




Operation took 51.98129391670227 seconds.
Iteration 66 / 250




Operation took 29.39082431793213 seconds.
Iteration 67 / 250




Operation took 131.50131011009216 seconds.
Iteration 68 / 250




Operation took 23.870829343795776 seconds.
Iteration 69 / 250




Operation took 40.16134071350098 seconds.
Iteration 70 / 250




Operation took 105.48115301132202 seconds.
Iteration 71 / 250




Operation took 52.031935930252075 seconds.
Iteration 72 / 250




Operation took 29.380329608917236 seconds.
Iteration 73 / 250




Operation took 51.50411939620972 seconds.
Iteration 74 / 250




Operation took 21.915040016174316 seconds.
Iteration 75 / 250




Operation took 49.27386116981506 seconds.
Iteration 76 / 250




Operation took 17.57485032081604 seconds.
Iteration 77 / 250




Operation took 39.84673261642456 seconds.
Iteration 78 / 250




Operation took 54.201915979385376 seconds.
Iteration 79 / 250




Operation took 82.54523420333862 seconds.
Iteration 80 / 250




Operation took 15.172305822372437 seconds.
Iteration 81 / 250




Operation took 29.440914630889893 seconds.
Iteration 82 / 250




Operation took 20.36540412902832 seconds.
Iteration 83 / 250




Operation took 14.654191255569458 seconds.
Iteration 84 / 250




Operation took 26.44592308998108 seconds.
Iteration 85 / 250




Operation took 42.29031705856323 seconds.
Iteration 86 / 250




Operation took 22.08726167678833 seconds.
Iteration 87 / 250




Operation took 40.45809745788574 seconds.
Iteration 88 / 250




Operation took 65.36248636245728 seconds.
Iteration 89 / 250




Operation took 25.350048542022705 seconds.
Iteration 90 / 250




Operation took 29.14751648902893 seconds.
Iteration 91 / 250




Operation took 90.13376879692078 seconds.
Iteration 92 / 250




Operation took 45.20994710922241 seconds.
Iteration 93 / 250




Operation took 58.181458473205566 seconds.
Iteration 94 / 250




Operation took 62.812137603759766 seconds.
Iteration 95 / 250




Operation took 22.15482258796692 seconds.
Iteration 96 / 250




Operation took 27.691478967666626 seconds.
Iteration 97 / 250




Operation took 39.97381782531738 seconds.
Iteration 98 / 250




Operation took 34.87810015678406 seconds.
Iteration 99 / 250




Operation took 13.235070705413818 seconds.
Iteration 100 / 250




Operation took 27.617701292037964 seconds.




Iteration 101 / 250




Operation took 10.158295631408691 seconds.
Iteration 102 / 250




Operation took 1.8749802112579346 seconds.
Iteration 103 / 250




Operation took 9.963179588317871 seconds.
Iteration 104 / 250




Operation took 9.487940311431885 seconds.
Iteration 105 / 250




Operation took 16.913678407669067 seconds.
Iteration 106 / 250




Operation took 79.94099712371826 seconds.
Iteration 107 / 250




Operation took 51.24648594856262 seconds.
Iteration 108 / 250




Operation took 93.80779957771301 seconds.
Iteration 109 / 250




Operation took 74.52963256835938 seconds.
Iteration 110 / 250




Operation took 28.642570972442627 seconds.
Iteration 111 / 250




Operation took 10.635670900344849 seconds.
Iteration 112 / 250




Operation took 28.59625220298767 seconds.
Iteration 113 / 250




Operation took 38.8838369846344 seconds.
Iteration 114 / 250




Operation took 26.838054656982422 seconds.
Iteration 115 / 250




Operation took 23.032397270202637 seconds.
Iteration 116 / 250




Operation took 32.56122422218323 seconds.
Iteration 117 / 250




Operation took 43.41992712020874 seconds.
Iteration 118 / 250




Operation took 64.90433382987976 seconds.
Iteration 119 / 250




Operation took 88.83197832107544 seconds.
Iteration 120 / 250




Operation took 40.54688763618469 seconds.
Iteration 121 / 250




Operation took 75.95414233207703 seconds.
Iteration 122 / 250




Operation took 128.9901406764984 seconds.
Iteration 123 / 250




Operation took 131.6344940662384 seconds.
Iteration 124 / 250




Operation took 136.87487936019897 seconds.
Iteration 125 / 250




Operation took 25.0877468585968 seconds.
Iteration 126 / 250




Operation took 56.8318247795105 seconds.
Iteration 127 / 250




Operation took 67.604238986969 seconds.
Iteration 128 / 250




Operation took 40.041712284088135 seconds.
Iteration 129 / 250




Operation took 52.492302894592285 seconds.
Iteration 130 / 250




Operation took 32.44178223609924 seconds.
Iteration 131 / 250




Operation took 23.150970458984375 seconds.
Iteration 132 / 250




Operation took 28.76347041130066 seconds.
Iteration 133 / 250




Operation took 16.792389631271362 seconds.
Iteration 134 / 250




Operation took 20.4934241771698 seconds.
Iteration 135 / 250




Operation took 20.61402702331543 seconds.
Iteration 136 / 250




Operation took 23.604618549346924 seconds.
Iteration 137 / 250




Operation took 33.28693175315857 seconds.
Iteration 138 / 250




Operation took 43.07893419265747 seconds.
Iteration 139 / 250




Operation took 24.972266674041748 seconds.
Iteration 140 / 250




Operation took 18.39188575744629 seconds.
Iteration 141 / 250




Operation took 18.84420371055603 seconds.
Iteration 142 / 250




Operation took 24.118848085403442 seconds.
Iteration 143 / 250




Operation took 26.205150604248047 seconds.
Iteration 144 / 250




Operation took 34.92729210853577 seconds.
Iteration 145 / 250




Operation took 22.215231895446777 seconds.
Iteration 146 / 250




Operation took 47.24680948257446 seconds.
Iteration 147 / 250




Operation took 19.88274383544922 seconds.
Iteration 148 / 250




Operation took 18.13741898536682 seconds.
Iteration 149 / 250




Operation took 38.184128284454346 seconds.
Iteration 150 / 250




Operation took 15.890381097793579 seconds.




Iteration 151 / 250




Operation took 37.2683789730072 seconds.
Iteration 152 / 250




Operation took 26.604265689849854 seconds.
Iteration 153 / 250




Operation took 17.28608274459839 seconds.
Iteration 154 / 250




Operation took 33.98696708679199 seconds.
Iteration 155 / 250




Operation took 46.10784196853638 seconds.
Iteration 156 / 250




Operation took 8.479487419128418 seconds.
Iteration 157 / 250




Operation took 27.772531032562256 seconds.
Iteration 158 / 250




Operation took 28.2308566570282 seconds.
Iteration 159 / 250




Operation took 7.103296756744385 seconds.
Iteration 160 / 250




Operation took 6.370581150054932 seconds.
Iteration 161 / 250




Operation took 28.330405712127686 seconds.
Iteration 162 / 250




Operation took 45.280967712402344 seconds.
Iteration 163 / 250




Operation took 45.626235485076904 seconds.
Iteration 164 / 250




Operation took 47.231703758239746 seconds.
Iteration 165 / 250




Operation took 27.901227474212646 seconds.
Iteration 166 / 250




Operation took 27.998197317123413 seconds.
Iteration 167 / 250




Operation took 49.82306218147278 seconds.
Iteration 168 / 250




Operation took 94.45419025421143 seconds.
Iteration 169 / 250




Operation took 52.991156816482544 seconds.
Iteration 170 / 250




Operation took 25.23879313468933 seconds.
Iteration 171 / 250




Operation took 12.483375787734985 seconds.
Iteration 172 / 250




Operation took 46.11701703071594 seconds.
Iteration 173 / 250




Operation took 24.08855962753296 seconds.
Iteration 174 / 250




Operation took 22.44302201271057 seconds.
Iteration 175 / 250




Operation took 37.495675802230835 seconds.
Iteration 176 / 250




Operation took 23.73256254196167 seconds.
Iteration 177 / 250




Operation took 22.109366416931152 seconds.
Iteration 178 / 250




Operation took 43.87615370750427 seconds.
Iteration 179 / 250




Operation took 56.92978310585022 seconds.
Iteration 180 / 250




Operation took 16.73576021194458 seconds.
Iteration 181 / 250




Operation took 17.93135404586792 seconds.
Iteration 182 / 250




Operation took 33.52702736854553 seconds.
Iteration 183 / 250




Operation took 13.135967254638672 seconds.
Iteration 184 / 250




Operation took 19.244394302368164 seconds.
Iteration 185 / 250




Operation took 30.770976066589355 seconds.
Iteration 186 / 250




Operation took 67.49591994285583 seconds.
Iteration 187 / 250




Operation took 24.914307594299316 seconds.
Iteration 188 / 250




Operation took 21.78612184524536 seconds.
Iteration 189 / 250




Operation took 19.21573567390442 seconds.
Iteration 190 / 250




Operation took 24.090635299682617 seconds.
Iteration 191 / 250




Operation took 25.02867817878723 seconds.
Iteration 192 / 250




Operation took 23.137032747268677 seconds.
Iteration 193 / 250




Operation took 19.816287994384766 seconds.
Iteration 194 / 250


Trying again with a new set of initial conditions.


Operation took 44.35692477226257 seconds.
Iteration 195 / 250




Operation took 24.29187512397766 seconds.
Iteration 196 / 250




Operation took 20.57752752304077 seconds.
Iteration 197 / 250




Operation took 29.26692509651184 seconds.
Iteration 198 / 250




Operation took 54.2033588886261 seconds.
Iteration 199 / 250




Operation took 22.037179708480835 seconds.
Iteration 200 / 250




Operation took 33.11057257652283 seconds.




Iteration 201 / 250




Operation took 10.411766290664673 seconds.
Iteration 202 / 250




Operation took 13.570268392562866 seconds.
Iteration 203 / 250




Operation took 13.654623985290527 seconds.
Iteration 204 / 250




Operation took 4.029750823974609 seconds.
Iteration 205 / 250




Operation took 8.32880425453186 seconds.
Iteration 206 / 250




Operation took 14.929072618484497 seconds.
Iteration 207 / 250




Operation took 46.426759004592896 seconds.
Iteration 208 / 250




Operation took 13.377853631973267 seconds.
Iteration 209 / 250




Operation took 8.229377031326294 seconds.
Iteration 210 / 250




Operation took 18.942062377929688 seconds.
Iteration 211 / 250




Operation took 13.567198276519775 seconds.
Iteration 212 / 250




Operation took 13.809942245483398 seconds.
Iteration 213 / 250




Operation took 5.095235586166382 seconds.
Iteration 214 / 250




Operation took 9.22952914237976 seconds.
Iteration 215 / 250




Operation took 9.138210773468018 seconds.
Iteration 216 / 250




Operation took 6.4061150550842285 seconds.
Iteration 217 / 250




Operation took 18.473100423812866 seconds.
Iteration 218 / 250




Operation took 6.927709102630615 seconds.
Iteration 219 / 250




Operation took 8.619102001190186 seconds.
Iteration 220 / 250




Operation took 7.742298603057861 seconds.
Iteration 221 / 250




Operation took 11.465794563293457 seconds.
Iteration 222 / 250




Operation took 8.120166778564453 seconds.
Iteration 223 / 250




Operation took 11.732980489730835 seconds.
Iteration 224 / 250




Operation took 11.769858598709106 seconds.
Iteration 225 / 250




Operation took 8.396206617355347 seconds.
Iteration 226 / 250




Operation took 12.706088066101074 seconds.
Iteration 227 / 250




Operation took 9.940029859542847 seconds.
Iteration 228 / 250




Operation took 12.608113527297974 seconds.
Iteration 229 / 250




Operation took 18.915907859802246 seconds.
Iteration 230 / 250




Operation took 17.350945234298706 seconds.
Iteration 231 / 250




Operation took 29.426334619522095 seconds.
Iteration 232 / 250




Operation took 10.934624910354614 seconds.
Iteration 233 / 250




Operation took 18.894349336624146 seconds.
Iteration 234 / 250




Operation took 12.537399053573608 seconds.
Iteration 235 / 250




Operation took 11.017449140548706 seconds.
Iteration 236 / 250




Operation took 9.439123392105103 seconds.
Iteration 237 / 250




Operation took 14.451641082763672 seconds.
Iteration 238 / 250




Operation took 18.83786177635193 seconds.
Iteration 239 / 250




Operation took 12.260863304138184 seconds.
Iteration 240 / 250




Operation took 9.042762279510498 seconds.
Iteration 241 / 250




Operation took 29.09670400619507 seconds.
Iteration 242 / 250




Operation took 18.610153913497925 seconds.
Iteration 243 / 250




Operation took 18.855323791503906 seconds.
Iteration 244 / 250




Operation took 22.319166898727417 seconds.
Iteration 245 / 250




Operation took 14.13136076927185 seconds.
Iteration 246 / 250




Operation took 26.16221809387207 seconds.
Iteration 247 / 250




Operation took 8.493204593658447 seconds.
Iteration 248 / 250




Operation took 21.724406003952026 seconds.
Iteration 249 / 250




Operation took 28.056666374206543 seconds.
Iteration 250 / 250




Operation took 25.760644912719727 seconds.


In [5]:
# import dill
# with open('MC-Emittance-Phys-Random-2d-Results.pkl', 'wb') as f:
#     dill.dump(trial_data, f)

In [6]:
import dill

# Derive the filename from the settings saved alongside the results, so the
# acquisition label ("BAX"/"Random") and "<ndim>d" cannot drift from the data.
# With random_acq=False and ndim=3 this is 'MC-Emittance-Phys-BAX-3d-Results-test.pkl'.
_settings = trial_data['settings']
_acq_label = 'Random' if _settings['random_acq'] else 'BAX'
results_path = f"MC-Emittance-Phys-{_acq_label}-{_settings['ndim']}d-Results-test.pkl"

with open(results_path, 'wb') as f:
    dill.dump(trial_data, f)