In [1]:
import os
import json
import tqdm
import torch
import pickle
import numpy as np
import pandas as pd
import os
import ninjax as nj
from pathlib import Path
from typing import Tuple, List, Dict, Any
from torch.utils.data import DataLoader
import elements
import jax.numpy as jnp
import ruamel.yaml as yaml

In [2]:
# Mount Google Drive into the runtime at /content/drive.
# `google.colab` is provided by the Colab runtime, so this cell is Colab-specific.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
!pip install -r ../dreamerv3/requirements.txt -U

[31mERROR: Operation cancelled by user[0m[31m
[0m

In [None]:
!pip install elements
!pip install ninjax
!pip install crafter
!pip install jax
!pip install ale-py

In [4]:
# Make the parent of the working directory and its `dreamerv3` sub-directory
# importable. The location is derived from the current working directory, so
# the notebook has to be started from a direct child of that parent directory.
import sys  # `os.sys` only works by accident; use the real module.

root = Path(os.getcwd()).parent
for _extra_path in (str(root), str(root / "dreamerv3")):
    # Skip entries that are already present so that re-running this cell does
    # not keep growing sys.path.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [5]:
# project-local imports (resolved through the sys.path entries added above)
from models.agent_probe import Agent, sg
from models.rssm_probe import Decoder
from dreamerv3.agent import Agent as AgentOri
from embodied.envs import crafter
from embodied.jax import transform

# Functions

In [6]:
def load_config(argv=None):
    """Assemble the run configuration from ``./configs.yaml``.

    Adapted from ``dreamerv3/main.py``. The ``defaults`` preset is loaded first
    and every preset named by the ``configs`` flag is layered on top of it.
    Flags other than ``configs`` are split off by the parser but are not
    applied to the configuration.

    Args:
        argv: Forwarded unchanged to ``elements.Flags(...).parse_known``.

    Returns:
        The resulting ``elements.Config`` with the ``{timestamp}`` placeholder
        of ``logdir`` filled in and, when the ``JOB_COMPLETION_INDEX``
        environment variable is set, ``replica`` taken from it.

    Side effects:
        Prints the replica index, the log directory and the run script.
    """
    yaml_text = elements.Path('./configs.yaml').read()
    presets = yaml.YAML(typ='safe').load(yaml_text)

    parsed, _unapplied_flags = elements.Flags(
        configs=['defaults']).parse_known(argv)

    config = elements.Config(presets['defaults'])
    for preset_name in parsed.configs:
        config = config.update(presets[preset_name])

    # Stamp the log directory with the current time.
    stamped_logdir = config.logdir.format(timestamp=elements.timestamp())
    config = config.update(logdir=stamped_logdir)

    # When present, the job index from the environment selects the replica.
    if 'JOB_COMPLETION_INDEX' in os.environ:
        replica_index = int(os.environ['JOB_COMPLETION_INDEX'])
        config = config.update(replica=replica_index)
    print('Replica:', config.replica, '/', config.replicas)

    print('Logdir:', elements.Path(config.logdir))
    print('Run script:', config.script)

    return config

  and should_run_async(code)


In [7]:
def load_dataset(path='./dataset/crafter') -> Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]]:
    """Load JSON-lines recordings into two parallel sample lists.

    Every regular file directly inside ``path`` is read as JSON lines. Each
    record must provide the keys ``'image'``, ``'pos'`` and ``'episode'``.

    Args:
        path: Directory containing the JSON-lines files (``str`` or ``Path``).

    Returns:
        ``(img_pos, img_epi)``: two lists of equal length. Entry ``i`` of
        ``img_pos`` is ``{'image': ndarray, 'pos': ndarray}`` and entry ``i`` of
        ``img_epi`` is ``{'image': ndarray, 'episode': ndarray}``, both built
        from the same record.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    root = Path(path)
    img_pos = []
    img_epi = []

    # Directory listing order is arbitrary; sorting keeps the sample order
    # reproducible across runs and machines.
    for file_path in sorted(root.iterdir()):
        # Skip sub-directories (e.g. Jupyter's `.ipynb_checkpoints`), which
        # cannot be opened as data files.
        if not file_path.is_file():
            continue

        with open(file_path, 'r') as f:
            # Blank lines carry no record and would make json.loads fail.
            records = [json.loads(line) for line in f if line.strip()]

        for record in records:
            img_pos.append({"image": np.array(record['image']), 'pos': np.array(record['pos'])})
            img_epi.append({"image": np.array(record['image']), 'episode': np.array(record['episode'])})

    return img_pos, img_epi

In [8]:
def load_agent(agent, path_to_pkl="./models/agent.pkl"):
    """Restore an agent from a pickled checkpoint.

    The checkpoint is expected to unpickle to a mapping whose ``'agent'`` entry
    is handed to ``agent.load``.

    Args:
        agent: Object exposing a ``load(weights)`` method.
        path_to_pkl: Location of the pickle file.

    Returns:
        The same ``agent`` object, after ``agent.load`` has been called.

    Side effects:
        Prints the checkpoint path, the type of the loaded weights and the
        type of ``agent`` before and after loading.
    """
    checkpoint_path = Path(path_to_pkl)

    # NOTE: unpickling can execute arbitrary code; only open trusted checkpoints.
    with checkpoint_path.open('rb') as handle:
        checkpoint = pickle.load(handle)
    weight = checkpoint['agent']

    print("Agent loaded from", checkpoint_path)
    print("Agent type:", type(weight))

    print(type(agent))
    agent.load(weight)
    print(type(agent))

    return agent

In [9]:
import inspect
def overwrite_func(agent):
    """Patch ``agent.model`` so its decoder call and policy use the probing code.

    Two attributes are replaced on the objects reachable from ``agent``:

    * ``agent.model.dec.__call__`` delegates to ``Decoder.__call__``.
    * ``agent.model.policy`` delegates to ``Agent.policy``.

    Both wrappers forward every argument they receive. Previously ``single``
    and ``mode`` were accepted but discarded (always ``False`` / ``'train'``),
    so a caller asking for ``single=True`` or another mode was silently
    ignored. Calls that rely on the defaults behave exactly as before.

    Args:
        agent: Wrapper object exposing ``model`` with a ``dec`` attribute.

    Returns:
        The same ``agent`` object, patched in place.
    """

    def new_dec_call(carry, feat, reset, training, single=False):
        # `agent.model` is looked up at call time, as in the original wrapper.
        return Decoder.__call__(
            agent.model.dec, carry, feat, reset, training, single=single)

    # NOTE(review): Python resolves `obj(...)` through `type(obj).__call__`, so
    # an instance-level `__call__` is normally only reached by an explicit
    # `dec.__call__(...)`. Confirm that the decoder is actually routed through
    # this override.
    agent.model.dec.__call__ = new_dec_call

    def new_agent_policy(carry, obs, mode='train'):
        return Agent.policy(agent.model, carry, obs, mode=mode)

    agent.model.policy = new_agent_policy

    return agent

In [10]:
def reconstruct_obs_dummy(img):
    """Wrap a batch of images in an observation dict with placeholder fields.

    Only the images carry information; the reward is zero and all episode
    flags are ``False`` for every sample.

    Args:
        img: Array-like batch whose first axis is the batch axis
            (``img.shape[0]`` samples). An empty batch is allowed.

    Returns:
        A dict with, in order, the keys ``image`` (``img`` itself), ``reward``
        (float32 zeros of shape ``(batch,)``) and ``is_first`` / ``is_last`` /
        ``is_terminal`` (bool arrays of shape ``(batch,)`` filled with False).
    """
    batch_size = img.shape[0]
    # np.zeros builds the arrays directly and, unlike np.stack over a Python
    # list, also works for a batch of size zero.
    return dict(
        image=img,
        reward=np.zeros(batch_size, dtype=np.float32),
        is_first=np.zeros(batch_size, dtype=bool),
        is_last=np.zeros(batch_size, dtype=bool),
        is_terminal=np.zeros(batch_size, dtype=bool),
    )

In [11]:
import jax
def collect_activations(carry, obs, agent: Agent):
    """Run one encoder -> dynamics -> decoder -> policy step and keep the intermediates.

    Args:
        carry: 4-tuple ``(enc_carry, dyn_carry, dec_carry, prevact)``; the last
            element is the previous action.
        obs: Observation mapping; ``obs['is_first']`` is used as the reset flag.
        agent: Object exposing ``enc``, ``dyn.observe``, ``dec``, ``pol`` and
            ``feat2tensor``.

    Returns:
        ``(carry, act, out, activations)`` where

        * ``carry`` is the updated 4-tuple, ending in the newly sampled action,
        * ``act`` is the action sampled from the policy output,
        * ``out['finite']`` is a flattened dict of finiteness checks over the
          inputs (``obs``, the incoming ``carry``) and the computed ``tokens``,
          ``feat`` and ``act``,
        * ``activations`` has the keys ``'encoder'``, ``'dynamic'``,
          ``'decoder'`` and ``'policy'``.
    """
    # Unpack the per-module carries; prevact = previous action.
    (enc_carry, dyn_carry, dec_carry, prevact) = carry
    # Every module below is called with training disabled and single=True.
    kw = dict(training=False, single=True)
    reset = obs['is_first']

    # Added for probing: intermediate results are collected here.
    activations = {}

    # Encoder output; NOTE(review): presumably the embedding of the
    # observation that the dynamics model consumes - confirm.
    enc_carry, enc_entry, tokens = agent.enc(enc_carry, obs, reset, **kw)

    # Added for probing.
    activations['encoder'] = tokens

    dyn_carry, dyn_entry, feat = agent.dyn.observe(
        dyn_carry, tokens, prevact, reset, **kw)

    # Added for probing. The caller cells index this entry with 'deter',
    # 'stoch' and 'logit'.
    activations['dynamic'] = feat

    dec_entry = {}
    # The decoder is always run here (the `if dec_carry:` guard is disabled).
    # if dec_carry:
    dec_carry, dec_entry, recons = agent.dec(dec_carry, feat, reset, **kw)

    # Added for probing. Requires the decoder output to contain an 'x' entry.
    activations['decoder'] = recons['x']

    # Policy head on the dynamics features.
    # NOTE(review): presumably maps the RSSM feature to action distributions - confirm.
    # `sample` draws from every leaf of the policy output using nj.seed().
    sample = lambda xs: jax.tree.map(lambda x: x.sample(nj.seed()), xs)
    policy = agent.pol(agent.feat2tensor(feat), bdims=1)
    act = sample(policy)

    # Added for probing: stores the policy output itself, not the sampled action.
    activations['policy'] = policy

    out = {}
    # For every leaf, reduce isfinite over all axes except the first.
    # `carry` here is still the incoming carry; it is rebound just below.
    out['finite'] = elements.tree.flatdict(jax.tree.map(
        lambda x: jnp.isfinite(x).all(range(1, x.ndim)),
        dict(obs=obs, carry=carry, tokens=tokens, feat=feat, act=act)))
    carry = (enc_carry, dyn_carry, dec_carry, act)
    return carry, act, out, activations

# Initialization

In [None]:
# Full run configuration; load_config() also prints replica, logdir and script.
config = load_config()

# Environment: a Crafter instance is created only to read its observation and
# action spaces and is closed again right after.
SEED = 0
TASK = "reward"
INDEX = 0
# Derive the environment seed (a value below 2**32 - 1) from the (SEED, INDEX) pair.
seed = hash((SEED, INDEX)) % (2 ** 32 - 1)
env = crafter.Crafter(seed=seed, task=TASK)
# Adapted from dreamerv3/main.py: drop the 'log/'-prefixed observation keys and
# the 'reset' action key before handing the spaces to the agent.
notlog = lambda k: not k.startswith('log/')
obs_space = {k: v for k, v in env.obs_space.items() if notlog(k)}
act_space = {k: v for k, v in env.act_space.items() if k != 'reset'}

env.close()

# Agent (adapted from dreamerv3/main.py).
# NOTE: `config` is rebound here from the full run configuration to the agent
# configuration: the `agent` section plus the top-level fields listed below.
# Later cells (e.g. `config.batch_size`) read this rebound object.
config = elements.Config(
      **config.agent,
      logdir=config.logdir,
      seed=config.seed,
      jax=config.jax,
      batch_size=config.batch_size,
      batch_length=config.batch_length,
      replay_context=config.replay_context,
      report_length=config.report_length,
      replica=config.replica,
      replicas=config.replicas,
  )
agentori = AgentOri(obs_space, act_space, config)

Replica: 0 / 1
Logdir: /root/logdir/20250112T234839
Run script: test
Observations
  image            Space(uint8, shape=(64, 64, 3), low=0, high=255)
  reward           Space(float32, shape=(), low=-inf, high=inf)
  is_first         Space(bool, shape=(), low=False, high=True)
  is_last          Space(bool, shape=(), low=False, high=True)
  is_terminal      Space(bool, shape=(), low=False, high=True)
Actions
  action           Space(int32, shape=(), low=0, high=17)
Extras
  consec           Space(int32, shape=(), low=-2147483648, high=2147483647)
  stepid           Space(uint8, shape=(20,), low=0, high=255)
  dyn/deter        Space(float32, shape=(8192,), low=-inf, high=inf)
  dyn/stoch        Space(float32, shape=(32, 64), low=-inf, high=inf)
JAX devices (1): [cuda:0]
Policy devices: cuda:0
Train devices:  cuda:0
Initializing parameters...


In [None]:
# Swap in the probing decoder call / policy on the wrapped model (see overwrite_func).
agentori = overwrite_func(agentori)

# Sharding / partitioning objects read from the wrapper and reused below.
# `tp` is unpacked but not used in this cell.
pm, ps = agentori.policy_mirrored, agentori.policy_sharded
tp, pp = agentori.train_params_sharding, agentori.policy_params_sharding
_, ar = agentori.partition_rules
shared_kwargs = {'use_shardmap': agentori.jaxcfg.use_shardmap}

# Rebuild the wrapper's `_policy` from the patched `model.policy`.
# NOTE(review): presumably required because `_policy` was built from the
# original `model.policy` when the agent was constructed - confirm against
# embodied.jax.
# NOTE(review): static_argnums=(4,) presumably marks the `mode` argument as
# static - confirm.
agentori._policy = transform.apply(
    nj.pure(agentori.model.policy), agentori.policy_mesh,
    (pp, pm, ps, ps), (ps, ps, ps), ar,
    static_argnums=(4,), **shared_kwargs)

# Load the checkpoint into the wrapper (load_agent prints what it loaded).
TEST = "./models/checkpoint_test.pkl"
agentori = load_agent(agentori, TEST)

# Shorthand for the model object held by the wrapper.
agent = agentori.model

In [None]:
# dataset
img_pos, img_epi = load_dataset("./test/")
print(len(img_pos), len(img_epi))
assert len(img_pos) == len(img_epi)

BSIZE = config.batch_size
dl_space = DataLoader(img_pos, batch_size=BSIZE, shuffle=True)
dl_time = DataLoader(img_epi, batch_size=BSIZE, shuffle=True)

# Activations

In [None]:
# Set JAX's transfer guard to "allow" so that the implicit host/device
# transfers done in the cells below are not rejected.
jax.config.update("jax_transfer_guard", "allow")
from ninjax import pure

## spatial activations

In [None]:
# Collect activations for the (image, position) samples and save them to
# ./dataset/act/space.npz.

# Number of batches to process. The loop used to stop after a hard-coded
# 2 batches; that cap is kept as the default but is now explicit.
# Set to None to run over the whole loader.
MAX_BATCHES = 2

carry = agentori.init_policy(batch_size=BSIZE)

path_to_save = Path('./dataset/act/')
os.makedirs(path_to_save, exist_ok=True)

# One list per saved array; every loop iteration appends one batch to each.
images = []
positions = []
deters = []
stochs = []
logits = []
encoders = []
decoders = []
policies = []

dtype=jnp.uint8

# Device-placement helpers; both also cast to `dtype`.
cpu = lambda x: jax.device_put(x, device=jax.devices('cpu')[0]).astype(dtype)
cuda = lambda x: jax.device_put(x, device=jax.devices()[0]).astype(dtype)


for idx, batch in enumerate(tqdm.tqdm(dl_space)):
    if MAX_BATCHES is not None and idx >= MAX_BATCHES:
        break

    # torch tensors -> jax arrays on the default device.
    # NOTE(review): `cuda` casts to uint8, which also applies to the positions;
    # confirm that they always fit into 0..255.
    img = cuda(jnp.asarray(batch['image'].numpy()))
    pos = cuda(jnp.asarray(batch['pos'].numpy()))

    img = sg(img)
    pos = sg(pos)

    carry = sg(carry)
    obs = reconstruct_obs_dummy(img)

    # NOTE(review): a final batch smaller than BSIZE would not match the carry
    # created above and would make np.stack below ragged; relevant once
    # MAX_BATCHES is raised or disabled - confirm.
    carry, act, out, activations = agentori.policy(carry, obs)

    deter = activations['dynamic']['deter']
    stoch = activations['dynamic']['stoch']
    logit = activations['dynamic']['logit']

    # Detach and copy everything to host numpy arrays.
    deter = np.array(sg(deter))
    stoch = np.array(sg(stoch))
    logit = np.array(sg(logit))
    enc = np.array(sg(activations['encoder']))
    dec = np.array(sg(activations['decoder']))
    pol = np.array(sg(activations['policy']))

    print(deter.shape, stoch.shape, logit.shape, enc.shape, dec.shape, pol.shape)

    images.append(img)
    positions.append(pos)
    deters.append(deter)
    stochs.append(stoch)
    logits.append(logit)
    encoders.append(enc)
    decoders.append(dec)
    policies.append(pol)

# Stack along a new leading axis: the first axis of each array indexes batches.
images = np.stack(images, axis=0)
positions = np.stack(positions, axis=0)
deters = np.stack(deters, axis=0)
stochs = np.stack(stochs, axis=0)
logits = np.stack(logits, axis=0)
encoders = np.stack(encoders, axis=0)
decoders = np.stack(decoders, axis=0)
policies = np.stack(policies, axis=0)

np.savez_compressed(
    path_to_save / 'space.npz',
    images=images,
    positions=positions,
    deters=deters,
    stochs=stochs,
    logits=logits,
    encoders=encoders,
    decoders=decoders,
    policies=policies
)

## temporal activations

In [None]:
# Collect activations for the (image, episode) samples and save them to
# ./dataset/act/time.npz.

# NOTE(review): this cell drives the bare model (`agent.init_policy`,
# `agent(...)` on raw images), whereas the spatial cell goes through
# `agentori.init_policy`, `reconstruct_obs_dummy` and `agentori.policy`.
# Confirm which path is intended.
carry = agent.init_policy(batch_size=BSIZE)

path_to_save = Path('./dataset/act/')
# Create the output directory here too, so this cell does not depend on the
# spatial cell having been run first.
os.makedirs(path_to_save, exist_ok=True)

# One list per saved array; every loop iteration appends one batch to each.
images = []
episodes = []
deters = []
stochs = []
logits = []
encoders = []
decoders = []
policies = []

for batch in tqdm.tqdm(dl_time):
    img = batch['image'].numpy()
    epi = batch['episode'].numpy()
    img = sg(img)
    epi = sg(epi)

    assert isinstance(img, np.ndarray), type(img)
    assert isinstance(epi, np.ndarray), type(epi)

    carry = sg(carry)

    carry, act, out, activations = agent(carry, img, training=False)

    deter = activations['dynamic']['deter']
    stoch = activations['dynamic']['stoch']
    logit = activations['dynamic']['logit']

    # Detach and copy everything to host numpy arrays.
    deter = np.array(sg(deter))
    stoch = np.array(sg(stoch))
    logit = np.array(sg(logit))
    enc = np.array(sg(activations['encoder']))
    dec = np.array(sg(activations['decoder']))
    pol = np.array(sg(activations['policy']))

    print(deter.shape, stoch.shape, logit.shape, enc.shape, dec.shape, pol.shape)

    images.append(img)
    # Store this batch's episode ids. The previous code appended `pos`, a
    # variable left over from the spatial cell: undefined on a fresh kernel and
    # unrelated to this batch otherwise.
    episodes.append(epi)
    deters.append(deter)
    stochs.append(stoch)
    logits.append(logit)
    encoders.append(enc)
    decoders.append(dec)
    policies.append(pol)

# Stack along a new leading axis: the first axis of each array indexes batches.
images = np.stack(images, axis=0)
episodes = np.stack(episodes, axis=0)
deters = np.stack(deters, axis=0)
stochs = np.stack(stochs, axis=0)
logits = np.stack(logits, axis=0)
encoders = np.stack(encoders, axis=0)
decoders = np.stack(decoders, axis=0)
policies = np.stack(policies, axis=0)

# The label array is saved as `episodes`; this file holds no position data.
np.savez_compressed(
    path_to_save / 'time.npz',
    images=images,
    episodes=episodes,
    deters=deters,
    stochs=stochs,
    logits=logits,
    encoders=encoders,
    decoders=decoders,
    policies=policies
)