# To-Do:

- InputStandardize vs Normalize

- check hyperparam/prior definitions (scaled space vs raw space)

- keep in mind modality of NaN results in emittance

- Try a different number of steps along the measurement dimension

- increase dimensionality of tuning space

- fix legend location (only plot on 1 heatmap)

# In this notebook, we fit a GPyTorch GP to a simple emittance model with 1 tuning parameter. We use the GP to evaluate the Expected Information Gain (EIG) about the output of a grid-scan emittance-minimization routine.

In [1]:
import torch
from emitutils import toy_beam_size_squared_nd, fit_gp_model_emittance
from utils import unif_random_sample_domain
from matplotlib import pyplot as plt
from algorithms import GridMinimizeEmittance
from acquisition import ExpectedInformationGain
from botorch.optim import optimize_acqf
import time
from mpl_toolkits.axes_grid1 import make_axes_locatable
import copy

# Settings

In [2]:
# Acquisition domain: one [lower, upper] row per input dimension, so the
# tensor has shape (ndim, 2) with domain[i, 0] / domain[i, 1] the lower /
# upper bound of the i-th input. The same bounds are applied to the sampled
# execution paths.
# Previously used domain: torch.tensor([[-2, 2], [-65, 35]]).double()
domain = torch.tensor([[-3, 1], [-40, 60]], dtype=torch.float64)
ndim = len(domain)

# --- posterior-sample grid scans ---
n_samples = 100                  # posterior samples on which to evaluate execution paths
n_steps_tuning_params = 51       # grid steps per tuning dimension in the posterior-sample scans
n_steps_measurement_param = 51   # grid steps along the measurement dimension
squared = True                   # minimize "emittance squared" (can be negative under the model)

# --- experiment loop ---
random_acq = False   # True: next point drawn uniformly at random; False: maximize EIG
n_trials = 5
n_iter = 20
n_obs_init = 3       # random observations on which to initialize the model


In [3]:
# domain = torch.tensor([[-2,2], [-2,2], [-65,35]]).double() #the acquisition domain, must have shape = (ndim, 2)
#                                                         #where domain[i,0] and domain[i,1] represent
#                                                         #the lower and upper bounds of the ith input dimension
#                                                         #(these same bounds will be applied to the sampled execution paths)
        
# ndim = domain.shape[0]
    





# n_samples = 100 #number of posterior samples on which to evaluate execution paths
# n_steps_tuning_params = 11 #number of steps per dimension in the posterior sample grid scans 
# n_steps_measurement_param = 11
# squared = True #whether or not to minimize the "emittance squared" (which can be negative according to the model)





# random_acq = True
# n_trials = 20
# n_iter = 20
# n_obs_init = 5 #number of random observations on which to initialize model

# Initialize

In [None]:
# Run n_trials independent optimization trials and collect per-iteration
# state in trial_data:
#   trial_data['settings'] -> the configuration used for this run
#   trial_data[trial][i]   -> snapshot after iteration i (i = 0 is the state
#                             after fitting on the n_obs_init random points)
trial_data = {}
trial_data['settings'] = {'domain': domain,
                          'ndim': ndim,
                          'n_obs_init': n_obs_init,
                          'n_samples': n_samples,
                          'n_steps_tuning_params': n_steps_tuning_params,
                          'n_steps_measurement_param': n_steps_measurement_param,
                          'n_trials': n_trials,
                          'n_iter': n_iter,
                          'squared': squared,
                          'random_acq': random_acq}

# Toy ground-truth function being observed; the same for every trial.
target_func = toy_beam_size_squared_nd


def _propose_next(acq_fn):
    """Return the point maximizing acq_fn, or None when random_acq is set.

    Under random acquisition the next point is drawn at the start of the
    following iteration instead, so nothing is optimized here.
    """
    if random_acq:
        return None
    x_next, _ = optimize_acqf(
        acq_function=acq_fn,
        bounds=acq_fn.algo.domain.T,
        q=1,
        num_restarts=20,
        raw_samples=100,
        options={},
    )
    return x_next


def _snapshot(x_obs, y_obs, x_next, model, rng_state):
    """Bundle one iteration's state; the model is deep-copied so the stored
    copy is independent of the live object."""
    return {'x_obs': x_obs,
            'y_obs': y_obs,
            'x_next': x_next,
            'model': copy.deepcopy(model),
            'rng_state': rng_state}


for trial in range(n_trials):
    torch.manual_seed(trial)

    # Register this trial's dict before doing any work: iterations are slow,
    # and if the run is interrupted mid-trial the completed iterations are
    # still reachable through trial_data (and can be saved by the next cell).
    iter_data = {}
    trial_data[trial] = iter_data

    # Observe the target function at n_obs_init uniformly sampled points.
    x_obs = unif_random_sample_domain(n_samples=n_obs_init, domain=domain)
    y_obs = target_func(x_obs)

    # The GP is fit on outputs scaled by 1e6.
    # NOTE(review): presumably to bring values to O(1) — confirm units.
    model = fit_gp_model_emittance(x_obs, y_obs * 1.e6)

    algo = GridMinimizeEmittance(domain=domain,
                                 n_samples=n_samples,
                                 n_steps_tuning_params=n_steps_tuning_params,
                                 n_steps_measurement_param=n_steps_measurement_param,
                                 squared=squared)

    # RNG state captured just before the acquisition function is built and
    # optimized; stored with each snapshot.
    rng_state = torch.get_rng_state()
    acq_fn = ExpectedInformationGain(model=model, algo=algo)
    x_next = _propose_next(acq_fn)
    iter_data[0] = _snapshot(x_obs, y_obs, x_next, model, rng_state)

    for i in range(1, n_iter + 1):
        start = time.perf_counter()  # monotonic clock for interval timing
        print('Iteration', trial * n_iter + i, '/', n_trials * n_iter)

        if random_acq:
            x_new = unif_random_sample_domain(n_samples=1, domain=domain)
        else:
            x_new = x_next
        y_new = target_func(x_new)

        x_obs = torch.cat((x_obs, x_new), dim=0)
        y_obs = torch.cat((y_obs, y_new), dim=0)
        model = fit_gp_model_emittance(x_obs, y_obs * 1.e6)

        rng_state = torch.get_rng_state()
        # NOTE(review): acq_fn is built even when random_acq is set, where it
        # goes unused. Kept so the RNG stream matches earlier runs in case
        # construction draws random numbers — confirm before removing.
        acq_fn = ExpectedInformationGain(model=model, algo=algo)
        x_next = _propose_next(acq_fn)

        end = time.perf_counter()
        print('Operation took', end - start, 'seconds.')

        iter_data[i] = _snapshot(x_obs, y_obs, x_next, model, rng_state)




Iteration 1 / 100




Operation took 13.021301984786987 seconds.
Iteration 2 / 100




Operation took 40.086755990982056 seconds.
Iteration 3 / 100




Operation took 19.531930446624756 seconds.
Iteration 4 / 100




Operation took 47.07220911979675 seconds.
Iteration 5 / 100




Operation took 34.70206117630005 seconds.
Iteration 6 / 100




Operation took 98.61155724525452 seconds.
Iteration 7 / 100




Operation took 49.666378021240234 seconds.
Iteration 8 / 100




Operation took 10.810072422027588 seconds.
Iteration 9 / 100




Operation took 12.454872608184814 seconds.
Iteration 10 / 100




Operation took 15.58082890510559 seconds.
Iteration 11 / 100




Operation took 6.419825315475464 seconds.
Iteration 12 / 100




Operation took 9.772191762924194 seconds.
Iteration 13 / 100




Operation took 18.7216854095459 seconds.
Iteration 14 / 100




Operation took 10.923776865005493 seconds.
Iteration 15 / 100




Operation took 12.059475183486938 seconds.
Iteration 16 / 100




Operation took 46.70242643356323 seconds.
Iteration 17 / 100




Operation took 10.711866855621338 seconds.
Iteration 18 / 100




Operation took 21.48720622062683 seconds.
Iteration 19 / 100




Operation took 14.328136682510376 seconds.
Iteration 20 / 100




Operation took 12.332969427108765 seconds.




Iteration 21 / 100




Operation took 1.9393839836120605 seconds.
Iteration 22 / 100




Operation took 9.496249198913574 seconds.
Iteration 23 / 100




Operation took 45.71478819847107 seconds.
Iteration 24 / 100




Operation took 41.712040424346924 seconds.
Iteration 25 / 100




Operation took 128.67060947418213 seconds.
Iteration 26 / 100




Operation took 24.403102159500122 seconds.
Iteration 27 / 100




Operation took 19.62763524055481 seconds.
Iteration 28 / 100




Operation took 14.177024602890015 seconds.
Iteration 29 / 100




Operation took 9.532472133636475 seconds.
Iteration 30 / 100




Operation took 10.757670879364014 seconds.
Iteration 31 / 100




Operation took 12.867915868759155 seconds.
Iteration 32 / 100




Operation took 32.153347969055176 seconds.
Iteration 33 / 100




Operation took 25.677207946777344 seconds.
Iteration 34 / 100




Operation took 20.026148796081543 seconds.
Iteration 35 / 100




Operation took 38.48443007469177 seconds.
Iteration 36 / 100




Operation took 8.622680425643921 seconds.
Iteration 37 / 100




Operation took 17.702862977981567 seconds.
Iteration 38 / 100




Operation took 19.876479864120483 seconds.
Iteration 39 / 100




Operation took 21.540583610534668 seconds.
Iteration 40 / 100




Operation took 31.455955505371094 seconds.




Iteration 41 / 100




Operation took 8.565735578536987 seconds.
Iteration 42 / 100




Operation took 25.701690912246704 seconds.
Iteration 43 / 100




Operation took 1.76792573928833 seconds.
Iteration 44 / 100




Operation took 23.913989305496216 seconds.
Iteration 45 / 100




Operation took 8.954219341278076 seconds.
Iteration 46 / 100




Operation took 52.02328395843506 seconds.
Iteration 47 / 100




Operation took 17.644250631332397 seconds.
Iteration 48 / 100




Operation took 36.852325439453125 seconds.
Iteration 49 / 100




Operation took 166.25622177124023 seconds.
Iteration 50 / 100




Operation took 12.503758668899536 seconds.
Iteration 51 / 100




Operation took 69.43336033821106 seconds.
Iteration 52 / 100




Operation took 18.434372186660767 seconds.
Iteration 53 / 100




Operation took 14.160531282424927 seconds.
Iteration 54 / 100




Operation took 126.71080732345581 seconds.
Iteration 55 / 100




Operation took 21.768287658691406 seconds.
Iteration 56 / 100




Operation took 56.26460909843445 seconds.
Iteration 57 / 100




Operation took 10.225687026977539 seconds.
Iteration 58 / 100




Operation took 30.264554023742676 seconds.
Iteration 59 / 100




Operation took 49.81946063041687 seconds.
Iteration 60 / 100




Operation took 54.92922782897949 seconds.




Iteration 61 / 100




Operation took 1.3291091918945312 seconds.
Iteration 62 / 100




Operation took 4.142114877700806 seconds.
Iteration 63 / 100




Operation took 5.577220439910889 seconds.
Iteration 64 / 100




Operation took 6.41988205909729 seconds.
Iteration 65 / 100




Operation took 2.6748082637786865 seconds.
Iteration 66 / 100




Operation took 24.408491849899292 seconds.
Iteration 67 / 100




Operation took 11.477437496185303 seconds.
Iteration 68 / 100




Operation took 56.44148373603821 seconds.
Iteration 69 / 100




Operation took 36.28406620025635 seconds.
Iteration 70 / 100




Operation took 28.843416690826416 seconds.
Iteration 71 / 100




Operation took 30.81887650489807 seconds.
Iteration 72 / 100




Operation took 82.8038969039917 seconds.
Iteration 73 / 100




Operation took 36.859132051467896 seconds.
Iteration 74 / 100




Operation took 39.97119140625 seconds.
Iteration 75 / 100




Operation took 23.845489263534546 seconds.
Iteration 76 / 100




Operation took 36.32708168029785 seconds.
Iteration 77 / 100




Operation took 30.16129994392395 seconds.
Iteration 78 / 100




Operation took 30.025294303894043 seconds.
Iteration 79 / 100




Operation took 16.861082553863525 seconds.
Iteration 80 / 100




Operation took 164.61326003074646 seconds.




Iteration 81 / 100




Operation took 122.22316765785217 seconds.
Iteration 82 / 100




Operation took 62.687846183776855 seconds.
Iteration 83 / 100




Operation took 17.835548639297485 seconds.
Iteration 84 / 100




Operation took 36.091410875320435 seconds.
Iteration 85 / 100




Operation took 44.63975405693054 seconds.
Iteration 86 / 100




Operation took 109.24133586883545 seconds.
Iteration 87 / 100




Operation took 53.84236979484558 seconds.
Iteration 88 / 100




Operation took 19.743927240371704 seconds.
Iteration 89 / 100




Operation took 13.528844594955444 seconds.
Iteration 90 / 100




Operation took 15.688855648040771 seconds.
Iteration 91 / 100




Operation took 24.478126049041748 seconds.
Iteration 92 / 100




Operation took 9.17183804512024 seconds.
Iteration 93 / 100




Operation took 12.133857488632202 seconds.
Iteration 94 / 100




Operation took 26.031176567077637 seconds.
Iteration 95 / 100




Operation took 20.4365177154541 seconds.
Iteration 96 / 100




Operation took 18.301671266555786 seconds.
Iteration 97 / 100




Operation took 17.1375253200531 seconds.
Iteration 98 / 100




Operation took 24.99575424194336 seconds.
Iteration 99 / 100




Operation took 31.368059873580933 seconds.
Iteration 100 / 100




In [None]:
# import dill
# with open('MC-Emittance-Phys-Random-2d-Results.pkl', 'wb') as f:
#     dill.dump(trial_data, f)

In [None]:
# Serialize all collected trial results (settings + per-iteration snapshots).
# dill is used rather than pickle so that objects such as the fitted models
# can be written out.
import dill

results_path = 'MC-Emittance-NonPhys-BAX-2d-Results-test.pkl'
with open(results_path, 'wb') as results_file:
    dill.dump(trial_data, results_file)