In [1]:
import argparse
from omegaconf import OmegaConf
import gymnasium
from world_models import WorldModel

import colorama
import numpy as np
from tinygrad import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict

import env_wrapper

In [2]:
def build_single_env(env_name, image_size, seed):
  """Create one environment and wrap it with the project's wrapper stack.

  Args:
    env_name: Environment id handed to ``gymnasium.make``.
    image_size: Forwarded as ``shape`` to ``gymnasium.wrappers.ResizeObservation``.
    seed: Forwarded as ``seed`` to ``env_wrapper.SeedEnvWrapper``.

  Returns:
    The outermost wrapper (``env_wrapper.LifeLossInfo``) around the environment.
  """
  # The base environment is created with frameskip=1; skipping is delegated to
  # MaxLast2FrameSkipWrapper (skip=4) further down the stack.
  base_env = gymnasium.make(
    env_name,
    full_action_space=False,
    render_mode="rgb_array",
    frameskip=1,
  )
  seeded_env = env_wrapper.SeedEnvWrapper(base_env, seed=seed)
  skipping_env = env_wrapper.MaxLast2FrameSkipWrapper(seeded_env, skip=4)
  resized_env = gymnasium.wrappers.ResizeObservation(skipping_env, shape=image_size)
  return env_wrapper.LifeLossInfo(resized_env)

In [3]:
def build_world_model(conf, action_dim):
    """Construct a ``WorldModel`` from the ``models.world_model`` config node.

    Args:
        conf: Config object exposing ``models.world_model`` with the fields
            ``in_channels`` and ``transformer_{max_length,hidden_dim,num_layers,num_heads}``.
        action_dim: Passed through unchanged as the model's ``action_dim``.

    Returns:
        The ``WorldModel`` instance.
    """
    wm_conf = conf.models.world_model
    model = WorldModel(
        in_channels=wm_conf.in_channels,
        action_dim=action_dim,
        transformer_max_length=wm_conf.transformer_max_length,
        transformer_hidden_dim=wm_conf.transformer_hidden_dim,
        transformer_num_layers=wm_conf.transformer_num_layers,
        transformer_num_heads=wm_conf.transformer_num_heads,
    )
    return model

In [4]:
# Flags are parsed from a fixed argv list (not sys.argv) so the cell can run
# inside the notebook kernel. Every flag is required.
parser = argparse.ArgumentParser()
for _flag, _type in (
    ("-n", str),
    ("-seed", int),
    ("-config_path", str),
    ("-env_name", str),
    ("-trajectory_path", str),
):
    parser.add_argument(_flag, type=_type, required=True)
del _flag, _type  # keep the loop temporaries out of the notebook namespace

args = parser.parse_args(
    [
        "-n", "MsPacman",
        "-seed", "1",
        "-config_path", "STORM.yaml",
        "-env_name", "ALE/MsPacman-v5",
        "-trajectory_path", "D_TRAJ/MsPacman.pkl",
    ]
)

# Load the configuration file named by -config_path.
conf = OmegaConf.load(args.config_path)

In [5]:
# Echo the parsed arguments in red, then reset the terminal colour.
print(f"{colorama.Fore.RED}{args}{colorama.Style.RESET_ALL}")

# Seed tinygrad's RNG and NumPy's global RNG with the same run seed.
Tensor.manual_seed(args.seed)
np.random.seed(args.seed)

[31mNamespace(n='MsPacman', seed=1, config_path='STORM.yaml', env_name='ALE/MsPacman-v5', trajectory_path='D_TRAJ/MsPacman.pkl')[0m


In [6]:
# Build a throwaway environment only to read the size of its discrete action space.
dummy_env = build_single_env(args.env_name, conf.basic_settings.image_size, seed=0)
try:
    # int() yields a plain Python int whether `n` is a NumPy integer scalar
    # or already a Python int (the previous `.item()` required the former).
    action_dim = int(dummy_env.action_space.n)
finally:
    # Close the environment explicitly so it can release its resources;
    # deleting the reference alone does not call close().
    dummy_env.close()
    del dummy_env

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [7]:
# Instantiate the world model from the loaded config and the action-space size.
world_model = build_world_model(conf=conf, action_dim=action_dim)