## 1. Getting Started

[Doc: getting started](https://robomimic.github.io/docs/introduction/getting_started.html)

To monitor training with TensorBoard, run:
```bash
tensorboard --logdir bc_trained_models/test --host 127.0.0.1 --port 6006
```

### 1.1. Off-screen renderer test

In [None]:
# Off-screen renderer test
import os

# Select the headless EGL backend *before* importing robosuite / robomimic.
# The rendering backend is picked up from these variables when the MuJoCo
# bindings are first imported, so assigning them after the imports (as this
# cell used to do) comes too late to influence the choice.
# NOTE: if robosuite was already imported in this kernel, restart the kernel
# for the settings to take effect.
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"
# Respect a GPU already chosen by the user; otherwise default to device 0.
os.environ["MUJOCO_EGL_DEVICE_ID"] = os.environ.get("MUJOCO_EGL_DEVICE_ID", "0")

import numpy as np
import matplotlib.pyplot as plt
import robosuite as suite
from robomimic.envs.env_robosuite import EnvRobosuite
import robomimic.utils.obs_utils as ObsUtils

def test_robosuite_offscreen_render():
    """Render one off-screen frame from a raw robosuite ``Lift`` environment.

    Builds the environment without an on-screen viewer, renders a 256x256
    frame from the ``frontview`` camera, shows it with matplotlib and prints
    basic statistics about the returned array.

    The environment is always closed before returning, so re-running the cell
    does not accumulate simulator / rendering contexts.
    """
    env = suite.make(
        env_name="Lift",
        robots="Panda",
        has_renderer=False,              # no onscreen window
        has_offscreen_renderer=True,     # we want rgb_array frames
        use_camera_obs=False,
        control_freq=20,
    )
    try:
        env.reset()
        frame = env.sim.render(height=256, width=256, camera_name="frontview")

        # NOTE(review): the frame is shown exactly as returned by the simulator;
        # if it appears upside down, display ``frame[::-1]`` instead -- confirm.
        plt.imshow(frame)
        plt.show()
        print("\tframe dtype/shape:", frame.dtype, frame.shape)
        print("\tframe min/max:", np.min(frame), np.max(frame))
        print("\t✅ robosuite render OK")
    finally:
        # Release the simulator and its off-screen GL context even on failure.
        env.close()
    
def test_robomimic_offscreen_render():
    """Render one off-screen frame through robomimic's ``EnvRobosuite`` wrapper.

    Registers empty observation specs with ``ObsUtils``, builds a ``Lift``
    environment with off-screen rendering only, requests a 256x256
    ``rgb_array`` frame, shows it and prints its type, shape and dtype.
    """
    # Same (empty) modality lists for both the "obs" and the "goal" group.
    modalities = ("low_dim", "rgb", "depth", "scan")
    empty_specs = {
        group: {modality: [] for modality in modalities}
        for group in ("obs", "goal")
    }
    ObsUtils.initialize_obs_utils_with_obs_specs(empty_specs)

    env_kwargs = dict(
        env_name="Lift",
        robots="Panda",
        render=False,           # no onscreen
        render_offscreen=True,  # important
        use_image_obs=False,
    )
    env = EnvRobosuite(**env_kwargs)
    env.reset()

    img = env.render(mode="rgb_array", height=256, width=256)
    plt.imshow(img)
    plt.show()

    img_shape = getattr(img, "shape", None)
    img_dtype = getattr(img, "dtype", None)
    print("\trobomimic rgb_array:", type(img), img_shape, img_dtype)
    print("\t✅ robomimic render OK")

# Run both smoke tests: first the raw robosuite renderer, then the robomimic wrapper.
test_robosuite_offscreen_render()
test_robomimic_offscreen_render()


### 1.2. Training pipeline sanity check: train_bc_rnn.py (~1.5 min)

In [None]:
# python examples/train_bc_rnn.py --debug
import sys, subprocess
from pathlib import Path

# The repo root is taken to be the parent of the current working directory,
# i.e. the notebook is expected to run from a folder directly below it.
REPO_ROOT = Path.cwd().resolve().parent
ROBOMIMIC_CWD = REPO_ROOT / "third_party" / "robomimic"
print(f"robomimic working directory: {ROBOMIMIC_CWD}")
# Fail early with a readable message instead of an opaque subprocess error.
assert ROBOMIMIC_CWD.is_dir(), f"robomimic dir not found: {ROBOMIMIC_CWD}"

train_script = ROBOMIMIC_CWD / "examples" / "train_bc_rnn.py"
assert train_script.is_file(), f"training script not found: {train_script}"

cmd = [sys.executable, str(train_script), "--debug"]
# "y\n" is fed to stdin so that a yes/no prompt from the script (if any) is
# answered automatically; check=True raises if the script exits non-zero.
subprocess.run(cmd, input="y\n", text=True, check=True, cwd=str(ROBOMIMIC_CWD))

### 1.3. Actually train a diffusion policy (~15 minutes)

This script can run for a long time; feel free to terminate it early once a few checkpoints have been saved.

Robomimic has an open issue with [rendering diffusion policy rollouts](https://github.com/ARISE-Initiative/robomimic/issues/269).
- The current workaround is to set `render_video=False` in [robomimic/exps/templates/diffusion_policy.json](../third_party/robomimic/robomimic/exps/templates/diffusion_policy.json).

In [None]:
# Download the "lift" proficient-human dataset and train a policy on it.
# Both steps run as subprocesses with robomimic/ as working directory.
import os, sys, subprocess
from pathlib import Path

# Rendering-related environment variables; the child processes started below
# inherit them from this process.
# NOTE(review): MUJOCO_EGL_DEVICE_ID is presumably irrelevant while the
# backend is "glx" -- confirm whether it is still needed here.
os.environ["MUJOCO_GL"] = "glx"
os.environ["PYOPENGL_PLATFORM"] = "glx"
os.environ["MUJOCO_EGL_DEVICE_ID"] = os.environ.get("MUJOCO_EGL_DEVICE_ID", "0")

# The repo root is taken to be the parent of the current working directory.
REPO_ROOT = Path.cwd().resolve().parent
ROBOMIMIC_CWD = REPO_ROOT / "third_party" / "robomimic"

print(f"robomimic working directory: {ROBOMIMIC_CWD}")
assert ROBOMIMIC_CWD.is_dir(), f"robomimic dir not found: {ROBOMIMIC_CWD}"

# 1) Download datasets (run from robomimic/)
cmd = [sys.executable, "-m", "robomimic.scripts.download_datasets",
       "--tasks", "lift", "--dataset_types", "ph"]
subprocess.run(cmd, input="y\n",  # automatically answer "y" when download_datasets.py prompts "Overwrite?"
               text=True, check=True, cwd=str(ROBOMIMIC_CWD))

# 2) Train (run from robomimic/)
# The config and dataset paths are relative to ROBOMIMIC_CWD (the cwd of the
# subprocess). The commented-out "--config" line is an alternative template.
cmd = [sys.executable, "-m", "robomimic.scripts.train",
       "--config", "robomimic/exps/templates/diffusion_policy.json",
    #    "--config", "robomimic/exps/templates/bc.json",
       "--dataset", "datasets/lift/ph/low_dim_v15.hdf5"]
# check=True raises CalledProcessError if training exits with a non-zero status.
subprocess.run(cmd, input="y\n", text=True, check=True, cwd=str(ROBOMIMIC_CWD))



## 2. Generate Rollout Policy Checkpoint

In [None]:
%load_ext autoreload
%autoreload 2
import sys
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple
from IPython.display import display
import h5py
import imageio

import numpy as np
import torch
import flax.nnx as nnx # only used here for displaying h5 trees

# Make the repository root importable (for the ``src.*`` imports below).
# The path is resolved so it stays valid if the working directory changes,
# and it is only added once so re-running the cell does not grow sys.path.
repo_root = Path('../').resolve()
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

import robomimic.utils.obs_utils as ObsUtils
from robomimic.models.policy_nets import ActorNetwork, GaussianActorNetwork
from robomimic.models.diffusion_policy_nets import ConditionalUnet1D

from src.latent_sope.robomimic_interface.rollout import (
    rollout,
    RolloutPolicy,
    RolloutLatentRecorder,
    RolloutLatentTrajectory,
    save_rollout_latents,
    get_policy_frame_stack,
    PolicyFeatureHook,
)
from src.latent_sope.robomimic_interface.checkpoints import (
    load_checkpoint,
    load_demo,
    build_h5_tree,
    build_algo_from_checkpoint,
    build_env_from_checkpoint,
    build_rollout_policy_from_checkpoint,
    prepare_obs,
    EnvBase,
)
from src.latent_sope.utils.common import CONSOLE_LOGGER

dataset_path = Path("../third_party/robomimic/datasets/lift/ph/low_dim_v15.hdf5")
# policy_train_dir = Path("../third_party/robomimic/bc_trained_models/test/20260119152203")
test_dir = Path("../third_party/robomimic/diffusion_policy_trained_models/test")
# Run directories are sorted by name and the last one is used.
policy_train_dirs = sorted([d for d in test_dir.glob("*") if d.is_dir()])
assert len(policy_train_dirs) > 0, "No policy train dirs found, you have to train a policy first!"
policy_train_dir = policy_train_dirs[-1]

output_video_path = policy_train_dir / "rollout.mp4"
output_latents_path = policy_train_dir / "rollout_latents.h5"
demo_index = 0
num_steps = 60

# ---- 1. Load policy and environment from checkpoint ----
policy_model_checkpoint = load_checkpoint(policy_train_dir.resolve(),
                                          ckpt_path="last.pth")
policy_algo = build_algo_from_checkpoint(policy_model_checkpoint)
policy_net: ActorNetwork | ConditionalUnet1D = policy_algo.nets.policy

# Debug: list the observation keys the policy was configured with and flag
# the ones whose name suggests an image input.
keys = list(policy_algo.global_config.all_obs_keys)
print("policy expects keys (len={}):".format(len(keys)))
print(keys)

img_like = [k for k in keys if ("image" in k) or ("rgb" in k)]
print("\nimage-like keys in policy:", img_like)

# Use the GPU when one is available instead of hard-coding "cuda", so the
# notebook also runs on CPU-only machines.
rollout_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy:RolloutPolicy = build_rollout_policy_from_checkpoint(policy_model_checkpoint,
                                              device=rollout_device, verbose=False)

from robomimic.models.obs_nets import ObservationEncoder
import robomimic.utils.tensor_utils as TensorUtils

# Debug: report the first ObservationEncoder inside policy_net together with
# its observation keys and the subset registered under the "rgb" modality.
found = False
for name, m in policy_net.named_modules():
    if isinstance(m, ObservationEncoder):
        found = True
        print("FOUND ObservationEncoder at:", name)

        enc_keys = list(m.obs_shapes.keys())
        print("encoder obs keys:", enc_keys)

        rgb_keys = [k for k in enc_keys if ObsUtils.key_is_obs_modality(key=k, obs_modality="rgb")]
        print("encoder rgb keys:", rgb_keys)
        break

if not found:
    print("No ObservationEncoder found inside policy_net.")

def find_first_obs_encoder(root):
    """Return ``(name, module)`` of the first ``ObservationEncoder`` under *root*.

    Sub-modules are visited in the order produced by ``root.named_modules()``.
    ``(None, None)`` is returned when no sub-module is an ``ObservationEncoder``.
    """
    encoders = (
        (module_name, module)
        for module_name, module in root.named_modules()
        if isinstance(module, ObservationEncoder)
    )
    return next(encoders, (None, None))

# ---- A) locate the obs encoder inside the torch module used for forward ----
# policy_net is searched first (usually the right place).
enc_name, obs_encoder = find_first_obs_encoder(policy_net)
print("found obs encoder in policy_net:", enc_name, type(obs_encoder))

# If nothing is found here (or the hook never fires later on), the fallback
# would be to hook inside `policy`, the RolloutPolicy wrapper.
assert obs_encoder is not None, "No ObservationEncoder found inside policy_net. This policy may not be vision-based."

# Observation keys known to the encoder, in the encoder's own order.
keys = [*obs_encoder.obs_shapes.keys()]
print("encoder obs keys:", keys)

# Keep only the keys registered under the "rgb" modality.
rgb_keys = list(filter(lambda obs_key: ObsUtils.key_is_obs_modality(obs_key, "rgb"), keys))
print("rgb_keys:", rgb_keys)
assert rgb_keys, "No rgb keys -> this checkpoint/policy is not using vision (cannot hook visual latent)."

# The first rgb key is the one that gets hooked.
rgb_key = next(iter(rgb_keys))
print("using rgb_key =", rgb_key)

# ---- B) hook VisualCore output and mirror post-net steps ----
class VisualLatentHook:
    """Capture the per-key visual feature of an observation encoder.

    A forward hook is registered on ``obs_encoder.obs_nets[rgb_key]``. Each
    time that network runs, the hook post-processes its output (activation,
    randomizer ``forward_out`` in reverse order, flatten from axis 1) and
    stores the result in ``latest``.

    The hook stays registered until ``close()`` is called. The object can
    also be used as a context manager, which removes the hook on exit::

        with VisualLatentHook(obs_encoder, rgb_key) as hook:
            ...  # run the policy
            z = hook.latest

    Args:
        obs_encoder: object exposing ``obs_nets``, ``obs_randomizers`` (both
            indexable by ``rgb_key``) and ``activation``.
        rgb_key: observation key whose network should be hooked.
        detach: if True (default), the stored feature is detached from the
            autograd graph.
    """

    def __init__(self, obs_encoder, rgb_key, detach=True):
        self.obs_encoder = obs_encoder
        self.rgb_key = rgb_key
        self.detach = detach

        self.visual_net = obs_encoder.obs_nets[rgb_key]                 # VisualCore
        self.randomizers = list(obs_encoder.obs_randomizers[rgb_key])   # ModuleList
        self.activation = obs_encoder.activation                        # ReLU or None

        # Most recent post-processed feature; None until the hook has fired.
        self.latest = None
        self.handle = self.visual_net.register_forward_hook(self._hook)

        print("[VisualLatentHook] hooked net:", type(self.visual_net), "for key:", rgb_key)

    def __enter__(self):
        """Return the hook itself so it can be bound with ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove the forward hook; exceptions are never suppressed."""
        self.close()
        return False

    def _hook(self, module, inputs, output):
        """Forward-hook callback: post-process *output* and store it in ``latest``.

        Returns None, so the hooked module's own output is left untouched.
        """
        x = output

        # same order as ObservationEncoder.forward AFTER obs_nets[k](x)
        if self.activation is not None:
            x = self.activation(x)

        for rand in reversed(self.randomizers):
            if rand is not None:
                x = rand.forward_out(x)

        x = TensorUtils.flatten(x, begin_axis=1)

        if self.detach:
            x = x.detach()

        self.latest = x

    def close(self):
        """Remove the forward hook. Safe to call more than once."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

vis_hook = VisualLatentHook(obs_encoder, rgb_key)

env:EnvBase = build_env_from_checkpoint(
    policy_model_checkpoint,
    render=False,
    render_offscreen=True,
    verbose=False,
)

# ---- 2. Load the demo dataset with h5py and examine its structure ----
# The file is only kept open while data is read from it.
with h5py.File(dataset_path, "r") as h5:
    print("==== Showing dataset h5 tree (max_depth=2):")
    nnx.display(build_h5_tree(h5, max_depth=2))
    print("\n==== Showing data.demo_0 tree: ")
    nnx.display(build_h5_tree(h5["data"]["demo_0"], max_children=10))

    demo_keys = sorted(h5["data"].keys()) # ['demo_0', 'demo_1', 'demo_10', ...]
    demo_str_id = f"demo_{demo_index}"
    if demo_str_id not in demo_keys:
        raise ValueError(f"{demo_str_id!r} not found in {dataset_path} ({len(demo_keys)} demos available)")
    obs_keys = sorted(h5["data"][demo_str_id]["obs"].keys()) # ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_eef_quat_site', ...]

    # Keep only the observation keys the policy was configured with.
    global_cfg_keys = set(policy_algo.global_config.all_obs_keys)
    obs_keys = [k for k in obs_keys if k in global_cfg_keys]
    obs_keys_visual = [k for k in obs_keys
                       if ObsUtils.key_is_obs_modality(key=k, obs_modality="rgb")
                       or ObsUtils.key_is_obs_modality(key=k, obs_modality="depth")]

    print(f"\n==== obs_keys filtered by global cfg: ")
    print(obs_keys)

    obs_np, actions_np = load_demo(h5, demo_str_id, obs_keys, num_steps)

    # obs_stats = policy_model_checkpoint.ckpt_dict.get("obs_normalization_stats", None)
    obs_stats = None
    obs_torch = prepare_obs(obs_np, device=policy_algo.device, obs_stats=obs_stats)
    actions_torch = torch.as_tensor(actions_np, device=policy_algo.device, dtype=torch.float32)

    print(f"\n==== Tensors in obs_torch: ")
    print("\n".join([f"\t{k}: {v.shape}, {v.device}, {v.dtype}" for k, v in obs_torch.items()]))
    print(f"\n==== actions_t: {actions_torch.shape}, {actions_torch.device}, {actions_torch.dtype} \n")

# ---- 3. Build the feature hook / recorder and roll the policy out ----
camera_names = ['agentview']

frame_stack = get_policy_frame_stack(policy)
feature_hook = PolicyFeatureHook(
    policy,
    feat_type="visual_latent",
)
print("rgb_key:", getattr(feature_hook, "_rgb_key", None))
print("visual_net:", type(getattr(feature_hook, "_visual_net", None)))
recorder = RolloutLatentRecorder(
    feature_hook,
    obs_keys=obs_keys,
    store_obs=False,
    store_next_obs=False,
)

# The writer is opened right before the rollout and closed in `finally`, so
# the video file is finalized even if the rollout raises.
video_writer = imageio.get_writer(str(output_video_path), fps=20)
try:
    stats = rollout(
        policy=policy,
        env=env,
        horizon=num_steps,
        render=False,
        video_writer=video_writer,
        video_skip=1,
        camera_names=camera_names,
        recorder=recorder,
    )
finally:
    video_writer.close()

# The manual hook must have fired at least once during the rollout.
assert vis_hook.latest is not None, "Hook never fired -> wrong module instance or policy not using rgb."
print("visual latent shape:", vis_hook.latest.shape)

print(f"==== Rollout video saved at \n{output_video_path.resolve()}")

traj:RolloutLatentTrajectory = recorder.finalize(stats)
# -------------------------
# Test 0: sanity check that traj.latents is changing over time
# -------------------------
print("traj.latents shape:", None if traj.latents is None else traj.latents.shape)
if traj.latents is not None:
    print("traj.latents mean/std:", traj.latents.mean(), traj.latents.std())
    if len(traj.latents) > 1:
        # Norm of the difference between consecutive latents, one value per step pair.
        d = np.linalg.norm(traj.latents[1:] - traj.latents[:-1], axis=-1)  # (T-1,)
        print("mean ||Δlatent||:", d.mean(), "min:", d.min(), "max:", d.max())
    else:
        # With fewer than two steps there are no differences to summarize;
        # min()/max() of an empty array would raise.
        print("fewer than two latent steps recorded; skipping ||Δlatent|| statistics")

save_rollout_latents(output_latents_path, traj)
print(f"==== Saved rollout latents at \n{output_latents_path.resolve()}")
nnx.display(traj)

# Disabled check, kept for reference: compares the recorded latents against the
# concatenation of the stored observations.
#allclose = np.allclose(traj.latents, np.concatenate([traj.obs[k] for k in traj.obs.keys()], axis=-1))
#print("==== feats @traj.latents stored by @PolicyFeatureHook should be\n" +
     #f"     close to feats @traj.obs. Are they? {allclose}")
#assert allclose, "latents and stacked obs should be close"

In [None]:
# ============================
# TEST: prove your hook is visual + prove policy uses vision
# (NO TRAINING REQUIRED)
# ============================
import numpy as np

# 0) sanity: does policy config expect an image key?
keys = list(policy_algo.global_config.all_obs_keys)
img_like = [k for k in keys if ("image" in k) or ("rgb" in k)]
print("policy expects image-like keys:", img_like)
assert len(img_like) > 0, "Policy config has no image key -> this checkpoint is likely low-dim only."

# 1) env reset should return that key
obs0 = env.reset()
print("env.reset keys:", list(obs0.keys())[:20], "... total:", len(obs0.keys()))
rgb_key = getattr(feature_hook, "_rgb_key", None)  # your PolicyFeatureHook should set this
print("feature_hook._rgb_key =", rgb_key)
assert rgb_key is not None, "feature_hook has no _rgb_key (your visual_latent hook setup didn't run)."
assert rgb_key in obs0, f"env.reset() does not contain rgb_key={rgb_key}. Your env/policy obs keys mismatch."

# 2) make a perturbed copy of obs0 (only change pixels)
obs1 = {k: (v.copy() if hasattr(v, "copy") else v) for k, v in obs0.items()}
img = obs1[rgb_key].astype(np.float32)
# Scale the perturbation to the image's value range. Integer images, and float
# images with values above 1, are treated as 0..255 (same noise as before);
# float images within 0..1 get proportionally smaller noise so the
# perturbation does not swamp the image.
# NOTE(review): whether this env returns uint8 or float [0, 1] frames depends
# on how it post-processes observations -- confirm.
if np.issubdtype(obs0[rgb_key].dtype, np.integer) or img.max() > 1.0:
    pixel_max = 255.0
else:
    pixel_max = 1.0
noise = np.random.normal(0, 8.0 * pixel_max / 255.0, size=img.shape).astype(np.float32)
obs1[rgb_key] = np.clip(img + noise, 0, pixel_max).astype(obs0[rgb_key].dtype)
# helper to force one policy forward (works across RolloutPolicy variants)
def force_forward(pol, obs):
    """Run *pol* once on *obs* and return whatever the policy returns.

    The entry point is chosen in this order:
      1. ``pol.get_action(obs)`` if *pol* has a ``get_action`` attribute,
      2. ``pol(obs)`` if *pol* is callable,
      3. ``pol.policy.get_action(obs)`` if the wrapped ``policy`` has one.

    Raises:
        RuntimeError: if none of the entry points above is available.
    """
    missing = object()  # sentinel: distinguishes "absent" from any real value

    direct = getattr(pol, "get_action", missing)
    if direct is not missing:
        return direct(obs)

    if callable(pol):
        return pol(obs)

    wrapped = getattr(pol, "policy", missing)
    nested = missing if wrapped is missing else getattr(wrapped, "get_action", missing)
    if nested is missing:
        raise RuntimeError("Can't run policy forward. Inspect policy object methods.")
    return nested(obs)

# 3) run forward on obs0 and obs1 and read the latent that your PolicyFeatureHook captured
import torch


def _fresh_forward(obs, seed=0):
    """Run one policy forward on *obs*; return ``(action, captured_latent)``.

    Three things are reset before the forward so that the two runs being
    compared differ only in their input image:

    * ``feature_hook._last_feature`` is cleared, so a hook that does not fire
      yields ``None`` instead of the stale latent of the previous run.
    * ``policy.start_episode()`` is called when available.
      NOTE(review): presumably this clears internal buffers (e.g. queued
      actions) that could let the policy answer without running its
      network -- confirm against the RolloutPolicy implementation.
    * the torch RNG is seeded inside ``fork_rng`` so any sampling noise is the
      same for both runs and the global RNG state is restored afterwards.
    """
    if hasattr(policy, "start_episode"):
        policy.start_episode()
    feature_hook._last_feature = None
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        action = force_forward(policy, obs)
    return action, feature_hook._last_feature


a0, z0 = _fresh_forward(obs0)
assert z0 is not None, "Hook didn't fire on obs0 forward (wrong module instance OR policy not using rgb path)."
z0n = z0.detach().float().cpu().numpy()

a1, z1 = _fresh_forward(obs1)
assert z1 is not None, "Hook didn't fire on obs1 forward."
z1n = z1.detach().float().cpu().numpy()

print("\n[HOOK TEST]")
print("latent shape:", z0n.shape)
print("delta latent L2 :", float(np.linalg.norm(z1n - z0n)))
print("delta latent MAE:", float(np.mean(np.abs(z1n - z0n))))

# 4) Optional: prove the ACTION changes when image changes (stronger: policy truly uses vision)
# a0 / a1 come from the same two forwards as the latents above (same policy
# state, same seed), so a non-zero delta is attributable to the image change.

# a0/a1 might be numpy already or torch; normalize to numpy
if hasattr(a0, "detach"):
    a0 = a0.detach().cpu().numpy()
if hasattr(a1, "detach"):
    a1 = a1.detach().cpu().numpy()
a0 = np.asarray(a0)
a1 = np.asarray(a1)

print("\n[POLICY-USES-VISION TEST]")
print("action shape:", a0.shape)
print("delta action L2:", float(np.linalg.norm(a1 - a0)))


## 3. Load Policy Rollout Checkpoint and Train Chunk Diffusion Model modules

**Policy diffusion:** Implemented in [policy.py](../third_party/sope/opelab/core/policy.py) as `DiffusionPolicy`, which wraps [CleanDiffuser](https://github.com/CleanDiffuserTeam/CleanDiffuser) components (`PearceMlp`, `PearceObsCondition`, `DiscreteDiffusionSDE`) to sample actions from observations.

**Trajectory-chunk diffusion (sequence model):** 
- Implemented in [temporal.py](../third_party/sope/opelab/core/baselines/diffusion/temporal.py) (`TemporalUnet` backbone)
- Wrapped by [diffusion.py](../third_party/sope/opelab/core/baselines/diffusion/diffusion.py) (`GaussianDiffusion` sampler)
- Used by [diffuser.py](../third_party/sope/opelab/core/baselines/diffuser.py) to generate chunked trajectories.

**Guidance link between them:**  During chunked sampling, the trajectory diffusion uses the policy diffusion to compute guidance gradients. 
- In [diffusion.py](../third_party/sope/opelab/core/baselines/diffusion/diffusion.py), `default_sample_fn(...)` checks the policy type and calls `gradlog_diffusion(...)`
- which in turn calls `DiffusionPolicy.grad_log_prob(...)` in [policy.py](../third_party/sope/opelab/core/policy.py) to get a score/grad-log term. 
- This gradient is scaled and injected into the trajectory diffusion step as the guidance term.


In [None]:
%load_ext autoreload
%autoreload 2
import sys
from copy import deepcopy
from pathlib import Path
from IPython.display import display

import numpy as np
import torch
import flax.nnx as nnx # only used here for displaying h5 trees

# Make the repository root importable (for the ``src.*`` imports below).
# The path is resolved so it stays valid if the working directory changes,
# and it is only added once so re-running the cell does not grow sys.path.
repo_root = Path('../').resolve()
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

from src.latent_sope.utils.common import CONSOLE_LOGGER
from src.latent_sope.robomimic_interface.rollout import RolloutLatentTrajectory
from src.latent_sope.robomimic_interface.dataset import (
    RolloutChunkDatasetConfig,
    RolloutChunkDataset,
    load_rollout_latents,
    make_rollout_chunk_dataloader,
)
from src.latent_sope.diffusion.ddpm_latents import SopeChunkDiffusionConfig, SopeChunkDiffusion

# Locate the most recent training run: run directories are sorted by name and
# the last one is used.
test_dir = Path("../third_party/robomimic/diffusion_policy_trained_models/test")
policy_train_dirs = sorted(filter(Path.is_dir, test_dir.glob("*")))
assert policy_train_dirs, "No policy train dirs found, you have to train a policy first!"
policy_train_dir = policy_train_dirs[-1]

# Rollout latents are expected next to the checkpoints of that run.
rollout_path = policy_train_dir.joinpath("rollout_latents.h5")
assert rollout_path.exists(), f"Missing rollout file: {rollout_path}. Update the path to a saved rollout."

# Wrap the saved trajectory in a chunk dataset built with the default config.
dataset_config = RolloutChunkDatasetConfig()
data:RolloutLatentTrajectory = load_rollout_latents(rollout_path)
dataset:RolloutChunkDataset = RolloutChunkDataset(traj=data, config=dataset_config)

# First item of the dataset, kept for inspection.
item = next(iter(dataset))
# frame_stack = 1
# obs_dim = data["z"].shape[1] * frame_stack
# action_dim = data["actions"].shape[1]

# dl, stats = make_rollout_chunk_dataloader(
#     paths=[rollout_path],
#     W=8,
#     stride=1,
#     batch_size=4,
#     frame_stack=frame_stack,
#     source="z",
#     include_actions=True,
#     normalize=True,
# )

# batch = next(iter(dl))
# if isinstance(batch, (list, tuple)):
#     batch = batch[0]

# cfg = SopeChunkDiffusionConfig(
#     horizon=8,
#     obs_dim=obs_dim,
#     action_dim=action_dim,
#     diffusion_steps=64,
# )
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = SopeChunkDiffusion(cfg, normalization_stats=stats, device=device)
# opt = model.make_optimizer()

# cond = model.make_cond(batch)
# loss, _ = model.loss(batch, cond)
# loss.backward()
# opt.step()
# opt.zero_grad()
# print("loss:", float(loss.item()))

In [None]:
# Pretty-print the first dataset item (nnx.display is used purely as a viewer here).
nnx.display(item)