# To-Do:

- Compare the `InputStandardize` and `Normalize` input transforms

- Check hyperparameter/prior definitions (scaled space vs. raw space)

- Keep in mind the modality of NaN results in the emittance

- Try a different number of steps along the measurement dimension

- Increase the dimensionality of the tuning space

- Fix the legend location (plot it on only one heatmap)

# In this notebook, we fit a GPyTorch GP to a simple emittance model with one tuning parameter. We then use the GP to evaluate the Expected Information Gain about the result of a grid-scan minimization routine.

In [1]:
import torch
from emitutils import toy_beam_size_squared_nd, fit_gp_model_emittance
from utils import unif_random_sample_domain
from matplotlib import pyplot as plt
from algorithms import GridMinimizeEmittance
from acquisition import ExpectedInformationGain
from botorch.optim import optimize_acqf
import time
from mpl_toolkits.axes_grid1 import make_axes_locatable
import copy

# Settings

In [2]:
# Alternative domain used in an earlier run:
# domain = torch.tensor([[-2,2], [-65,35]]).double()

# Acquisition domain, shape (ndim, 2): row i holds the (lower, upper) bounds of
# input dimension i. The same bounds are applied to the sampled execution paths.
domain = torch.tensor([[-3, 1], [-40, 60]], dtype=torch.float64)
ndim = len(domain)

# --- posterior-sample grid scans -------------------------------------------
n_samples = 100                   # posterior samples on which execution paths are evaluated
n_steps_tuning_params = 51        # grid steps per tuning dimension in each posterior-sample scan
n_steps_measurement_param = 51    # grid steps along the measurement dimension
squared = False                   # True -> minimize "emittance squared" (can be negative under the model)

# --- experiment loop ---------------------------------------------------------
random_acq = False                # True -> query uniformly random points instead of optimizing EIG
n_trials = 5                      # independent repetitions of the whole run
n_iter = 20                       # acquisitions per trial
n_obs_init = 3                    # random observations used to initialize the model


In [3]:
# domain = torch.tensor([[-2,2], [-2,2], [-65,35]]).double() #the acquisition domain, must have shape = (ndim, 2)
#                                                         #where domain[i,0] and domain[i,1] represent
#                                                         #the lower and upper bounds of the ith input dimension
#                                                         #(these same bounds will be applied to the sampled execution paths)
        
# ndim = domain.shape[0]
    





# n_samples = 100 #number of posterior samples on which to evaluate execution paths
# n_steps_tuning_params = 11 #number of steps per dimension in the posterior sample grid scans 
# n_steps_measurement_param = 11
# squared = True #whether or not to minimize the "emittance squared" (which can be negative according to the model)





# random_acq = True
# n_trials = 20
# n_iter = 20
# n_obs_init = 5 #number of random observations on which to initialize model

# Initialize

In [4]:
# Results container: trial_data['settings'] records the run configuration and
# trial_data[trial] maps iteration index -> snapshot dict (see _snapshot).
trial_data = {}
trial_data['settings'] = {'domain': domain,
                          'ndim': ndim,
                          'n_obs_init': n_obs_init,
                          'n_samples': n_samples,
                          'n_steps_tuning_params': n_steps_tuning_params,
                          'n_steps_measurement_param': n_steps_measurement_param,
                          'n_trials': n_trials,
                          'n_iter': n_iter,
                          'squared': squared,
                          'random_acq': random_acq}

# Factor applied to the observations before every GP fit. It is defined once so
# the initial fit and the per-iteration refits cannot drift apart.
# NOTE(review): presumably rescales beam-size-squared values to O(1) — confirm.
Y_SCALE = 1.e6

# The target function does not depend on the trial, so bind it once.
target_func = toy_beam_size_squared_nd


def _propose_next_point(acq_fn):
    """Return the maximizer of ``acq_fn`` over its algorithm's domain.

    Returns None when ``random_acq`` is set. In that mode the next query point
    is drawn uniformly at random inside the iteration loop instead.
    """
    if random_acq:
        return None
    candidate, _ = optimize_acqf(
        acq_function=acq_fn,
        bounds=acq_fn.algo.domain.T,  # domain is (ndim, 2); bounds must be (2, ndim)
        q=1,
        num_restarts=20,
        raw_samples=100,
        options={},
    )
    return candidate


def _snapshot(x_obs, y_obs, x_next, model, rng_state):
    """Bundle the state of one iteration for later analysis.

    The model is deep-copied so the stored snapshot is independent of the live
    object.
    """
    return {'x_obs': x_obs,
            'y_obs': y_obs,
            'x_next': x_next,
            'model': copy.deepcopy(model),
            'rng_state': rng_state}


for trial in range(n_trials):
    torch.manual_seed(trial)  # each trial is reproducible from its index

    # Observe the target n_obs_init times at uniformly sampled points in the domain.
    x_obs = unif_random_sample_domain(n_samples = n_obs_init, domain = domain)
    y_obs = target_func(x_obs)

    # Fit the model on the initial observations.
    model = fit_gp_model_emittance(x_obs, y_obs*Y_SCALE)

    algo = GridMinimizeEmittance(domain = domain,
                                 n_samples = n_samples,
                                 n_steps_tuning_params = n_steps_tuning_params,
                                 n_steps_measurement_param = n_steps_measurement_param,
                                 squared = squared)

    # Capture the RNG state right before the acquisition function is built.
    # NOTE(review): presumably stored so the same posterior samples can be
    # regenerated later from the snapshot — confirm.
    rng_state = torch.get_rng_state()
    acq_fn = ExpectedInformationGain(model = model, algo = algo)
    x_next = _propose_next_point(acq_fn)

    iter_data = {0: _snapshot(x_obs, y_obs, x_next, model, rng_state)}

    for i in range(1, n_iter+1):
        start = time.perf_counter()  # monotonic clock for elapsed-time measurement
        print('Iteration', trial*n_iter + i, '/', n_trials*n_iter)

        x_new = (unif_random_sample_domain(n_samples = 1, domain = domain)
                 if random_acq else x_next)
        y_new = target_func(x_new)

        x_obs = torch.cat((x_obs, x_new), dim=0)
        y_obs = torch.cat((y_obs, y_new), dim=0)

        model = fit_gp_model_emittance(x_obs, y_obs*Y_SCALE)

        rng_state = torch.get_rng_state()
        # Built in both modes; in random_acq mode it is not used to pick x_next.
        acq_fn = ExpectedInformationGain(model = model, algo = algo)
        x_next = _propose_next_point(acq_fn)

        end = time.perf_counter()
        print('Operation took', end - start, 'seconds.')

        iter_data[i] = _snapshot(x_obs, y_obs, x_next, model, rng_state)

    trial_data[trial] = iter_data




Iteration 1 / 100




Operation took 5.923458576202393 seconds.
Iteration 2 / 100




Operation took 1.015073537826538 seconds.
Iteration 3 / 100




Operation took 35.77023124694824 seconds.
Iteration 4 / 100




Operation took 24.02318572998047 seconds.
Iteration 5 / 100




Operation took 5.7902727127075195 seconds.
Iteration 6 / 100




Operation took 16.361685037612915 seconds.
Iteration 7 / 100




Operation took 16.20863175392151 seconds.
Iteration 8 / 100




Operation took 12.930803775787354 seconds.
Iteration 9 / 100




Operation took 23.11672806739807 seconds.
Iteration 10 / 100




Operation took 17.578937292099 seconds.
Iteration 11 / 100




Operation took 16.374406814575195 seconds.
Iteration 12 / 100




Operation took 29.225892543792725 seconds.
Iteration 13 / 100




Operation took 14.998883724212646 seconds.
Iteration 14 / 100




Operation took 21.437307834625244 seconds.
Iteration 15 / 100




Operation took 34.5930449962616 seconds.
Iteration 16 / 100




Operation took 26.568583250045776 seconds.
Iteration 17 / 100




Operation took 60.400023460388184 seconds.
Iteration 18 / 100




Operation took 27.400992393493652 seconds.
Iteration 19 / 100




Operation took 11.193016529083252 seconds.
Iteration 20 / 100




Operation took 20.0681312084198 seconds.


Trying again with a new set of initial conditions.


Iteration 21 / 100




Operation took 0.8749549388885498 seconds.
Iteration 22 / 100




Operation took 20.00701355934143 seconds.
Iteration 23 / 100




Operation took 30.44447660446167 seconds.
Iteration 24 / 100




Operation took 31.38685965538025 seconds.
Iteration 25 / 100




Operation took 65.98337459564209 seconds.
Iteration 26 / 100




Operation took 49.72317934036255 seconds.
Iteration 27 / 100




Operation took 20.62986946105957 seconds.
Iteration 28 / 100




Operation took 7.315247535705566 seconds.
Iteration 29 / 100




Operation took 9.075652122497559 seconds.
Iteration 30 / 100




Operation took 12.981511116027832 seconds.
Iteration 31 / 100




Operation took 7.228952407836914 seconds.
Iteration 32 / 100




Operation took 30.82413101196289 seconds.
Iteration 33 / 100




Operation took 37.57856202125549 seconds.
Iteration 34 / 100




Operation took 27.472299337387085 seconds.
Iteration 35 / 100




Operation took 10.706819295883179 seconds.
Iteration 36 / 100




Operation took 19.68525195121765 seconds.
Iteration 37 / 100




Operation took 39.09226679801941 seconds.
Iteration 38 / 100




Operation took 38.43990516662598 seconds.
Iteration 39 / 100




Operation took 27.607048273086548 seconds.
Iteration 40 / 100




Operation took 23.584920406341553 seconds.




Iteration 41 / 100




Operation took 17.13505220413208 seconds.
Iteration 42 / 100




Operation took 251.60341215133667 seconds.
Iteration 43 / 100




Operation took 61.789382457733154 seconds.
Iteration 44 / 100




Operation took 154.5232949256897 seconds.
Iteration 45 / 100




Operation took 55.802428245544434 seconds.
Iteration 46 / 100




Operation took 46.1994354724884 seconds.
Iteration 47 / 100




Operation took 16.947425365447998 seconds.
Iteration 48 / 100




Operation took 17.93860912322998 seconds.
Iteration 49 / 100




Operation took 32.62141466140747 seconds.
Iteration 50 / 100




Operation took 24.281391620635986 seconds.
Iteration 51 / 100




Operation took 87.39247250556946 seconds.
Iteration 52 / 100




Operation took 4.586270809173584 seconds.
Iteration 53 / 100




Operation took 19.29021430015564 seconds.
Iteration 54 / 100




Operation took 21.194556713104248 seconds.
Iteration 55 / 100




Operation took 27.87487506866455 seconds.
Iteration 56 / 100




Operation took 12.029344320297241 seconds.
Iteration 57 / 100




Operation took 22.41781735420227 seconds.
Iteration 58 / 100




Operation took 43.26779532432556 seconds.
Iteration 59 / 100




Operation took 30.90449619293213 seconds.
Iteration 60 / 100




Operation took 29.982768774032593 seconds.




Iteration 61 / 100




Operation took 1.4110383987426758 seconds.
Iteration 62 / 100




Operation took 9.726462602615356 seconds.
Iteration 63 / 100




Operation took 14.950825691223145 seconds.
Iteration 64 / 100




Operation took 33.446385622024536 seconds.
Iteration 65 / 100




Operation took 75.02229452133179 seconds.
Iteration 66 / 100




Operation took 12.005374908447266 seconds.
Iteration 67 / 100




Operation took 38.53338956832886 seconds.
Iteration 68 / 100




Operation took 27.396394729614258 seconds.
Iteration 69 / 100




Operation took 41.40396046638489 seconds.
Iteration 70 / 100




Operation took 20.763717889785767 seconds.
Iteration 71 / 100




Operation took 27.056302070617676 seconds.
Iteration 72 / 100




Operation took 34.98971343040466 seconds.
Iteration 73 / 100




Operation took 32.19004511833191 seconds.
Iteration 74 / 100




Operation took 44.3727285861969 seconds.
Iteration 75 / 100




Operation took 47.02225613594055 seconds.
Iteration 76 / 100




Operation took 38.0305871963501 seconds.
Iteration 77 / 100




Operation took 12.76411509513855 seconds.
Iteration 78 / 100




Operation took 26.89466881752014 seconds.
Iteration 79 / 100




Operation took 20.307164669036865 seconds.
Iteration 80 / 100




Operation took 34.367703437805176 seconds.




Iteration 81 / 100




Operation took 74.85226321220398 seconds.
Iteration 82 / 100




Operation took 17.828404903411865 seconds.
Iteration 83 / 100




Operation took 60.32107925415039 seconds.
Iteration 84 / 100




Operation took 41.480682373046875 seconds.
Iteration 85 / 100




Operation took 12.776601314544678 seconds.
Iteration 86 / 100




Operation took 20.547049045562744 seconds.
Iteration 87 / 100




Operation took 10.887054920196533 seconds.
Iteration 88 / 100




Operation took 28.08390188217163 seconds.
Iteration 89 / 100




Operation took 7.785062074661255 seconds.
Iteration 90 / 100




Operation took 19.39391016960144 seconds.
Iteration 91 / 100




Operation took 16.320213556289673 seconds.
Iteration 92 / 100




Operation took 20.39414072036743 seconds.
Iteration 93 / 100




Operation took 20.622052669525146 seconds.
Iteration 94 / 100




Operation took 35.32005977630615 seconds.
Iteration 95 / 100




Operation took 11.245167016983032 seconds.
Iteration 96 / 100




Operation took 24.47987961769104 seconds.
Iteration 97 / 100




Operation took 24.197529554367065 seconds.
Iteration 98 / 100




Operation took 11.458471298217773 seconds.
Iteration 99 / 100




Operation took 18.670682430267334 seconds.
Iteration 100 / 100




Operation took 23.470306873321533 seconds.


In [5]:
# import dill
# with open('MC-Emittance-Phys-Random-2d-Results.pkl', 'wb') as f:
#     dill.dump(trial_data, f)

In [6]:
import dill

# Persist the full experiment history (settings plus per-iteration snapshots).
# NOTE(review): dill rather than pickle, presumably because the stored models
# are not plain-picklable — confirm.
results_path = 'MC-Emittance-Phys-BAX-2d-Results-test.pkl'
with open(results_path, 'wb') as results_file:
    dill.dump(trial_data, results_file)