In [1]:
import argparse
from omegaconf import OmegaConf
import gymnasium
from world_models import WorldModel

import colorama
import numpy as np
from tinygrad import Tensor, nn
from tinygrad.nn.state import get_parameters, get_state_dict

import env_wrapper

In [2]:
def build_single_env(env_name, image_size, seed):
  env = gymnasium.make(env_name, full_action_space=False, render_mode="rgb_array", frameskip=1)
  env = env_wrapper.SeedEnvWrapper(env, seed=seed)
  env = env_wrapper.MaxLast2FrameSkipWrapper(env, skip=4)
  env = gymnasium.wrappers.ResizeObservation(env, shape=image_size)
  env = env_wrapper.LifeLossInfo(env)
  return env

In [3]:
def build_world_model(conf, action_dim):
    return WorldModel(
        in_channels=conf.models.world_model.in_channels,
        action_dim=action_dim,
        transformer_max_length=conf.models.world_model.transformer_max_length,
        transformer_hidden_dim=conf.models.world_model.transformer_hidden_dim,
        transformer_num_layers=conf.models.world_model.transformer_num_layers,
        transformer_num_heads=conf.models.world_model.transformer_num_heads
    )

In [4]:
parser = argparse.ArgumentParser()
parser.add_argument("-n", type=str, required=True)
parser.add_argument("-seed", type=int, required=True)
parser.add_argument("-config_path", type=str, required=True)
parser.add_argument("-env_name", type=str, required=True)
parser.add_argument("-trajectory_path", type=str, required=True)
args = parser.parse_args([
    "-n", "MsPacman",
    "-seed", "1",
    "-config_path", "STORM.yaml",
    "-env_name", "ALE/MsPacman-v5",
    "-trajectory_path", "D_TRAJ/MsPacman.pkl"
])
conf = OmegaConf.load(args.config_path)

In [5]:
print(colorama.Fore.RED + str(args) + colorama.Style.RESET_ALL)

np.random.seed(args.seed)
Tensor.manual_seed(args.seed)

[31mNamespace(n='MsPacman', seed=1, config_path='STORM.yaml', env_name='ALE/MsPacman-v5', trajectory_path='D_TRAJ/MsPacman.pkl')[0m


In [6]:
dummy_env = build_single_env(args.env_name, conf.basic_settings.image_size, seed=0)
action_dim = dummy_env.action_space.n.item()
del dummy_env

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [7]:
world_model = build_world_model(conf, action_dim)

In [8]:
i = 1
updatable_count = 0
for k, v in get_state_dict(world_model).items():
    print(i, "- ", k, ": ", v.shape, " -- ", v.requires_grad)
    if v.requires_grad is None: updatable_count += 1
    i+=1
print("Total updatable parameters: ", updatable_count)

1 -  encoder.conv_in.weight :  (32, 3, 4, 4)  --  None
2 -  encoder.bn_in.weight :  (32,)  --  None
3 -  encoder.bn_in.bias :  (32,)  --  None
4 -  encoder.bn_in.running_mean :  (32,)  --  False
5 -  encoder.bn_in.running_var :  (32,)  --  False
6 -  encoder.bn_in.num_batches_tracked :  (1,)  --  False
7 -  encoder.convs.0.weight :  (64, 32, 4, 4)  --  None
8 -  encoder.convs.1.weight :  (128, 64, 4, 4)  --  None
9 -  encoder.convs.2.weight :  (256, 128, 4, 4)  --  None
10 -  encoder.bns.0.weight :  (64,)  --  None
11 -  encoder.bns.0.bias :  (64,)  --  None
12 -  encoder.bns.0.running_mean :  (64,)  --  False
13 -  encoder.bns.0.running_var :  (64,)  --  False
14 -  encoder.bns.0.num_batches_tracked :  (1,)  --  False
15 -  encoder.bns.1.weight :  (128,)  --  None
16 -  encoder.bns.1.bias :  (128,)  --  None
17 -  encoder.bns.1.running_mean :  (128,)  --  False
18 -  encoder.bns.1.running_var :  (128,)  --  False
19 -  encoder.bns.1.num_batches_tracked :  (1,)  --  False
20 -  encoder

# World model input

In [27]:
from tinygrad.nn.state import safe_load, load_state_dict

cur_device = Tensor.randn(1).device
print("cur device: ", cur_device)

inputs = safe_load("world_model_inputs.safetensors")
obs, action, reward, termination = inputs["obs"].to(cur_device), inputs["action"].to(cur_device), inputs["reward"].to(cur_device), inputs["termination"].to(cur_device)

cur device:  CUDA


In [28]:
with Tensor.train():
    total_loss = world_model.loss(obs, action, reward, termination, logger =None)
    total_loss.realize()

forward time: 0.02721s


In [29]:
total_loss.backward()

<Tensor <LB CUDA () contig:True (<BinaryOps.ADD: 1>, <buf device:CUDA size:1 dtype:dtypes.float>)> on CUDA with grad <LB CUDA () contig:True (<LoadOps.CONST: 2>, None)>>

In [31]:
i = 0
for p in get_parameters(world_model):
    if p.grad is None:
        i+=1
        print(p.shape, "No grad")
print("No grad count: ", i)

(32, 3, 4, 4) No grad
(32,) No grad
(32,) No grad
(32,) No grad
(32,) No grad
(1,) No grad
(64, 32, 4, 4) No grad
(128, 64, 4, 4) No grad
(256, 128, 4, 4) No grad
(64,) No grad
(64,) No grad
(64,) No grad
(64,) No grad
(1,) No grad
(128,) No grad
(128,) No grad
(128,) No grad
(128,) No grad
(1,) No grad
(256,) No grad
(256,) No grad
(256,) No grad
(256,) No grad
(1,) No grad
(512, 1033) No grad
(512,) No grad
(512,) No grad
(512, 512) No grad
(512,) No grad
(512,) No grad
(64, 512) No grad
(1, 1, 64) No grad
(512, 512) No grad
(512, 512) No grad
(512, 512) No grad
(512, 512) No grad
(512,) No grad
(512,) No grad
(1024, 512) No grad
(1024,) No grad
(512, 1024) No grad
(512,) No grad
(512,) No grad
(512,) No grad
(512, 512) No grad
(512, 512) No grad
(512, 512) No grad
(512, 512) No grad
(512,) No grad
(512,) No grad
(1024, 512) No grad
(1024,) No grad
(512, 1024) No grad
(512,) No grad
(512,) No grad
(512,) No grad
(512,) No grad
(512,) No grad
(1024, 4096) No grad
(1024,) No grad
(1024