## Train MM / explore with random sampling for HP tuning


In [1]:
# Select matplotlib's "Agg" backend before the remaining imports of this cell.
# NOTE(review): presumably so that figures produced during training are
# rendered off-screen instead of being displayed — confirm against the
# agent's plotting code.
import matplotlib

matplotlib.use("Agg")

import logging

# Fetch the root logger (no name is passed) and set its `disabled` flag.
# NOTE(review): presumably meant to silence log output while training — confirm.
logger = logging.getLogger()
logger.disabled = True

import os
from agent import DQNAgent
from tqdm.auto import tqdm
import random
import itertools

# Random hyperparameter search: build the grid of candidate settings, draw
# `num_combinations` of them at random, and train one DQNAgent per draw.

# Number of combinations you want
num_combinations = 500  # Change this to however many combinations you need

# Settings that are held fixed for every run of this cell.
room_size = "xl"
capacity_max = 12
batch_size = 1000
terminates_at = 99
num_iterations = (terminates_at + 1) * 10
validation_starts_at = 0

# Output directory; it is keyed on the probability type, room size and capacity.
prob_type = (
    "non-equal-object-probs" if "different-prob" in room_size else "equal-object-probs"
)
root_path = (
    f"./training-results/{prob_type}/dqn/room_size={room_size}/capacity={capacity_max}/"
)

# Candidate values per hyperparameter. The grid below is their Cartesian product.
test_seed_ = list(range(num_combinations))
target_update_interval_ = [10]
# NOTE(review): 0.99 and 0.990 are the same float, so this list holds a single
# distinct value. Possibly 0.999 was intended (as in `gamma_explore_`) — confirm
# with the author before changing it.
gamma_mm_ = [0.99, 0.990]
gamma_explore_ = [0.99, 0.999]
semantic_decay_factor_ = [0.8]
embedding_dim_ = [32]
relu_for_attention_ = [False]
concat_embeddings_ = [True]

replay_buffer_size_ = [
    1000,
    # terminates_at,
    # 200,
    # num_iterations // 10,
    # num_iterations // 100,
]
warm_start_ = [
    1000,
    # terminates_at,
    # num_iterations // 10,
    # num_iterations // 20,
    # num_iterations // 200,
]


# Generate all combinations. `dict.fromkeys` drops repeated tuples while keeping
# their order, so a value listed twice in a candidate list above cannot make
# `random.sample` return the very same configuration more than once.
params_all = list(
    dict.fromkeys(
        itertools.product(
            test_seed_,
            target_update_interval_,
            gamma_mm_,
            gamma_explore_,
            semantic_decay_factor_,
            replay_buffer_size_,
            warm_start_,
            embedding_dim_,
            relu_for_attention_,
            concat_embeddings_,
        )
    )
)

# Sampling is without replacement; `random.sample` raises ValueError if
# `num_combinations` exceeds the number of distinct combinations.
# NOTE: `random` is not seeded in this cell, so the drawn subset can differ
# from one execution to the next.
random_combinations = random.sample(params_all, num_combinations)

# Wrap the list itself (not an `enumerate` object) so tqdm can read its length
# and display a total / ETA instead of an open-ended counter.
for params in tqdm(random_combinations):
    (
        test_seed,
        target_update_interval,
        gamma_mm,
        gamma_explore,
        semantic_decay_factor,
        replay_buffer_size,
        warm_start,
        embedding_dim,
        relu_for_attention,
        concat_embeddings,
    ) = params

    # Keyword arguments forwarded verbatim to DQNAgent.
    params_dict = {
        "env_str": "room_env:RoomEnv-v2",
        "num_iterations": num_iterations,
        "replay_buffer_size": replay_buffer_size,
        "validation_starts_at": validation_starts_at,
        "warm_start": warm_start,
        "batch_size": batch_size,
        "target_update_interval": target_update_interval,
        "epsilon_decay_until": num_iterations,
        "max_epsilon": 1.0,
        "min_epsilon": 0.01,
        "gamma": {"mm": gamma_mm, "explore": gamma_explore},
        # The capacity budget is split evenly between episodic and semantic.
        "capacity": {
            "episodic": capacity_max // 2,
            "semantic": capacity_max // 2,
            "short": 1,
        },
        "pretrain_semantic": False,
        "semantic_decay_factor": semantic_decay_factor,
        "lstm_params": {
            "num_layers": 2,
            "embedding_dim": embedding_dim,
            "hidden_size": embedding_dim,
            "bidirectional": False,
            "max_timesteps": terminates_at + 1,
            "max_strength": terminates_at + 1,
            "relu_for_attention": relu_for_attention,
            "concat_embeddings": concat_embeddings,
        },
        "mlp_params": {
            "hidden_size": embedding_dim,
            "num_hidden_layers": 1,
            "dueling_dqn": True,
        },
        "num_samples_for_results": {"val": 5, "test": 10},
        "validation_interval": 1,
        "plotting_interval": 50,
        # The training seed is offset from the test seed by 5.
        "train_seed": test_seed + 5,
        "test_seed": test_seed,
        "device": "cpu",
        "qa_function": "episodic_semantic",
        "explore_policy_heuristic": "avoid_walls",
        "env_config": {
            "question_prob": 1.0,
            "terminates_at": terminates_at,
            "randomize_observations": "objects",
            "room_size": room_size,
            "rewards": {"correct": 1, "wrong": 0, "partial": 0},
            "make_everything_static": False,
            "num_total_questions": 1000,
            "question_interval": 5,
            "include_walls_in_observations": True,
        },
        "ddqn": True,
        "default_root_dir": root_path,
    }

    agent = DQNAgent(**params_dict)
    agent.train()

  plt.legend(loc="upper left")
  plt.show()
0it [00:16, ?it/s]


KeyboardInterrupt: 

In [6]:
# Display the `question_interval` attribute of the current `agent`'s environment.
agent.env.question_interval

5

## Run fixed combinations


In [None]:
# Select matplotlib's "Agg" backend before the remaining imports of this cell.
# NOTE(review): presumably so that figures produced during training are
# rendered off-screen instead of being displayed — confirm against the
# agent's plotting code.
import matplotlib

matplotlib.use("Agg")

import logging

# Fetch the root logger (no name is passed) and set its `disabled` flag.
# NOTE(review): presumably meant to silence log output while training — confirm.
logger = logging.getLogger()
logger.disabled = True

import os
from agent import DQNAgent
from tqdm.auto import tqdm
import random
import itertools

# Fixed-combination sweep: every (test seed, capacity, agent type) triple is
# trained once with the hyperparameters set below.
room_size = "xl"
terminates_at = 99
num_iterations = 100 * (terminates_at + 1)
replay_buffer_size = num_iterations // 10
warm_start = num_iterations // 100
validation_starts_at = 0
batch_size = 32
target_update_interval = 100
gamma_mm = 0.999
gamma_explore = 0.99
embedding_dim = 32
semantic_decay_factor = 0.8

# The probability type depends only on `room_size`, so it is resolved once.
if "different-prob" in room_size:
    prob_type = "non-equal-object-probs"
else:
    prob_type = "equal-object-probs"

# `itertools.product` varies its last argument fastest, which gives the same
# visiting order as nesting the loops seed -> capacity -> agent type.
for test_seed, capacity_max, agent_type in itertools.product(
    [0, 1, 2, 3, 4],
    [12, 24, 48],
    ["hybrid", "semantic", "episodic"],
):
    # Results are stored per capacity; all agent types share the directory.
    root_path = (
        f"./training-results/{prob_type}/dqn/"
        f"room_size={room_size}/capacity={capacity_max}/"
    )

    # How the capacity budget is divided for each agent type. The table is
    # rebuilt on every pass so each agent receives its own dict object.
    capacity_by_agent = {
        "hybrid": {
            "episodic": capacity_max // 2,
            "semantic": capacity_max // 2,
            "short": 1,
        },
        "episodic": {"episodic": capacity_max, "semantic": 0, "short": 1},
        "semantic": {"episodic": 0, "semantic": capacity_max, "short": 1},
    }
    if agent_type not in capacity_by_agent:
        raise ValueError(f"Unknown agent_type: {agent_type}")
    capacity = capacity_by_agent[agent_type]

    # Nested configuration sections, created fresh for every agent.
    lstm_params = {
        "num_layers": 2,
        "embedding_dim": embedding_dim,
        "hidden_size": embedding_dim,
        "bidirectional": False,
        "max_timesteps": terminates_at + 1,
        "max_strength": terminates_at + 1,
        "relu_for_attention": True,
    }
    mlp_params = {
        "hidden_size": embedding_dim,
        "num_hidden_layers": 1,
        "dueling_dqn": True,
    }
    env_config = {
        "question_prob": 1.0,
        "terminates_at": terminates_at,
        "randomize_observations": "objects",
        "room_size": room_size,
        "rewards": {"correct": 1, "wrong": 0, "partial": 0},
        "make_everything_static": False,
        "num_total_questions": 1000,
        "question_interval": 5,
        "include_walls_in_observations": True,
    }

    # Keyword arguments forwarded verbatim to DQNAgent.
    params_dict = {
        "env_str": "room_env:RoomEnv-v2",
        "num_iterations": num_iterations,
        "replay_buffer_size": replay_buffer_size,
        "validation_starts_at": validation_starts_at,
        "warm_start": warm_start,
        "batch_size": batch_size,
        "target_update_interval": target_update_interval,
        "epsilon_decay_until": num_iterations,
        "max_epsilon": 1.0,
        "min_epsilon": 0.01,
        "gamma": {"mm": gamma_mm, "explore": gamma_explore},
        "capacity": capacity,
        "pretrain_semantic": False,
        "semantic_decay_factor": semantic_decay_factor,
        "lstm_params": lstm_params,
        "mlp_params": mlp_params,
        "num_samples_for_results": {"val": 5, "test": 10},
        "validation_interval": 1,
        "plotting_interval": 50,
        # The training seed is offset from the test seed by 5.
        "train_seed": test_seed + 5,
        "test_seed": test_seed,
        "device": "cpu",
        "qa_function": "episodic_semantic",
        "explore_policy_heuristic": "avoid_walls",
        "env_config": env_config,
        "ddqn": True,
        "default_root_dir": root_path,
    }

    agent = DQNAgent(**params_dict)
    agent.train()