# In [None]:

# Experiment: for several damping coefficients, fit a baseline GP model
# (get_MOI kernel) and a damped polynomial local-invariance GP model to
# damped-SHM trajectories, printing each model's rounded log marginal
# likelihood and the first two entries returned by evaluate_model on a longer
# test trajectory.

# CUDA_VISIBLE_DEVICES is read when TensorFlow initialises the CUDA runtime, so
# it must be set before TensorFlow (or anything importing it, e.g. gpflow and
# the project kernel modules) can touch a GPU.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

import gpflow
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from scipy.integrate import solve_ivp, odeint
from gpflow.utilities import print_summary, positive, to_default_float, set_trainable
from invariance_kernels import zero_mean, get_MOI
from invariance_functions import degree_of_freedom, get_GPR_model, get_damped_SHM_data, get_damped_pendulum_data, get_grid_of_points
# NOTE(review): evaluate_model was used below but never imported; it is assumed
# to live in invariance_functions next to get_GPR_model -- confirm location.
from invariance_functions import evaluate_model
from local_invariance_kernels import get_Damped_Polynomial_Local_Invariance

# Optional seed for reproducible initial conditions. None keeps the original
# behaviour (fresh random initial conditions on every run).
seed = None
if seed is not None:
    np.random.seed(seed)
    tf.random.set_seed(seed)

# Module-level mean kept so other notebook cells referencing `mean` still work;
# the models below each get their own fresh zero_mean(2).
mean = zero_mean(2)
time_step = 0.1
training_time = 1
testing_time = 5

# Initial conditions: positions in [min_x, max_x), velocities in
# [-max_x/5, max_x/5).
max_x = 3
min_x = 0.1
n_train = 3
train_starting_position = np.random.uniform(min_x, max_x, (n_train))
train_starting_velocity = np.random.uniform(-max_x/5, max_x/5, (n_train))

test_starting_position = np.random.uniform(min_x, max_x)
test_starting_velocity = np.random.uniform(-max_x/5, max_x/5)

print(test_starting_position)
print(test_starting_velocity)

for gamma in [0.01, 0.05, 0.1]:
    print("current damping: %s" %gamma)
    # NOTE(review): 1e-8 is presumably the noise level of the generated data -- confirm.
    data = get_damped_SHM_data(gamma, time_step, training_time, 1e-8, train_starting_position, train_starting_velocity) #switch
    test_data = get_damped_SHM_data(gamma, time_step, testing_time, 1e-8, [test_starting_position], [test_starting_velocity])

    # Baseline model: it does not depend on jitter, so fit it once per gamma
    # rather than once per jitter value. The last argument (100 here, 300
    # below) is presumably an optimiser iteration budget -- confirm.
    mean_function = zero_mean(2)
    m = get_GPR_model(get_MOI(), mean_function, data, 100)
    print("%s, "%round(m.log_marginal_likelihood().numpy()))
    print(evaluate_model(m, test_data, time_step)[:2])

    for jitter in [1e-5]:
        try:
            kernel = get_Damped_Polynomial_Local_Invariance(5, 40, jitter, poly_f_d=2, poly_g_d=2)#switch
            # Fresh mean per model, matching the baseline above, instead of the
            # shared module-level `mean`, so no state carries across gammas.
            m = get_GPR_model(kernel, zero_mean(2), data, 300)
            print(round(m.log_marginal_likelihood().numpy()))
            evaluate_invariance = evaluate_model(m, test_data, time_step)
            print(evaluate_invariance[:2])
        except tf.errors.InvalidArgumentError:
            # Raised by TF on numerical failure during fit/evaluation
            # (e.g. a failed Cholesky); treated here as jitter being too small.
            print("jitter too small")
            break
        else:
            # Learned invariance parameters; only reached when fitting succeeded.
            print(kernel.f_poly.numpy())
            print(kernel.g_poly.numpy())
            print(m.kernel.epsilon.numpy())
