In [None]:
import yaml
from AntController.HyperparamSearch import create_data_from_data_stores, \
            create_actor_func_from_hyper_param_search, \
            create_value_func_from_hyper_param_search, \
            pretrain_predictor_as_value_func, \
            get_value_func_loss

from AntController.AntEnvironment import EpisodeData
from AntController.HaikuPredictor import HaikuPredictor
import os
import numpy as np

In [None]:
# Recorded episodes used for every fit in this notebook.
training_data_dir = "TrainingData/Fixed_Walk_With_Sensor"

# YAML files describing the actor / critic hyperparameter search spaces.
actor_config_path = "AntController/configs/actor_hyperparam_search_config.yaml"
critic_config_path = "AntController/configs/critic_hyperparam_search_config.yaml"

data = create_data_from_data_stores(training_data_dir)


In [None]:
# Name the five components of the dataset tuple returned above.
rewards, states, actions, shifted_states, is_non_terminal = data

# Discount passed to the value-function helpers (get_value_func_loss / pretraining).
discount = 0.98
# Split index handed to those helpers as `test_pivot`: 10% of the sample count.
test_pivot = int(len(rewards) * 0.1)

In [None]:
with open(actor_config_path) as actor_config_file:
    actor_config = yaml.load(actor_config_file, Loader=yaml.FullLoader)

# Run the actor hyperparameter search; the returned 3-tuple is unpacked into the
# selected actor, its reported test error and the config that produced it.
(best_actor,
 best_actor_test_error,
 best_actor_config) = create_actor_func_from_hyper_param_search(actor_config, data)

In [None]:
with open(critic_config_path) as critic_config_file:
    critic_config = yaml.load(critic_config_file, Loader=yaml.FullLoader)

# Run the critic (value function) hyperparameter search; the returned 3-tuple is
# unpacked into the selected critic, its reported test error and its config.
(best_critic,
 best_critic_test_error,
 best_critic_config) = create_value_func_from_hyper_param_search(critic_config, data)


In [None]:
import pickle


def save_models(path, config, model):
    """Pickle a model's config, parameters and optimizer state to ``path``.

    The file holds the tuple ``(config, model.params, model.optimizer_state)``,
    the same layout that ``get_model`` unpacks.

    Args:
        path: destination file path (overwritten if it exists).
        config: configuration object stored alongside the weights.
        model: object exposing ``params`` and ``optimizer_state`` attributes.
    """
    saved_params = (config, model.params, model.optimizer_state)
    # The context manager closes (and flushes) the file even if pickling fails;
    # the previous bare open() leaked the handle.
    with open(path, "wb") as f:
        pickle.dump(saved_params, f)
    
def get_model(path):
    """Rebuild a predictor from a pickle in the layout written by ``save_models``.

    The file must contain ``(config, params, optimizer_state)``. A fresh predictor
    is generated from ``config`` with
    ``HaikuPredictor.generate_controller_from_config``, and the stored params and
    optimizer state are then attached to it.

    Args:
        path: path of the pickle file to read.

    Returns:
        A ``(predictor, config)`` tuple. Callers must unpack both values.

    Warning:
        ``pickle.load`` can execute arbitrary code; only load files you created.
    """
    # Close the handle deterministically instead of leaking it.
    with open(path, "rb") as f:
        config, params, optimizer_state = pickle.load(f)
    predictor = HaikuPredictor.generate_controller_from_config(config)
    predictor.params = params
    predictor.optimizer_state = optimizer_state
    return predictor, config


In [None]:
# Relative directory where the selected models are pickled and later reloaded.
save_dir = "AntController/configs/"

In [None]:
selected_actor_path = os.path.join(save_dir, "selected_actor.p")
save_models(selected_actor_path, best_actor_config, best_actor)

In [None]:
selected_critic_path = os.path.join(save_dir, "selected_critic.p")
save_models(selected_critic_path, best_critic_config, best_critic)

In [None]:
# get_model returns (predictor, config). Unpack it so `selected_critic` is the
# predictor itself, as every following cell expects.
selected_critic, selected_critic_config = get_model(os.path.join(save_dir, "selected_critic.p"))

In [None]:
# Value-function loss of the reloaded critic on the dataset (displayed as the cell output).
critic_value_loss = get_value_func_loss(selected_critic, discount, data, test_pivot)
critic_value_loss

In [None]:
# Training hyperparameters; the actor noise-training loop later reuses both names.
batch_size = 256
epochs = 1024

pretrain_predictor_as_value_func(selected_critic, discount, data, epochs, batch_size, test_pivot)

In [None]:
import optax

# Replace the critic's optimizer with Adam (lr 2e-5, b1=0.5, b2=0.9).
# NOTE(review): optimizer_state is not re-initialised here; confirm HaikuPredictor
# handles an optimizer swap with the existing state.
critic_optimizer = optax.adam(learning_rate=2e-5, b1=0.5, b2=0.9)
selected_critic.optimizer = critic_optimizer

In [None]:
# Display the critic's `learning_rate` attribute.
# NOTE(review): the optimizer was just replaced with optax.adam(0.00002, ...). Confirm whether
# HaikuPredictor updates `learning_rate` on reassignment; otherwise this shows the old value.
selected_critic.learning_rate

In [None]:
# get_model returns (predictor, config). Later cells refer to the predictor as
# `selected_actor`, so bind exactly that name.
selected_actor, selected_actor_config = get_model(os.path.join(save_dir, "selected_actor.p"))

In [None]:
# Display the reloaded actor's `learning_rate` attribute (before its optimizer is replaced).
selected_actor.learning_rate

In [None]:
# Replace the actor's optimizer with Adam (lr 1e-4, b1=0.5, b2=0.9).
actor_optimizer = optax.adam(learning_rate=1e-4, b1=0.5, b2=0.9)
selected_actor.optimizer = actor_optimizer

In [None]:
for _ in range(1000):
    print("train error: ", train_actor_data_with_noise(selected_actor, batch_size, epochs, states, actions))

In [None]:
def train_actor_data_with_noise(predictor, batch_size, epochs, data, labels):
        losses = []
        for _ in range(epochs):
            noise = np.pad(np.random.normal(0,0.07, (batch_size, 8)), ((0,0), (0,14)), 'constant',constant_values= 0)
            batch_index = np.random.choice(range(len(data)), batch_size)
            loss = predictor.train_batch(data[batch_index] + noise, labels[batch_index])
            losses.append(loss)
        return np.mean(losses)/batch_size

In [None]:
noise_trained_actor_path = os.path.join(save_dir, "selected_actor_trained_with_noise.p")
save_models(noise_trained_actor_path, selected_actor_config, selected_actor)

In [1]:
import numpy as np
from AntController.JaxUtils import *

In [5]:
# Sanity check of normal_density: means and variances of 1, evaluated at two rows of points.
means = np.ones(3, dtype=int)
variances = np.ones(3, dtype=int)
values = np.array([[1.0, 1.1, 0.9],
                   [1.0, 1.0, 1.0]])
normal_density(means, variances, values)

DeviceArray([[0.3989423 , 0.39695254, 0.39695254],
             [0.3989423 , 0.3989423 , 0.3989423 ]], dtype=float32)

In [7]:
# Negative log-likelihood of `values` under the densities above.
log_densities = np.log(normal_density(means, variances, values))
-np.sum(log_densities)

5.5236316

In [None]:
# Scratch cell: previously held the incomplete expression `np/`, a SyntaxError that
# aborted Restart & Run All. Left empty intentionally.