In [2]:
import os
import json
import tqdm
import torch
import pickle
import numpy as np
import pandas as pd
import os
import ninjax
from pathlib import Path
from typing import Tuple, List, Dict, Any
from torch.utils.data import DataLoader

In [None]:
# Make the parent of the kernel's working directory (taken to be the repository
# root) and its `dreamerv3` sub-directory importable for the project imports.
import sys  # use `sys` directly; `os.sys` is only an accidental alias of it

root = Path(os.getcwd()).parent
for extra_path in (root, root / "dreamerv3"):
    # Skip entries that are already present so re-running this cell does not
    # keep growing sys.path with duplicates.
    if str(extra_path) not in sys.path:
        sys.path.append(str(extra_path))

In [16]:
# import locals
from models.agent_probe import Agent, sg
from embodied.envs import crafter

# Functions

In [17]:
def load_dataset(path='./dataset/crafter') -> Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]]:
    """Load JSON-lines recordings from a directory into two parallel sample lists.

    Every regular file directly inside ``path`` is parsed as JSON lines (one
    JSON object per line). Each record must provide the keys ``image``,
    ``pos`` and ``episode``.

    Args:
        path: Directory containing the JSON-lines files.

    Returns:
        A tuple ``(img_pos, img_epi)`` of two lists with equal length and
        matching order. ``img_pos[i]`` is ``{"image": ..., "pos": ...}`` and
        ``img_epi[i]`` is ``{"image": ..., "episode": ...}``; every value is
        converted with ``np.array``. The two dicts built from one record share
        the same image array, so the image is only held in memory once.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        NotADirectoryError: If ``path`` is not a directory.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``image``, ``pos`` or ``episode``.
    """
    img_pos = []
    img_epi = []

    # Directory listing order is arbitrary; sorting makes the sample order
    # (and anything derived from sample indices) reproducible across runs.
    for file in sorted(Path(path).iterdir()):
        # Skip sub-directories (e.g. `.ipynb_checkpoints`), which cannot be opened.
        if not file.is_file():
            continue

        with open(file, 'r') as f:
            for line in f:
                # Tolerate blank lines, such as a trailing newline at end of file.
                if not line.strip():
                    continue
                d = json.loads(line)

                # Convert the image once and reference it from both samples.
                image = np.array(d['image'])
                img_pos.append({"image": image, 'pos': np.array(d['pos'])})
                img_epi.append({"image": image, 'episode': np.array(d['episode'])})

    return img_pos, img_epi

In [18]:
def load_agent(path_to_pkl="./models/agent.pkl"):
    """Restore a pickled agent from disk and report its origin and type.

    Args:
        path_to_pkl: Location of the pickle file (string or path-like).

    Returns:
        The object stored in the pickle, exactly as unpickled.

    Note:
        Unpickling can execute arbitrary code embedded in the file; only load
        pickles that come from a trusted source.
    """
    pkl_path = Path(path_to_pkl)

    with pkl_path.open('rb') as handle:
        restored = pickle.load(handle)

    print("Agent loaded from", pkl_path)
    print("Agent type:", type(restored))
    return restored

In [19]:
def overwrite_func(agent: Agent):
    """Patch a loaded agent so selected entry points route through `Agent`.

    Two attributes are replaced on the given instance:

    * ``agent.dec['simple'].__call__`` becomes a wrapper around
      ``Agent.__call__`` bound to ``agent``.
    * ``agent.policy`` becomes a wrapper around ``Agent.policy`` bound to
      ``agent``.

    Both wrappers forward the caller's ``single`` / ``mode`` argument;
    previously the defaults were hard-coded in the forwarded call, so any
    value the caller supplied was silently ignored.

    Args:
        agent: The agent instance to patch (modified in place).

    Returns:
        The same ``agent`` object, after patching.
    """

    # overwrite call function in decoder
    # NOTE(review): Python resolves `obj(...)` through type(obj).__call__, so
    # this instance attribute is only used by explicit `.__call__(...)` calls —
    # confirm that is how the decoder is invoked.
    # NOTE(review): the decoder hook delegates to Agent.__call__ (not a
    # decoder method) — presumably intentional for the probe; verify.
    def new_dec_call(carry, feat, reset, training, single=False):
        return Agent.__call__(agent, carry, feat, reset, training, single=single)

    agent.dec['simple'].__call__ = new_dec_call

    # overwrite policy function in agent
    def new_agent_policy(carry, obs, mode='train'):
        return Agent.policy(agent, carry, obs, mode=mode)

    agent.policy = new_agent_policy

    return agent

# Initialization

In [None]:
# Environment: a Crafter instance for the "reward" task.
TASK = "reward"
SEED = 0
# Fold the base seed (paired with index 0) into the range [0, 2**32 - 2].
seed = hash((SEED, 0)) % (2 ** 32 - 1)
env = crafter.Crafter(task=TASK, seed=seed)

# Agent: restored from the default pickle path and patched in one step.
# Alternative kept for reference — build a fresh agent instead of unpickling:
# kwargs = {"size": [64, 64], "logs": False, "use_seed": True, "use_logdir": False}
# agent = Agent(env.obs_space, env.act_space, **kwargs)
agent = overwrite_func(load_agent())

In [28]:
# dataset
img_pos, img_epi = load_dataset("./test/")
print(len(img_pos), len(img_epi))
assert len(img_pos) == len(img_epi)

BSIZE = 16
# Give each loader its own generator seeded with SEED so the shuffled batch
# order is the same on every run; without it the shuffle uses torch's global
# RNG state and batch composition changes between runs.
dl_space = DataLoader(
    img_pos,
    batch_size=BSIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
dl_time = DataLoader(
    img_epi,
    batch_size=BSIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

1066 1066


In [10]:
# test = [{"image": np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8), "pos": np.random.randint(0, 100, (2,), dtype=np.int32)} for _ in range(32)]
# BSIZE = 16

# dl_space = DataLoader(test, batch_size=BSIZE, shuffle=True)

# for batch in dl_space:
#     print(batch.keys())
#     print(batch['image'].shape, batch['pos'].shape)
#     print(batch['pos'])

dict_keys(['image', 'pos'])
torch.Size([16, 64, 64, 3]) torch.Size([16, 2])
tensor([[20, 23],
        [27, 47],
        [99, 76],
        [96, 41],
        [50, 34],
        [32, 87],
        [45, 23],
        [29, 84],
        [23, 65],
        [31, 58],
        [35,  7],
        [35, 28],
        [68, 19],
        [56, 68],
        [98, 82],
        [91, 87]], dtype=torch.int32)
dict_keys(['image', 'pos'])
torch.Size([16, 64, 64, 3]) torch.Size([16, 2])
tensor([[99, 46],
        [95, 76],
        [72, 74],
        [ 8, 44],
        [60, 95],
        [42, 94],
        [58, 40],
        [ 4, 50],
        [40, 97],
        [58, 99],
        [52, 22],
        [89, 26],
        [66, 31],
        [36, 44],
        [57, 43],
        [59, 34]], dtype=torch.int32)


# Activations

## Spatial activations

In [None]:
# Run the agent over every (image, pos) batch and store the recorded
# activations, one compressed .npz file per batch.
carry = agent.init_policy(batch_size=BSIZE)

path_to_save = Path('./dataset/act/')
# np.savez_compressed does not create missing directories.
path_to_save.mkdir(parents=True, exist_ok=True)

for idx, batch in enumerate(tqdm.tqdm(dl_space)):
    img = batch['image'].numpy()
    pos = batch['pos'].numpy()
    img = sg(img)
    pos = sg(pos)

    assert type(img) == np.ndarray, type(img)
    assert type(pos) == np.ndarray, type(pos)

    # The loader does not drop the last batch, so it can be shorter than
    # BSIZE. The carry was created for BSIZE samples and presumably cannot be
    # combined with a smaller batch, so start from a fresh carry of the right
    # size in that case — TODO confirm this matches the intended probing setup.
    if len(img) != BSIZE:
        carry = agent.init_policy(batch_size=len(img))

    carry = sg(carry)

    carry, act, out, activations = agent(carry, img, training=False)

    deter = activations['dynamic']['deter']
    stoch = activations['dynamic']['stoch']
    logit = activations['dynamic']['logit']

    deter = np.array(sg(deter))
    stoch = np.array(sg(stoch))
    logit = np.array(sg(logit))
    enc = np.array(sg(activations['encoder']))
    dec = np.array(sg(activations['decoder']))
    pol = np.array(sg(activations['policy']))

    # Report the shapes once; printing every batch floods the output and
    # breaks up the tqdm progress bar.
    if idx == 0:
        print(deter.shape, stoch.shape, logit.shape, enc.shape, dec.shape, pol.shape)

    np.savez_compressed(
        path_to_save / f'idx{idx}.npz',
        img=img,
        pos=pos,
        deter=deter,
        stoch=stoch,
        logit=logit,
        enc=enc,
        dec=dec,
        pol=pol
    )

## Time activations

In [None]:
# Run the agent over every (image, episode) batch and store the recorded
# activations, one compressed .npz file per batch.
carry = agent.init_policy(batch_size=BSIZE)

# Kept separate from './dataset/act/': the spatial-activation cell writes
# files with the same `idx{idx}.npz` names there, so sharing the directory
# would overwrite them.
path_to_save = Path('./dataset/act_time/')
# np.savez_compressed does not create missing directories.
path_to_save.mkdir(parents=True, exist_ok=True)

for idx, batch in enumerate(tqdm.tqdm(dl_time)):
    img = batch['image'].numpy()
    # dl_time is built from img_epi, whose samples carry 'episode' (there is
    # no 'pos' key in these batches).
    episode = batch['episode'].numpy()
    img = sg(img)
    episode = sg(episode)

    assert type(img) == np.ndarray, type(img)
    assert type(episode) == np.ndarray, type(episode)

    # The loader does not drop the last batch, so it can be shorter than
    # BSIZE. The carry was created for BSIZE samples and presumably cannot be
    # combined with a smaller batch, so start from a fresh carry of the right
    # size in that case — TODO confirm this matches the intended probing setup.
    if len(img) != BSIZE:
        carry = agent.init_policy(batch_size=len(img))

    carry = sg(carry)

    carry, act, out, activations = agent(carry, img, training=False)

    deter = activations['dynamic']['deter']
    stoch = activations['dynamic']['stoch']
    logit = activations['dynamic']['logit']

    deter = np.array(sg(deter))
    stoch = np.array(sg(stoch))
    logit = np.array(sg(logit))
    enc = np.array(sg(activations['encoder']))
    dec = np.array(sg(activations['decoder']))
    pol = np.array(sg(activations['policy']))

    # Report the shapes once; printing every batch floods the output and
    # breaks up the tqdm progress bar.
    if idx == 0:
        print(deter.shape, stoch.shape, logit.shape, enc.shape, dec.shape, pol.shape)

    np.savez_compressed(
        path_to_save / f'idx{idx}.npz',
        img=img,
        episode=episode,
        deter=deter,
        stoch=stoch,
        logit=logit,
        enc=enc,
        dec=dec,
        pol=pol
    )