In [None]:
import logging

logger = logging.getLogger()
logger.disabled = True

from copy import deepcopy
import random

from tqdm.auto import tqdm

from agent import DQNAgent


train_configs = []
for pretrain_semantic in [False]:
    for gamma in [0.5, 0.75]:
        for test_seed in [0, 1, 2, 3, 4]:
            for batch_size in [256, 512, 1024]:
                for ddqn in [True, False]:
                    for dueling_dqn in [True, False]:
                        params = {
                            "env_str": "room_env:RoomEnv-v1",
                            "env_config": {
                                "des_size": "l",
                                "question_prob": 1.0,
                                "allow_random_human": False,
                                "allow_random_question": False,
                                "check_resources": True,
                            },
                            "num_iterations": 128 * 20,
                            "replay_buffer_size": 128 * 20,
                            "epsilon_decay_until": 128 * 20,
                            "warm_start": 128 * 10,
                            "batch_size": batch_size,
                            "target_update_interval": 10,
                            "max_epsilon": 1.0,
                            "min_epsilon": 0.1,
                            "gamma": gamma,
                            "capacity": {
                                "episodic": 16,
                                "semantic": 16,
                                "short": 1,
                            },
                            "pretrain_semantic": pretrain_semantic,
                            "nn_params": {
                                "hidden_size": 64,
                                "num_layers": 2,
                                "embedding_dim": 64,
                                "v1_params": {
                                    "include_human": "sum",
                                    "human_embedding_on_object_location": False,
                                },
                                "v2_params": None,
                                "fuse_information": "sum",
                                "include_positional_encoding": True,
                                "max_timesteps": 128,
                                "max_strength": 128,
                            },
                            "run_test": True,
                            "num_samples_for_results": 10,
                            "plotting_interval": 10,
                            "train_seed": test_seed + 5,
                            "test_seed": test_seed,
                            "device": "cpu",
                            "ddqn": ddqn,
                            "dueling_dqn": dueling_dqn,
                            "default_root_dir": f"./training_results/",
                        }
                        train_configs.append(deepcopy(params))

random.shuffle(train_configs)
for params in tqdm(train_configs):
    agent = DQNAgent(**params)
    agent.train()

In [1]:
from glob import glob
from explicit_memory.utils import read_yaml
import pandas as pd

results_all = []
for results_path in glob("./training_results/DQN/*/results.yaml"):
    train_path = results_path.replace("results.yaml", "train.yaml")
    train = read_yaml(train_path)
    results = read_yaml(results_path)
    results_all.append(
        {
            "gamma": train["gamma"],
            "batch_size": train["batch_size"],
            "ddqn": train["ddqn"],
            "dueling_dqn": train["dueling_dqn"],
            "test_score": results["test_score"]["mean"],
            "path": results_path.split("/")[-2],
        }
    )

df = pd.DataFrame(results_all)
df_sorted = df.sort_values(by="test_score", ascending=False)
df_sorted[:20]

Unnamed: 0,gamma,batch_size,ddqn,dueling_dqn,test_score,path
107,0.9,32,False,False,84.2,2024-02-23 07:15:37.385522
199,0.5,256,False,False,83.6,2024-02-21 14:44:10.721757
65,0.5,256,False,False,80.2,2024-02-20 19:26:57.937234
152,0.5,256,True,True,77.8,2024-02-21 04:04:46.595617
360,0.9,32,True,True,76.8,2024-02-25 01:18:50.172876
186,0.5,32,False,False,76.8,2024-02-21 02:58:53.944084
10,0.9,128,True,True,76.8,2024-02-25 12:08:19.884828
35,0.75,64,False,False,76.4,2024-02-23 00:36:32.949645
226,0.9,128,True,False,75.6,2024-02-23 15:16:02.490512
235,0.75,1024,False,False,75.0,2024-02-21 16:57:25.989463


In [6]:
results_path.split('/')[-2]

'2024-02-23 00:42:55.690184'