In [2]:
import argparse
from omegaconf import OmegaConf
import gymnasium
from world_models import WorldModel

import colorama
import numpy as np
from tinygrad import Tensor, nn
from tinygrad.nn.state import get_parameters, get_state_dict

import env_wrapper

In [3]:
def build_single_env(env_name, image_size, seed):
  env = gymnasium.make(env_name, full_action_space=False, render_mode="rgb_array", frameskip=1)
  env = env_wrapper.SeedEnvWrapper(env, seed=seed)
  env = env_wrapper.MaxLast2FrameSkipWrapper(env, skip=4)
  env = gymnasium.wrappers.ResizeObservation(env, shape=image_size)
  env = env_wrapper.LifeLossInfo(env)
  return env

In [4]:
def build_world_model(conf, action_dim):
    return WorldModel(
        in_channels=conf.models.world_model.in_channels,
        action_dim=action_dim,
        transformer_max_length=conf.models.world_model.transformer_max_length,
        transformer_hidden_dim=conf.models.world_model.transformer_hidden_dim,
        transformer_num_layers=conf.models.world_model.transformer_num_layers,
        transformer_num_heads=conf.models.world_model.transformer_num_heads
    )

In [5]:
parser = argparse.ArgumentParser()
parser.add_argument("-n", type=str, required=True)
parser.add_argument("-seed", type=int, required=True)
parser.add_argument("-config_path", type=str, required=True)
parser.add_argument("-env_name", type=str, required=True)
parser.add_argument("-trajectory_path", type=str, required=True)
args = parser.parse_args([
    "-n", "MsPacman",
    "-seed", "1",
    "-config_path", "STORM.yaml",
    "-env_name", "ALE/MsPacman-v5",
    "-trajectory_path", "D_TRAJ/MsPacman.pkl"
])
conf = OmegaConf.load(args.config_path)

In [6]:
print(colorama.Fore.RED + str(args) + colorama.Style.RESET_ALL)

np.random.seed(args.seed)
Tensor.manual_seed(args.seed)

[31mNamespace(n='MsPacman', seed=1, config_path='STORM.yaml', env_name='ALE/MsPacman-v5', trajectory_path='D_TRAJ/MsPacman.pkl')[0m


In [7]:
dummy_env = build_single_env(args.env_name, conf.basic_settings.image_size, seed=0)
action_dim = dummy_env.action_space.n.item()
del dummy_env

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [8]:
world_model = build_world_model(conf, action_dim)

In [9]:
# i = 1
# updatable_count = 0
# for k, v in get_state_dict(world_model).items():
#     print(i, "- ", k, ": ", v.shape, " -- ", v.requires_grad)
#     if v.requires_grad is None: updatable_count += 1
#     i+=1
# print("Total updatable parameters: ", updatable_count)

# World model safetensors from torch

In [10]:
from safetensors import safe_open

In [11]:
state_dict = {}
with safe_open("world_model.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        state_dict[k] = f.get_tensor(k)

In [12]:
state_dict.keys()

dict_keys(['dist_head.post_head.bias', 'dist_head.post_head.weight', 'dist_head.prior_head.bias', 'dist_head.prior_head.weight', 'encoder.backbone.0.weight', 'encoder.backbone.1.bias', 'encoder.backbone.1.num_batches_tracked', 'encoder.backbone.1.running_mean', 'encoder.backbone.1.running_var', 'encoder.backbone.1.weight', 'encoder.backbone.10.bias', 'encoder.backbone.10.num_batches_tracked', 'encoder.backbone.10.running_mean', 'encoder.backbone.10.running_var', 'encoder.backbone.10.weight', 'encoder.backbone.3.weight', 'encoder.backbone.4.bias', 'encoder.backbone.4.num_batches_tracked', 'encoder.backbone.4.running_mean', 'encoder.backbone.4.running_var', 'encoder.backbone.4.weight', 'encoder.backbone.6.weight', 'encoder.backbone.7.bias', 'encoder.backbone.7.num_batches_tracked', 'encoder.backbone.7.running_mean', 'encoder.backbone.7.running_var', 'encoder.backbone.7.weight', 'encoder.backbone.9.weight', 'image_decoder.backbone.0.weight', 'image_decoder.backbone.10.weight', 'image_deco

In [13]:
Tensor(state_dict['dist_head.post_head.bias'].cpu().numpy())

<Tensor <LB CUDA (1024,) contig:True (<LoadOps.COPY: 3>, None)> on CUDA with grad None>

In [14]:
for k, v in get_state_dict(world_model).items():
    if k in state_dict.keys():
        print(k, " -- ", "yes")
    else:
        print(k, " -- ", "no")

encoder.backbone.0.weight  --  yes
encoder.backbone.1.weight  --  yes
encoder.backbone.1.bias  --  yes
encoder.backbone.1.running_mean  --  yes
encoder.backbone.1.running_var  --  yes
encoder.backbone.1.num_batches_tracked  --  yes
encoder.backbone.3.weight  --  yes
encoder.backbone.4.weight  --  yes
encoder.backbone.4.bias  --  yes
encoder.backbone.4.running_mean  --  yes
encoder.backbone.4.running_var  --  yes
encoder.backbone.4.num_batches_tracked  --  yes
encoder.backbone.6.weight  --  yes
encoder.backbone.7.weight  --  yes
encoder.backbone.7.bias  --  yes
encoder.backbone.7.running_mean  --  yes
encoder.backbone.7.running_var  --  yes
encoder.backbone.7.num_batches_tracked  --  yes
encoder.backbone.9.weight  --  yes
encoder.backbone.10.weight  --  yes
encoder.backbone.10.bias  --  yes
encoder.backbone.10.running_mean  --  yes
encoder.backbone.10.running_var  --  yes
encoder.backbone.10.num_batches_tracked  --  yes
storm_transformer.stem.0.weight  --  yes
storm_transformer.stem.1.w

In [16]:
sum([x.numel() for x in get_parameters(world_model)])

16510730

In [None]:
from tqdm import tqdm

model_state_dict = get_state_dict(world_model)
for k, v in tqdm(model_state_dict.items()):
    if k in state_dict.keys():
        if state_dict[k].ndim == 0: t = state_dict[k].reshape(1).cpu().numpy()
        else: t = state_dict[k].cpu().numpy()
        model_state_dict[k].assign(Tensor(t)).realize()

In [None]:
world_model_inputs = {}
with safe_open("world_model_inputs.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        world_model_inputs[k] = f.get_tensor(k)

world_model_inputs = {k: Tensor(v.cpu().numpy()) for k, v in world_model_inputs.items()}

In [None]:
with Tensor.train():
    total_loss = world_model.loss(world_model_inputs["obs"], world_model_inputs["action"], world_model_inputs["reward"], world_model_inputs["termination"],logger=None)
total_loss.backward()

In [None]:
total_loss.numpy()

In [None]:
world_model_outputs = {}
with safe_open("world_model_outputs.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        world_model_outputs[k] = f.get_tensor(k)

In [None]:
grad_out = {}
with safe_open("grad_out.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        grad_out[k] = f.get_tensor(k)

In [None]:
grad_out["encoder.backbone.0.weight.grad"]

In [None]:
grad_out["termination_decoder.backbone.0.weight.grad"]

In [None]:
world_model.encoder.backbone[0].weight

In [None]:
world_model_outputs["total_loss"]

In [None]:
from tinygrad.nn.state import safe_load
wm_out_tiny = safe_load("world_model_output_embed_post_logits.safetensors")

In [None]:
wm_out_tiny["embedding"].numpy()

In [None]:
world_model_outputs["embedding"].cpu().numpy()