In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.optim as optim
import os
import glob
from tqdm import tqdm

from toddler.action_coding import mass_answers, force_answers
from toddler.models import ValueNetwork
from toddler.RecurrentWorker import train
from toddler.validate import validate

from isaac.models import ComplexRNNModel

from toddler.simulator.config import generate_every_world_configuration, generate_cond

from generate_passive_simulations import get_configuration_answer

In [2]:
# Filesystem layout used by this notebook. Both paths are relative to the
# working directory and deliberately end with "/" because later cells build
# file names by plain string concatenation.

# Root folder containing one sub-folder of saved checkpoints per training seed.
model_directory = "models/grab_a_puck/"

# Folder where the aggregated validation statistics are written.
data_directory = "grab_a_puck_plots/"

In [3]:
# Select the compute device once; it is reused below when instantiating the
# value network and is passed to `validate`.
# NOTE(review): the helper presumably falls back to CPU when CUDA is not
# available (inferred from its name only) — confirm in isaac.utils.
from isaac.utils import get_cuda_device_if_available
device = get_cuda_device_if_available()
# Echo the chosen device so the notebook output records where the run happened.
print(device)

cuda:0


In [4]:
# ---------------------------------------------------------------------------
# Validate every saved checkpoint of every training seed on a fixed set of
# validation worlds and collect mean/std of "control" and "episode_length"
# into `validation_dfs` (one single-row DataFrame per checkpoint).
# ---------------------------------------------------------------------------

# Not used in this cell; kept because other cells may read it.
discount_factor = 0.95

# Enumerate all world configurations and their answers (used to stratify).
every_conf = generate_every_world_configuration()
every_world_answer = np.array(list(map(get_configuration_answer, every_conf)))
n_configurations = len(every_conf)

train_size = 0.7
val_size = 0.15
test_size = 0.15

# Stratified train / validation / test split over configuration indices.
all_indices = np.arange(n_configurations)
train_indices, not_train_indices = train_test_split(all_indices, train_size=train_size,
                                                    random_state=0, stratify=every_world_answer)
# Split the remainder according to the declared val/test proportions instead
# of a hard-coded 0.5 (identical result for 0.15 / 0.15).
val_indices, test_indices = train_test_split(not_train_indices,
                                             train_size=val_size / (val_size + test_size),
                                             random_state=0,
                                             stratify=every_world_answer[not_train_indices])

N_WORLDS = 100
timeout = 1800
print("N_WORLDS", N_WORLDS)

# Seed both libraries before sampling the validation worlds.
torch.manual_seed(0)
np.random.seed(0)
repeated_val_indices = np.random.choice(val_indices, N_WORLDS, replace=True)
val_cond = generate_cond(every_conf[repeated_val_indices])

# Not used in this cell; kept because other cells may read them.
experience_replay = ()
agent_answers = ()

n_bodies = 1
action_repeat = 1
starting_step = 0
starting_episode = 0

# Every validation world gets the same timeout and a single zero force entry.
for cond in val_cond:
    cond["timeout"] = timeout
    cond["forces"] = [[0]]

# Architecture of the saved checkpoints; identical for all of them, so it is
# built once instead of once per checkpoint.
net_params = {"input_dim":10, "hidden_dim":25, "n_layers":4, "output_dim":6, "dropout":0.0}

validation_dfs = []
for seed in [0, 42, 72]:

    this_seed_model_directory = model_directory + str(seed) + "/"

    # glob returns files in arbitrary, filesystem-dependent order. The warning
    # echoed in this cell's recorded output suggests actions are sampled from
    # NumPy's global RNG, so the visiting order would influence the results.
    # Sorting makes repeated runs visit checkpoints in the same order.
    model_paths = sorted(glob.glob(this_seed_model_directory + "*"))

    for model_path in tqdm(model_paths):
        # File names start with the episode number followed by "_".
        # basename() instead of split("/") so Windows-style paths work too.
        episode_number = os.path.basename(model_path).split("_")[0]

        value_network = ValueNetwork(**net_params).to(torch.device(device))
        # map_location lets checkpoints saved on a GPU load on whatever device
        # was selected for this run (e.g. a CPU-only machine).
        value_network.load_state_dict(torch.load(model_path, map_location=device))

        # All checkpoints are evaluated on the same validation conditions.
        # NOTE(review): this is the same list object for every checkpoint; if
        # `validate` mutates the conditions, later checkpoints see the changes.
        agent_cond = val_cond

        valArgs = {"value_network": value_network, "val_cond": agent_cond,
                   "timeout": timeout, "n_bodies": n_bodies,
                   "action_repeat": action_repeat, "print_stats":False,
                   "device": device, "reward_control": True, "done_with_control": True,
                   "reward_not_controlling_negatively": True, "remove_features_in_index": [2, 3]}

        validation_result = validate(**valArgs)

        # Collapse the per-world results into avg/std columns:
        # avg_control, std_control, avg_episode_length, std_episode_length.
        validation_data = {stat+"_"+attr: [f(validation_result[attr])]
                           for attr in ["control", "episode_length"]
                           for stat, f in zip(["avg", "std"], [np.mean, np.std])}

        df = pd.DataFrame.from_dict(validation_data)
        df["seed"] = seed
        df["episode"] = episode_number  # kept as the string parsed from the file name
        validation_dfs.append(df)

  0%|          | 0/34 [00:00<?, ?it/s]

N_WORLDS 1


  3%|▎         | 1/34 [00:02<01:14,  2.26s/it]

NO ANSWER
NO ANSWER


  policy = np.array(policy) / sum(policy)
  selected_action = np.random.choice(possible_actions, p=policy)
 12%|█▏        | 4/34 [00:03<00:37,  1.23s/it]

NO ANSWER
NO ANSWER


 15%|█▍        | 5/34 [00:03<00:26,  1.11it/s]

NO ANSWER
NO ANSWER


 21%|██        | 7/34 [00:03<00:18,  1.48it/s]

NO ANSWER


 26%|██▋       | 9/34 [00:04<00:13,  1.88it/s]

NO ANSWER
NO ANSWER


 32%|███▏      | 11/34 [00:05<00:10,  2.29it/s]

NO ANSWER
NO ANSWER


 35%|███▌      | 12/34 [00:06<00:11,  1.86it/s]

NO ANSWER


 38%|███▊      | 13/34 [00:06<00:10,  2.00it/s]

NO ANSWER


 44%|████▍     | 15/34 [00:07<00:08,  2.24it/s]

NO ANSWER
NO ANSWER


 47%|████▋     | 16/34 [00:07<00:06,  2.63it/s]

NO ANSWER
NO ANSWER


 59%|█████▉    | 20/34 [00:08<00:04,  3.35it/s]

NO ANSWER
NO ANSWER
NO ANSWER
NO ANSWER


 74%|███████▎  | 25/34 [00:09<00:02,  4.13it/s]

NO ANSWER
NO ANSWER
NO ANSWER
NO ANSWER
NO ANSWER


 82%|████████▏ | 28/34 [00:10<00:01,  4.13it/s]

NO ANSWER
NO ANSWER


 85%|████████▌ | 29/34 [00:11<00:02,  2.36it/s]

NO ANSWER


 88%|████████▊ | 30/34 [00:11<00:01,  2.58it/s]

NO ANSWER


 91%|█████████ | 31/34 [00:12<00:01,  1.99it/s]

NO ANSWER
NO ANSWER


 97%|█████████▋| 33/34 [00:13<00:00,  2.07it/s]

NO ANSWER


100%|██████████| 34/34 [00:14<00:00,  1.74it/s]
  0%|          | 0/34 [00:00<?, ?it/s]

NO ANSWER


  3%|▎         | 1/34 [00:00<00:26,  1.23it/s]

NO ANSWER
NO ANSWER


 15%|█▍        | 5/34 [00:01<00:14,  1.98it/s]

NO ANSWER
NO ANSWER
NO ANSWER


 21%|██        | 7/34 [00:02<00:12,  2.19it/s]

NO ANSWER
NO ANSWER


 26%|██▋       | 9/34 [00:02<00:06,  3.59it/s]

NO ANSWER
NO ANSWER


 29%|██▉       | 10/34 [00:03<00:05,  4.25it/s]

NO ANSWER
NO ANSWER


 38%|███▊      | 13/34 [00:03<00:04,  4.30it/s]

NO ANSWER
NO ANSWER
NO ANSWER


 47%|████▋     | 16/34 [00:04<00:04,  4.06it/s]

NO ANSWER
NO ANSWER


 50%|█████     | 17/34 [00:05<00:03,  4.49it/s]

NO ANSWER


 53%|█████▎    | 18/34 [00:05<00:06,  2.53it/s]

NO ANSWER


 56%|█████▌    | 19/34 [00:06<00:07,  2.00it/s]

NO ANSWER


 59%|█████▉    | 20/34 [00:07<00:06,  2.05it/s]

NO ANSWER


 62%|██████▏   | 21/34 [00:07<00:07,  1.77it/s]

NO ANSWER


 68%|██████▊   | 23/34 [00:08<00:05,  2.11it/s]

NO ANSWER
NO ANSWER


 74%|███████▎  | 25/34 [00:09<00:03,  2.38it/s]

NO ANSWER
NO ANSWER


 76%|███████▋  | 26/34 [00:10<00:04,  1.90it/s]

NO ANSWER


 85%|████████▌ | 29/34 [00:10<00:01,  2.93it/s]

NO ANSWER
NO ANSWER
NO ANSWER


 88%|████████▊ | 30/34 [00:11<00:01,  2.99it/s]

NO ANSWER


 91%|█████████ | 31/34 [00:11<00:01,  2.19it/s]

NO ANSWER
NO ANSWER


100%|██████████| 34/34 [00:12<00:00,  2.67it/s]
  0%|          | 0/34 [00:00<?, ?it/s]

NO ANSWER
NO ANSWER


  6%|▌         | 2/34 [00:00<00:04,  7.74it/s]

NO ANSWER
NO ANSWER
NO ANSWER


 12%|█▏        | 4/34 [00:00<00:03,  9.36it/s]

NO ANSWER


 15%|█▍        | 5/34 [00:01<00:08,  3.28it/s]

NO ANSWER
NO ANSWER


 21%|██        | 7/34 [00:01<00:08,  3.04it/s]

NO ANSWER
NO ANSWER


 26%|██▋       | 9/34 [00:02<00:06,  3.77it/s]

NO ANSWER
NO ANSWER


 32%|███▏      | 11/34 [00:02<00:07,  3.24it/s]

NO ANSWER


 38%|███▊      | 13/34 [00:03<00:07,  2.86it/s]

NO ANSWER
NO ANSWER


 41%|████      | 14/34 [00:03<00:05,  3.59it/s]

NO ANSWER


 44%|████▍     | 15/34 [00:04<00:08,  2.33it/s]

NO ANSWER


 47%|████▋     | 16/34 [00:05<00:09,  1.92it/s]

NO ANSWER


 50%|█████     | 17/34 [00:06<00:10,  1.70it/s]

NO ANSWER


 53%|█████▎    | 18/34 [00:06<00:10,  1.58it/s]

NO ANSWER


 56%|█████▌    | 19/34 [00:07<00:08,  1.75it/s]

NO ANSWER


 59%|█████▉    | 20/34 [00:08<00:08,  1.61it/s]

NO ANSWER


 62%|██████▏   | 21/34 [00:08<00:08,  1.52it/s]

NO ANSWER


 65%|██████▍   | 22/34 [00:09<00:08,  1.47it/s]

NO ANSWER
NO ANSWER


 71%|███████   | 24/34 [00:10<00:05,  1.69it/s]

NO ANSWER


 74%|███████▎  | 25/34 [00:11<00:05,  1.57it/s]

NO ANSWER


 76%|███████▋  | 26/34 [00:11<00:04,  1.68it/s]

NO ANSWER


 79%|███████▉  | 27/34 [00:12<00:04,  1.58it/s]

NO ANSWER
NO ANSWER
NO ANSWER
NO ANSWER


 97%|█████████▋| 33/34 [00:13<00:00,  2.67it/s]

NO ANSWER
NO ANSWER
NO ANSWER


100%|██████████| 34/34 [00:13<00:00,  2.08it/s]

NO ANSWER





In [5]:
# Combine the per-checkpoint rows into one table and persist it.
#
# The result is bound to a new name so that `validation_dfs` stays a list and
# this cell can be re-run without failing (concatenating an already
# concatenated DataFrame raises a TypeError).
# ignore_index=True gives each row a unique label; every input frame has a
# single row labelled 0, so the plain concat would produce an all-zero index.
validation_df = pd.concat(validation_dfs, ignore_index=True)

# Create the output folder on first use instead of failing inside to_hdf.
os.makedirs(data_directory, exist_ok=True)
validation_df.to_hdf(data_directory+"validation_data.h5", key="validation_data")