# To-Do:

- InputStandardize vs Normalize

- check hyperparam/prior definitions (scaled space vs raw space)

- keep in mind modality of NaN results in emittance

- Try different number of steps along measurement dimension

- increase dimensionality of tuning space

- fix legend location (only plot on 1 heatmap)

# In this notebook, we fit a gpytorch GP to a simple emittance model with 1 tuning parameter. We use the GP to evaluate the Expected Information Gain toward the result of a grid-scan minimization routine.

In [1]:
import torch
from emitutils import toy_beam_size_squared_nd, fit_gp_model_emittance
from utils import unif_random_sample_domain
from matplotlib import pyplot as plt
from algorithms import GridMinimizeEmittance
from acquisition import ExpectedInformationGain
from botorch.optim import optimize_acqf
import time
from mpl_toolkits.axes_grid1 import make_axes_locatable
import copy

# Settings

In [2]:
# domain = torch.tensor([[-2,2], [-65,35]]).double() #the acquisition domain, must have shape = (ndim, 2)
# --- Search space -----------------------------------------------------------
# The acquisition domain must have shape (ndim, 2): domain[i, 0] and
# domain[i, 1] are the lower and upper bounds of the i-th input dimension.
# The same bounds are applied to the sampled execution paths.
domain = torch.tensor(
    [
        [-3, 1],
        [-3, 1],
        [-40, 60],
    ],
    dtype=torch.float64,
)
ndim = domain.size(0)

# --- Experiment schedule ----------------------------------------------------
n_obs_init = 5      # number of random observations on which to initialize model
n_iter = 50         # acquisition iterations performed in each trial
n_trials = 5        # number of trials (each trial is seeded with its index)
random_acq = False  # True: draw new points uniformly from the domain
                    # instead of optimizing the acquisition function

# --- Grid-scan algorithm ----------------------------------------------------
n_samples = 100                 # number of posterior samples on which to
                                # evaluate execution paths
n_steps_tuning_params = 11      # number of steps per dimension in the
                                # posterior sample grid scans
n_steps_measurement_param = 11
squared = True                  # whether or not to minimize the "emittance
                                # squared" (which can be negative according
                                # to the model)


In [3]:
# domain = torch.tensor([[-2,2], [-2,2], [-65,35]]).double() #the acquisition domain, must have shape = (ndim, 2)
#                                                         #where domain[i,0] and domain[i,1] represent
#                                                         #the lower and upper bounds of the ith input dimension
#                                                         #(these same bounds will be applied to the sampled execution paths)
        
# ndim = domain.shape[0]
    





# n_samples = 100 #number of posterior samples on which to evaluate execution paths
# n_steps_tuning_params = 11 #number of steps per dimension in the posterior sample grid scans 
# n_steps_measurement_param = 11
# squared = True #whether or not to minimize the "emittance squared" (which can be negative according to the model)





# random_acq = True
# n_trials = 20
# n_iter = 20
# n_obs_init = 5 #number of random observations on which to initialize model

# Initialize

In [4]:
def _propose_next_point(model, algo):
    """Build the acquisition function for ``model`` and pick the next input.

    The global torch RNG state is captured *before* the acquisition function
    is constructed, and is returned so that it can be stored alongside the
    model in the iteration record.

    Args:
        model: GP model returned by ``fit_gp_model_emittance``.
        algo: the ``GridMinimizeEmittance`` instance shared by the trial.

    Returns:
        tuple: ``(x_next, acq_fn, rng_state)``. ``x_next`` is ``None`` when
        the module-level ``random_acq`` flag is set; otherwise it is the
        candidate returned by ``optimize_acqf`` with ``q=1``.
    """
    rng_state = torch.get_rng_state()

    # NOTE(review): the acquisition function is built even when random_acq is
    # True and its result is then unused. It is kept because constructing it
    # may advance the RNG, and skipping it would change seeded results --
    # confirm before removing.
    acq_fn = ExpectedInformationGain(model=model, algo=algo)

    if random_acq:
        return None, acq_fn, rng_state

    # Only the candidate is kept; the acquisition value is discarded.
    x_next, _ = optimize_acqf(
        acq_function=acq_fn,
        bounds=acq_fn.algo.domain.T,
        q=1,
        num_restarts=20,
        raw_samples=100,
        options={},
    )
    return x_next, acq_fn, rng_state


def _iteration_record(x_obs, y_obs, x_next, model, rng_state):
    """Return the dict stored for one iteration of a trial.

    The model is deep-copied so that the record does not share the model
    object with later code.
    """
    return {'x_obs': x_obs,
            'y_obs': y_obs,
            'x_next': x_next,
            'model': copy.deepcopy(model),
            'rng_state': rng_state}


# trial_data maps 'settings' to the run configuration and each integer trial
# index to that trial's {iteration index: record} dictionary.
trial_data = {
    'settings': {
        'domain': domain,
        'ndim': ndim,
        'n_obs_init': n_obs_init,
        'n_samples': n_samples,
        'n_steps_tuning_params': n_steps_tuning_params,
        'n_steps_measurement_param': n_steps_measurement_param,
        'n_trials': n_trials,
        'n_iter': n_iter,
        'squared': squared,
        'random_acq': random_acq,
    },
}

for trial in range(n_trials):
    # Seed with the trial index so each trial is individually repeatable.
    torch.manual_seed(trial)

    # Target function that is observed throughout the trial.
    target_func = toy_beam_size_squared_nd

    # Observe the target n_obs_init times on a uniform sample of the domain.
    x_obs = unif_random_sample_domain(n_samples=n_obs_init, domain=domain)
    y_obs = target_func(x_obs)

    # Fit the model on the initial observations. The targets are multiplied by
    # 1e6 before fitting (presumably a rescaling for the GP -- TODO confirm);
    # the stored y_obs stay unscaled.
    model = fit_gp_model_emittance(x_obs, y_obs * 1.e6)

    algo = GridMinimizeEmittance(
        domain=domain,
        n_samples=n_samples,
        n_steps_tuning_params=n_steps_tuning_params,
        n_steps_measurement_param=n_steps_measurement_param,
        squared=squared,
    )

    x_next, acq_fn, rng_state = _propose_next_point(model, algo)

    # Record 0 holds the state after the initial fit, before any acquisition.
    iter_data = {0: _iteration_record(x_obs, y_obs, x_next, model, rng_state)}

    for i in range(1, n_iter + 1):
        start = time.time()
        print('Iteration', trial*n_iter + i, '/', n_trials*n_iter)

        # Choose the point to observe: a fresh uniform sample in random mode,
        # otherwise the candidate proposed at the end of the previous step.
        if random_acq:
            x_new = unif_random_sample_domain(n_samples=1, domain=domain)
        else:
            x_new = x_next

        y_new = target_func(x_new)

        # Append the new observation (one row per observation) and refit.
        x_obs = torch.cat((x_obs, x_new), dim=0)
        y_obs = torch.cat((y_obs, y_new), dim=0)
        model = fit_gp_model_emittance(x_obs, y_obs * 1.e6)

        x_next, acq_fn, rng_state = _propose_next_point(model, algo)

        end = time.time()
        print('Operation took', end - start, 'seconds.')

        iter_data[i] = _iteration_record(x_obs, y_obs, x_next, model, rng_state)

    trial_data[trial] = iter_data




Iteration 1 / 250




Operation took 0.5951366424560547 seconds.
Iteration 2 / 250




Operation took 11.622281551361084 seconds.
Iteration 3 / 250




Operation took 4.2119057178497314 seconds.
Iteration 4 / 250




Operation took 8.856431484222412 seconds.
Iteration 5 / 250




Operation took 8.08725905418396 seconds.
Iteration 6 / 250




Operation took 6.299430847167969 seconds.
Iteration 7 / 250




Operation took 11.338035345077515 seconds.
Iteration 8 / 250




Operation took 13.756847381591797 seconds.
Iteration 9 / 250




Operation took 20.232379913330078 seconds.
Iteration 10 / 250




Operation took 12.188704252243042 seconds.
Iteration 11 / 250




Operation took 11.029697895050049 seconds.
Iteration 12 / 250




Operation took 9.252173900604248 seconds.
Iteration 13 / 250




Operation took 8.611566066741943 seconds.
Iteration 14 / 250




Operation took 10.930546283721924 seconds.
Iteration 15 / 250




Operation took 17.919996976852417 seconds.
Iteration 16 / 250




Operation took 18.79954481124878 seconds.
Iteration 17 / 250




Operation took 21.100425481796265 seconds.
Iteration 18 / 250




Operation took 27.676061868667603 seconds.
Iteration 19 / 250




Operation took 41.21087169647217 seconds.
Iteration 20 / 250




Operation took 13.581188917160034 seconds.
Iteration 21 / 250




Operation took 24.898972749710083 seconds.
Iteration 22 / 250




Operation took 16.591513633728027 seconds.
Iteration 23 / 250




Operation took 6.140655040740967 seconds.
Iteration 24 / 250




Operation took 4.402473211288452 seconds.
Iteration 25 / 250




Operation took 14.67047643661499 seconds.
Iteration 26 / 250




Operation took 8.629069805145264 seconds.
Iteration 27 / 250




Operation took 25.795382976531982 seconds.
Iteration 28 / 250




Operation took 5.082986354827881 seconds.
Iteration 29 / 250




Operation took 7.945439577102661 seconds.
Iteration 30 / 250




Operation took 13.200979471206665 seconds.
Iteration 31 / 250




Operation took 9.768194198608398 seconds.
Iteration 32 / 250




Operation took 14.552184104919434 seconds.
Iteration 33 / 250




Operation took 8.377030849456787 seconds.
Iteration 34 / 250




Operation took 7.310768365859985 seconds.
Iteration 35 / 250




Operation took 15.19032073020935 seconds.
Iteration 36 / 250




Operation took 12.950366973876953 seconds.
Iteration 37 / 250




Operation took 8.154961824417114 seconds.
Iteration 38 / 250




Operation took 22.421923398971558 seconds.
Iteration 39 / 250




Operation took 21.74246621131897 seconds.
Iteration 40 / 250




Operation took 6.440871477127075 seconds.
Iteration 41 / 250


Trying again with a new set of initial conditions.


Operation took 23.515208959579468 seconds.
Iteration 42 / 250




Operation took 13.019025802612305 seconds.
Iteration 43 / 250




Operation took 8.237618684768677 seconds.
Iteration 44 / 250




Operation took 15.84885859489441 seconds.
Iteration 45 / 250


Trying again with a new set of initial conditions.


Operation took 27.167447566986084 seconds.
Iteration 46 / 250




Operation took 25.224257707595825 seconds.
Iteration 47 / 250




Operation took 21.008172273635864 seconds.
Iteration 48 / 250




Operation took 23.82477879524231 seconds.
Iteration 49 / 250


Trying again with a new set of initial conditions.


Operation took 34.51429891586304 seconds.
Iteration 50 / 250




Operation took 17.256624221801758 seconds.




Iteration 51 / 250




Operation took 22.86375403404236 seconds.
Iteration 52 / 250




Operation took 0.8120977878570557 seconds.
Iteration 53 / 250




Operation took 16.540895462036133 seconds.
Iteration 54 / 250




Operation took 12.703195333480835 seconds.
Iteration 55 / 250




Operation took 29.40591526031494 seconds.
Iteration 56 / 250




Operation took 13.208673477172852 seconds.
Iteration 57 / 250




Operation took 14.899308681488037 seconds.
Iteration 58 / 250




Operation took 21.819032430648804 seconds.
Iteration 59 / 250




Operation took 48.55161714553833 seconds.
Iteration 60 / 250




Operation took 18.103031873703003 seconds.
Iteration 61 / 250




Operation took 10.481316804885864 seconds.
Iteration 62 / 250




Operation took 58.05393147468567 seconds.
Iteration 63 / 250




Operation took 9.266867399215698 seconds.
Iteration 64 / 250




Operation took 17.687456846237183 seconds.
Iteration 65 / 250




Operation took 11.25882363319397 seconds.
Iteration 66 / 250




Operation took 64.53298783302307 seconds.
Iteration 67 / 250




Operation took 28.51418662071228 seconds.
Iteration 68 / 250




Operation took 40.662365436553955 seconds.
Iteration 69 / 250




Operation took 34.72114539146423 seconds.
Iteration 70 / 250




Operation took 37.078045129776 seconds.
Iteration 71 / 250




Operation took 27.50834560394287 seconds.
Iteration 72 / 250




Operation took 11.076736450195312 seconds.
Iteration 73 / 250




Operation took 31.715935230255127 seconds.
Iteration 74 / 250




Operation took 10.608288526535034 seconds.
Iteration 75 / 250




Operation took 9.561359167098999 seconds.
Iteration 76 / 250




Operation took 11.430431604385376 seconds.
Iteration 77 / 250




Operation took 10.723898649215698 seconds.
Iteration 78 / 250




Operation took 7.388714075088501 seconds.
Iteration 79 / 250




Operation took 13.353243827819824 seconds.
Iteration 80 / 250




Operation took 8.903457880020142 seconds.
Iteration 81 / 250




Operation took 11.868966579437256 seconds.
Iteration 82 / 250




Operation took 12.95896291732788 seconds.
Iteration 83 / 250




Operation took 12.658265113830566 seconds.
Iteration 84 / 250




Operation took 7.936623811721802 seconds.
Iteration 85 / 250




Operation took 16.87240433692932 seconds.
Iteration 86 / 250




Operation took 16.857123613357544 seconds.
Iteration 87 / 250




Operation took 15.927941799163818 seconds.
Iteration 88 / 250




Operation took 9.081196546554565 seconds.
Iteration 89 / 250


Trying again with a new set of initial conditions.


Operation took 38.48580574989319 seconds.
Iteration 90 / 250




Operation took 12.04929518699646 seconds.
Iteration 91 / 250


Trying again with a new set of initial conditions.


Operation took 19.335794925689697 seconds.
Iteration 92 / 250


Trying again with a new set of initial conditions.


Operation took 19.036084413528442 seconds.
Iteration 93 / 250




Operation took 19.376843690872192 seconds.
Iteration 94 / 250




Operation took 20.411036252975464 seconds.
Iteration 95 / 250




Operation took 16.75018072128296 seconds.
Iteration 96 / 250


Trying again with a new set of initial conditions.


Operation took 15.425008773803711 seconds.
Iteration 97 / 250




Operation took 8.116418361663818 seconds.
Iteration 98 / 250


Trying again with a new set of initial conditions.


Operation took 59.43424654006958 seconds.
Iteration 99 / 250




Operation took 26.57988452911377 seconds.
Iteration 100 / 250


Trying again with a new set of initial conditions.


Operation took 53.48202967643738 seconds.




Iteration 101 / 250




Operation took 7.192124843597412 seconds.
Iteration 102 / 250




Operation took 4.178504228591919 seconds.
Iteration 103 / 250




Operation took 0.6968247890472412 seconds.
Iteration 104 / 250




Operation took 11.240174055099487 seconds.
Iteration 105 / 250




Operation took 4.417217016220093 seconds.
Iteration 106 / 250




Operation took 2.229933500289917 seconds.
Iteration 107 / 250




Operation took 15.818317651748657 seconds.
Iteration 108 / 250




Operation took 12.84285831451416 seconds.
Iteration 109 / 250




Operation took 32.59651017189026 seconds.
Iteration 110 / 250




Operation took 11.931344747543335 seconds.
Iteration 111 / 250




Operation took 27.61566162109375 seconds.
Iteration 112 / 250




Operation took 6.781786680221558 seconds.
Iteration 113 / 250




Operation took 18.6323504447937 seconds.
Iteration 114 / 250




Operation took 26.595143795013428 seconds.
Iteration 115 / 250




Operation took 29.199714422225952 seconds.
Iteration 116 / 250




Operation took 52.61343431472778 seconds.
Iteration 117 / 250




Operation took 26.606133460998535 seconds.
Iteration 118 / 250




Operation took 40.52986454963684 seconds.
Iteration 119 / 250




Operation took 19.562267303466797 seconds.
Iteration 120 / 250




Operation took 42.96505308151245 seconds.
Iteration 121 / 250




Operation took 63.944324254989624 seconds.
Iteration 122 / 250




Operation took 25.712027072906494 seconds.
Iteration 123 / 250




Operation took 17.661481857299805 seconds.
Iteration 124 / 250




Operation took 37.237826108932495 seconds.
Iteration 125 / 250




Operation took 14.679784059524536 seconds.
Iteration 126 / 250




Operation took 11.816922664642334 seconds.
Iteration 127 / 250




Operation took 10.804047107696533 seconds.
Iteration 128 / 250




Operation took 12.336942434310913 seconds.
Iteration 129 / 250




Operation took 10.322232723236084 seconds.
Iteration 130 / 250




Operation took 11.750911235809326 seconds.
Iteration 131 / 250




Operation took 13.12127137184143 seconds.
Iteration 132 / 250




Operation took 17.59068751335144 seconds.
Iteration 133 / 250




Operation took 17.384315490722656 seconds.
Iteration 134 / 250




Operation took 16.446138381958008 seconds.
Iteration 135 / 250




Operation took 22.110710382461548 seconds.
Iteration 136 / 250




Operation took 16.489445447921753 seconds.
Iteration 137 / 250


Trying again with a new set of initial conditions.


Operation took 15.38353967666626 seconds.
Iteration 138 / 250




Operation took 5.052512884140015 seconds.
Iteration 139 / 250




Operation took 13.61495590209961 seconds.
Iteration 140 / 250




Operation took 13.864623069763184 seconds.
Iteration 141 / 250




Operation took 12.444512367248535 seconds.
Iteration 142 / 250




Operation took 11.058007717132568 seconds.
Iteration 143 / 250




Operation took 10.70081114768982 seconds.
Iteration 144 / 250




Operation took 12.620197772979736 seconds.
Iteration 145 / 250


Trying again with a new set of initial conditions.


Operation took 48.71432280540466 seconds.
Iteration 146 / 250


Trying again with a new set of initial conditions.


Operation took 24.919740676879883 seconds.
Iteration 147 / 250




Operation took 16.75124502182007 seconds.
Iteration 148 / 250




Operation took 19.364403247833252 seconds.
Iteration 149 / 250




Operation took 25.901832818984985 seconds.
Iteration 150 / 250




Operation took 11.242137908935547 seconds.




Iteration 151 / 250




Operation took 4.567394971847534 seconds.
Iteration 152 / 250




Operation took 3.1980504989624023 seconds.
Iteration 153 / 250




Operation took 4.265796899795532 seconds.
Iteration 154 / 250




Operation took 7.376360893249512 seconds.
Iteration 155 / 250




Operation took 16.91477918624878 seconds.
Iteration 156 / 250




Operation took 13.308047533035278 seconds.
Iteration 157 / 250




Operation took 25.978585481643677 seconds.
Iteration 158 / 250




Operation took 8.751350164413452 seconds.
Iteration 159 / 250




Operation took 23.502462148666382 seconds.
Iteration 160 / 250




Operation took 25.56118083000183 seconds.
Iteration 161 / 250




Operation took 65.10478568077087 seconds.
Iteration 162 / 250




Operation took 27.866142511367798 seconds.
Iteration 163 / 250




Operation took 31.88875913619995 seconds.
Iteration 164 / 250




Operation took 54.2431366443634 seconds.
Iteration 165 / 250




Operation took 36.56377410888672 seconds.
Iteration 166 / 250




Operation took 24.57958436012268 seconds.
Iteration 167 / 250




Operation took 10.400006771087646 seconds.
Iteration 168 / 250




Operation took 22.946537971496582 seconds.
Iteration 169 / 250




Operation took 17.023967266082764 seconds.
Iteration 170 / 250




Operation took 15.010732173919678 seconds.
Iteration 171 / 250




Operation took 7.201876401901245 seconds.
Iteration 172 / 250




Operation took 15.591720342636108 seconds.
Iteration 173 / 250




Operation took 9.618608951568604 seconds.
Iteration 174 / 250




Operation took 4.46114444732666 seconds.
Iteration 175 / 250




Operation took 5.0759663581848145 seconds.
Iteration 176 / 250




Operation took 7.117524862289429 seconds.
Iteration 177 / 250




Operation took 22.01649236679077 seconds.
Iteration 178 / 250




Operation took 20.20735192298889 seconds.
Iteration 179 / 250




Operation took 7.750939846038818 seconds.
Iteration 180 / 250




Operation took 5.586484670639038 seconds.
Iteration 181 / 250




Operation took 12.89521837234497 seconds.
Iteration 182 / 250




Operation took 3.513056516647339 seconds.
Iteration 183 / 250




Operation took 18.572382926940918 seconds.
Iteration 184 / 250




Operation took 20.239585399627686 seconds.
Iteration 185 / 250




Operation took 22.082828283309937 seconds.
Iteration 186 / 250




Operation took 14.07226824760437 seconds.
Iteration 187 / 250




Operation took 18.385629653930664 seconds.
Iteration 188 / 250




Operation took 18.022754669189453 seconds.
Iteration 189 / 250




Operation took 22.78997802734375 seconds.
Iteration 190 / 250




Operation took 8.564190864562988 seconds.
Iteration 191 / 250


Trying again with a new set of initial conditions.


Operation took 35.98107123374939 seconds.
Iteration 192 / 250




Operation took 20.524348735809326 seconds.
Iteration 193 / 250




Operation took 34.974141120910645 seconds.
Iteration 194 / 250


Trying again with a new set of initial conditions.


Operation took 48.88768291473389 seconds.
Iteration 195 / 250


Trying again with a new set of initial conditions.


Operation took 39.59233117103577 seconds.
Iteration 196 / 250




Operation took 10.1803457736969 seconds.
Iteration 197 / 250




Operation took 24.9035427570343 seconds.
Iteration 198 / 250




Operation took 7.906152248382568 seconds.
Iteration 199 / 250




Operation took 24.814323902130127 seconds.
Iteration 200 / 250




Operation took 24.953022718429565 seconds.




Iteration 201 / 250




Operation took 0.2923259735107422 seconds.
Iteration 202 / 250




Operation took 17.582939863204956 seconds.
Iteration 203 / 250




Operation took 4.27402925491333 seconds.
Iteration 204 / 250




Operation took 14.799147844314575 seconds.
Iteration 205 / 250




Operation took 8.927725791931152 seconds.
Iteration 206 / 250




Operation took 20.7370445728302 seconds.
Iteration 207 / 250




Operation took 98.97805595397949 seconds.
Iteration 208 / 250




Operation took 2.4467437267303467 seconds.
Iteration 209 / 250




Operation took 1.2258117198944092 seconds.
Iteration 210 / 250




Operation took 16.741041660308838 seconds.
Iteration 211 / 250




Operation took 24.949247360229492 seconds.
Iteration 212 / 250




Operation took 79.07435178756714 seconds.
Iteration 213 / 250




Operation took 49.4382381439209 seconds.
Iteration 214 / 250




Operation took 1.7135000228881836 seconds.
Iteration 215 / 250




Operation took 1.43198823928833 seconds.
Iteration 216 / 250




Operation took 13.699976205825806 seconds.
Iteration 217 / 250




Operation took 44.56188201904297 seconds.
Iteration 218 / 250




Operation took 14.731920957565308 seconds.
Iteration 219 / 250




Operation took 19.569080591201782 seconds.
Iteration 220 / 250




Operation took 27.54439067840576 seconds.
Iteration 221 / 250




Operation took 7.769123077392578 seconds.
Iteration 222 / 250




Operation took 8.567857265472412 seconds.
Iteration 223 / 250




Operation took 18.100083351135254 seconds.
Iteration 224 / 250




Operation took 7.241258144378662 seconds.
Iteration 225 / 250




Operation took 14.907915115356445 seconds.
Iteration 226 / 250




Operation took 21.422094106674194 seconds.
Iteration 227 / 250




Operation took 14.258193016052246 seconds.
Iteration 228 / 250




Operation took 19.58192729949951 seconds.
Iteration 229 / 250




Operation took 2.9527058601379395 seconds.
Iteration 230 / 250




Operation took 5.330028772354126 seconds.
Iteration 231 / 250




Operation took 13.23324728012085 seconds.
Iteration 232 / 250




Operation took 5.22284460067749 seconds.
Iteration 233 / 250




Operation took 10.120109558105469 seconds.
Iteration 234 / 250




Operation took 5.541484832763672 seconds.
Iteration 235 / 250




Operation took 15.967736959457397 seconds.
Iteration 236 / 250




Operation took 12.837082862854004 seconds.
Iteration 237 / 250




Operation took 10.621501684188843 seconds.
Iteration 238 / 250




Operation took 15.437013387680054 seconds.
Iteration 239 / 250




Operation took 11.80711817741394 seconds.
Iteration 240 / 250




Operation took 15.294490337371826 seconds.
Iteration 241 / 250




Operation took 5.914621114730835 seconds.
Iteration 242 / 250




Operation took 22.364749431610107 seconds.
Iteration 243 / 250




Operation took 22.5723717212677 seconds.
Iteration 244 / 250




Operation took 31.3913094997406 seconds.
Iteration 245 / 250


Trying again with a new set of initial conditions.


Operation took 59.77801966667175 seconds.
Iteration 246 / 250




Operation took 23.662986278533936 seconds.
Iteration 247 / 250




Operation took 19.796268701553345 seconds.
Iteration 248 / 250




Operation took 22.38390588760376 seconds.
Iteration 249 / 250




Operation took 19.739720821380615 seconds.
Iteration 250 / 250


Trying again with a new set of initial conditions.


Operation took 41.58762288093567 seconds.


In [5]:
# import dill
# with open('MC-Emittance-Phys-Random-2d-Results.pkl', 'wb') as f:
#     dill.dump(trial_data, f)

In [6]:
import dill

# Serialize the complete trial_data dictionary to a pickle file with dill.
results_filename = 'MC-Emittance-NonPhys-BAX-3d-Results-test.pkl'
with open(results_filename, mode='wb') as results_file:
    dill.dump(trial_data, results_file)