In [None]:
from simple_foraging_env import SimpleForagingEnv, SimpleAgent, RandomAgent
from opponent_model import OpponentModel, SubGoalSelector
from q_agent import QLearningAgent, ReplayBuffer
from q_agent_classic import QLearningAgentClassic
from omg_args import OMGArgs
import transformers as t
import matplotlib.pyplot as plt
import torch
import os
import random
import numpy as np
from collections import deque
from typing import Deque, Dict, List, Tuple, Optional

In [None]:
# NOTE(review): this import binds the name `q_agent` to a module, but the name
# is rebound to a QLearningAgent instance at the bottom of this cell before it
# is ever used — confirm whether the import is still needed.
from experiments.om import q_agent


# Environment instance shared by every cell below.
env = SimpleForagingEnv(grid_size=7, max_steps=30)

# Agent 0's observation right after a reset fixes the (H, W, F) layout that
# the models are configured with.
obs_sample = env.reset()[0]
H, W, F_dim = obs_sample.shape
NUM_ACTIONS = 4

# Single hyper-parameter object handed to the VAE/CVAE, the selector and the
# Q-agent. Keyword arguments are grouped by theme; their order is irrelevant.
args = OMGArgs(
  # run / loop settings
  device="cpu",
  folder_id="0",
  batch_size=8,
  max_steps=30,
  horizon_H=6,
  # shapes derived from the sample observation
  state_shape=obs_sample.shape,
  H=H, W=W,
  state_feature_splits=(F_dim,),
  action_dim=NUM_ACTIONS,
  # selector settings and beta / tau schedules
  selector_mode="conservative",
  selector_tau_start=2.0,
  selector_tau_end=0.1,
  vae_beta=0.1,
  beta_start=1.0,
  beta_end=2.0,
  # network sizes
  qnet_hidden=256,
  latent_dim=32,
  d_model=256,
  nhead=8,
  num_encoder_layers=2,
  num_decoder_layers=2,
  dim_feedforward=1024,
  dropout=0.1,
)

os.makedirs("./diagrams_0", exist_ok=True)

# Build each model and immediately restore its saved weights.
vae = t.TransformerVAE(args)
vae_weights = torch.load('./models_0/vae.pth', map_location=args.device)
vae.load_state_dict(vae_weights)

cvae = t.TransformerCVAE(args)
cvae_weights = torch.load('./models_0/cvae.pth', map_location=args.device)
cvae.load_state_dict(cvae_weights)

selector = SubGoalSelector(args)
op_model = OpponentModel(cvae, vae, selector, args=args)

# From here on `q_agent` refers to the agent instance, not the imported module.
q_agent = QLearningAgent(env=env, opponent_model=op_model, args=args)
qnet_weights = torch.load('./models_0/qnet.pth', map_location=args.device)
# Kept as the cell's last bare expression so the load report is still displayed.
q_agent.q.load_state_dict(qnet_weights)

In [None]:
# Probe the prior model with agent 0 placed at position (2, 0): visualise the
# sub-goal logits it produces for agent 0's observation and print the latent mean.
env._place_agent(0, (2, 0))
obs = env._get_observations()
op_model.prior_model.eval()
# Inference only: skip autograd bookkeeping, matching how the inference model
# is queried at the end of this notebook.
with torch.no_grad():
    recon_logits, mu, logvar = op_model.prior_model(
        torch.from_numpy(obs[0]).float().unsqueeze(0).to(args.device)
    )
op_model.visualize_subgoal_logits(
    obs[0], recon_logits)
print(mu.squeeze(0).detach().cpu().numpy())

In [None]:
# Probe the prior model with agent 0 placed at position (2, 1): visualise the
# sub-goal logits it produces for agent 0's observation and print the latent mean.
env._place_agent(0, (2, 1))
obs = env._get_observations()
op_model.prior_model.eval()
# Inference only: skip autograd bookkeeping, matching how the inference model
# is queried at the end of this notebook.
with torch.no_grad():
    recon_logits, mu, logvar = op_model.prior_model(
        torch.from_numpy(obs[0]).float().unsqueeze(0).to(args.device)
    )
op_model.visualize_subgoal_logits(
    obs[0], recon_logits)
print(mu.squeeze(0).detach().cpu().numpy())

In [None]:
# Probe the prior model with agent 0 placed at position (2, 2): visualise the
# sub-goal logits it produces for agent 0's observation and print the latent mean.
env._place_agent(0, (2, 2))
obs = env._get_observations()
op_model.prior_model.eval()
# Inference only: skip autograd bookkeeping, matching how the inference model
# is queried at the end of this notebook.
with torch.no_grad():
    recon_logits, mu, logvar = op_model.prior_model(
        torch.from_numpy(obs[0]).float().unsqueeze(0).to(args.device)
    )
op_model.visualize_subgoal_logits(
    obs[0], recon_logits)
print(mu.squeeze(0).detach().cpu().numpy())

In [None]:
# Probe the prior model with agent 0 placed at position (2, 3): visualise the
# sub-goal logits it produces for agent 0's observation and print the latent mean.
env._place_agent(0, (2, 3))
obs = env._get_observations()
op_model.prior_model.eval()
# Inference only: skip autograd bookkeeping, matching how the inference model
# is queried at the end of this notebook.
with torch.no_grad():
    recon_logits, mu, logvar = op_model.prior_model(
        torch.from_numpy(obs[0]).float().unsqueeze(0).to(args.device)
    )
op_model.visualize_subgoal_logits(
    obs[0], recon_logits)
print(mu.squeeze(0).detach().cpu().numpy())

In [None]:
# Display labels for the four discrete action indices; used when printing the
# action chosen at each step of a rendered episode.
action_names = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
def _choose_action(qvals: torch.Tensor) -> int:
    qvals = qvals.squeeze(0)  # (A,)
    max_q = torch.max(qvals).item()
    max_actions = (qvals == max_q).nonzero(as_tuple=False).view(-1)
    if len(max_actions) > 1:
      return int(max_actions[torch.randint(len(max_actions), (1,))].item())
    return int(torch.argmax(qvals, dim=-1).item())

def select_action(s_t: np.ndarray, history: Dict[str, List[torch.Tensor]]) -> Tuple[int, torch.Tensor, torch.Tensor]:
  """
  (interaction phase) Infer g_hat from the history and act greedily on Q(s, g_hat, *).

  There is no epsilon exploration here: the action with the largest Q-value is
  taken, with exact ties broken at random by `_choose_action`. The Q-values
  are printed as a side effect.

  Args:
    s_t: current observation of the controlled agent.
    history: past states and actions, passed through to `q_agent._infer_ghat`.

  Returns:
    (action index, inferred latent mean, inferred latent log-variance); both
    latent tensors have their leading dimension squeezed away.
  """
  ghat_mu, ghat_logvar = q_agent._infer_ghat(s_t, history)  # (1, latent_dim)
  s = torch.from_numpy(s_t).float().unsqueeze(0).to(args.device)
  # The Q-values are only printed and arg-maxed, so no autograd graph is needed.
  with torch.no_grad():
    qvals = q_agent.q(s, ghat_mu)
  print("\t\t\tUP DOWN LEFT RIGHT")
  print(f"Infered latent Q: {qvals.squeeze(0).detach().cpu().numpy()}")
  a = _choose_action(qvals)
  return a, ghat_mu.squeeze(0), ghat_logvar.squeeze(0)

def select_action_random_latent(s_t: np.ndarray) -> Tuple[int, torch.Tensor, torch.Tensor]:
  """
  (interaction phase) Baseline: act greedily on Q(s, g, *) with a random latent g.

  Instead of inferring the latent, both the mean and the log-variance are drawn
  from a standard normal. The Q-values are printed as a side effect so they can
  be compared with the ones obtained from the inferred latent.

  Args:
    s_t: current observation of the controlled agent.

  Returns:
    (action index, random latent mean, random latent log-variance); both
    latent tensors have their leading dimension squeezed away.
  """
  ghat_mu = torch.randn(1, args.latent_dim).to(args.device)
  ghat_logvar = torch.randn(1, args.latent_dim).to(args.device)
  s = torch.from_numpy(s_t).float().unsqueeze(0).to(args.device)
  # The Q-values are only printed and arg-maxed, so no autograd graph is needed.
  with torch.no_grad():
    qvals = q_agent.q(s, ghat_mu)
  print(f"Random latent Q: {qvals.squeeze(0).detach().cpu().numpy()}")
  a = _choose_action(qvals)
  return a, ghat_mu.squeeze(0), ghat_logvar.squeeze(0)

In [None]:
# Roll out 100 episodes with the loaded agent against a scripted opponent and
# store every transition, together with a fixed-length window of the states
# that followed it, in `replay`.
replay = ReplayBuffer(10_000)
opponent_agent = SimpleAgent(1)

for episode in range(100):
  opponent_agent.reset()
  obs = env.reset()
  done = False
  ep_ret = 0.0

  # History container (bounded, so only the most recent steps are kept)
  history_len = args.max_history_length
  history = {
      "states": deque(maxlen=history_len),
      "actions": deque(maxlen=history_len)
  }

  # Sliding window holding one step plus the horizon_H steps that follow it.
  step_buffer = deque(maxlen=args.horizon_H + 1)

  for step in range(args.max_steps or 500):
    # Convert deque to list for the model
    current_history = {k: list(v) for k, v in history.items()}

    a, ghat_mu, ghat_logvar = select_action(obs[0], current_history)
    a_opponent = opponent_agent.select_action(obs[1])
    actions = {0: a, 1: a_opponent}
    next_obs, reward, done, info = env.step(actions)

    # Store the current step's info
    step_info = {
        "state": obs[0].copy(),
        "action": a,
        "reward": float(reward[0]),
        "next_state": next_obs[0].copy(),
        "done": bool(done),
        "infer_mu": ghat_mu.detach().cpu(),
        "infer_log_var": ghat_logvar.detach().cpu(),
        # Snapshot the history so later steps cannot alias these tensors.
        "history": {k: [h.clone() for h in v] for k, v in current_history.items()}
    }
    step_buffer.append(step_info)

    # Once the buffer is full, the oldest step has its full future window
    if len(step_buffer) == args.horizon_H + 1:
      transition_to_store = step_buffer[0]
      future_states = [s["state"] for s in list(step_buffer)[1:]]
      transition_to_store["future_states"] = future_states
      replay.push(transition_to_store)
      if done:
        # Already stored; remove it so the end-of-episode flush below does
        # not push it a second time.
        step_buffer.popleft()

    # End of episode: flush what is still buffered. This must run even when
    # the window was full on the final step, otherwise the last horizon_H
    # transitions of every longer episode would be dropped.
    if done:
      # NOTE(review): the loop stops with one entry left, so the very last
      # step of an episode (the one with done=True) is never stored —
      # confirm this is intended.
      while len(step_buffer) > 1:
        transition_to_store = step_buffer.popleft()
        future_states = [s["state"] for s in list(step_buffer)]
        # Pad the window to horizon_H with copies of the last buffered state
        for _ in range(args.horizon_H - len(future_states)):
          future_states.append(step_buffer[-1]["state"])
        transition_to_store["future_states"] = future_states
        replay.push(transition_to_store)

    # Update history for the next step
    history["states"].append(torch.from_numpy(obs[0]).float())
    history["actions"].append(torch.tensor(a, dtype=torch.long))

    ep_ret += reward[0]
    obs = next_obs

    if done:
      break
  print(f"Episode {episode + 1} Return: {ep_ret}")
print("Data collection complete.")

In [None]:
def run_episode() -> float:
  """
  Play one rendered episode with the loaded agent against a fresh SimpleAgent.

  At every step the Q-values under a random latent and under the inferred
  latent are printed, the observation is rendered, and the chosen action and
  the inferred latent mean are shown. The action actually played is always
  the one selected with the inferred latent.

  Returns:
    The sum of agent 0's rewards over the episode.
  """
  opponent_agent = SimpleAgent(1)
  obs = env.reset()
  ep_ret = 0.0

  # History container (bounded, so only the most recent steps are kept)
  history_len = args.max_history_length
  history = {
      "states": deque(maxlen=history_len),
      "actions": deque(maxlen=history_len)
  }

  for _ in range(args.max_steps or 500):
    # Convert deque to list for the model
    current_history = {k: list(v) for k, v in history.items()}

    # Called only for its printed Q-values (baseline for comparison); its
    # return value is deliberately discarded.
    select_action_random_latent(obs[0])
    a, ghat_mu, _ghat_logvar = select_action(obs[0], current_history)
    a_opponent = opponent_agent.select_action(obs[1])
    actions = {0: a, 1: a_opponent}
    env.render_from_obs(obs[0])
    print("Selected action:", action_names[a])
    print(ghat_mu.detach().cpu().numpy())

    next_obs, reward, done, info = env.step(actions)

    # Update history for the next step
    history["states"].append(torch.from_numpy(obs[0]).float())
    history["actions"].append(torch.tensor(a, dtype=torch.long))

    ep_ret += reward[0]
    obs = next_obs

    if done:
      break

  return float(ep_ret)

In [None]:
# Play a single episode; every step renders the observation and prints the
# Q-values, the selected action and the inferred latent mean.
run_episode()

In [None]:
# Draw one stored transition from the replay buffer and collate it into the
# batched layout that is fed to the inference model.
batch_list = replay.sample(1)

om_batch = {
    # States: (B, H, W, F)
    "states": torch.stack(
        [torch.from_numpy(sample["state"]).float() for sample in batch_list]
    ),
    "history": q_agent._collate_history([sample["history"] for sample in batch_list]),
    # Each sample's list of future states is stacked first, then the samples.
    "future_states": torch.stack(
        [torch.from_numpy(np.stack(sample["future_states"])) for sample in batch_list]
    ),
    "infer_mu": torch.stack([sample["infer_mu"] for sample in batch_list]),
    "infer_log_var": torch.stack([sample["infer_log_var"] for sample in batch_list]),
    "dones": torch.tensor(
        [sample["done"] for sample in batch_list],
        dtype=torch.float32,
        device=args.device,
    ),
}

# Sanity check: one line per entry (shape, or length for the history).
print(
    om_batch["states"].shape,
    len(om_batch["history"]),
    om_batch["future_states"].shape,
    om_batch["infer_mu"].shape,
    om_batch["infer_log_var"].shape,
    sep="\n",
)


In [None]:
# Run the inference model on the sampled batch (state plus history), report
# what it produced and visualise the sub-goal logits for the first sample.
op_model.inference_model.eval()
with torch.no_grad():
    recon_logits, mu, logvar = op_model.inference_model(
        om_batch["states"],
        om_batch["history"],
    )
print("Reconstructed logits shape:", recon_logits.shape)
# These two lines print the tensors themselves, so the labels no longer say "shape".
print("Inferred mu:", mu)
print("Inferred logvar:", logvar)
op_model.visualize_subgoal_logits(
    om_batch["states"][0].numpy(), recon_logits)