In [None]:
import argparse
from omegaconf import OmegaConf
import gymnasium
from world_models import WorldModel

import colorama
import numpy as np
from tinygrad import Tensor, nn
from tinygrad.nn.state import get_parameters, get_state_dict

import env_wrapper

In [None]:
def build_single_env(env_name, image_size, seed):
    """Create one environment and wrap it with the preprocessing stack.

    Wrappers are applied innermost-first: seeding, frame skipping with
    ``skip=4``, observation resizing to ``image_size``, then life-loss info.

    Args:
        env_name: Environment id passed to ``gymnasium.make``.
        image_size: Target shape forwarded to ``ResizeObservation``.
        seed: Seed handed to ``env_wrapper.SeedEnvWrapper``.

    Returns:
        The fully wrapped environment.
    """
    base_env = gymnasium.make(
        env_name,
        full_action_space=False,
        render_mode="rgb_array",
        # NOTE(review): presumably 1 because MaxLast2FrameSkipWrapper does the
        # actual frame skipping below — confirm.
        frameskip=1,
    )
    seeded_env = env_wrapper.SeedEnvWrapper(base_env, seed=seed)
    skipping_env = env_wrapper.MaxLast2FrameSkipWrapper(seeded_env, skip=4)
    resized_env = gymnasium.wrappers.ResizeObservation(skipping_env, shape=image_size)
    return env_wrapper.LifeLossInfo(resized_env)

In [None]:
def build_world_model(conf, action_dim):
    """Construct a ``WorldModel`` from the ``models.world_model`` config section.

    Args:
        conf: Configuration object exposing ``models.world_model`` with the
            fields ``in_channels`` and ``transformer_{max_length,hidden_dim,
            num_layers,num_heads}``.
        action_dim: Value forwarded as the model's ``action_dim``.

    Returns:
        The newly created ``WorldModel`` instance.
    """
    wm_conf = conf.models.world_model
    return WorldModel(
        in_channels=wm_conf.in_channels,
        action_dim=action_dim,
        transformer_max_length=wm_conf.transformer_max_length,
        transformer_hidden_dim=wm_conf.transformer_hidden_dim,
        transformer_num_layers=wm_conf.transformer_num_layers,
        transformer_num_heads=wm_conf.transformer_num_heads,
    )

In [None]:
# Every option is a required flag, but the values are parsed from the explicit
# list below rather than from sys.argv, so the notebook carries its own
# settings.
parser = argparse.ArgumentParser()
for _flag, _flag_type in (
    ("-n", str),
    ("-seed", int),
    ("-config_path", str),
    ("-env_name", str),
    ("-trajectory_path", str),
):
    parser.add_argument(_flag, type=_flag_type, required=True)

args = parser.parse_args(
    [
        "-n", "MsPacman",
        "-seed", "1",
        "-config_path", "STORM.yaml",
        "-env_name", "ALE/MsPacman-v5",
        "-trajectory_path", "D_TRAJ/MsPacman.pkl",
    ]
)

# Load the YAML configuration named by -config_path.
conf = OmegaConf.load(args.config_path)

In [None]:
# Echo the parsed arguments in red so they stand out in the cell output.
print(f"{colorama.Fore.RED}{args}{colorama.Style.RESET_ALL}")

# Seed both NumPy's global RNG and tinygrad's RNG from the same value.
np.random.seed(args.seed)
Tensor.manual_seed(args.seed)

In [None]:
# Build a throwaway environment only to read the size of its discrete action
# space, then release it.
dummy_env = build_single_env(args.env_name, conf.basic_settings.image_size, seed=0)
try:
    # int() accepts both a NumPy integer and a plain Python int for `n`,
    # whereas `.item()` only exists on the NumPy scalar.
    action_dim = int(dummy_env.action_space.n)
finally:
    # Close explicitly so the environment's resources are freed right away
    # instead of whenever the object happens to be garbage collected.
    dummy_env.close()
del dummy_env

In [None]:
# Instantiate the world model, sized for the environment's action space.
world_model = build_world_model(conf=conf, action_dim=action_dim)

In [None]:
# i = 1
# updatable_count = 0
# for k, v in get_state_dict(world_model).items():
#     print(i, "- ", k, ": ", v.shape, " -- ", v.requires_grad)
#     if v.requires_grad is None: updatable_count += 1
#     i+=1
# print("Total updatable parameters: ", updatable_count)

# Load the world-model weights that were exported from PyTorch as safetensors

In [None]:
from safetensors import safe_open

In [None]:
# Read every tensor of the PyTorch-exported checkpoint into a plain dict keyed
# by parameter name.
# The tensors are loaded on the CPU: requiring CUDA device 0 just to read the
# file makes the cell fail on machines without a GPU, and the `.cpu()` calls
# made on these tensors elsewhere remain valid (they become no-ops).
with safe_open("world_model.safetensors", framework="pt", device="cpu") as f:
    state_dict = {k: f.get_tensor(k) for k in f.keys()}

In [None]:
# Display the parameter names contained in the loaded checkpoint.
state_dict.keys()

In [None]:
# Spot-check the torch -> NumPy -> tinygrad Tensor conversion on a single
# checkpoint entry.
Tensor(state_dict['dist_head.post_head.bias'].cpu().numpy())

In [None]:
# For every parameter name of the tinygrad model, report whether the
# checkpoint holds a tensor under the same key.
for k in get_state_dict(world_model):
    print(k, " -- ", "yes" if k in state_dict else "no")

In [None]:
from tqdm import tqdm

# Copy the checkpoint into the tinygrad model parameter by parameter, matching
# by name. Parameters that have no counterpart in the checkpoint keep their
# current values; they are collected and reported rather than skipped
# silently, because a partially loaded model would otherwise go unnoticed.
model_state_dict = get_state_dict(world_model)
missing_keys = []
for k, v in tqdm(model_state_dict.items()):
    if k not in state_dict:
        missing_keys.append(k)
        continue
    # 0-d torch scalars are loaded with shape (1,); every other tensor keeps
    # its shape unchanged.
    t = np.atleast_1d(state_dict[k].cpu().numpy())
    v.assign(Tensor(t)).realize()

if missing_keys:
    print(f"{len(missing_keys)} parameter(s) not found in the checkpoint:", *missing_keys, sep="\n  ")

In [None]:
# Load the saved loss inputs and convert each one to a tinygrad Tensor.
# The file is read on the CPU because every tensor is immediately converted
# to NumPy; asking for CUDA device 0 would only add a hard GPU requirement
# and an extra device-to-host copy.
with safe_open("world_model_inputs.safetensors", framework="pt", device="cpu") as f:
    world_model_inputs = {k: Tensor(f.get_tensor(k).numpy()) for k in f.keys()}

In [None]:
# Evaluate the world-model loss under tinygrad's training context, then
# backpropagate from the resulting loss tensor.
with Tensor.train():
    total_loss = world_model.loss(
        *(world_model_inputs[name] for name in ("obs", "action", "reward", "termination")),
        logger=None,
    )
total_loss.backward()

In [None]:
# Materialise the loss as a NumPy value for display.
total_loss.numpy()

In [None]:
# Load the stored outputs to compare against.
# Read on the CPU so the cell does not need CUDA device 0; `.cpu()` calls on
# these tensors elsewhere remain valid and become no-ops.
with safe_open("world_model_outputs.safetensors", framework="pt", device="cpu") as f:
    world_model_outputs = {k: f.get_tensor(k) for k in f.keys()}

In [None]:
# Load the stored gradients, keyed as "<parameter name>.grad".
# Read on the CPU so the cell does not need CUDA device 0 just to inspect
# the values.
with safe_open("grad_out.safetensors", framework="pt", device="cpu") as f:
    grad_out = {k: f.get_tensor(k) for k in f.keys()}

In [None]:
# Stored gradient for the first encoder backbone layer's weight
# (presumably the PyTorch reference gradient — confirm).
grad_out["encoder.backbone.0.weight.grad"]

In [None]:
# Stored gradient for the first termination-decoder backbone layer's weight
# (presumably the PyTorch reference gradient — confirm).
grad_out["termination_decoder.backbone.0.weight.grad"]

In [None]:
# The tinygrad model's weight tensor for the first encoder backbone layer.
world_model.encoder.backbone[0].weight

In [None]:
# Loss value stored in world_model_outputs.safetensors
# (presumably from the PyTorch run, for comparison with total_loss — confirm).
world_model_outputs["total_loss"]

In [None]:
from tinygrad.nn.state import safe_load
# Load this safetensors file with tinygrad's own loader rather than safe_open.
# NOTE(review): the name suggests these are outputs saved from the tinygrad
# model (embedding / posterior logits) — confirm where the file is produced.
wm_out_tiny = safe_load("world_model_output_embed_post_logits.safetensors")

In [None]:
# "embedding" entry from the file loaded with tinygrad's safe_load, as NumPy.
wm_out_tiny["embedding"].numpy()

In [None]:
# "embedding" entry from world_model_outputs.safetensors, as NumPy
# (presumably shown to compare by eye with the tinygrad embedding — confirm).
world_model_outputs["embedding"].cpu().numpy()