In [1]:
from agents.forward_backward.replay_buffer import FBReplayBuffer
from agents.utils import set_seed_everywhere
import warnings
from pathlib import Path
import yaml
from dm_control import suite
import torch

# Silence DeprecationWarning messages for the rest of the session.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

# Every path below is resolved against the parent of the current working
# directory, so the notebook must be launched from one level below that root.
working_dir = Path.cwd().parent
config_path = working_dir.joinpath("agents", "forward_backward", "config.yaml")
model_path = working_dir.joinpath("agents", "forward_backward", "saved_models")

# Parse the agent's YAML configuration into a plain Python object.
with config_path.open("rb") as f:
    config = yaml.safe_load(f)

# Seed first, so the environment and buffer created below are built after
# the seed from the config has been applied.
set_seed_everywhere(config["seed"])

# Dataset location: <working_dir>/datasets/<domain>/<exploration algorithm>/buffer
data_dir = working_dir.joinpath(
    "datasets",
    config["domain_name"],
    config["exploration_algorithm"],
    "buffer",
)

# Build the dm_control environment named in the config, seeded with the
# same seed value.
env = suite.load(
    task_name=config["task_name"],
    domain_name=config["domain_name"],
    task_kwargs=dict(random=config["seed"]),
)

# Load the offline dataset into a CPU-resident replay buffer.
# NOTE(review): max_episodes and discount are hard-coded here rather than
# read from `config` — confirm that is intentional.
replay_buffer = FBReplayBuffer(
    data_dir=data_dir,
    env=env,
    device=torch.device("cpu"),
    discount=0.99,
    max_episodes=1000,
    shuffle_data_across_episodes=True,
)

 11%|█         | 1000/8928 [00:58<07:40, 17.23it/s]


In [11]:
# List the field names held in the replay buffer's storage mapping.
replay_buffer.storage.keys()

dict_keys(['observation', 'action', 'reward', 'discount', 'physics', 'skill'])

In [15]:
# Peek at the first entry stored under "observation"
# (presumably the observations of one episode — TODO confirm against FBReplayBuffer).
replay_buffer.storage["observation"][0]

array([[ 0.84898543, -0.23791558, -0.96012914, -0.999999  , -0.94315654,
        -0.68368876],
       [-0.1374791 , -0.22520307, -0.999999  , -0.94550586, -0.9208139 ,
         0.6910253 ],
       [ 0.7811452 ,  0.5230211 , -0.61812615, -0.05428113,  0.999999  ,
         0.33884507],
       ...,
       [-0.74251556,  0.9129435 , -0.28740636, -0.4348122 , -0.3442296 ,
         0.508444  ],
       [ 0.32568395,  0.3202653 ,  0.4224036 , -0.24123122, -0.26647812,
         0.45888558],
       [-0.8553836 , -0.41908494, -0.47923994,  0.5223094 , -0.20252648,
        -0.3121467 ]], dtype=float32)

In [3]:
# Check whether PyTorch reports the MPS backend as available on this machine.
torch.backends.mps.is_available()

True

In [6]:
import wandb  # kept from the original cell; nothing in this cell references it

# Load the previously saved forward-backward agent.
#
# The checkpoint is addressed through `model_path` (the saved-models directory
# defined in the setup cell) instead of a second, cwd-relative string, so the
# notebook has a single definition of where saved models live.
# NOTE(review): this assumes the setup cell has been run and that the kernel's
# working directory is the `agents` directory, so that `model_path` points at
# the same folder the old relative path did — confirm.
#
# SECURITY: torch.load unpickles arbitrary Python objects; only load
# checkpoint files that come from a trusted source.
agent = torch.load(model_path / "forward_backward_long.pickle")

In [7]:
# Display the loaded agent; the bare expression renders its repr as the cell output.
agent

FBAgent(
  (FB): ForwardBackwardRepresentation(
    (forward_representation): ForwardRepresentation(
      (F1): ForwardModel(
        (trunk): Sequential(
          (0): Linear(in_features=1024, out_features=1024, bias=True)
          (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (2): Tanh()
          (3): Linear(in_features=1024, out_features=1024, bias=True)
          (4): ReLU()
          (5): Linear(in_features=1024, out_features=50, bias=True)
        )
        (obs_action_preprocessor): AbstractPreprocessor(
          (trunk): Sequential(
            (0): Linear(in_features=30, out_features=1024, bias=True)
            (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (2): Tanh()
            (3): Linear(in_features=1024, out_features=1024, bias=True)
            (4): ReLU()
            (5): Linear(in_features=1024, out_features=512, bias=True)
          )
        )
        (obs_z_preprocessor): AbstractPreprocessor(
          (trunk)