## Train MM / explore with random sampling


In [1]:
import matplotlib

matplotlib.use("Agg")

import logging

logger = logging.getLogger()
logger.disabled = True

import os
import torch
import numpy as np

from agent import DQNAgent
from tqdm.auto import tqdm
import random
import itertools

# Number of combinations you want
num_combinations = 100  # Change this to however many combinations you need

# default
room_size = "xl-different-prob"
capacity_max = 12
batch_size = 8
terminates_at = 9
num_iterations = (terminates_at + 1) * 100
validation_starts_at = num_iterations // 2

prob_type = (
    "non-equal-object-probs" if "different-prob" in room_size else "equal-object-probs"
)
root_path = (
    f"./training-results/{prob_type}/dqn/room_size={room_size}/capacity={capacity_max}/"
)

root_path = "training-results/TRASH"

# random
test_seed_ = [i for i in range(num_combinations)]
target_update_interval_ = [10]
min_epsilon_ = [0.1]
gamma_ = [0.8, 0.9, 0.99]
semantic_decay_factor_ = [0.7, 0.9, 0.99]
pretrain_semantic_ = [False]

# Weights for agent_capacity_ elements
replay_buffer_size_ = [
    num_iterations,
    num_iterations // 2,
]
warm_start_ = [
    num_iterations // 2,
    num_iterations // 4,
    num_iterations // 10,
]


# Generate all combinations
params_all = list(
    itertools.product(
        test_seed_,
        target_update_interval_,
        min_epsilon_,
        gamma_,
        semantic_decay_factor_,
        pretrain_semantic_,
        replay_buffer_size_,
        warm_start_,
    )
)

# Random combinations with weighted agent_capacity_
random_combinations = random.sample(params_all, num_combinations)

for i, params in tqdm(enumerate(random_combinations)):
    (
        test_seed,
        target_update_interval,
        min_epsilon,
        gamma,
        semantic_decay_factor,
        pretrain_semantic,
        replay_buffer_size,
        warm_start,
    ) = params

    params_dict = {
        "env_str": "room_env:RoomEnv-v2",
        "num_iterations": num_iterations,
        "replay_buffer_size": replay_buffer_size,
        "validation_starts_at": validation_starts_at,
        "warm_start": warm_start,
        "batch_size": batch_size,
        "target_update_interval": target_update_interval,
        "epsilon_decay_until": num_iterations,
        "max_epsilon": 1.0,
        "min_epsilon": min_epsilon,
        "gamma": gamma,
        "capacity": {"long": capacity_max, "short": 15},
        "pretrain_semantic": pretrain_semantic,
        "semantic_decay_factor": semantic_decay_factor,
        "dqn_params": {
            "gcn_layer_params": {
                "type": "stare",
                "embedding_dim": 10,
                "num_layers": 2,
                "gcn_drop": 0.1,
                "triple_qual_weight": 0.8,
            },
            "relu_between_gcn_layers": True,
            "dropout_between_gcn_layers": True,
            "mlp_params": {"num_hidden_layers": 2, "dueling_dqn": True},
        },
        "num_samples_for_results": {"val": 5, "test": 10},
        "validation_interval": 5,
        "plotting_interval": 50,
        "train_seed": test_seed + 5,
        "test_seed": test_seed,
        "device": "cpu",
        "qa_function": "latest_strongest",
        "env_config": {
            "question_prob": 1.0,
            "terminates_at": terminates_at,
            "randomize_observations": "all",
            "room_size": room_size,
            "rewards": {"correct": 1, "wrong": 0, "partial": 0},
            "make_everything_static": False,
            "num_total_questions": 1000,
            "question_interval": 1,
            "include_walls_in_observations": True,
        },
        "ddqn": True,
        "default_root_dir": root_path,
    }

    agent = DQNAgent(**params_dict)
    agent.train()

  from .autonotebook import tqdm as notebook_tqdm
  logger.deprecation(
  logger.deprecation(
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")


Running on cpu
> [0;32m/home/tk/repos/agent-room-env-v2-gnn/agent/dqn/nn/gnn.py[0m(420)[0;36mprocess_batch[0;34m()[0m
[0;32m    418 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    419 [0;31m[0;34m[0m[0m
[0m[0;32m--> 420 [0;31m        [0mshort_memory_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0mshort_memory_idx[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    421 [0;31m        [0magent_entity_index[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0magent_entity_index[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    422 [0;31m        [0mnum_short_memories[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0mnum_short_memories[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
tensor([[ 10,  11,  12,  10,  12,  11,  10,  11,  10

0it [00:45, ?it/s]


In [2]:
# torch.save(entity_embeddings, "entity_embeddings.pt")
# torch.save(relation_embeddings, "relation_embeddings.pt")
# torch.save(edge_index, "edge_index.pt")
# torch.save(edge_type, "edge_type.pt")
# torch.save(quals, "quals.pt")
# torch.save(short_memory_idx, "short_memory_idx.pt")
# torch.save(agent_entity_index, "agent_entity_index.pt")
# torch.save(num_short_memories, "num_short_memories.pt")

# load what's saved above:

entity_embeddings = torch.load("entity_embeddings.pt")
relation_embeddings = torch.load("relation_embeddings.pt")
edge_index = torch.load("edge_index.pt")
edge_type = torch.load("edge_type.pt")
quals = torch.load("quals.pt")
short_memory_idx = torch.load("short_memory_idx.pt")
agent_entity_index = torch.load("agent_entity_index.pt")
num_short_memories = torch.load("num_short_memories.pt")


print(f"Entity Embeddings Shape:\t{entity_embeddings.shape}")
print(f"Relation Embeddings Shape:\t{relation_embeddings.shape}")
print(f"Edge Index Shape:\t\t{edge_index.shape}")
print(f"Edge Type Shape:\t\t{edge_type.shape}")
print(f"Quals Shape:\t\t\t{quals.shape}")
print(f"Short Memory Index Shape:\t{short_memory_idx.shape}")
print(f"Agent Entity Index Shape:\t{agent_entity_index.shape}")
print(f"Num Short Memories Shape:\t{num_short_memories.shape}")



Entity Embeddings Shape:	torch.Size([91, 10])
Relation Embeddings Shape:	torch.Size([102, 10])
Edge Index Shape:		torch.Size([2, 178])
Edge Type Shape:		torch.Size([178])
Quals Shape:			torch.Size([3, 236])
Short Memory Index Shape:	torch.Size([44])
Agent Entity Index Shape:	torch.Size([8])
Num Short Memories Shape:	torch.Size([8])


In [3]:
quals[0, :]

tensor([ 10,  11,  10,  11,  10,  12,  11,  10,  11,  10,  12,  12,  11,  12,
         11,  12,  11,  12,  11,  10,  11,  10,  11,  10,  12,  11,  10,  11,
         10,  12,  12,  11,  12,  11,  12,  11,  12,  11,  23,  23,  23,  23,
         23,  24,  24,  24,  25,  24,  24,  25,  24,  24,  25,  23,  23,  23,
         23,  23,  24,  24,  24,  25,  24,  24,  25,  24,  24,  25,  36,  36,
         36,  36,  36,  36,  36,  36,  36,  36,  36,  36,  36,  36,  47,  47,
         47,  47,  47,  48,  48,  49,  49,  47,  47,  47,  47,  47,  48,  48,
         49,  49,  60,  61,  60,  61,  60,  62,  60,  62,  60,  62,  61,  61,
         62,  61,  62,  62,  61,  62,  62,  61,  60,  61,  60,  61,  60,  62,
         60,  62,  60,  62,  61,  61,  62,  61,  62,  62,  61,  62,  62,  61,
         73,  73,  73,  74,  73,  75,  73,  73,  75,  73,  74,  73,  73,  73,
         74,  73,  75,  73,  73,  75,  73,  74,  86,  86,  86,  86,  86,  87,
         88,  87,  87,  87,  88,  87,  87,  88,  87,  86,  86,  

In [4]:
quals[1, :]

tensor([ 8,  9,  8,  9,  8, 10,  9,  8,  9,  8, 10, 10, 11, 12, 10, 10, 12, 10,
        12,  8,  9,  8,  9,  8, 10,  9,  8,  9,  8, 10, 10, 11, 12, 10, 10, 12,
        10, 12, 22, 22, 22, 22, 22, 23, 23, 23, 24, 23, 23, 23, 23, 23, 23, 22,
        22, 22, 22, 22, 23, 23, 23, 24, 23, 23, 23, 23, 23, 23, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 40, 40, 40, 40, 40, 40, 40, 41,
        41, 40, 40, 40, 40, 40, 40, 40, 41, 41, 50, 51, 50, 52, 50, 52, 50, 52,
        50, 51, 51, 51, 53, 51, 54, 54, 52, 54, 55, 51, 50, 51, 50, 52, 50, 52,
        50, 52, 50, 51, 51, 51, 53, 51, 54, 54, 52, 54, 55, 51, 63, 63, 63, 63,
        63, 64, 63, 63, 64, 63, 63, 63, 63, 63, 63, 63, 64, 63, 63, 64, 63, 63,
        71, 71, 71, 71, 71, 72, 73, 72, 73, 73, 73, 74, 74, 73, 74, 71, 71, 71,
        71, 71, 72, 73, 72, 73, 73, 73, 74, 74, 73, 74, 83, 84, 83, 84, 85, 83,
        85, 84, 83, 86, 83, 87, 85, 85, 88, 85, 89, 89, 87, 89, 90, 85, 83, 84,
        83, 84, 85, 83, 85, 84, 83, 86, 

In [5]:
quals[2, :]

tensor([  0,   0,   1,   1,   2,   2,   2,   3,   3,   4,   4,   5,   6,   7,
          8,   8,   9,  10,  11,   0,   0,   1,   1,   2,   2,   2,   3,   3,
          4,   4,   5,   6,   7,   8,   8,   9,  10,  11,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  24,  25,  26,
         27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  54,  55,
         56,  57,  58,  59,  60,  54,  55,  56,  57,  58,  59,  60,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  68,  69,  70,  71,  72,  73,  74,
         75,  76,  86,  86,  87,  87,  88,  88,  89,  89,  90,  90,  90,  91,
         92,  93,  93,  94,  95,  96,  97,  97,  86,  86,  87,  87,  88,  88,
         89,  89,  90,  90,  90,  91,  92,  93,  93,  94,  95,  96,  97,  97,
        110, 111, 112, 112, 113, 113, 114, 115, 115, 116, 116, 110, 111, 112,
        112, 113, 113, 114, 115, 115, 116, 116, 124, 125, 126, 127, 128, 129,
        130, 131, 132, 133, 134, 135, 136, 137, 138, 124, 125, 1

In [8]:
edge_type

tensor([  0,   1,   0,   2,   3,   4,   0,   5,   6,   5,   7,   8,   9,   5,
         13,  14,  15,  16,  17,  13,  17,  15,  15,  18,  19,  20,  21,  22,
         18,  22,  20,  20,  26,  27,  28,  29,  30,  29,  28,  30,  30,  31,
         32,  33,  34,  35,  34,  33,  35,  35,  39,  39,  40,  41,  42,  39,
         43,  44,  44,  45,  46,  47,  44,  48,  52,  53,  54,  55,  55,  55,
         56,  57,  58,  59,  60,  60,  60,  61,  65,  66,  67,  68,  69,  66,
         69,  66,  67,  70,  71,  72,  73,  74,  71,  74,  71,  72,  78,  79,
         80,  81,  82,  78,  82,  81,  79,  82,  80,  82,  83,  84,  85,  86,
         87,  83,  87,  86,  84,  87,  85,  87,  91,  92,  93,  94,  95,  92,
         92,  91,  92,  95,  93,  91,  96,  97,  98,  99, 100,  97,  97,  96,
         97, 100,  98,  96])

In [2]:
torch.save(torch.tensor(0), "foo.pt")

In [3]:
len(agent.dqn.entities), len(agent.dqn.relations)

(68, 13)

In [4]:
agent.dqn.relations

['north',
 'east',
 'south',
 'west',
 'atlocation',
 'north_inv',
 'east_inv',
 'south_inv',
 'west_inv',
 'atlocation_inv',
 'current_time',
 'timestamp',
 'strength']

In [2]:
data_single = np.array([[['dep_001', 'atlocation', 'room_000', {'current_time': 0}],
        ['room_000', 'south', 'room_004', {'current_time': 0}],
        ['room_000', 'north', 'wall', {'current_time': 0}],
        ['agent', 'atlocation', 'room_000', {'current_time': 0}],
        ['dep_007', 'atlocation', 'room_000', {'current_time': 0}],
        ['room_000', 'west', 'wall', {'current_time': 0}],
        ['room_000', 'east', 'room_001', {'current_time': 0}]]],
      dtype=object)

print(data_single.shape)

data = np.array([list([['dep_007', 'atlocation', 'room_000', {'current_time': 2, 'strength': 1}], ['agent', 'atlocation', 'room_000', {'current_time': 2, 'strength': 1}], ['room_000', 'west', 'wall', {'current_time': 2, 'strength': 1}], ['room_000', 'north', 'wall', {'current_time': 2, 'strength': 1.8}], ['dep_001', 'atlocation', 'room_000', {'current_time': 2, 'timestamp': [0, 1]}], ['room_000', 'south', 'room_004', {'current_time': 2, 'timestamp': [1]}], ['room_000', 'east', 'room_001', {'current_time': 2, 'timestamp': [0], 'strength': 1}]]),
       list([['room_005', 'east', 'room_006', {'current_time': 5}], ['agent', 'atlocation', 'room_005', {'current_time': 5}], ['room_005', 'south', 'wall', {'current_time': 5}], ['room_005', 'north', 'room_001', {'current_time': 5}], ['room_005', 'west', 'room_004', {'current_time': 5}], ['room_000', 'east', 'room_001', {'timestamp': [0], 'strength': 1}], ['agent', 'atlocation', 'room_000', {'timestamp': [0, 2, 3], 'strength': 1}], ['dep_001', 'atlocation', 'room_000', {'timestamp': [0, 1], 'strength': 1.62}], ['dep_007', 'atlocation', 'room_000', {'strength': 1, 'timestamp': [1]}], ['room_000', 'west', 'wall', {'timestamp': [1, 2], 'strength': 1}], ['room_000', 'north', 'wall', {'strength': 1}], ['room_000', 'south', 'room_004', {'timestamp': [3]}], ['room_004', 'south', 'wall', {'timestamp': [4]}], ['room_004', 'north', 'room_000', {'timestamp': [4]}]]),
       list([['agent', 'atlocation', 'room_001', {'current_time': 9, 'strength': 1.3122000000000003, 'timestamp': [6, 7]}], ['room_001', 'west', 'room_000', {'current_time': 9, 'timestamp': [1, 6]}], ['room_001', 'south', 'room_005', {'current_time': 9, 'strength': 2.2680000000000002, 'timestamp': [5]}], ['room_001', 'north', 'wall', {'current_time': 9, 'timestamp': [1, 2], 'strength': 2.52}], ['room_001', 'east', 'wall', {'current_time': 9, 'strength': 2.1222000000000003, 'timestamp': [2, 6]}], ['dep_001', 'atlocation', 'room_000', {'strength': 1.0628820000000003, 'timestamp': [4]}], ['room_000', 'south', 'room_004', {'strength': 1, 'timestamp': [3, 4]}], ['room_000', 'east', 'room_001', {'timestamp': [0, 3]}], ['room_000', 'west', 'wall', {'timestamp': [0]}], ['agent', 'atlocation', 'room_000', {'timestamp': [0], 'strength': 1}], ['room_000', 'north', 'wall', {'timestamp': [0], 'strength': 1}], ['dep_007', 'atlocation', 'room_000', {'strength': 1}]]),
       list([['dep_001', 'atlocation', 'room_000', {'current_time': 4, 'strength': 1.8, 'timestamp': [2]}], ['room_000', 'south', 'room_004', {'current_time': 4, 'timestamp': [3]}], ['room_000', 'east', 'room_001', {'current_time': 4, 'timestamp': [0]}], ['dep_007', 'atlocation', 'room_000', {'current_time': 4, 'strength': 1, 'timestamp': [3]}], ['room_000', 'west', 'wall', {'current_time': 4, 'strength': 1, 'timestamp': [2]}], ['room_000', 'north', 'wall', {'current_time': 4, 'strength': 1}], ['agent', 'atlocation', 'room_000', {'current_time': 4, 'timestamp': [0, 2], 'strength': 1}], ['room_004', 'south', 'wall', {'timestamp': [1]}], ['room_004', 'east', 'room_005', {'strength': 1}]]),
       list([['agent', 'atlocation', 'room_004', {'current_time': 3, 'timestamp': [1, 2]}], ['room_004', 'north', 'room_000', {'current_time': 3, 'strength': 1}], ['room_004', 'south', 'wall', {'current_time': 3, 'strength': 1, 'timestamp': [2]}], ['room_004', 'west', 'wall', {'current_time': 3, 'timestamp': [1, 2]}], ['room_004', 'east', 'room_005', {'current_time': 3, 'timestamp': [1]}], ['room_000', 'west', 'wall', {'strength': 1}], ['room_000', 'south', 'room_004', {'timestamp': [0]}], ['room_000', 'north', 'wall', {'timestamp': [0]}], ['room_000', 'east', 'room_001', {'timestamp': [0]}], ['agent', 'atlocation', 'room_000', {'strength': 1}]]),
       list([['room_004', 'north', 'room_000', {'current_time': 9, 'strength': 1}], ['agent', 'atlocation', 'room_004', {'current_time': 9, 'timestamp': [7], 'strength': 1}], ['room_004', 'south', 'wall', {'current_time': 9, 'strength': 1, 'timestamp': [7]}], ['room_004', 'east', 'room_005', {'current_time': 9}], ['room_004', 'west', 'wall', {'current_time': 9}], ['room_000', 'south', 'room_004', {'strength': 1}], ['room_005', 'north', 'room_001', {'timestamp': [2, 5], 'strength': 1}], ['room_006', 'east', 'room_007', {'timestamp': [3]}], ['room_006', 'west', 'room_005', {'strength': 1.1809800000000004}], ['sta_004', 'atlocation', 'room_006', {'strength': 1.1809800000000004}], ['agent', 'atlocation', 'room_006', {'timestamp': [4]}], ['room_005', 'east', 'room_006', {'strength': 1.4580000000000002}], ['room_005', 'south', 'wall', {'timestamp': [6]}], ['room_005', 'west', 'room_004', {'timestamp': [6]}]]),
       list([['agent', 'atlocation', 'room_000', {'current_time': 4, 'strength': 1}], ['room_000', 'east', 'room_001', {'current_time': 4, 'timestamp': [0, 3]}], ['room_000', 'north', 'wall', {'current_time': 4, 'strength': 1.8}], ['room_000', 'west', 'wall', {'current_time': 4, 'strength': 1, 'timestamp': [1, 3]}], ['dep_001', 'atlocation', 'room_000', {'current_time': 4, 'timestamp': [0, 1, 3]}], ['dep_007', 'atlocation', 'room_000', {'current_time': 4, 'timestamp': [0], 'strength': 1}], ['room_000', 'south', 'room_004', {'current_time': 4, 'timestamp': [0, 1], 'strength': 1}], ['room_004', 'west', 'wall', {'strength': 1}], ['room_004', 'south', 'wall', {'timestamp': [2]}], ['room_004', 'north', 'room_000', {'timestamp': [2]}], ['room_004', 'east', 'room_005', {'strength': 1}]]),
       list([['room_001', 'south', 'room_005', {'current_time': 5}], ['agent', 'atlocation', 'room_001', {'current_time': 5}], ['room_001', 'north', 'wall', {'current_time': 5}], ['room_001', 'east', 'wall', {'current_time': 5}], ['room_001', 'west', 'room_000', {'current_time': 5}], ['agent', 'atlocation', 'room_000', {'strength': 1, 'timestamp': [3]}], ['dep_001', 'atlocation', 'room_000', {'strength': 1, 'timestamp': [1, 2, 3]}], ['room_000', 'south', 'room_004', {'strength': 1.4580000000000002, 'timestamp': [3, 4]}], ['room_000', 'north', 'wall', {'strength': 3.168}], ['room_000', 'west', 'wall', {'strength': 1.62, 'timestamp': [2, 4]}], ['dep_007', 'atlocation', 'room_000', {'strength': 1.8, 'timestamp': [2, 3]}], ['room_000', 'east', 'room_001', {'timestamp': [3], 'strength': 1}]])],
      dtype=object)

print(data.shape)

(1, 7, 4)
(8,)


In [5]:
(
    entity_embeddings,
    relation_embeddings,
    edge_index,
    edge_type,
    quals,
    short_memory_idx,
    agent_node_idx,
    num_short_memories
) = agent.dqn.process_sample(data[0])

In [7]:
print(f"Entity Embeddings Shape:\t{entity_embeddings.shape}")
print(f"Relation Embeddings Shape:\t{relation_embeddings.shape}")
print(f"Edge Index Shape:\t\t{edge_index.shape}")
print(f"Edge Type Shape:\t\t{edge_type.shape}")
print(f"Quals Shape:\t\t\t{quals.shape}")
print(f"Short Memory Index Shape:\t{short_memory_idx.shape}")
print(f"Agent Node Index Shape:\t\t{agent_node_idx.shape}")
print(f"Number of Short Memories:\t{num_short_memories}")


Entity Embeddings Shape:	torch.Size([10, 10])
Relation Embeddings Shape:	torch.Size([13, 10])
Edge Index Shape:		torch.Size([2, 14])
Edge Type Shape:		torch.Size([14])
Quals Shape:			torch.Size([3, 30])
Short Memory Index Shape:	torch.Size([7])
Agent Node Index Shape:		torch.Size([])
Number of Short Memories:	7


In [9]:
(
    entity_embeddings,
    relation_embeddings,
    edge_index,
    edge_type,
    quals,
    short_memory_idx,
    agent_node_idx,
    num_short_memories
) = agent.dqn.process_batch(data)

In [10]:
print(f"Entity Embeddings Shape:\t{entity_embeddings.shape}")
print(f"Relation Embeddings Shape:\t{relation_embeddings.shape}")
print(f"Edge Index Shape:\t\t{edge_index.shape}")
print(f"Edge Type Shape:\t\t{edge_type.shape}")
print(f"Quals Shape:\t\t\t{quals.shape}")
print(f"Short Memory Index Shape:\t{short_memory_idx.shape}")
print(f"Agent Node Index Shape:\t\t{agent_node_idx.shape}")
print(f"Number of Short Memories:\t{num_short_memories}")


Entity Embeddings Shape:	torch.Size([107, 10])
Relation Embeddings Shape:	torch.Size([104, 10])
Edge Index Shape:		torch.Size([2, 178])
Edge Type Shape:		torch.Size([178])
Quals Shape:			torch.Size([3, 308])
Short Memory Index Shape:	torch.Size([46])
Agent Node Index Shape:		torch.Size([8])
Number of Short Memories:	tensor([7, 5, 5, 7, 5, 5, 7, 5])


In [3]:
import torch

# Example tensors
q_mm = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]])
num_short_memories = torch.tensor([5])

split_triples = [q_mm[start:start + num] for start, num in zip(num_short_memories.cumsum(0).roll(1), num_short_memories)]
split_triples[0] = q_mm[:num_short_memories[0]]

# for t in split_triples:
#     print(t)


In [4]:
split_triples

[tensor([[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12],
         [13, 14, 15]])]

In [18]:
t.shape

torch.Size([3, 3])

In [11]:
num_short_memories.sum()

tensor(46)

In [43]:
import torch

# Example tensor of shape (N, M)
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Split the tensor into M row vectors
row_vectors = list(tensor.unbind(dim=0))

# Display the result
for row in row_vectors:
    print(row.unsqueeze(0).shape)


torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])


In [44]:
row_vectors

[tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])]

In [40]:
tensor.shape

torch.Size([3, 3])

In [37]:
import torch

mem = torch.tensor([1, 2, 2, 3, 3, 3])
num_mems = torch.tensor([1, 2, 3])

split_mem = [
    mem[start : start + num.item()]
    for start, num in zip(num_mems.cumsum(0).roll(1), num_mems)
]
split_mem[0] = mem[: num_mems[0]]

print(split_mem)

[tensor([1]), tensor([2, 2]), tensor([3, 3, 3])]


In [26]:
torch.tensor([torch.tensor(0), torch.tensor(1), torch.tensor(2)]).shape

torch.Size([3])

In [15]:
agent_node_idx

tensor([ 2, 12, 25, 48, 55, 67, 81, 96])

In [22]:
torch.tensor([[1,2]])[0].shape

torch.Size([2])

In [23]:
torch.tensor([[1,2]])[0:1].shape

torch.Size([1, 2])

In [16]:
batch = [agent.dqn.convert_sample_to_data(sample) for sample in data]
batch

AttributeError: 'GNN' object has no attribute 'convert_sample_to_data'

In [None]:
[Data(x=[7, 8], edge_index=[2, 7],
 Data(x=[9, 8], edge_index=[2, 14],
 Data(x=[8, 8], edge_index=[2, 12],
 Data(x=[8, 8], edge_index=[2, 9],
 Data(x=[6, 8], edge_index=[2, 10],
 Data(x=[9, 8], edge_index=[2, 14],
 Data(x=[8, 8], edge_index=[2, 11],
 Data(x=[8, 8], edge_index=[2, 12],]

In [32]:
[len(sample) for sample in data]

[7, 14, 12, 9, 10, 14, 11, 12]

In [13]:
batch[-1].x

tensor([[-0.0407,  0.0931, -0.0577, -0.1373, -0.2149,  0.0200,  0.1025, -0.1131],
        [ 0.0064, -0.0282, -0.0982,  0.1061,  0.0975,  0.1682,  0.0088,  0.0496],
        [-0.0398,  0.2858, -0.0652, -0.1250,  0.2002,  0.2245, -0.1621,  0.0538],
        [-0.2273,  0.3185,  0.0252, -0.1168,  0.1527, -0.4556, -0.0915, -0.0586],
        [-0.0714,  0.1853,  0.1917,  0.3164,  0.0983, -0.1425,  0.3199,  0.1631],
        [-0.1893,  0.1739, -0.1608, -0.1605,  0.0994, -0.1334,  0.0982,  0.0991],
        [ 0.0174,  0.1311, -0.1083,  0.1488, -0.0361, -0.0792, -0.0231,  0.0907],
        [ 0.3198, -0.1265,  0.1058,  0.0101, -0.1545,  0.2962,  0.1155,  0.0226]],
       grad_fn=<StackBackward0>)

In [14]:
batch[-1].edge_index

tensor([[0, 2, 0, 0, 0, 2, 5, 4, 4, 4, 7, 4],
        [1, 0, 3, 3, 4, 4, 4, 6, 3, 3, 4, 0]])

In [15]:
batch[-1].short_triples

[{'head_idx': 0, 'relation_idx': 2, 'tail_idx': 1},
 {'head_idx': 2, 'relation_idx': 4, 'tail_idx': 0},
 {'head_idx': 0, 'relation_idx': 0, 'tail_idx': 3},
 {'head_idx': 0, 'relation_idx': 1, 'tail_idx': 3},
 {'head_idx': 0, 'relation_idx': 3, 'tail_idx': 4}]

In [16]:
batch[-1].agent_node

2

In [17]:
from agent.dqn.nn.gnn import DataLoader
loader = DataLoader(batch, len(batch), shuffle=False)
batch_ = next(iter(loader))
batch_



DataBatch(x=[63, 8], edge_index=[2, 89], short_triples=[8], agent_node=[8], batch=[63], ptr=[9])

In [19]:
import torch
torch.equal(torch.cat([b.x for b in batch]), batch_.x)

True

In [26]:
batch_.edge_index[:, -5:]

tensor([[59, 59, 59, 62, 59],
        [61, 58, 58, 59, 55]])

In [20]:
[b.x.shape for b in batch]

[torch.Size([7, 8]),
 torch.Size([9, 8]),
 torch.Size([8, 8]),
 torch.Size([8, 8]),
 torch.Size([6, 8]),
 torch.Size([9, 8]),
 torch.Size([8, 8]),
 torch.Size([8, 8])]

In [21]:
foo = torch.zeros(2, 0)
to_increment = 0
for idx, b in enumerate(batch):
    edge_index_incresed = b.edge_index + to_increment
    foo = torch.cat([foo, edge_index_incresed], dim=1)
    to_increment += b.x.shape[0]

## Run fixed combinations

In [1]:
import matplotlib

matplotlib.use("Agg")

import logging

logger = logging.getLogger()
logger.disabled = True

import os
from agent import DQNAgent
from tqdm.auto import tqdm
import random
import itertools


room_size = "xl-different-prob"
terminates_at = 99
num_iterations = (terminates_at + 1) * 100
replay_buffer_size = num_iterations // 2
validation_starts_at = num_iterations // 2
warm_start = num_iterations // 4
batch_size = 32
target_update_interval = 10
gamma = 0.9
semantic_decay_factor = 0.9

for capacity_max in [24, 12, 6, 48]:
    prob_type = (
        "non-equal-object-probs"
        if "different-prob" in room_size
        else "equal-object-probs"
    )
    root_path = (
        f"./training-results/{prob_type}/dqn/"
        f"room_size={room_size}/capacity={capacity_max}/"
    )
    for pretrain_semantic in [False, "include_walls", "exclude_walls"]:
        for test_seed in [0, 1, 2, 3, 4]:
            params_dict = {
                "env_str": "room_env:RoomEnv-v2",
                "num_iterations": num_iterations,
                "replay_buffer_size": replay_buffer_size,
                "validation_starts_at": validation_starts_at,
                "warm_start": warm_start,
                "batch_size": batch_size,
                "target_update_interval": target_update_interval,
                "epsilon_decay_until": num_iterations,
                "max_epsilon": 1.0,
                "min_epsilon": 0.1,
                "gamma": gamma,
                "capacity": {"long": capacity_max, "short": 15},
                "pretrain_semantic": pretrain_semantic,
                "semantic_decay_factor": semantic_decay_factor,
                "dqn_params": {
                    "embedding_dim": 10,
                    "num_layers_GNN": 2,
                    "num_hidden_layers_MLP": 1,
                    "dueling_dqn": True,
                },
                "num_samples_for_results": {"val": 5, "test": 10},
                "validation_interval": 5,
                "plotting_interval": 50,
                "train_seed": test_seed + 5,
                "test_seed": test_seed,
                "device": "cpu",
                "qa_function": "latest_strongest",
                "env_config": {
                    "question_prob": 1.0,
                    "terminates_at": terminates_at,
                    "randomize_observations": "all",
                    "room_size": room_size,
                    "rewards": {"correct": 1, "wrong": 0, "partial": 0},
                    "make_everything_static": False,
                    "num_total_questions": 1000,
                    "question_interval": 1,
                    "include_walls_in_observations": True,
                },
                "ddqn": True,
                "default_root_dir": root_path,
            }

            agent = DQNAgent(**params_dict)
            agent.train()

  from .autonotebook import tqdm as notebook_tqdm


Running on cpu


TypeError: GNN.__init__() got an unexpected keyword argument 'embedding_dim'