In [3]:
#!/usr/bin/env python3
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange
import time
import numpy as np
import math

import wandb
from data.wm_dataset import WorldModelDataset
from data.common import EquiSampler, transpose_collate
from models.world_model import WorldModelRSSM
from trainers.metrics import MetricsHelper

def transpose_collate(batch):
    """Collate samples with the default collate, then swap axes 0 and 1 of each field.

    Each collated field goes from batch-major ``(B, T, ...)`` to time-major
    ``(T, B, ...)``. Returns a list with one tensor per field.

    Note: this local definition shadows the ``transpose_collate`` imported
    from ``data.common`` at the top of the notebook.
    """
    from torch.utils.data._utils.collate import default_collate

    time_major = []
    for field in default_collate(batch):
        time_major.append(torch.transpose(field, 0, 1))
    return time_major


class Featurizer(object):
    """Computes world-model features for every transition of a dataset.

    Reads the configuration, loads a WorldModelRSSM from the checkpoint named
    in ``conf["dreamer"]["wm_checkpoint"]``, builds a WorldModelDataset from
    ``dataset_config`` and runs the model over that dataset to collect
    per-transition features (and the actions that were fed in).

    Attributes set by ``get_features`` (``None`` until it has run):
        dataset, dataloader: the objects returned by ``build_dataloader``.
        wm: the loaded world model, in eval mode with gradients disabled.
        chunk_size: ``ceil(len(dataset) / batch_size)``.
    """

    def __init__(self, dataset_config, conf, device="cuda:0", batch_size=500,
                 feature_dim=2048, action_dim=18):
        """
        Args:
            dataset_config: mapping passed (as a copy) to WorldModelDataset.
            conf: full configuration; handed to WorldModelRSSM and used to
                locate the checkpoint.
            device: anything ``torch.device`` accepts. Defaults to "cuda:0".
            batch_size: number of sequences processed in parallel.
            feature_dim: width of the feature rows that are allocated.
            action_dim: width of the action rows that are allocated.
        """
        super().__init__()

        self.dataset_config = dataset_config
        self.conf = conf

        self.device = torch.device(device)
        self.batch_size = batch_size
        # Kept fixed: the inference loops write exactly one row per sampled
        # index, which only works with a single time step per batch.
        self.batch_length = 1
        self.feature_dim = feature_dim
        self.action_dim = action_dim

        # Filled in by get_features() so that infer_features2() and later
        # notebook cells can reuse them.
        self.dataset = None
        self.dataloader = None
        self.wm = None
        self.chunk_size = None

    def get_features(self):
        """Build the dataset, load the world model and run ``infer_features``.

        Also stores the dataset, dataloader, model and chunk size on ``self``.

        Returns:
            (features, actions) exactly as returned by ``infer_features``.
        """
        print("Building WM Dataset for featurizer...")
        dataloader, dataset = self.build_dataloader()

        print("Loading in trained world model...")
        wm = self._load_world_model()

        self.dataset = dataset
        self.dataloader = dataloader
        self.wm = wm
        self.chunk_size = math.ceil(len(dataset) / self.batch_size)

        print("Getting features...")
        features, actions = self.infer_features(dataset, dataloader, wm)
        return features, actions

    def _load_world_model(self):
        """Instantiate WorldModelRSSM from ``self.conf`` and load its checkpoint."""
        # The model constructor needs the full conf, which is why it is
        # carried all the way down to this class.
        wm = WorldModelRSSM(self.conf)
        # map_location lets a checkpoint saved from another device be loaded
        # onto the device this featurizer was configured with.
        checkpoint = torch.load(self.conf["dreamer"]["wm_checkpoint"],
                                map_location=self.device)
        wm.load_state_dict(checkpoint['model_state_dict'])
        wm.to(self.device)
        wm.eval()
        wm.requires_grad_(False)
        return wm

    def build_dataloader(self):
        """Create the dataset and a DataLoader driven by an EquiSampler.

        Returns:
            (dataloader, dataset) -- note the order.
        """
        # Work on a copy so the caller's config mapping is not modified.
        dataset_config = self.dataset_config.copy()
        dataset_config["batch_length"] = self.batch_length
        dataset_config["rank"] = self.device
        dataset = WorldModelDataset(dataset_config)
        batch_sampler = EquiSampler(
            len(dataset), self.batch_length, self.batch_size, init_idx = 0)
        dataloader = DataLoader(dataset,
                                pin_memory=True,
                                batch_sampler=batch_sampler,
                                collate_fn=transpose_collate)
        return dataloader, dataset

    @torch.inference_mode()
    def infer_features(self, dataset, dataloader, wm):
        """Run ``wm`` over ``dataloader`` twice and collect features and actions.

        Args:
            dataset: only its length is used (output rows, sampler size).
            dataloader: yields ``(action, image, reset)`` batches.
            wm: callable as ``wm(obs, in_states) -> (features, out_states)``
                and providing ``init_state(batch_size)``.

        Returns:
            (feats, acts): float32 arrays of shape ``(len(dataset), feature_dim)``
            and ``(len(dataset), action_dim)``.
        """
        s = time.time()

        feats = np.zeros((len(dataset), self.feature_dim), dtype=np.float32)
        acts = np.zeros((len(dataset), self.action_dim), dtype=np.float32)

        for epoch in range(2):
            # potentially need to do more than 2 epochs in edge cases when
            # episode is longer than chunk_size and extends across 3 chunks

            # A second sampler built with the same arguments tells us which
            # dataset rows each batch belongs to.
            # NOTE(review): assumes EquiSampler is deterministic, so this one
            # yields the same index batches as the DataLoader's -- confirm.
            fake_sampler = iter(EquiSampler(len(dataset),
                                            self.batch_length,
                                            self.batch_size,
                                            init_idx = 0))
            if epoch == 0:
                in_states = wm.init_state(self.batch_size)
            else:
                # put the last hidden state of previous chunk
                # as the first hidden state of next chunk
                # and zero out first one (assume index starts at 0)
                in_states = [torch.roll(x, 1, 0) for x in in_states]
                in_states[0][0] = torch.zeros_like(in_states[0][0])
                in_states[1][0] = torch.zeros_like(in_states[1][0])
            for i, batch in enumerate(dataloader):
                action, image, reset = [x.to(self.device) for x in batch]

                obs = {"image": image, "reset": reset, "action": action}

                # features come back as (T, B, Feats); T = 1 for inference
                features, out_states = wm(obs, in_states)
                in_states = out_states

                features = features.cpu().numpy().squeeze()

                idxs = np.array(next(fake_sampler))
                # Rows written in the first pass are overwritten in the second.
                feats[idxs] = features
                acts[idxs] = action.cpu().numpy().squeeze()
                if i == 0 or i == 1:
                    print(i)
                    print("first 5 idxs",idxs[:5])
                    print("first 5 actions", action[:5, :5])
                    print('NEXT')
        print("Feature inference took", int(time.time()-s), "seconds")
        return feats, acts

    @torch.inference_mode()
    def infer_features2(self, dataset, dataloader=None, wm=None):
        """Single-pass feature inference under autocast.

        Args:
            dataset: only its length is used (output rows, sampler size).
            dataloader: batches to run; defaults to the one stored by
                ``get_features``.
            wm: world model; defaults to the one stored by ``get_features``.

        Returns:
            float32 array of shape ``(len(dataset), feature_dim)``.

        Raises:
            RuntimeError: if no dataloader or model is given and
                ``get_features`` has not stored one yet.
        """
        if dataloader is None:
            dataloader = self.dataloader
        if wm is None:
            wm = self.wm
        if dataloader is None or wm is None:
            raise RuntimeError(
                "infer_features2 needs a dataloader and a world model: "
                "pass them explicitly or call get_features() first")

        # size of the final feature array will be (transitions, feature_dim)
        s = time.time()
        fake_sampler = iter(EquiSampler(len(dataset),
                                        self.batch_length,
                                        self.batch_size,
                                        init_idx = 0))
        feats = np.zeros((len(dataset), self.feature_dim), dtype=np.float32)

        in_states = wm.init_state(self.batch_size)

        print('in infer_features2')
        for batch in dataloader:
            action, image, reset = [x.to(self.device) for x in batch]

            obs = {"image": image, "reset": reset, "action": action}

            with autocast(enabled=True):
                features, out_states = wm(obs, in_states)
                # T,B,Feats
            in_states = out_states

            features = features.cpu().numpy().squeeze()

            feats[np.array(next(fake_sampler))] = features
        print("Feature2 inference took", int(time.time()-s), "seconds")
        return feats

    @torch.inference_mode()
    def infer_features_exact(self, dataset, dataloader, wm):
        """Sequential reference inference: one transition at a time, batch of 1.

        Args:
            dataset: must support ``len()`` and ``get_trans(i)``, the latter
                returning ``(action, image, reset)``.
            dataloader: unused; kept so the signature matches ``infer_features``.
            wm: world model, called with autocast disabled.

        Returns:
            Array stacking the squeezed features of every transition.
        """
        # size of the final feature array will be (transitions, feature_dim)
        print("Begin inference")
        s = time.time()
        feature_list = []
        in_states = wm.init_state(1)
        for i in range(len(dataset)):
            action, image, reset = [x.to(self.device) for x in dataset.get_trans(i)]
            obs = {"image": image, "reset": reset, "action": action}

            with autocast(enabled=False):
                features, out_states = wm(obs, in_states)
            feature_list.append(features.cpu().numpy().squeeze())
            in_states = out_states
            if i%10000 == 0:
                print("inferred", i, "features in", int(time.time()-s), "seconds")

        print("got", len(feature_list), "features")

        print("Feature inference took", int(time.time()-s), "seconds")
        features = np.stack(feature_list)
        return features


def load_conf(config_file, env_name):
    """Read a YAML config file and flatten it for one environment.

    The result holds the keys of the ``base`` section overridden by those of
    the ``env_name`` section, plus the ``dreamer`` section (when the file has
    one) under ``"dreamer"`` and the ``data`` section under ``"data"``.

    Raises:
        KeyError: if ``base``, ``data`` or ``env_name`` is missing from the file.
    """
    with open(config_file, "r") as handle:
        parsed = yaml.safe_load(handle)

    merged = dict(parsed["base"])
    data_section = parsed["data"]
    # Environment-specific values win over the base ones.
    merged.update(parsed[env_name])

    if "dreamer" in parsed:
        merged["dreamer"] = parsed["dreamer"]
    merged["data"] = data_section
    return merged


def run_trainer(conf):
    """Construct a Featurizer for the training split described by ``conf``.

    Despite the name, nothing is trained or run here: the Featurizer is only
    built from ``conf["data"]["train"]`` and the full ``conf``, then returned.
    """
    train_data_conf = conf["data"]["train"]
    featurizer = Featurizer(train_data_conf, conf)
    return featurizer


def main(config_file="../config/config_agent.yaml", env_name="atari_breakout"):
    """Load the configuration and return a ready-to-use Featurizer.

    Args:
        config_file: path of the YAML configuration to read. Defaults to the
            path this notebook previously hard-coded.
        env_name: name of the environment section to merge over ``base``.
            Defaults to the environment this notebook previously hard-coded.

    Returns:
        The object built by ``run_trainer`` for that configuration.
    """
    conf = load_conf(config_file, env_name)
    trainer = run_trainer(conf)
    return trainer




In [4]:
# Build the Featurizer from the default config, then run feature inference
# over the training dataset. `trainer`, `features` and `actions` are read by
# later cells, so this cell must run before them.
trainer = main()
features, actions = trainer.get_features()

Building WM Dataset for featurizer...
Building WorldModelDataset for ['breakout-expert-v2']


A.L.E: Arcade Learning Environment (version 0.7.5+db37282)
[Powered by Stella]


post fix in fix_actions [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
total transitions: 1000000
num batch elements: 1000000
Chunk size: 2000
n steps: 2000
Loading in trained world model...
Getting features...
Chunk size: 2000
n steps: 2000
iters[:5] [0, 2000, 4000, 6000, 8000]
iters[:5] [0, 2000, 4000, 6000, 8000]
0
first 5 idxs [   0 2000 4000 6000 8000]
first 5 actions tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.],
         [0.,

In [3]:
# Sanity check: both arrays are allocated with one row per dataset entry,
# so their first dimensions should agree.
print(features.shape, actions.shape)

In [4]:
# Reference feature arrays saved on disk, used below for comparison.
# NOTE(review): judging by the file names, `gta` presumably holds features
# computed under autocast and `gt` without it -- confirm.
# NOTE(review): these are absolute, machine-specific paths; the cell fails on
# any other machine.
gt = np.load("/home/bdemoss/research/div-rl/data/breakout/valf/expertv2-features.npy")
gta = np.load("/home/bdemoss/research/div-rl/data/breakout/valfeatures-expertv2-autocast.npy").squeeze()

In [10]:
# Rebuild the trained world model at notebook level (same steps as
# Featurizer.get_features) so it can be inspected directly.
device = torch.device("cuda:0")
config_file = "../config/config_agent.yaml"
env_name = "atari_breakout"
conf = load_conf(config_file, env_name)
wm = WorldModelRSSM(conf)
# map_location puts the checkpoint tensors on `device` even if the checkpoint
# was saved from a different device.
wm.load_state_dict(
    torch.load(conf["dreamer"]["wm_checkpoint"], map_location=device)['model_state_dict'])
wm.to(device)
wm.eval()
# Last expression on purpose: it returns the module, so the cell displays the
# model architecture.
wm.requires_grad_(False)

WorldModelRSSM(
  (encoder): EncoderModel(
    (encoder): CnnEncoder(
      (model): Sequential(
        (0): Conv2d(1, 48, kernel_size=(4, 4), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(48, 96, kernel_size=(4, 4), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(96, 192, kernel_size=(4, 4), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): Conv2d(192, 384, kernel_size=(4, 4), stride=(2, 2))
        (7): ELU(alpha=1.0)
        (8): Flatten(start_dim=1, end_dim=-1)
      )
    )
  )
  (rssm_core): RSSMCore(
    (cell): RSSMCell(
      (z_mlp): Linear(in_features=1024, out_features=1000, bias=True)
      (a_mlp): Linear(in_features=18, out_features=1000, bias=False)
      (in_norm): LayerNorm((1000,), eps=0.001, elementwise_affine=True)
      (gru): GRUCellStack(
        (layers): ModuleList(
          (0): GRUCell(1000, 1024)
        )
      )
      (prior_mlp_h): Linear(in_features=1024, out_features=1000, bias=True)
      (prior_norm): LayerNorm

In [14]:
# Squared difference between the reference features and the freshly inferred
# ones, over the first 1024 feature columns only, summed and divided by
# 1000000 (the transition count printed when the dataset was built).
# NOTE(review): the first 1024 columns are presumably the deterministic half
# of the state -- confirm.
deter_sq_err = np.square(gta[:, :1024] - features[:, :1024])
total_diff = np.sum(deter_sq_err) / 1000000.
print(total_diff)

In [30]:
# Compare the action array stored in the raw .npz file with the `actions`
# returned by feature inference.
# NOTE(review): absolute, machine-specific path.
b = np.load("/home/bdemoss/research/div-rl/data/breakout/valf/breakout-expert-v2.npz")
print(b["action"].shape)
print(actions.shape)
# Only the final expression of a cell is displayed: the three bare
# expressions below are evaluated and their results discarded.
np.any(b["action"] == actions)
i = 1
np.all(b["action"][i] == actions[i])
actions[:5]
# Displayed result: first five rows of the stored actions.
b["action"][:5]

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]])

In [23]:
# Load the same-named dataset through gym to look at its raw actions.
# NOTE(review): `d4rl_atari` is imported but never referenced; presumably the
# import registers 'breakout-expert-v2' with gym -- confirm.
import d4rl_atari
import gym
env = gym.make('breakout-expert-v2')
data = env.get_dataset()
# Displayed result: shape of the raw action array.
data["actions"].shape

(1000000,)

In [25]:
# First ten raw actions of the gym dataset, for comparison with the rows
# shown above.
data["actions"][:10]

array([3, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=int32)

In [138]:
def cdiff(x,y):
    """Summed squared difference of two feature vectors, split at index 1024.

    Both inputs are squeezed first.

    Returns:
        (head, tail): the sum of squared differences over elements ``[:1024]``
        and over elements ``[1024:]`` respectively.
    """
    a, b = x.squeeze(), y.squeeze()
    head, tail = slice(None, 1024), slice(1024, None)
    tail_err = np.sum(np.square(a[tail] - b[tail]))
    head_err = np.sum(np.square(a[head] - b[head]))
    return head_err, tail_err

# Probe a single world-model step: start from the stored feature vector of
# one transition, feed the dataset's action and a hand-written action, and
# compare the two resulting feature vectors.
# NOTE(review): `trainer.ampfeatures` is not assigned anywhere in the visible
# notebook code; this cell depends on state from an earlier session. It also
# expects `trainer.dataset` and `trainer.wm` to be populated -- confirm.
index = 666
in_states = torch.tensor(trainer.ampfeatures[index]).unsqueeze(0).to(trainer.device)
# Split the feature vector into the two state tensors the model takes.
in_states = (in_states[..., :1024], in_states[..., 1024:])
action, image, reset = [x.to(trainer.device) for x in trainer.dataset.get_trans(index)]

print('action:',action,action.shape)
# Hand-written action; in the recorded output it is identical to the
# dataset's action for this index.
waction = torch.tensor([[[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.]]]).to(trainer.device)
print('waction:',waction,waction.shape)

obs = {"image": image, "reset": reset, "action": action}
wobs = {"image": image, "reset": reset, "action": waction}

with autocast(enabled=True):
    features, out_states = trainer.wm(obs, in_states)
with autocast(enabled=True):
    wfeatures, wout_states = trainer.wm(wobs, in_states)
    
# NOTE: rebinding `features` here overwrites the full feature array produced
# by trainer.get_features() in an earlier cell.
features = features.detach().cpu().numpy().squeeze()
wfeatures = wfeatures.detach().cpu().numpy().squeeze()

# cdiff returns (difference over [:1024], difference over [1024:]).
print('differences',cdiff(features,wfeatures))

print(features.shape)
print([out_states[x].shape for x in range(2)])


action: tensor([[[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.]]], device='cuda:0') torch.Size([1, 1, 18])
waction: tensor([[[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.]]], device='cuda:0') torch.Size([1, 1, 18])
differences (0.0, 18.0)
(2048,)
[torch.Size([1, 1024]), torch.Size([1, 1024])]


In [3]:
from data.common import EquiSampler

# Check that two EquiSamplers built with identical arguments yield identical
# index batches. Featurizer relies on this when it re-creates a sampler to
# recover the dataset rows of each batch.
sampler = EquiSampler(1000000, 1, 100, init_idx = 0)
sampler2 = EquiSampler(1000000, 1, 100, init_idx = 0)
batch_sampler = iter(sampler)
batch_sampler2 = iter(sampler2)
print(len(sampler))
print(len(sampler2))

idxs = []
idxs2 = []

# Step both iterators in lock-step until the first one is exhausted. If the
# second runs out earlier it contributes None entries, which makes the final
# comparison print False.
while True:
    n = next(batch_sampler, None)
    # Named `n2` rather than `nn`: `nn` would overwrite the `torch.nn as nn`
    # alias from the import cell for the rest of the session.
    n2 = next(batch_sampler2, None)
    if n is None:
        break
    idxs.append(n)
    idxs2.append(n2)

print(idxs[0])


print(idxs2[0])

print(idxs2==idxs)

Chunk size: 10000
n steps: 10000
Chunk size: 10000
n steps: 10000
10000
10000
[0, 10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000, 110000, 120000, 130000, 140000, 150000, 160000, 170000, 180000, 190000, 200000, 210000, 220000, 230000, 240000, 250000, 260000, 270000, 280000, 290000, 300000, 310000, 320000, 330000, 340000, 350000, 360000, 370000, 380000, 390000, 400000, 410000, 420000, 430000, 440000, 450000, 460000, 470000, 480000, 490000, 500000, 510000, 520000, 530000, 540000, 550000, 560000, 570000, 580000, 590000, 600000, 610000, 620000, 630000, 640000, 650000, 660000, 670000, 680000, 690000, 700000, 710000, 720000, 730000, 740000, 750000, 760000, 770000, 780000, 790000, 800000, 810000, 820000, 830000, 840000, 850000, 860000, 870000, 880000, 890000, 900000, 910000, 920000, 930000, 940000, 950000, 960000, 970000, 980000, 990000]
[0, 10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000, 110000, 120000, 130000, 140000, 150000, 160000, 170000, 1800