Prediction plots of trained amortization network

In [None]:
from amorstructgp.models.gp_model_amortized_structured import GPModelAmortizedStructured
from amorstructgp.models.gp_model_amortized_ensemble import GPModelAmortizedEnsemble
from amorstructgp.config.models.gp_model_amortized_structured_config import PaperAmortizedStructuredConfig,AmortizedStructuredWithMaternConfig
from amorstructgp.config.models.gp_model_amortized_ensemble_config import PaperAmortizedEnsembleConfig,AmortizedEnsembleWithMaternConfig
from amorstructgp.models.model_factory import ModelFactory
from amorstructgp.config.nn.amortized_infer_models_configs import WiderCrossAttentionKernelEncSharedDatasetEncMLPWrapperAmortizedModelConfig,SmallerStandardNoDropoutCrossAttentionKernelEncSharedDatasetEncMLPWrapperAmortizedModelConfig,SmallerStandardSmallNoiseBoundNoDropoutCrossAttentionKernelEncSharedDatasetEncMLPWrapperAmortizedModelConfig
from amorstructgp.utils.enums import PredictionQuantity
from amorstructgp.gp.base_symbols import BaseKernelTypes
from amorstructgp.data_generators.simulator import Simulator
from amorstructgp.config.prior_parameters import NOISE_VARIANCE_EXPONENTIAL_LAMBDA
from amorstructgp.gp.base_kernels import transform_kernel_list_to_expression
from amorstructgp.utils.plotter import Plotter
import torch
import numpy as np
import os

Input Settings

In [None]:
# Choose which pretrained checkpoint to load: the paper model when
# USE_PAPER_MODEL is True, otherwise the variant that also includes Matern kernels.
USE_PAPER_MODEL = False

# Directory that contains the pretrained state dicts (empty string -> relative file names).
PATH_TO_PRETRAINED = ""

paper_model_path = os.path.join(PATH_TO_PRETRAINED, "main_state_dict_paper.pth")
matern_model_path = os.path.join(PATH_TO_PRETRAINED, "main_state_dict_with_matern.pth")

# Checkpoint path used by the model-loading cells below.
model_path = paper_model_path if USE_PAPER_MODEL else matern_model_path

Load Model

In [None]:
# Pick the structured-model config that matches the selected checkpoint,
# then let the factory build the amortized model from it.
amortized_model_config = (
    PaperAmortizedStructuredConfig(checkpoint_path=model_path)
    if USE_PAPER_MODEL
    else AmortizedStructuredWithMaternConfig(checkpoint_path=model_path)
)
model = ModelFactory.build(amortized_model_config)

Load Ensemble Model

In [None]:
from amorstructgp.utils.gaussian_mixture_density import EntropyApproximation
# Pick the ensemble config that matches the selected checkpoint (the same
# checkpoint path is used as for the single amortized model), then build the ensemble.
amortized_ensemble_model_config = (
    PaperAmortizedEnsembleConfig(checkpoint_path=model_path)
    if USE_PAPER_MODEL
    else AmortizedEnsembleWithMaternConfig(checkpoint_path=model_path)
)
ensemble = ModelFactory.build(amortized_ensemble_model_config)

Main Settings

In [None]:
######## Kernel configurations ############

# Kernel from which the data is generated. The inner list is interpreted as a sum (see paper).
simulator_kernel_list = [[BaseKernelTypes.MATERN52, BaseKernelTypes.LIN]]

# A collection of possible kernel structures, each represented as a nested list.
# The outer list runs over input dimensions; only 1D data is considered here,
# so it always has exactly one element.
input_kernel_list_1 = [[BaseKernelTypes.MATERN52]]
input_kernel_list_2 = [[BaseKernelTypes.LIN]]
input_kernel_list_3 = [[BaseKernelTypes.PER]]
input_kernel_list_4 = [[BaseKernelTypes.SE_MULT_LIN]]
input_kernel_list_5 = [[BaseKernelTypes.LIN_MULT_PER]]
input_kernel_list_6 = [[BaseKernelTypes.SE, BaseKernelTypes.LIN]]
input_kernel_list_7 = [[BaseKernelTypes.PER, BaseKernelTypes.LIN]]
input_kernel_list_8 = [[BaseKernelTypes.PER, BaseKernelTypes.SE_MULT_LIN, BaseKernelTypes.LIN]]
input_kernel_list_9 = [[BaseKernelTypes.SE, BaseKernelTypes.LIN, BaseKernelTypes.SE_MULT_LIN]]
input_kernel_list_10 = [[BaseKernelTypes.LIN, BaseKernelTypes.SE_MULT_LIN]]

# Kernel structures that should be tried; the prediction cell loops over this list.
input_kernel_lists = [input_kernel_list_1]

# Kernel structures that make up the amortized ensemble.
ensemble_kernel_list = [
    input_kernel_list_1,
    input_kernel_list_2,
    input_kernel_list_3,
    input_kernel_list_4,
    input_kernel_list_5,
    input_kernel_list_6,
    input_kernel_list_7,
    input_kernel_list_8,
    input_kernel_list_9,
    input_kernel_list_10,
]

########### Dataset configurations ###########

# Range from which the ground-truth data is generated. Input data for the amortization
# should lie between 0.0 and 1.0, but test data may lie outside of that interval.
data_a = -0.5
data_b = 1.5

n_data = 10  # number of training datapoints - uniform random in (0.0, 1.0)
n_test = 400  # number of test datapoints - uniform random in (data_a, data_b)
observation_noise = 0.05  # observation noise that is added to the data

########## Models that should be plotted ###########

add_ml_model = True  # include the predictive dist of standard ML (with repeated optimization)
add_gt_model = True  # include the predictive dist of the ground-truth GP
add_ensemble = True  # include the predictive dist of the amortized ensemble



Initialize GP Models

In [None]:
from amorstructgp.config.kernels.gpytorch_kernels.elementary_kernels_pytorch_configs import BasicRBFPytorchConfig,RBFWithPriorPytorchConfig
from amorstructgp.config.models.gp_model_gpytorch_config import BasicGPModelPytorchConfig,GPModelPytorchMultistartConfig
# Config for the GP baseline that is built in the prediction cell when
# `add_ml_model` is enabled. Its kernel module is replaced there with the kernel
# structure currently being evaluated, so the RBF kernel config given here only
# serves as the initial kernel of the built model.
gp_model_config  = GPModelPytorchMultistartConfig(kernel_config=RBFWithPriorPytorchConfig(input_dimension=1))
gp_model_config.add_constant_mean_function = False
gp_model_config.set_prior_on_observation_noise=True
# Config for the ground-truth GP that is built in the prediction cell when
# `add_gt_model` is enabled. Its kernel module is replaced there with the
# simulator's ground-truth kernel and its initial likelihood noise is set to
# the simulated observation noise.
# NOTE(review): input_dimension=0 although the data is 1D - presumably harmless
# because the kernel module is overwritten before inference; confirm.
gp_model_gt_config  = BasicGPModelPytorchConfig(kernel_config=BasicRBFPytorchConfig(input_dimension=0))
gp_model_gt_config.add_constant_mean_function = False
# presumably disables hyperparameter optimization so the ground-truth values are kept - verify
gp_model_gt_config.optimize_hps=False


Simulate Dataset

In [None]:
import random
# Draw a fresh seed and apply it to torch, the stdlib RNG and NumPy, so that a
# run can be reproduced from the seed value.
seed = random.randint(0, 10000)
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Sample a dataset from a GP with the configured ground-truth kernel.
n_sim = 500
simulator = Simulator(data_a, data_b, NOISE_VARIANCE_EXPONENTIAL_LAMBDA)
kernel_expression = transform_kernel_list_to_expression(simulator_kernel_list)
simulated_dataset = simulator.create_sample(n_sim, n_test, kernel_expression, observation_noise)
simulated_dataset.add_kernel_list_gt(simulator_kernel_list)
input_dimension = simulated_dataset.get_input_dimension()

normalize_data = False
x_data, y_data = simulated_dataset.get_dataset(normalized_version=normalize_data)

# Keep only simulated points whose input lies in [0.0, 1.0], then take the first
# n_data of them as the training set (as column vectors).
inside_unit_interval = (x_data >= 0.0) & (x_data <= 1.0)
y_data = np.expand_dims(y_data[inside_unit_interval], axis=1)[:n_data]
x_data = np.expand_dims(x_data[inside_unit_interval], axis=1)[:n_data]

x_test, y_test = simulated_dataset.get_test_dataset(normalized_version=normalize_data)
_, f_test = simulated_dataset.get_ground_truth_f(normalized_version=normalize_data)

Make inference and prediction

In [None]:
from amorstructgp.utils.plotter import PlotterPlotly
from amorstructgp.models.model_factory import ModelFactory
import gpytorch
import time
def _timed_predictive_dist(predictor, x_train, y_train, x_eval):
    """Run ``infer`` and ``predictive_dist`` on ``predictor`` and time both calls together.

    Returns the two values produced by ``predictive_dist`` followed by the
    elapsed time in seconds, measured with ``time.perf_counter``.
    """
    started = time.perf_counter()
    predictor.infer(x_train, y_train)
    mu, sigma = predictor.predictive_dist(x_eval)
    elapsed = time.perf_counter() - started
    return mu, sigma, elapsed


def _mean_nll(predictor, x_eval, y_eval):
    """Return the negated mean of ``predictor.predictive_log_likelihood(x_eval, y_eval)``."""
    return -1 * np.mean(predictor.predictive_log_likelihood(x_eval, y_eval))


def _draw_panel(figure, mu, sigma, panel):
    """Add ground-truth curve, predictive distribution and training points to one panel.

    Reads ``x_test``, ``f_test``, ``x_data`` and ``y_data`` from the notebook namespace.
    """
    figure.add_gt_function(x_test, f_test, "red", panel)
    figure.add_predictive_dist(np.squeeze(x_test), np.squeeze(mu), np.squeeze(sigma), panel)
    figure.add_datapoints(x_data, y_data, "limegreen", panel)


print(f"seed: {seed}")

for input_kernel_list in input_kernel_lists:
    # --- Amortized model: predicted hyperparameters, then timed inference + prediction ---
    model.set_kernel_list(input_kernel_list)
    kernel_parameters, noise_variances = model.get_predicted_parameters(x_data, y_data)
    pred_mu, pred_sigma, model_time = _timed_predictive_dist(model, x_data, y_data, x_test)
    nll_model = _mean_nll(model, x_test, y_test)

    print("Predicted parameters:")
    print(kernel_parameters)
    print(torch.sqrt(noise_variances))
    print("GT parameters:")
    print(simulated_dataset.get_gt_kernel_parameter_list())
    print(simulated_dataset.get_observation_noise())

    # Panel 0 always shows the amortized model; every enabled extra model takes the next panel.
    n_plots = 1

    # --- GP baseline using the kernel structure that is currently evaluated ---
    if add_ml_model:
        gp_model = ModelFactory.build(gp_model_config)
        print(gp_model)
        gp_model.kernel_module = gpytorch.kernels.AdditiveKernel(
            transform_kernel_list_to_expression(input_kernel_list).get_kernel()
        )
        pred_mu_gp, pred_sigma_gp, gp_model_time = _timed_predictive_dist(gp_model, x_data, y_data, x_test)
        nll_gp_model = _mean_nll(gp_model, x_test, y_test)
        gp_model.eval_log_posterior_density(x_data, y_data)
        ml_model_index = n_plots
        n_plots += 1

    # --- GP with the simulator's ground-truth kernel and observation noise ---
    if add_gt_model:
        gp_model_gt_config.initial_likelihood_noise = simulated_dataset.get_observation_noise()
        gt_model = ModelFactory.build(gp_model_gt_config)
        gt_model.kernel_module = simulated_dataset.get_kernel_expression_gt().get_kernel()
        pred_mu_gt, pred_sigma_gt, gt_model_time = _timed_predictive_dist(gt_model, x_data, y_data, x_test)
        nll_gt_model = _mean_nll(gt_model, x_test, y_test)
        gt_model_index = n_plots
        n_plots += 1

    # --- Amortized ensemble over all configured kernel structures ---
    if add_ensemble:
        ensemble.set_kernel_list(ensemble_kernel_list)
        ensemble.fast_batch_inference = True
        pred_mu_ensemble, pred_sigma_ensemble, ensemble_model_time = _timed_predictive_dist(
            ensemble, x_data, y_data, x_test
        )
        pred_mus_ensemble, _ = ensemble.predict(x_test)
        nll_ensemble = _mean_nll(ensemble, x_test, y_test)
        ensemble_index = n_plots
        n_plots += 1

    # --- Plotting: one panel per model, all panels sharing the y axis ---
    plotter = PlotterPlotly(n_plots, share_y=True)

    _draw_panel(plotter, pred_mu, pred_sigma, 0)
    display(f"NLL amor model: {nll_model}")
    display(f"Time amor model: {model_time} sec")

    if add_ml_model:
        _draw_panel(plotter, pred_mu_gp, pred_sigma_gp, ml_model_index)
        display(f"NLL gp model: {nll_gp_model}")
        display(f"Time GP model: {gp_model_time} sec")

    if add_gt_model:
        _draw_panel(plotter, pred_mu_gt, pred_sigma_gt, gt_model_index)
        display(f"NLL gt model: {nll_gt_model}")
        display(f"Time gt model: {gt_model_time} sec")

    if add_ensemble:
        # The ground-truth function is not drawn on this panel; the individual
        # ensemble-member means are shown as lines instead.
        for member_mu in pred_mus_ensemble:
            plotter.add_gt_function(x_test, member_mu, "fuchsia", ensemble_index, line_opacity=1.0)
        plotter.add_predictive_dist(
            np.squeeze(x_test),
            np.squeeze(pred_mu_ensemble),
            np.squeeze(pred_sigma_ensemble),
            ensemble_index,
            opacity_scale=0.7,
        )
        # Optional: plotter.add_gt_function(x_test,f_test,"red",ensemble_index)
        plotter.add_datapoints(x_data, y_data, "limegreen", ensemble_index)
        display(f"NLL amor ensemble: {nll_ensemble}")
        display(f"Time amor ensemble: {ensemble_model_time} sec")

    # Optional: plotter.add_datapoints(x_test,y_test,"red",0)
    plotter.show()