In [5]:
import jax
# Enable Float64 for more stable matrix inversions.
jax.config.update("jax_enable_x64", True)

from dataclasses import dataclass
import warnings
from typing import List, Union
import pandas as pd
import cola


import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
    Num,
)
import tensorflow_probability.substrates.jax.bijectors as tfb
import optax as ox

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
from gpjax.typing import (
    Array,
    ScalarFloat,
)
from gpjax.distributions import GaussianDistribution
from gpjax.kernels import AdditiveKernel

import matplotlib.pyplot as plt
from matplotlib import rcParams
# plt.style.use(
#     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
# )
# colors = rcParams["axes.prop_cycle"].by_key()["color"]
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    mean_squared_error,
    r2_score,
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
mpl.rcParams.update(mpl.rcParamsDefault)

#imports AdditiveConjugatePosterior
# from OAK import AdditiveConjugatePosterior
from ChangingOptimiser import *


from pprint import pprint
import numpy as np
def calculate_nlpd(predictive_mean, predictive_stddev, actual_values):
    variance = predictive_stddev ** 2
    nlpd_sum = 0.5 * np.log(2 * np.pi * variance) + ((predictive_mean - actual_values) ** 2) / (2 * variance)
    total_nlpd = np.sum(nlpd_sum)
    return total_nlpd

In [29]:
def best_optimiser(
        D: gpx.Dataset = None,
        optimiser_list: list[ox.GradientTransformation] = None, 
        num_iters = 100,
        key: jr.PRNGKey = jr.PRNGKey(0),
        noises = 0.01,
        lengthscales = 1.0,
        trainable_noise: bool = True,
        ):
    if not isinstance(num_iters, list):
        if isinstance(num_iters, int):
            num_iters = [num_iters]
        else:
            raise ValueError("num_iters should be a list of integers")
    if len(num_iters) == 1:
        num_iters = num_iters * len(optimiser_list)
    else:
        assert len(num_iters) == len(optimiser_list), "Number of iterations should be same as number of optimisers"
    
    if not isinstance(noises, list):
        if isinstance(noises, Union[float, int]):
            noises = [noises]
        else:
            raise ValueError("Noises should be a list of floats or something like that")
    if len(noises) == 1:
        noises = noises * len(optimiser_list)
    else:
        assert len(noises) == len(optimiser_list), "Number of iterations should be same as number of optimisers"

    feature_dimension = X.shape[1]
    number_of_optimisers = len(optimiser_list)

    if not isinstance(lengthscales, list):
        if isinstance(lengthscales, Union[float, int]):
            lengthscales = [[float(lengthscales)]*feature_dimension]
        else:
            raise ValueError("num_iters should be a list of integers")
    if len(lengthscales) == 1:
        lengthscales = lengthscales * len(optimiser_list)
    else:
        assert len(lengthscales) == len(optimiser_list), "Number of iterations should be same as number of optimisers"
    for index in range(number_of_optimisers):
        if isinstance(lengthscales[index], Union[float, int]):
            lengthscales[index] = [lengthscales[index]] * feature_dimension
        else:
            assert len(lengthscales[index]) == feature_dimension, "Lengthscales should be equal to feature dimension"

    # hi kishan

    opt_posteriors = []
    minimum_value = []
    index = list(range(number_of_optimisers))

    for i in range(number_of_optimisers):
        noise = noises[i]
        lengthscales_specific = lengthscales[i]
        print(i)
        meanf = gpx.mean_functions.Zero()
        base_kernels = [gpx.kernels.RBF(active_dims=[j],
            lengthscale=jnp.array(lengthscales_specific[j])
            ) for j in range(feature_dimension)]
        #base_kernels = [OrthogonalRBF(active_dims=[i], lengthscale=jnp.array([1.0])) for i in range(feature_dimension)]
        likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=noise)
        if not trainable_noise:
            likelihood = likelihood.replace_trainable(obs_stddev=False)
        obj = gpx.objectives.ConjugateLOOCV(negative=True)
        maximum_interaction_depth = 2
        kernel = AdditiveKernel(
            kernels=base_kernels,
            interaction_variances=jnp.array([1.0]*(maximum_interaction_depth + 1)) * jnp.var(D.y), 
            max_interaction_depth=maximum_interaction_depth, 
            )
        prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
        posterior = AdditiveConjugatePosterior(prior =prior, likelihood=likelihood)
        optimiser = optimiser_list[i]
        num_iter = num_iters[i]
        opt_posterior, history = gpx.fit(
            model=posterior,
            objective=obj,
            train_data=D,
            optim=optimiser,
            num_iters=num_iter,
            key=key, 
            safe=False,
            verbose=False)
        opt_posteriors.append(opt_posterior)
        minimum_value.append(history[-1])       
    values_to_return = sorted(zip(index,
                                optimiser_list, 
                                opt_posteriors, 
                                minimum_value,
                                noises, 
                                lengthscales
                                ), key = lambda x: x[3])
    print("Returning values")
    print("\n")
    return values_to_return 

def best_optimiser_over_datasets(D: list[gpx.Dataset] = None,
        optimiser_list: list[ox.GradientTransformation] = None, 
        num_iters = 100,
        key: jr.PRNGKey = jr.PRNGKey(0),
        noises = 0.01,
        lengthscales = 1.0,
        trainable_noise: bool = True,
        ):
    if not isinstance(D, list):
        raise ValueError("D should be a list of datasets")
    

In [32]:
def f1(x):
    return x**2

def f2(x):
    return jnp.sin(x)

def f3(x):
    return jnp.cos(x**2) 

def f4(x):
    return x**3

def f5(x):
    return jnp.exp(x)

def f6(x):
    return jnp.tanh(x)

def f7(x):
    return jnp.log(jnp.abs(x)+1)

def f8(x):
    return 0.5*x**2 + 0.2*x + 1

def f9(x):
    return 0.1*x**3 + 0.3*x**2 + 2*x + 1

def f10(x):
    return 0.1*x**4 + 0.5*x**3 + 1

lof = [f1,f2,f3,f4,f5, f6, f7, f8, f9, f10]

def f(x):
    return sum([lof[i](x[:,i:i+1]) for i in range(len(lof))])

n, noise = 400, 0.01
key = jax.random.PRNGKey(np.random.randint(50))  # Replace 12345 with any desired seed value
X = jr.uniform(key, (n, len(lof)))
y = f(X) + jr.normal(key, (n, 1)) * noise
#Select Data wanted
X = jnp.array(X)
# y = jnp.array(y).reshape(-1, 1) 
D = gpx.Dataset(X, y)

In [35]:
optimisers = [
    ox.adam(1.0),
    ox.adam(0.5),
    ox.nadamw(1.0),
    ox.nadamw(0.5),
    ox.yogi(1.0),
    ox.yogi(0.5),
]

num_iterations = 100
noise = 0.01
lengthscale = 1.0
trainable_noise = True

data_to_use = best_optimiser(
    D=D,
    optimiser_list= optimisers, 
    num_iters=num_iterations, 
    noises=noise,
    lengthscales=lengthscale, 
    trainable_noise=trainable_noise
    )

for index, optimiser, opt_posterior, minimum_value, noise, lengthscales in data_to_use:
    print(f"Index: {index}")
    print(f"Minimum Value: {minimum_value}")
    print("\n")


0
1
2
3
4
5
Returning values


Index: 3
Minimum Value: -1297.5139108616438
Index: 0
Minimum Value: -1296.4733925103142
Index: 4
Minimum Value: -1296.3549298180199
Index: 1
Minimum Value: -1296.1153272670226
Index: 5
Minimum Value: -1296.0194085938574
Index: 2
Minimum Value: -1243.2194485265638
