## Prediction example with a trained amortization network

This is a minimal example of how to use our amortization network for GP inference.

In [None]:
from amorstructgp.models.gp_model_amortized_structured import GPModelAmortizedStructured
from amorstructgp.config.models.gp_model_amortized_structured_config import PaperAmortizedStructuredConfig, AmortizedStructuredWithMaternConfig
from amorstructgp.models.model_factory import ModelFactory
from amorstructgp.gp.base_symbols import BaseKernelTypes
from amorstructgp.data_generators.simulator import Simulator
from amorstructgp.config.prior_parameters import NOISE_VARIANCE_EXPONENTIAL_LAMBDA
from amorstructgp.gp.base_kernels import transform_kernel_list_to_expression
import os  # needed for os.path.join in the checkpoint cell
import torch
import numpy as np
from IPython.display import display  # used in the last cell (also available as a builtin inside Jupyter)

### Checkpoint Paths 

We provide two sets of pretrained weights: the one used in the paper (base kernels SE, PER and LIN plus their pairwise products) and one that additionally includes the Matern-52 kernel.

In [None]:
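USE_PAPER_MODEL = False  # True: paper model (SE, PER, LIN); False: model that also includes Matern-52
PATH_TO_PRETRAINED = ""  # directory containing the downloaded .pth checkpoint files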
paper_model_path = os.path.join(PATH_TO_PRETRAINED,"main_state_dict_paper.pth")
matern_model_path = os.path.join(PATH_TO_PRETRAINED,"main_state_dict_with_matern.pth")
if USE_PAPER_MODEL:
    model_path = paper_model_path
else:
    model_path = matern_model_path

### Load Model

We use the factory pattern to build models that adhere to the `BaseModel` interface. The `BaseModel` subclass associated with the amortization networks is `GPModelAmortizedStructured`, which is essentially a wrapper around the actual torch models.
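For readers unfamiliar with the factory pattern, the sketch below shows the general idea: a factory inspects a config object and returns the matching model that implements a shared interface. All names here (`ToyBaseModel`, `ToyAmortizedConfig`, `ToyAmortizedModel`, `ToyModelFactory`) are hypothetical stand-ins and not the `amorstructgp` API.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class ToyBaseModel(ABC):
    """Hypothetical stand-in for a shared model interface."""

    @abstractmethod
    def infer(self, x, y): ...

    @abstractmethod
    def predictive_dist(self, x_test): ...


@dataclass
class ToyAmortizedConfig:
    """Hypothetical config object that only carries a checkpoint path."""
    checkpoint_path: str


class ToyAmortizedModel(ToyBaseModel):
    """Hypothetical wrapper; a real one would hold a torch network."""

    def __init__(self, checkpoint_path):
        self.checkpoint_path = checkpoint_path

    def infer(self, x, y):
        self.x, self.y = x, y  # cache the training data for later prediction

    def predictive_dist(self, x_test):
        raise NotImplementedError("sketch only")


class ToyModelFactory:
    @staticmethod
    def build(config):
        # Dispatch on the config type: each config class maps to one model class.
        if isinstance(config, ToyAmortizedConfig):
            return ToyAmortizedModel(checkpoint_path=config.checkpoint_path)
        raise ValueError(f"Unknown config type: {type(config).__name__}")


toy_model = ToyModelFactory.build(ToyAmortizedConfig(checkpoint_path="weights.pth"))
```

The calling code only depends on the config and the interface, so swapping the paper model for the Matern-52 model is a one-line config change.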

In [None]:
if USE_PAPER_MODEL:
    amortized_model_config = PaperAmortizedStructuredConfig(checkpoint_path=model_path)
else:
    amortized_model_config = AmortizedStructuredWithMaternConfig(checkpoint_path=model_path)
model = ModelFactory.build(amortized_model_config)

### Specify input kernel

Input kernels are specified as nested lists, for example `[[BaseKernelTypes.SE,BaseKernelTypes.LIN]]`. Each element of the outer list describes the kernel for one input dimension, and these per-dimension kernels are multiplied. Each inner list contains the base kernels for that dimension, which are added. The available base kernels are the members of the `BaseKernelTypes` enum (the paper model was not trained on the `BaseKernelTypes` members that involve Matern-52 kernels).

Example kernels:

1D data:
- Linear + SE kernel: `[[BaseKernelTypes.LIN,BaseKernelTypes.SE]]`
- Linear + SE x PER kernel: `[[BaseKernelTypes.LIN,BaseKernelTypes.SE_MULT_PER]]`

2D data:
- SE1 x SE2 (RBF kernel): `[[BaseKernelTypes.SE],[BaseKernelTypes.SE]]`
- (LIN1 + SE1) x SE2: `[[BaseKernelTypes.LIN,BaseKernelTypes.SE],[BaseKernelTypes.SE]]`
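To make the sum-within-dimension, product-across-dimensions reading concrete, here is a small sketch that renders a nested list as a formula. The `Kernel` enum is a hypothetical stand-in for a few `BaseKernelTypes` members. The notebook itself uses `transform_kernel_list_to_expression` to build the actual kernel expression object.

```python
from enum import Enum, auto


class Kernel(Enum):
    """Hypothetical subset of base kernel names, for illustration only."""
    SE = auto()
    LIN = auto()
    PER = auto()
    SE_MULT_PER = auto()


def kernel_list_to_string(kernel_list):
    """Render a nested kernel list: inner lists are sums, the outer list is a product."""
    factors = []
    for dim, base_kernels in enumerate(kernel_list, start=1):
        # Base kernels inside one inner list act on the same dimension and are added.
        terms = [f"{k.name}[{dim}]" for k in base_kernels]
        factors.append("(" + " + ".join(terms) + ")")
    # The per-dimension sums are multiplied.
    return " x ".join(factors)


print(kernel_list_to_string([[Kernel.LIN, Kernel.SE], [Kernel.SE]]))
# (LIN[1] + SE[1]) x (SE[2])
```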

If the dataset is multidimensional and the outer list has only one element, e.g. `[[BaseKernelTypes.LIN,BaseKernelTypes.SE]]`, `GPModelAmortizedStructured` interprets this as applying the same kernel structure to every input dimension of the dataset.
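Numerically, this structure corresponds to the kernel $k(x, z) = \prod_d \sum_{b \in \text{list}_d} k_b(x_d, z_d)$. The NumPy sketch below evaluates such a kernel matrix, including reuse of a single inner list for all dimensions. The string names, the `BASE_KERNELS` map and the fixed hyperparameter values are hypothetical. In the notebook, the amortization network predicts the hyperparameters from the data instead.

```python
import numpy as np


def se_1d(a, b, lengthscale=0.2, variance=1.0):
    # Squared-exponential kernel between 1D input arrays a (n,) and b (m,)
    return variance * np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


def lin_1d(a, b, variance=1.0):
    # Linear (dot-product) kernel between 1D input arrays
    return variance * a[:, None] * b[None, :]


# Hypothetical name -> kernel-function map standing in for BaseKernelTypes
BASE_KERNELS = {"SE": se_1d, "LIN": lin_1d}


def gram_matrix(kernel_list, X, Z):
    """k(x, z) = product over dimensions d of the sum of base kernels in kernel_list[d]."""
    n_dims = X.shape[1]
    if len(kernel_list) == 1:
        # A single inner list is reused for every input dimension.
        kernel_list = kernel_list * n_dims
    if len(kernel_list) != n_dims:
        raise ValueError("need one inner list per input dimension (or exactly one)")
    K = np.ones((X.shape[0], Z.shape[0]))
    for d, base_kernels in enumerate(kernel_list):
        # Sum of the base kernels on dimension d, multiplied into the running product
        K *= sum(BASE_KERNELS[name](X[:, d], Z[:, d]) for name in base_kernels)
    return K


rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(5, 2))
K_shared = gram_matrix([["SE", "LIN"]], X, X)                   # one list, reused for both dims
K_explicit = gram_matrix([["SE", "LIN"], ["SE", "LIN"]], X, X)  # same structure spelled out
print(np.allclose(K_shared, K_explicit))  # True
```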

In [None]:
input_kernel = [[BaseKernelTypes.SE,BaseKernelTypes.LIN]]

### Generate simulated dataset

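Here we create an example dataset generated from a GP with a known ground-truth kernel `simulator_kernel`. The ground-truth kernel need not match the `input_kernel` given to the network. Here it uses a Matern-52 component where the input kernel uses SE.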
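Conceptually, simulating such a dataset means sampling latent function values from a zero-mean GP prior and adding Gaussian observation noise. The sketch below does this directly in NumPy for a 1D Matern-52 + LIN kernel. The hyperparameter values are hypothetical fixed choices, and this is not how `Simulator` is implemented internally.

```python
import numpy as np


def matern52(a, b, lengthscale=0.3, variance=1.0):
    # Matern-5/2 kernel between 1D input arrays
    r = np.abs(a[:, None] - b[None, :]) / lengthscale
    return variance * (1.0 + np.sqrt(5.0) * r + 5.0 / 3.0 * r**2) * np.exp(-np.sqrt(5.0) * r)


def linear(a, b, variance=0.5):
    # Linear (dot-product) kernel between 1D input arrays
    return variance * a[:, None] * b[None, :]


rng = np.random.default_rng(42)
x_sim = np.sort(rng.uniform(-0.5, 1.5, size=200))      # same range as data_a, data_b below
K_sim = matern52(x_sim, x_sim) + linear(x_sim, x_sim)   # sum of base kernels in one dimension
# A small diagonal jitter keeps the Cholesky factorization numerically stable.
L_sim = np.linalg.cholesky(K_sim + 1e-6 * np.eye(len(x_sim)))
f_sim = L_sim @ rng.standard_normal(len(x_sim))        # latent function sample f ~ GP(0, k)
noise_std = 0.05
y_sim = f_sim + noise_std * rng.standard_normal(len(x_sim))  # noisy observations
```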

In [None]:
simulator_kernel = [[BaseKernelTypes.MATERN52,BaseKernelTypes.LIN]] 

########### Dataset configurations ###########

# range from which ground-truth data is generated: training inputs of the amortization network should lie in [0.0, 1.0], but test data can lie outside
data_a = -0.5
data_b = 1.5

n_data = 10 # number of training datapoints, taken from [0.0, 1.0]
n_test = 400 # number of test datapoints, uniform random in (data_a, data_b)
observation_noise = 0.05 # observation noise that is added to the data

n_sim = 500 # number of simulated points in (data_a, data_b); can be set high, since the n_data training points are drawn from this set
simulator = Simulator(data_a,data_b,NOISE_VARIANCE_EXPONENTIAL_LAMBDA)
kernel_expression = transform_kernel_list_to_expression(simulator_kernel)
simulated_dataset = simulator.create_sample(n_sim,n_test,kernel_expression,observation_noise)
simulated_dataset.add_kernel_list_gt(simulator_kernel)
input_dimension = simulated_dataset.get_input_dimension()
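x_data,y_data = simulated_dataset.get_dataset()
# keep only points whose inputs lie in [0.0, 1.0] (the network's input range), then take the first n_data
in_unit_interval = (x_data >= 0.0) & (x_data <= 1.0)
y_data = np.expand_dims(y_data[in_unit_interval], axis=1)[:n_data]
x_data = np.expand_dims(x_data[in_unit_interval], axis=1)[:n_data]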
x_test,y_test = simulated_dataset.get_test_dataset()
_,f_test = simulated_dataset.get_ground_truth_f()
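The dataset configuration notes that training inputs for the amortization network should lie in $[0, 1]$. For your own 1D data, one simple option is min-max scaling fitted on the training inputs, with the same affine map applied to test inputs (which may then fall outside $[0, 1]$). This is an assumption for illustration, not a preprocessing step documented by the library, and `fit_unit_interval_scaler` is a hypothetical helper.

```python
import numpy as np


def fit_unit_interval_scaler(x_train):
    """Return forward/inverse maps sending min(x_train) -> 0 and max(x_train) -> 1."""
    x_min, x_max = float(np.min(x_train)), float(np.max(x_train))
    scale = x_max - x_min
    if scale == 0.0:
        raise ValueError("training inputs are constant; cannot rescale")

    def forward(x):
        return (x - x_min) / scale

    def inverse(u):
        return u * scale + x_min

    return forward, inverse


x_raw = np.array([[3.0], [5.0], [4.0], [7.0]])   # hypothetical raw inputs, shape (n, 1)
to_unit, from_unit = fit_unit_interval_scaler(x_raw)
x_scaled = to_unit(x_raw)                        # [[0.0], [0.5], [0.25], [1.0]]
x_test_scaled = to_unit(np.array([[9.0]]))       # [[1.5]]: test points may leave [0, 1]
```

The fitted map is reused for test inputs so that training and test points share one coordinate system.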



### Make inference and prediction

We use the `BaseModel` methods for inference and prediction. First, we set the input kernel of the `GPModelAmortizedStructured` object via `.set_kernel_list(input_kernel)`, which configures the kernel inside the amortization network. Calling `infer(x_data,y_data)` on a `GPModelAmortizedStructured` object performs a forward pass through the amortization network to predict all hyperparameters, and also caches the dataset for prediction. Calling `predictive_dist(x_test)` then evaluates the predictive distribution of the GP with the predicted hyperparameters and returns the predictive means and sigmas. For inspection, `get_predicted_parameters(x_data,y_data)` returns the predicted kernel parameters and noise variances directly.
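Once the hyperparameters are fixed, the predictive distribution follows the standard exact GP regression equations:

- $\mu_* = K_{*X}(K_{XX}+\sigma^2 I)^{-1}y$
- $\operatorname{var}_* = \operatorname{diag}(K_{**}) - \operatorname{diag}\big(K_{*X}(K_{XX}+\sigma^2 I)^{-1}K_{X*}\big)$

The NumPy sketch below implements these equations via a Cholesky factorization for a single SE kernel with hypothetical fixed hyperparameters. It illustrates what `predictive_dist` computes conceptually and is not the library's implementation. Whether the library's sigma includes the observation noise is not stated in this notebook, so the sketch exposes both variants through a hypothetical `include_noise` flag.

```python
import numpy as np


def se_gram(a, b, lengthscale=0.2, variance=1.0):
    # SE kernel matrix for 1D inputs stored as columns: a (n, 1), b (m, 1)
    sq_dist = (a[:, 0][:, None] - b[:, 0][None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def gp_predict(x_train, y_train, x_test, noise_variance, include_noise=True):
    """Exact GP posterior mean and standard deviation at x_test (SE kernel only)."""
    K = se_gram(x_train, x_train) + noise_variance * np.eye(len(x_train))
    L = np.linalg.cholesky(K)                        # K = L L^T
    # alpha = K^{-1} y, computed with two solves against the Cholesky factor
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    K_star = se_gram(x_test, x_train)                # (m, n) cross-covariances
    mu = K_star @ alpha                              # posterior mean, shape (m, 1)
    v = np.linalg.solve(L, K_star.T)                 # (n, m)
    var = np.diag(se_gram(x_test, x_test)) - np.sum(v**2, axis=0)
    if include_noise:
        var = var + noise_variance                   # variance of a noisy observation y*
    return mu, np.sqrt(np.maximum(var, 0.0))[:, None]


rng = np.random.default_rng(1)
x_tr = rng.uniform(0.0, 1.0, size=(10, 1))           # 10 training inputs in [0, 1]
y_tr = np.sin(6.0 * x_tr) + 0.05 * rng.standard_normal((10, 1))
x_te = np.linspace(-0.5, 1.5, 50)[:, None]           # test inputs may leave [0, 1]
mu, sigma = gp_predict(x_tr, y_tr, x_te, noise_variance=0.05**2)
```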

In [None]:

from amorstructgp.utils.plotter import PlotterPlotly


model.set_kernel_list(input_kernel)
kernel_parameters,noise_variances =model.get_predicted_parameters(x_data,y_data)
model.infer(x_data,y_data)
pred_mu,pred_sigma = model.predictive_dist(x_test)
nll_model = -1*np.mean(model.predictive_log_likelihood(x_test,y_test))

plotter = PlotterPlotly(1)
plotter.add_gt_function(x_test,f_test,"red",0)
plotter.add_predictive_dist(np.squeeze(x_test),np.squeeze(pred_mu),np.squeeze(pred_sigma),0)
plotter.add_datapoints(x_data,y_data,"limegreen",0)
plotter.show()

display("Predicted parameters:")
display(kernel_parameters)
display(torch.sqrt(noise_variances))  # predicted observation-noise standard deviation
display("GT parameters:")
display(simulated_dataset.get_gt_kernel_parameter_list())
display(simulated_dataset.get_observation_noise())
display(f"NLL amor model: {nll_model}")
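The reported `nll_model` is the negative mean of the per-point values returned by `predictive_log_likelihood`. For a Gaussian predictive distribution, each term is $\log \mathcal{N}(y_i \mid \mu_i, \sigma_i^2)$. The sketch below computes the mean NLL by hand with a hypothetical helper `gaussian_nll`. It assumes `pred_sigma` is a standard deviation that already includes the observation noise, which this notebook does not confirm.

```python
import numpy as np


def gaussian_nll(y, mu, sigma):
    """Mean negative log-likelihood of y under independent N(mu, sigma^2)."""
    y, mu, sigma = (np.asarray(a, dtype=float).ravel() for a in (y, mu, sigma))
    log_lik = -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2
    return -np.mean(log_lik)


# Perfect mean prediction with unit sigma gives 0.5 * log(2 * pi), roughly 0.9189
print(gaussian_nll([0.0, 1.0], [0.0, 1.0], [1.0, 1.0]))
```

If the assumption about `pred_sigma` holds, `gaussian_nll(y_test, pred_mu, pred_sigma)` should be close to `nll_model`.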
