In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys, os
import yaml
import pathlib
from types import SimpleNamespace # Used to mimic argparse.Namespace
import mani_skill
import exploration as expl
import models
import tools
import envs.wrappers as wrappers
from parallel import Parallel, Damy
import torch
from torch import nn
from torch import distributions as torchd
import datetime
to_np = lambda x: x.detach().cpu().numpy()
from dreamer import make_env, count_steps, Dreamer, make_dataset
import functools

In [None]:
# --- Configuration Loading ---
def recursive_update(base, update):
    """Recursively merge `update` into `base` in place.

    Nested dictionaries are merged key by key; every other value found in
    `update` replaces the corresponding entry of `base`.

    Args:
        base: Dictionary that is modified in place.
        update: Dictionary whose entries take precedence over `base`. Its
            nested dictionaries are never stored in `base` directly, so later
            merges into `base` cannot modify `update`.

    Returns:
        None. `base` is mutated.
    """
    for key, value in update.items():
        if isinstance(value, dict):
            target = base.get(key)
            if not isinstance(target, dict):
                # Key is missing, or a non-dict value is being overridden by a
                # dict: start from a fresh dict so we neither recurse into a
                # scalar nor alias the dict owned by `update`.
                target = base[key] = {}
            recursive_update(target, value)
        else:
            base[key] = value

def load_configs(config_names=None, config_path=None):
    """Build the run configuration from a YAML file of named sections.

    The "defaults" section is applied first, then every section listed in
    `config_names` is merged on top of it, in order. Names that are missing
    from the file are reported with a warning and skipped.

    Args:
        config_names: Optional iterable of section names to layer over
            "defaults".
        config_path: Optional path to the YAML file. When omitted,
            "~/dreamerv3-torch/configs.yaml" is used.

    Returns:
        types.SimpleNamespace whose attributes are the merged top-level keys,
        mimicking an argparse.Namespace.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
    """
    if config_path is None:
        # Adjust this default if configs.yaml lives somewhere else.
        config_path = "~/dreamerv3-torch/configs.yaml"
    config_path = pathlib.Path(config_path).expanduser()

    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    # An empty YAML document parses to None; treat it as "no sections".
    all_configs = yaml.safe_load(config_path.read_text()) or {}

    name_list = ["defaults", *config_names] if config_names else ["defaults"]

    final_config = {}
    for name in name_list:
        if name not in all_configs:
            print(f"Warning: Configuration '{name}' not found in configs.yaml. Skipping.")
            continue
        # A section that is present but empty parses to None.
        recursive_update(final_config, all_configs[name] or {})

    def _normalize_types(values):
        """Pass every non-dict entry of `values` through tools.args_type, at any depth."""
        # NOTE(review): tools.args_type is project code not shown here;
        # presumably it coerces a value the way the CLI parser would - confirm.
        for key, value in values.items():
            if isinstance(value, dict):
                _normalize_types(value)
            else:
                # Re-assigning an existing key while iterating is safe: the
                # dictionary's size does not change.
                values[key] = tools.args_type(value)(value)

    _normalize_types(final_config)

    # Convert the dictionary to a SimpleNamespace to mimic argparse.Namespace
    return SimpleNamespace(**final_config)



# --- Mimicking Command Line Arguments (for Jupyter) ---
# Pick the named config sections that load_configs layers over "defaults";
# this stands in for the flags a command-line launch would receive.
config = load_configs(config_names=['maniskill'])
# Alternative preset:
# config = load_configs(config_names=['dmc_vision'])

# Quick sanity check: show a few merged values and the type of a nested entry.
print(config.task, config.units, type(config.actor['lr']))

In [None]:
# Seed the run (and optionally switch to deterministic mode) before anything
# else draws random numbers.
tools.set_seed_everywhere(config.seed)
if config.deterministic_run:
    tools.enable_deterministic_run()

# Fall back to a timestamped log directory when none is configured.
if not config.logdir:
    config.logdir = f"./logs/{datetime.datetime.now().strftime(format='%d_%m_%y/%H:%M:%S')}"

logdir = pathlib.Path(config.logdir).expanduser()

# Episode directories default to sub-folders of the log directory and are
# always stored as Path objects, because .mkdir() is called on them below.
# The demo directory gets its own fallback: deriving it from config.traindir
# (always set by then) made the "demo_eps" default unreachable and pointed
# both settings at the same folder. getattr tolerates configs that do not
# define `demodir` at all.
config.traindir = pathlib.Path(config.traindir or logdir / "train_eps")
config.demodir = pathlib.Path(getattr(config, "demodir", None) or logdir / "demo_eps")
config.evaldir = pathlib.Path(config.evaldir or logdir / "eval_eps")

# Convert environment-step budgets into agent steps. The flag keeps the
# division from being applied again if this cell is re-run on the same config.
if not getattr(config, "_action_repeat_applied", False):
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat
    config._action_repeat_applied = True

print("Logdir", logdir, )
for run_dir in (logdir, config.traindir, config.evaldir, config.demodir):
    run_dir.mkdir(parents=True, exist_ok=True)

# Count the steps already stored on disk so logging resumes where it left off.
# NOTE(review): if config.traindir and config.demodir point at the same
# directory, its steps are counted twice here - confirm the two differ.
step = count_steps(config.traindir) + count_steps(config.demodir)
# step in logger is environmental step
logger = tools.Logger(logdir, config.action_repeat * step)

print("Create envs.")
# Training episodes come from the offline directory when one is configured
# (its template is filled in from the config values), else from traindir.
if config.offline_traindir:
    directory = config.offline_traindir.format(**vars(config))
else:
    directory = config.traindir
train_eps = tools.load_episodes(directory, limit=100) # hard-coded cap instead of config.dataset_size

# Same selection rule for the evaluation episodes.
if config.offline_evaldir:
    directory = config.offline_evaldir.format(**vars(config))
else:
    directory = config.evaldir
eval_eps = tools.load_episodes(directory, limit=1)

# Demonstration episodes are loaded from demodir with limit=None.
demo_eps = tools.load_episodes(config.demodir, limit=None)

def make(mode, id):
    """Create one environment for `mode` ("train" or "eval") with index `id`."""
    return make_env(config, mode, id)


# Build config.envs raw environments per mode, training ones first.
train_envs = [make("train", env_index) for env_index in range(config.envs)]
eval_envs = [make("eval", env_index) for env_index in range(config.envs)]

# Wrap each environment: a "process" Parallel wrapper when config.parallel is
# set, the Damy wrapper otherwise.
train_envs = [
    Parallel(raw_env, "process") if config.parallel else Damy(raw_env)
    for raw_env in train_envs
]
eval_envs = [
    Parallel(raw_env, "process") if config.parallel else Damy(raw_env)
    for raw_env in eval_envs
]

# The first training environment's action space is used for the whole run.
acts = train_envs[0].action_space
print("Action Space", acts)
if hasattr(acts, "n"):
    config.num_actions = acts.n
else:
    config.num_actions = acts.shape[0]

In [None]:
# Simulator state handed to later tools.simulate calls; it stays None when the
# prefill phase is skipped.
state = None
if not config.offline_traindir:
    # Only collect the steps that are still missing from the training directory.
    prefill = max(0, config.prefill - count_steps(config.traindir))
    print(f"Prefill dataset ({prefill} steps).")

    # Random policy: one row of distribution parameters per environment.
    if hasattr(acts, "discrete"):
        uniform_logits = torch.zeros(config.num_actions).repeat(config.envs, 1)
        random_actor = tools.OneHotDist(uniform_logits)
    else:
        action_low = torch.tensor(acts.low).repeat(config.envs, 1)
        action_high = torch.tensor(acts.high).repeat(config.envs, 1)
        random_actor = torchd.independent.Independent(
            torchd.uniform.Uniform(action_low, action_high), 1
        )

    def random_agent(o, d, s):
        """Ignore all inputs and return a sampled action, its log-prob and no state."""
        action = random_actor.sample()
        return {"action": action, "logprob": random_actor.log_prob(action)}, None

    state = tools.simulate(
        random_agent,
        train_envs,
        train_eps,
        config.traindir,
        logger,
        limit=config.dataset_size,
        steps=prefill,
    )
    # The logger counts environment steps, hence the action_repeat factor.
    logger.step += prefill * config.action_repeat
    print(f"Logger: ({logger.step} steps).")

print("Simulate agent.")
# One dataset per episode collection.
train_dataset = make_dataset(train_eps, config)
eval_dataset = make_dataset(eval_eps, config)
demo_dataset = make_dataset(demo_eps, config)
# Last expression of the cell: display the observation space.
train_envs[0].observation_space

In [None]:
# Disabled smoke test: flip the condition to True to reset the first training
# environment, take one random step and list the observation entries.
if False:
    env = train_envs[0]
    # reset()/step() return callables here; calling them yields the result.
    env.reset()()
    sampled_action = env.action_space.sample()
    obs, *_ = env.step({'action': sampled_action})()
    for k, v in obs.items():
        # Show the shape when the entry has one, its type otherwise.
        print(k, getattr(v, 'shape', type(v)))

In [None]:
# Build the agent from the first training environment's spaces and move it to
# the configured device.
reference_env = train_envs[0]
agent = Dreamer(
    reference_env.observation_space,
    reference_env.action_space,
    config,
    logger,
    train_dataset,
    demo_dataset,
).to(config.device)
agent.requires_grad_(requires_grad=False)

# Resume from the latest checkpoint in the log directory, if there is one.
checkpoint_path = logdir / "latest.pt"
if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path)
    agent.load_state_dict(checkpoint["agent_state_dict"])
    tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
    # Clear the one-shot pretrain flag on the restored agent.
    agent._should_pretrain._once = False

In [None]:
# Disabled scratch cell: walks through one reset plus a hand-unrolled
# observation step of the world model so intermediate shapes can be inspected.
# NOTE(review): the inline comments suggest this mirrors tools.simulate and the
# dynamics' obs_step - confirm against those sources before relying on it.
if False:
    # Note: the environments are built in "train" mode but stored as eval_envs.
    eval_envs = [make("train", i) for i in range(config.envs)]
    eval_envs = [Damy(env) for env in eval_envs]

    import numpy as np
    from tools import convert

    # Per-environment bookkeeping; every env starts out flagged as done so
    # that all of them are reset below.
    envs = eval_envs
    step, episode = 0, 0
    done = np.ones(len(envs), bool)
    length = np.zeros(len(envs), np.int32)
    obs = [None] * len(envs)
    agent_state = None
    reward = [0] * len(envs)
    if done.any():
        indices = [index for index, d in enumerate(done) if d]
        # reset() returns a callable; the second pass resolves the results.
        results = [envs[i].reset() for i in indices]
        results = [r() for r in results]
        for index, result in zip(indices, results): # NOTE: Michael, does this kill us? (pulls every transition off of the GPU to save it, done below as well.)
            t = result.copy()
            # for k,v in t.items():
            #     print(k, v)
            #     convert(v)
            t = {k: convert(v) for k, v in t.items()}
            # action will be added to transition in add_to_cache
            t["reward"] = 0.0
            t["discount"] = 1.0
            # replace obs with done by initial state
            # (`t` is built but not used afterwards in this cell.)
            obs[index] = result
    # Stack the per-env observations key by key, skipping keys containing "log_".
    obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}
    for k,v in obs.items():
        print(k, v.shape if hasattr(v, 'shape') else v)
    if agent_state is None:
        latent = action = None
    else:
        latent, action = agent_state
    obs = agent._wm.preprocess(obs)
    embed = agent._wm.encoder(obs)
    # latent, _ = agent._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
    prev_state = latent; prev_action = action; is_first = obs["is_first"]
    # def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
    # initialize all prev_state
    # NOTE(review): `prev_state == None` works here but `is None` is the
    # idiomatic test and avoids element-wise comparison on tensor-like values.
    if prev_state == None or torch.sum(is_first) == len(is_first):
        prev_state = agent._wm.dynamics.initial(len(is_first))
        prev_action = torch.zeros(
            (len(is_first), agent._wm.dynamics._num_actions), device=agent._wm.dynamics._device
        )

        # Add a leading axis when the embedding has more dimensions than the action.
        if len(embed.shape) > len(prev_action.shape):
            prev_state = {k:v.unsqueeze(0) for k,v in prev_state.items()}
            # prev_state = prev_state.unsqueeze(0)
            prev_action = prev_action.unsqueeze(0)
    # overwrite the prev_state only where is_first=True
    elif torch.sum(is_first) > 0:
        is_first = is_first[:, None]
        prev_action *= 1.0 - is_first
        init_state = agent._wm.dynamics.initial(len(is_first))
        for key, val in prev_state.items():
            # Pad is_first with trailing singleton axes so it broadcasts against val.
            is_first_r = torch.reshape(
                is_first,
                is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
            )
            prev_state[key] = (
                val * (1.0 - is_first_r) + init_state[key] * is_first_r
            )

    prior = agent._wm.dynamics.img_step(prev_state, prev_action)
    # Bare tuple of shapes: not displayed, since it is not the cell's last expression.
    embed.shape, prior["deter"].shape, prev_action.shape, prev_state['stoch'].shape, prev_state["deter"].shape

    # Concatenate the prior's deterministic state with the embedding on the last axis.
    x = torch.cat([prior["deter"], embed], -1)

In [None]:

# logger._videos = {}

# Reset the agent's step counter so the loop condition below starts from zero
# every time this cell is (re-)run.
agent._step = 0

print("Create envs.")
if config.offline_traindir:
    directory = config.offline_traindir.format(**vars(config))
else:
    directory = config.traindir
# Reload the stored episodes and hand the agent a dataset built from them.
train_eps = tools.load_episodes(directory, limit=100)  # hard-coded cap instead of config.dataset_size
train_dataset = make_dataset(train_eps, config)
agent._dataset = train_dataset

# Release the training environments left over from an earlier cell (or from an
# interrupted run of this one) before replacing them; otherwise they leak.
for stale_env in train_envs:
    try:
        stale_env.close()
    except Exception as exc:
        # Best effort: an env that is already closed must not block the restart.
        print(f"Warning: could not close previous train env: {exc!r}")

train_envs = [make("train", i) for i in range(config.envs)]
if config.parallel:
    train_envs = [Parallel(env, "process") for env in train_envs]
else:
    train_envs = [Damy(env) for env in train_envs]

# Any simulator state kept from earlier cells describes the environments that
# were just replaced, so it must not be fed to the fresh ones.
# NOTE(review): presumably tools.simulate resets every env when state is None,
# as in the prefill call - confirm.
state = None

# if config.offline_evaldir:
#     directory = config.offline_evaldir.format(**vars(config))
# else:
#     directory = config.evaldir
# eval_eps = tools.load_episodes(directory, limit=1)

# eval_dataset = make_dataset(eval_eps, config)
# eval_envs = [make("train", i) for i in range(config.envs)]
# eval_envs = [Damy(env) for env in eval_envs]

# make sure eval will be executed once after config.steps
while agent._step < config.steps + config.eval_every:
    # logger.write()
    # if config.eval_episode_num > 0:
    #     print("Start evaluation.")
    #     eval_policy = functools.partial(agent, training=False)
    #     tools.simulate(
    #         eval_policy,
    #         eval_envs,
    #         eval_eps,
    #         config.evaldir,
    #         logger,
    #         is_eval=True,
    #         episodes=config.eval_episode_num,
    #     )
        # if config.video_pred_log:
        #     d = next(eval_dataset)
        #     video_pred = agent._wm.video_pred(d)
        #     logger.video("eval_openl", to_np(video_pred))
    print("Start training.")
    # Run the agent for one evaluation interval, carrying the simulator state
    # from one iteration to the next.
    state = tools.simulate(
        agent,
        train_envs,
        train_eps,
        config.traindir,
        logger,
        limit=config.dataset_size,
        steps=config.eval_every,
        state=state,
    )
    # Checkpoint model and optimizer state after every interval.
    items_to_save = {
        "agent_state_dict": agent.state_dict(),
        "optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
    }
    torch.save(items_to_save, logdir / "latest.pt")

# Shut every environment down; report failures instead of hiding them.
for env in train_envs + eval_envs:
    try:
        env.close()
    except Exception as exc:
        print(f"Warning: could not close env: {exc!r}")

In [None]:
# Optional refresh of the episodes and dataset before sampling:
# train_eps = tools.load_episodes(directory, limit=100) #config.dataset_size)
# train_dataset = make_dataset(train_eps, config)
# Pull one batch, preprocess it with the world model and run the encoder on it;
# the encoder output is the cell's displayed value.
d = next(train_dataset)
data = agent._wm.preprocess(d)
agent._wm.encoder(data)

In [None]:
# Per evaluation episode: its key, the number of rewards, and the shape of the
# first image and first state entry.
for k, v in eval_eps.items():
    print(k, len(v['reward']))
    print(v['image'][0].shape)
    print(v['state'][0].shape)

print("Create envs.")
if config.offline_traindir:
    directory = config.offline_traindir.format(**vars(config))
else:
    directory = config.traindir
# Rebuild the training episodes and dataset from disk.
train_eps = tools.load_episodes(directory, limit=100)  # hard-coded cap instead of config.dataset_size
train_dataset = make_dataset(train_eps, config)

# Same summary for one sampled batch. The label is a fixed string: the loop
# variable `k` above belongs to the evaluation episodes (and is undefined when
# there are none), so it must not be reused here.
d = next(train_dataset)
print("batch", len(d['reward']))
print(d['image'].shape)
print(d['state'].shape)

In [None]:
# Draw a fresh batch from the training dataset; `d` is inspected in the next cell.
d = next(train_dataset)

In [None]:
# One line per batch entry: its key and its shape, or the raw value when the
# entry has no `shape` attribute.
for k, v in d.items():
    print(k, getattr(v, 'shape', v))

In [None]:
# NOTE(review): in the visible notebook `env` is only bound inside the disabled
# `if False:` debug cell or as the loop variable of the env-closing loop after
# training, so this cell depends on leftover kernel state - confirm which
# environment is meant and create it explicitly.
print("Observation space", env.observation_space)
print("Action space", env.action_space)

# obs, _ = env.reset() # TODO: reset with a seed for determinism
# NOTE(review): other cells call `env.reset()()` because the wrapped envs return
# a callable; here the result is stored without being called - confirm.
obs = env.reset() # TODO: reset with a seed for determinism


In [None]:
import time

# Step `env` with randomly sampled actions until the episode ends and report
# how long the rollout took.
# NOTE(review): this unpacks a 5-tuple straight from env.step(...), whereas the
# debug cell above calls the returned object first (`env.step(...)()`) - confirm
# which kind of environment `env` is expected to be here.
done = False
step = 0  # rebinds the notebook-level `step` set during logger setup
t0 = time.time()
while not done:
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step({'action': action})
    done = terminated or truncated
    step += 1
    # env.env.env.env.env.render() # maniskill render doesn't take any arguments, but the gymnasium environment does. Annoying.
env.close()

print(f"{step} steps took {time.time() - t0}")




In [None]:
# Display the keys of the most recent observation.
obs.keys()