In [1]:
import yaml
from AntController.HyperparamSearchUtils import create_data_from_data_stores, \
            create_actor_func_from_hyper_param_search, \
            create_value_func_from_hyper_param_search, \
            pretrain_predictor_as_value_func, \
            get_value_func_loss

from AntController.AntEnvironment import EpisodeData
from AntController.HaikuPredictor import HaikuPredictor
import os
import numpy as np

In [2]:
# Hyperparameter-search configs for the sensor-less actor and critic.
actor_config_path = "AntController/configs/actor_hyperparam_search_config_no_sensor.yaml"
critic_config_path = "AntController/configs/critic_hyperparam_search_config_no_sensor.yaml"

# Recorded walking data. sensor_enabled=False is passed even though the directory
# name says "With_Sensor" -- presumably to drop the sensor inputs; confirm against
# create_data_from_data_stores.
training_data_dir = "TrainingData/Fixed_Walk_With_Sensor"
data = create_data_from_data_stores(
    training_data_dir,
    sensor_enabled=False,
)


In [3]:
# `data` unpacks into five parallel sequences (shifted_states presumably holds the
# next states -- confirm against create_data_from_data_stores).
(rewards, states, actions,
 shifted_states, is_non_terminal) = data

# test_pivot is 10% of the sample count; it is passed as the test split point to
# the value-function helpers below.
test_pivot = int(0.1 * len(rewards))

discount = 0.9  # reward discount used by the value-function helpers

In [5]:
# Load the actor search space, then run the random hyperparameter search over it.
with open(actor_config_path) as config_file:
    actor_config = yaml.load(config_file, Loader=yaml.FullLoader)

actor_search_result = create_actor_func_from_hyper_param_search(actor_config, data)
best_actor, best_actor_test_error, best_actor_config = actor_search_result

Running random search on actor...


Trained Model 255 with Loss 8.14543. Current Optimal Loss: 0.12633: 100%|██████████| 256/256 [26:49<00:00,  6.29s/it] 


In [4]:
# Load the critic search space, then run the random hyperparameter search over it.
with open(critic_config_path) as config_file:
    critic_config = yaml.load(config_file, Loader=yaml.FullLoader)

critic_search_result = create_value_func_from_hyper_param_search(critic_config, data)
best_critic, best_critic_test_error, best_critic_config = critic_search_result


Running random search on value critic...


Trained Model 63 with Loss 0.09431. Current Optimal Loss: 0.00639: 100%|██████████| 64/64 [08:34<00:00,  8.03s/it]


In [5]:
import pickle


def save_models(path, config, model):
    """Pickle a predictor's config, parameters and optimizer state to ``path``.

    Args:
        path: Destination file; truncated/overwritten if it already exists.
        config: Hyperparameter config the predictor was built from (stored so
            ``get_model`` can rebuild the predictor).
        model: Predictor exposing ``params`` and ``optimizer_state`` attributes.
    """
    saved_params = (config, model.params, model.optimizer_state)
    # The context manager guarantees the file is flushed and closed, even if
    # pickling fails part-way.
    with open(path, "wb") as f:
        pickle.dump(saved_params, f)


def get_model(path):
    """Rebuild a predictor from a file written by ``save_models``.

    Args:
        path: File containing the pickled ``(config, params, optimizer_state)`` tuple.

    Returns:
        ``(predictor, config)``: a predictor built with
        ``HaikuPredictor.generate_controller_from_config(config)`` whose
        ``params`` and ``optimizer_state`` are restored from the file.

    Security: ``pickle.load`` can execute arbitrary code -- only load files you
    created yourself.
    """
    with open(path, "rb") as f:
        config, params, optimizer_state = pickle.load(f)
    predictor = HaikuPredictor.generate_controller_from_config(config)
    predictor.params = params
    predictor.optimizer_state = optimizer_state
    return predictor, config


In [6]:
save_dir: str = "AntController/configs/"  # directory the selected models are pickled into

In [9]:
actor_model_path = os.path.join(save_dir, "selected_actor_no_sensor.p")
save_models(actor_model_path, best_actor_config, best_actor)

In [8]:
critic_model_path = os.path.join(save_dir, "selected_critic_no_sensor.p")
save_models(critic_model_path, best_critic_config, best_critic)

In [15]:

critic_model_path = os.path.join(save_dir, "selected_critic_no_sensor.p")
selected_critic_no_sensor, selected_critic_config = get_model(critic_model_path)

In [22]:
# Value-function loss of the reloaded critic, using the same discount and split.
reloaded_critic_loss = get_value_func_loss(selected_critic_no_sensor, discount, data, test_pivot)
reloaded_critic_loss

DeviceArray(0.00831687, dtype=float32)

In [12]:
# Minibatch steps per call and rows per minibatch for the training helpers below.
epochs, batch_size = 1024, 256


In [None]:
# Continue training the reloaded critic as a value function on the same data split.
pretrain_predictor_as_value_func(
    selected_critic_no_sensor,
    discount,
    data,
    epochs,
    batch_size,
    test_pivot,
)

In [20]:
import optax


In [None]:
# Learning rate of the reloaded sensor-less critic (the critic is bound to
# `selected_critic_no_sensor`; the bare name `selected_critic` does not exist).
selected_critic_no_sensor.learning_rate

In [6]:
actor_model_path = os.path.join(save_dir, "selected_actor_no_sensor.p")
selected_actor, selected_actor_config = get_model(actor_model_path)



In [7]:
actor_learning_rate = selected_actor.learning_rate
actor_learning_rate

0.0001

In [23]:
# Replace the actor's optimizer with a new Adam instance for noisy fine-tuning.
# NOTE(review): selected_actor.learning_rate is left unchanged and the
# optimizer_state restored by get_model is kept alongside the new optimizer --
# confirm HaikuPredictor.train_batch expects this.
finetune_learning_rate = 0.00002
selected_actor.optimizer = optax.adam(finetune_learning_rate, b1=0.5, b2=0.9)

In [24]:
def train_actor_data_with_noise(predictor, batch_size, epochs, data, labels,
                                noise_std=0.07, noisy_features=8):
    """Train ``predictor`` on random minibatches whose inputs are jittered with noise.

    Each epoch draws ``batch_size`` row indices (with replacement) from ``data``,
    adds zero-mean Gaussian noise to the first ``noisy_features`` columns of
    those rows, leaves the remaining columns unchanged, and calls
    ``predictor.train_batch(noisy_rows, labels[rows])``.

    Args:
        predictor: Object with a ``train_batch(inputs, labels)`` method returning a loss.
        batch_size: Rows sampled per training step.
        epochs: Number of training steps (one minibatch each).
        data: 2-D numpy array of inputs, shape ``(n_samples, n_features)``.
        labels: Numpy array of targets, indexed in parallel with ``data``.
        noise_std: Standard deviation of the Gaussian input noise.
        noisy_features: Number of leading feature columns that receive noise.

    Returns:
        The mean of the per-step losses divided by ``batch_size``.

    Raises:
        ValueError: if ``noisy_features`` exceeds the number of columns in ``data``.
    """
    n_features = data.shape[1]
    if noisy_features > n_features:
        raise ValueError(
            f"noisy_features={noisy_features} exceeds data width {n_features}"
        )
    losses = []
    for _ in range(epochs):
        # Noise only for the leading columns, zero-padded to the full feature width.
        noise = np.pad(
            np.random.normal(0, noise_std, (batch_size, noisy_features)),
            ((0, 0), (0, n_features - noisy_features)),
            "constant",
            constant_values=0,
        )
        batch_index = np.random.choice(len(data), batch_size)
        loss = predictor.train_batch(data[batch_index] + noise, labels[batch_index])
        losses.append(loss)
    # NOTE(review): dividing by batch_size assumes train_batch returns a summed
    # (not already averaged) loss -- confirm against HaikuPredictor.train_batch.
    return np.mean(losses) / batch_size

In [25]:
# Repeatedly fine-tune the actor on noise-jittered states; each round runs
# `epochs` minibatch steps. Per-round errors are kept in a list so the history
# is still available if the kernel is interrupted to stop training early.
n_noise_training_rounds = 1000
actor_noise_train_errors = []
for round_idx in range(n_noise_training_rounds):
    train_error = train_actor_data_with_noise(selected_actor, batch_size, epochs, states, actions)
    actor_noise_train_errors.append(train_error)
    print(f"round {round_idx}: train error: {train_error}")

train error:  1.0541892051696777
train error:  1.0489362478256226
train error:  1.0425949096679688
train error:  1.0352531671524048
train error:  1.0340733528137207
train error:  1.0367422103881836
train error:  1.0444979667663574
train error:  1.0389244556427002
train error:  1.0475826263427734
train error:  1.0467956066131592
train error:  1.0444622039794922
train error:  1.0451042652130127
train error:  1.0444684028625488
train error:  1.0386672019958496
train error:  1.0459175109863281
train error:  1.0338637828826904
train error:  1.0355947017669678
train error:  1.0377000570297241
train error:  1.0493308305740356
train error:  1.0435633659362793
train error:  1.0411748886108398
train error:  1.0383318662643433
train error:  1.0365720987319946
train error:  1.0391643047332764
train error:  1.044542670249939
train error:  1.0330193042755127
train error:  1.035019874572754
train error:  1.0421440601348877
train error:  1.039217472076416
train error:  1.0341404676437378
train error: 

train error:  0.9908437728881836
train error:  1.0009737014770508
train error:  1.0029460191726685
train error:  1.004813551902771
train error:  0.9949741363525391
train error:  1.0012695789337158
train error:  0.9941476583480835
train error:  0.9975059628486633
train error:  0.996556282043457
train error:  0.9955832958221436
train error:  0.9943230152130127
train error:  1.000985026359558
train error:  0.9966117143630981
train error:  0.9984477758407593
train error:  1.005676031112671
train error:  1.000247836112976
train error:  0.9954559803009033
train error:  0.998216450214386
train error:  0.9905585050582886
train error:  0.9921779036521912
train error:  1.0029617547988892
train error:  0.9920607805252075
train error:  0.993502140045166
train error:  1.0026664733886719
train error:  0.9936804175376892
train error:  0.9968738555908203
train error:  1.0003834962844849
train error:  0.9994852542877197
train error:  0.9831328392028809
train error:  0.9923738241195679
train error:  0.9

KeyboardInterrupt: 

In [26]:
noisy_actor_path = os.path.join(save_dir, "selected_actor_trained_with_noise_no_sensor.p")
save_models(noisy_actor_path, selected_actor_config, selected_actor)

In [1]:
import numpy as np
from AntController.JaxUtils import *