In [1]:
import os
os.environ["MUJOCO_GL"] = "egl"
os.environ["DISPLAY"] = ""

from pathlib import Path

import gym
import h5py
import hydra
import mujoco
import imageio
from omegaconf import DictConfig
import numpy as np
import torch
import torch.nn as nn
from cleandiffuser.dataset.d4rl_antmaze_dataset import (
    D4RLAntmazeTDDataset,
    MultiHorizonD4RLAntmazeDataset,
)
from cleandiffuser.dataset.dataset_utils import dict_apply, loop_dataloader
from cleandiffuser.diffusion import ContinuousRectifiedFlow
from cleandiffuser.invdynamic import FancyMlpInvDynamic
from cleandiffuser.nn_condition import MLPCondition
from cleandiffuser.nn_diffusion import DiT1d
from cleandiffuser.utils import set_seed
from cleandiffuser.utils.iql import IQL
from torch.utils.data import DataLoader
from tqdm import tqdm

import d4rl

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
These new versions include large bug fixes, new versions of Python, and are where all new development will continue. Please upgrade these libraries as soon as you're able to do so.
If you'd like to read more about the story behind this switch, please check out ]8;;https://farama.org/Announcing-Minari\this blog post]8;;\.
  from distutils.dep_util import newer, newer_group
No module named 'flow'
No module named 'carla'
pybullet build time: Jan 29 2025 23:16:28
  from pkg_resources import parse_version
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https

In [2]:
class MultiHorizonD4RLAntmazeDatasetwQ(MultiHorizonD4RLAntmazeDataset):
    """Multi-horizon AntMaze dataset that can optionally carry IQL value predictions.

    Extends ``MultiHorizonD4RLAntmazeDataset`` with a per-step ``pred_values``
    array filled by :meth:`add_value`. When that array exists, every sampled
    segment also carries the predicted values under the ``"pred_val"`` key;
    otherwise the dataset can still be used (e.g. only for its normalizer).
    """

    # Per-step IQL value predictions, same shape as ``self.seq_rew``;
    # ``None`` until ``add_value`` has been called.
    pred_values = None

    @torch.no_grad()
    def add_value(self, iql: IQL, device: str):
        """Fill ``self.pred_values`` with ``iql.V`` evaluated on every stored path.

        Args:
            iql: trained IQL agent whose ``V`` network scores observations.
            device: torch device the ``iql`` networks live on.
        """
        self.pred_values = np.zeros_like(self.seq_rew)
        for i in tqdm(range(self.pred_values.shape[0])):
            # Explicit float32 so the value network's weight dtype always matches.
            path_obs = torch.tensor(self.seq_obs[i], device=device, dtype=torch.float32)
            self.pred_values[i] = iql.V(path_obs).cpu().numpy()

    def __getitem__(self, idx: int):
        """Return one aligned segment per horizon.

        ``idx`` is interpreted as a position out of ``self.len_each_horizon[-1]``
        and mapped proportionally onto each horizon's own index table.

        Returns:
            list of ``{"horizon": horizon, "data": dict_of_tensors}`` where the
            data dict holds ``obs.state``, ``act``, ``rew`` and ``val`` (rewards
            from ``start`` to the end of the stored path weighted by
            ``self.discount`` — presumably per-step discount powers set up by
            the base class), plus ``pred_val`` when :meth:`add_value` has run.
        """
        scale = idx / self.len_each_horizon[-1]
        horizon_idx = [
            int(self.len_each_horizon[i] * scale) for i in range(len(self.horizons))
        ]

        torch_datas = []
        for i, horizon in enumerate(self.horizons):
            path_idx, start, end = self.indices[i][horizon_idx[i]]

            # The value target sums rewards to the end of the path, not just to ``end``.
            rewards = self.seq_rew[path_idx, start:]
            values = (rewards * self.discount[: rewards.shape[0], None]).sum(0)

            data = {
                "obs": {"state": self.seq_obs[path_idx, start:end]},
                "act": self.seq_act[path_idx, start:end],
                "rew": self.seq_rew[path_idx, start:end],
            }
            # ``pred_values`` is None until ``add_value`` runs; indexing it
            # unconditionally would crash any consumer that skipped that step.
            if self.pred_values is not None:
                data["pred_val"] = self.pred_values[path_idx, start:end]
            data["val"] = values

            torch_datas.append(
                {"horizon": horizon, "data": dict_apply(data, torch.tensor)}
            )

        return torch_datas

In [3]:
def inference_pipeline(args):
    set_seed(args.seed)
    if args.test_model == "R2":
        args.diffusion_ckpt = args.reflow_ckpt

    save_path = f"results/{args.pipeline_name}/{args.task.env_name}/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    video_path = Path(save_path) / "antmaze_inference.mp4"

    w_cfgs = [1.0, 0.0, 0.0]
    planning_horizons = [5, 5, 9]
    n_levels = len(planning_horizons)
    temporal_horizons = [planning_horizons[-1] for _ in range(n_levels)]
    for i in range(n_levels - 1):
        temporal_horizons[-2 - i] = (planning_horizons[-2 - i] - 1) * (
            temporal_horizons[-1 - i] - 1
        ) + 1

    env = gym.make(args.task.env_name)
    dataset = MultiHorizonD4RLAntmazeDatasetwQ(
        env.get_dataset(), horizons=temporal_horizons, discount=args.discount
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        persistent_workers=True,
    )
    obs_dim, act_dim = dataset.o_dim, dataset.a_dim

    fix_masks = [torch.zeros((h, obs_dim)) for h in planning_horizons]
    loss_weights = [torch.ones((h, obs_dim)) for h in planning_horizons]
    for i in range(n_levels):
        fix_idx = 0 if i == 0 else [0, -1]
        fix_masks[i][fix_idx, :] = 1.0
        loss_weights[i][1, :] = args.next_obs_loss_weight

    nn_diffusions = [
        DiT1d(
            obs_dim,
            emb_dim=args.emb_dim,
            d_model=args.d_model,
            n_heads=args.n_heads,
            depth=args.depth,
            timestep_emb_type="fourier",
        )
        for _ in range(n_levels)
    ]
    nn_conditions = [
        MLPCondition(
            1,
            args.emb_dim,
            hidden_dims=[args.emb_dim],
        )
        for _ in range(n_levels)
    ]

    diffusions = [
        ContinuousRectifiedFlow(
            nn_diffusions[i],
            nn_conditions[i],
            fix_masks[i],
            loss_weights[i],
            ema_rate=args.ema_rate,
            device=args.device,
        )
        for i in range(n_levels)
    ]

    invdyn = FancyMlpInvDynamic(
        obs_dim, act_dim, 256, torch.nn.Tanh(), add_dropout=True, device=args.device
    )

    n_candidates = args.num_candidates

    for i in range(n_levels):
        diffusions[i].load(
            save_path
            + f'{"reflow" if args.test_model == "R2" else "diffusion"}{i}_ckpt_{args.diffusion_ckpt}.pt'
        )
        diffusions[i].eval()
    invdyn.load(save_path + f"invdyn_ckpt_{args.invdyn_ckpt}.pt")
    invdyn.eval()

    iql = IQL(obs_dim, act_dim, hidden_dim=512).to(args.device)
    iql.load(save_path + "iql_ckpt_latest.pt", device=args.device)
    iql.eval()

    # Single environment for evaluation
    eval_env = gym.make(args.task.env_name)

    normalizer = dataset.get_normalizer()
    episode_rewards = []

    priors = [
        torch.zeros((1, planning_horizons[i], obs_dim), device=args.device)
        for i in range(n_levels)
    ]
    priors[0] = torch.zeros((1, n_candidates, planning_horizons[0], obs_dim)).to(args.device)
    condition = torch.ones((1, n_candidates, 1), device=args.device)

    # Video writer
    writer = imageio.get_writer(str(video_path), fps=30)

    for i in range(1):

        obs, ep_reward, cum_done, t = eval_env.reset(), 0.0, False, 0

        while not cum_done and t < 1000 + 1:

            frame = eval_env.render(mode='rgb_array', width=256, height=256)
            writer.append_data(frame)

            target_return = np.ones(1, dtype=np.float32)
            if "medium-play" in args.task.env_name:
                target_return[:] = 0.2
                if obs[1] > 18.0:
                    target_return[:] = 0.8
            elif "medium-diverse" in args.task.env_name:
                target_return[:] = 0.2
                if obs[0] > 10.0:
                    target_return[:] = 0.3
                if obs[1] > 15.0:
                    target_return[:] = 0.8
            elif "large-play" in args.task.env_name:
                target_return[:] = 0.6
                if obs[0] >= 13.0 and obs[1] < 28:
                    target_return[:] = 0.25
                if obs[0] < 13.0:
                    target_return[:] = 0.1
            elif "large_diverse" in args.task.env_name:
                target_return[:] = 0.6
                if obs[0] >= 13.0 and obs[1] < 28:
                    target_return[:] = 0.3
                if obs[0] < 13.0:
                    target_return[:] = 0.25

            target_return = torch.tensor(target_return, device=args.device)[:, None, None]
            this_condition = condition * target_return

            obs_tensor = torch.tensor(normalizer.normalize(obs), device=args.device, dtype=torch.float32)  # shape [obs_dim]
            obs_tensor = obs_tensor.unsqueeze(0).unsqueeze(0)  # shape [1, 1, obs_dim]
            obs_tensor = obs_tensor.repeat(1, n_candidates, 1)  # shape [1, n_candidates, obs_dim]

            priors[0][:, :, 0] = obs_tensor

            priors[0][:, :, 0] = obs_tensor.unsqueeze(1)
            for j in range(n_levels):
                traj, _ = diffusions[j].sample(
                    priors[j].view(-1, planning_horizons[j], obs_dim),
                    n_samples=(n_candidates if j == 0 else 1),
                    sample_steps=2 if args.test_model == "R2" else 5,
                    use_ema=args.use_ema,
                    condition_cfg=(this_condition.reshape(-1, 1) if j == 0 else this_condition[:, 0]),
                    w_cfg=w_cfgs[j],
                    temperature=args.temperature,
                    sample_step_schedule="quad_continuous",
                )
                if j == 0:
                    traj = traj.reshape(1, n_candidates, -1, obs_dim)
                    with torch.no_grad():
                        value = iql.V(traj[:, :, 1])[:, :, 0]
                        idx = torch.argmax(value, -1)
                        traj = traj[torch.arange(1), idx]
                if j < n_levels - 1:
                    priors[j + 1][:, [0, -1]] = traj[:, [0, 1]]

            with torch.no_grad():
                act = invdyn(traj[:, 0], traj[:, 1]).cpu().numpy()[0]

            obs, rew, done, info = eval_env.step(act)
            t += 1
            cum_done = done
            ep_reward += rew
            print(f"[t={t}] xy: {obs[:2]}")
            print(f"[t={t}] rew: {ep_reward}")

        episode_rewards.append(np.clip(ep_reward, 0.0, 1.0))

        writer.close()  # Save the video

        episode_rewards = [env.get_normalized_score(r) for r in episode_rewards]
        episode_rewards = np.array(episode_rewards)
        print(np.mean(episode_rewards), np.std(episode_rewards))

        print(np.mean(episode_rewards, -1), np.std(episode_rewards, -1))

        eval_env.close()

In [4]:
from omegaconf import OmegaConf
from hydra import initialize, compose

# Compose the config with Hydra's compose API (no @hydra.main decorator).
# version_base="1.1" pins the defaults Hydra assumes when the argument is
# omitted, and the context manager clears Hydra's global state on exit, so
# re-running this cell does not fail with "GlobalHydra is already initialized".
with initialize(version_base="1.1", config_path="../configs/diffuserlite/antmaze"):
    args = compose(config_name="antmaze")

# Per-run overrides of the composed config.
args.task.env_name = "antmaze-medium-play-v2"
args.device = "cuda"
args.num_candidates = 16
args.seed = 42

print(OmegaConf.to_yaml(args))

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path="../configs/diffuserlite/antmaze")


pipeline_name: diffuserlite_d4rl_antmaze
mode: training
seed: 42
device: cuda
noreaching_penalty: -100
discount: 0.99
emb_dim: 256
d_model: 320
n_heads: 10
depth: 2
next_obs_loss_weight: 10.0
ema_rate: 0.9999
diffusion_gradient_steps: 1000000
invdyn_gradient_steps: 1000000
batch_size: 256
log_interval: 1000
save_interval: 10000
reflow_backbone_ckpt: latest
cond_dataset_size: 1600000
uncond_dataset_size: 400000
dataset_prepare_batch_size: 5000
dataset_prepare_sampling_steps: 20
reflow_gradient_steps: 200000
test_model: R1
diffusion_ckpt: 500000
reflow_ckpt: 500000
invdyn_ckpt: 500000
num_envs: 50
num_candidates: 16
num_episodes: 3
temperature: 1.0
use_ema: true
task:
  env_name: antmaze-medium-play-v2



In [5]:
# Run the evaluation rollout; frames are written to
# results/<pipeline_name>/<env_name>/antmaze_inference.mp4.
video_path = inference_pipeline(args=args)

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


Target Goal:  (20.740537089753065, 21.250043548508437)


load datafile: 100%|██████████| 8/8 [00:03<00:00,  2.37it/s]


Target Goal:  (20.185060446526535, 20.589082593223672)
Target Goal:  (20.61140725889111, 21.193027503877044)
Found 5 GPUs for rendering. Using device 0.


  deprecation(


[t=1] xy: [ 0.01712262 -0.0157921 ]
[t=1] rew: 0.0
[t=2] xy: [ 0.00932483 -0.03975155]
[t=2] rew: 0.0
[t=3] xy: [ 0.02299452 -0.07637495]
[t=3] rew: 0.0
[t=4] xy: [ 0.02688414 -0.03386415]
[t=4] rew: 0.0
[t=5] xy: [0.02512831 0.13431593]
[t=5] rew: 0.0
[t=6] xy: [-0.0058059  0.2915843]
[t=6] rew: 0.0
[t=7] xy: [-0.08553459  0.43590706]
[t=7] rew: 0.0
[t=8] xy: [-0.15491922  0.57292538]
[t=8] rew: 0.0
[t=9] xy: [-0.09070565  0.69679924]
[t=9] rew: 0.0
[t=10] xy: [-0.00231141  0.82575934]
[t=10] rew: 0.0
[t=11] xy: [0.07072875 0.96773441]
[t=11] rew: 0.0
[t=12] xy: [0.14874888 1.14672096]
[t=12] rew: 0.0
[t=13] xy: [0.22633984 1.30576038]
[t=13] rew: 0.0
[t=14] xy: [0.33236467 1.35152366]
[t=14] rew: 0.0
[t=15] xy: [0.43902786 1.37213129]
[t=15] rew: 0.0
[t=16] xy: [0.5242123  1.36809118]
[t=16] rew: 0.0
[t=17] xy: [0.62070293 1.41159835]
[t=17] rew: 0.0
[t=18] xy: [0.70222513 1.51450366]
[t=18] rew: 0.0
[t=19] xy: [0.81888418 1.68278566]
[t=19] rew: 0.0
[t=20] xy: [0.96481158 1.89744518