## Train HumemAI-Unified

In [None]:
# Select matplotlib's non-interactive "Agg" backend before any other module
# has a chance to import pyplot (presumably ``agent`` draws training plots --
# confirm). The backend must be chosen first, so keep this ordering.
import matplotlib

matplotlib.use("Agg")

import logging

# ``logging.getLogger()`` with no name returns the root logger; disabling it
# drops records that are logged directly on the root logger.
# NOTE(review): ``Logger.disabled`` is only checked on the logger a record is
# logged to, so records from named (child) loggers still propagate to the
# root's handlers. If the intent is to silence all logging, something like
# ``logging.disable(...)`` would be needed instead -- confirm intent.
logger = logging.getLogger()
logger.disabled = True

# Project agent plus helper modules; ``DQNAgent`` drives the sweep below.
import os
from agent import DQNAgent
from tqdm.auto import tqdm
import random
import itertools


# ---------------------------------------------------------------------------
# Hyperparameters shared by every run of the sweep.
# ---------------------------------------------------------------------------
room_size = "xxl-different-prob"
terminates_at = 199
# Presumably 200 episodes of (terminates_at + 1) steps each -- confirm.
num_iterations = (terminates_at + 1) * 200
replay_buffer_size = num_iterations // 10
batch_size = 32
target_update_interval = 20
# One discount factor per learned policy (keys mirror "mm_policy" /
# "explore_policy" in the params below).
gamma = {"mm": 0.95, "explore": 0.95}
semantic_decay_factor = 0.8
embedding_dim = 64
num_layers = 2
triple_qual_weight = 0.8

# The object-probability type depends only on ``room_size``, so compute it
# once rather than on every iteration of the sweep.
prob_type = (
    "non-equal-object-probs"
    if "different-prob" in room_size
    else "equal-object-probs"
)

# Train one agent per (test seed, long-term capacity, semantic pretraining)
# combination. ``itertools.product`` yields the combinations in the same order
# as the equivalent nested for-loops (seed outermost, pretraining innermost),
# and leaves the same loop names bound afterwards for later notebook cells.
for test_seed, capacity_max, pretrain_semantic in itertools.product(
    [0, 1, 2, 3, 4],
    [48, 96, 192],
    [False, "include_walls", "exclude_walls"],
):
    # Results are grouped by probability type, room size and capacity.
    root_path = (
        f"./training-results/{prob_type}/dqn/"
        f"room_size={room_size}/capacity={capacity_max}/"
    )
    params_dict = {
        "env_str": "room_env:RoomEnv-v2",
        "num_iterations": num_iterations,
        "replay_buffer_size": replay_buffer_size,
        "warm_start": batch_size,
        "batch_size": batch_size,
        "target_update_interval": target_update_interval,
        # Epsilon is scheduled over the full run length.
        "epsilon_decay_until": num_iterations,
        "max_epsilon": 1.0,
        "min_epsilon": 0.1,
        "gamma": gamma,
        "learning_rate": 0.001,
        "capacity": {"long": capacity_max, "short": 15},
        "pretrain_semantic": pretrain_semantic,
        "semantic_decay_factor": semantic_decay_factor,
        "dqn_params": {
            "gcn_layer_params": {
                "type": "stare",
                "embedding_dim": embedding_dim,
                "num_layers": num_layers,
                "gcn_drop": 0.1,
                "triple_qual_weight": triple_qual_weight,
            },
            "relu_between_gcn_layers": True,
            "dropout_between_gcn_layers": False,
            "mlp_params": {
                # NOTE(review): reuses the GCN depth for the MLP hidden
                # layers -- confirm this coupling is intended.
                "num_hidden_layers": num_layers,
                "dueling_dqn": True,
            },
        },
        "num_samples_for_results": {"val": 5, "test": 10},
        "validation_interval": 1,
        "plotting_interval": 50,
        # Offset by 5 so training seeds (5-9) never coincide with the test
        # seeds (0-4) swept above.
        "train_seed": test_seed + 5,
        "test_seed": test_seed,
        "device": "cpu",
        "qa_function": "latest_strongest",
        "env_config": {
            "question_prob": 1.0,
            "terminates_at": terminates_at,
            "randomize_observations": "all",
            "room_size": room_size,
            "rewards": {"correct": 1, "wrong": 0, "partial": 0},
            "make_everything_static": False,
            "num_total_questions": 1000,
            "question_interval": 1,
            "include_walls_in_observations": True,
        },
        "intrinsic_explore_reward": 0,
        "ddqn": True,
        "default_root_dir": root_path,
        "explore_policy": "rl",
        "mm_policy": "rl",
        "scale_reward": False,
    }

    agent = DQNAgent(**params_dict)
    agent.train()