In [None]:
from simple_foraging_env import SimpleForagingEnv
from opponent_model import OpponentModel, SubGoalSelector
from q_agent import QLearningAgent, ReplayBuffer
from q_agent_classic import QLearningAgentClassic
from omg_args import OMGArgs
import transformers as t
import torch
import os

In [None]:
from experiments.om import q_agent


# Episode step limit, shared by the environment and the agent configuration
# so the two values cannot drift apart.
MAX_STEPS = 30

env = SimpleForagingEnv(grid_size=7, max_steps=MAX_STEPS)

# The first element returned by reset() is used as a sample observation; it is
# unpacked below as a 3-D array of shape (H, W, F_dim).
obs_sample = env.reset()[0]
H, W, F_dim = obs_sample.shape
NUM_ACTIONS = 4

args = OMGArgs(
    device="cpu",
    folder_id="0",
    batch_size=8,
    horizon_H=6,
    qnet_hidden=256,
    max_steps=MAX_STEPS,
    selector_mode="conservative",
    vae_beta=0.1,
    beta_start=1.0,
    beta_end=2.0,
    selector_tau_start=2.0,
    selector_tau_end=0.1,
    # Full observation shape (H, W, F_dim). The previous `obs_sample[0].shape`
    # described a single row, (W, F_dim), which disagrees with the H/W/F_dim
    # unpacking above.
    # NOTE(review): confirm against how OMGArgs consumers use `state_shape`.
    state_shape=obs_sample.shape,
    H=H,
    W=W,
    state_feature_splits=(F_dim,),
    action_dim=NUM_ACTIONS,
    latent_dim=32,
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=1024,
    dropout=0.1,
)

# Make sure the ./diagrams_0 directory exists before anything writes to it.
os.makedirs("./diagrams_0", exist_ok=True)

vae = t.TransformerVAE(args)
cvae = t.TransformerCVAE(args)

# Restore the saved weights. map_location="cpu" remaps tensors that were saved
# from another device (e.g. CUDA) so the checkpoints also load on a CPU-only
# machine; load_state_dict then copies the values into the modules' parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
vae.load_state_dict(torch.load("./models_0/vae.pth", map_location="cpu"))
cvae.load_state_dict(torch.load("./models_0/cvae.pth", map_location="cpu"))

selector = SubGoalSelector(args)
# Only the CVAE's parameters are handed to this optimizer.
cvae_optimizer = torch.optim.Adam(cvae.parameters(), lr=args.cvae_lr)

op_model = OpponentModel(
    cvae,
    vae,
    selector,
    optimizer=cvae_optimizer,
    args=args,
)

# NOTE(review): this assignment rebinds `q_agent`, replacing the
# `experiments.om.q_agent` module imported at the top of this cell. The name is
# kept because later cells may refer to the agent as `q_agent`.
q_agent = QLearningAgent(
    env=env,
    opponent_model=op_model,
    args=args,
)