In [None]:
import sys

from absl import flags
from ml_collections import config_flags
from collections import defaultdict
import tqdm
from PIL import Image, ImageEnhance

import random
import jax
import jax.numpy as jnp
import numpy as np

from agents import agents
from utils.env_utils import make_env_and_datasets
from utils.flax_utils import restore_agent
from utils.evaluation import supply_rng, add_to, flatten


2025-06-29 19:31:23.643933: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751239884.672444  838088 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751239884.931175  838088 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from distutils.dep_util import newer, newer_group
  from distutils.dep_util import newer, newer_group


In [2]:
FLAGS = flags.FLAGS

# absl flags are registered process-wide, so defining `agent` a second time (i.e. re-running
# this cell in the same kernel) raises DuplicateFlagError. Register it only once.
if 'agent' not in FLAGS:
    config_flags.DEFINE_config_file('agent', '../impls/agents/qrl.py', lock_config=False)

env_name = 'antmaze-medium-stitch-v0'
seed = 1

# Parse once per kernel. known_only=True leaves unrecognized argv entries (presumably the
# kernel's own launch arguments) unparsed instead of raising.
if not FLAGS.is_parsed():
    FLAGS(sys.argv, known_only=True)

config = FLAGS.agent
env, train_dataset, val_dataset = make_env_and_datasets(env_name, frame_stack=config['frame_stack'])

# Seed the global Python / NumPy RNGs; the evaluation cell later draws its per-task JAX PRNG
# keys from np.random, so this also fixes those keys.
random.seed(seed)
np.random.seed(seed)

# A single-transition batch; presumably used by `create` only to infer input shapes.
example_batch = train_dataset.sample(1)

agent_class = agents[config['agent_name']]
agent = agent_class.create(
    seed,
    example_batch['observations'],
    example_batch['actions'],
    config,
)



In [9]:
# Checkpoint under evaluation: QRL, antmaze-medium-stitch-v0, squared_transition_loss=False, run sd001.
restore_dir = (
    "/n/fs/rl-chongyiz/exp_logs/ogbench_logs/qrl/"
    "20250615_qrl_antmaze-medium-stitch-v0_alpha=0.003_discount=0.99_squared_transition_loss=False/"
    "1/debug/sd001_s_24218514.0.20250615_150752"
)
# Alternative checkpoints (assign one to `restore_dir` to swap):
#   QRL, squared_transition_loss=True, run sd003:
#     "/n/fs/rl-chongyiz/exp_logs/ogbench_logs/qrl/20250615_qrl_antmaze-medium-stitch-v0_alpha=0.003_discount=0.99_squared_transition_loss=True/3/debug/sd003_s_24218513.0.20250615_150751"
#   psiql on antmaze-large-navigate-v0 (different agent/env -- `env_name` and the agent config
#   in the setup cell would have to match):
#     "/n/fs/rl-chongyiz/exp_logs/hdualrl_logs/psiql/20250618_psiql_antmaze-large-navigate-v0_expectile=0.7_alpha=3.0_discount=0.99/1/debug/sd001_s_24251254.0.20250618_134110"

# Training step of the checkpoint; this run restores params_1000000.pkl.
restore_epoch = 10**6

agent = restore_agent(agent, restore_dir, restore_epoch)

Restored from /n/fs/rl-chongyiz/exp_logs/ogbench_logs/qrl/20250615_qrl_antmaze-medium-stitch-v0_alpha=0.003_discount=0.99_squared_transition_loss=False/1/debug/sd001_s_24218514.0.20250615_150752/params_1000000.pkl


In [10]:
@jax.jit
def sample_actions(
    observations,
    agent=agent,
    goals=None,
    seed=None,
    temperature=1.0,
):
    """Sample goal-conditioned actions from the agent's actor network.

    Queries ``agent.network.select('actor')`` once with the observations and goals,
    draws a sample from the returned distribution, and clips it to [-1, 1]. There is no
    high-level / subgoal stage: a single actor produces the actions directly.

    Args:
        observations: Observation(s) passed to the actor.
        agent: Agent whose ``network`` exposes the ``'actor'`` module. The default is
            bound to the module-level ``agent`` at the moment this cell runs, so this cell
            must be re-run after ``agent`` is rebound (e.g. by restoring a checkpoint).
        goals: Goal(s) passed to the actor; may be None.
        seed: JAX PRNG key used for ``dist.sample``. A real key is expected here --
            presumably supplied by ``supply_rng`` in the evaluation loop.
        temperature: Forwarded unchanged to the actor.

    Returns:
        Sampled actions clipped to [-1, 1].
    """
    # TODO: posterior waypoint sampling. Idea: score candidate waypoints w (bracketed by the
    # current observation and the goal) with logits V(s, w) + V(w, g) - V(s, g) from the
    # 'value' network (averaging its two heads), sample one with jax.random.categorical,
    # and condition the actor on it. Note for whoever implements it: the goal batch must be
    # built from `goals`, not `observations`.
    dist = agent.network.select('actor')(observations, goals, temperature=temperature)
    actions = dist.sample(seed=seed)
    actions = jnp.clip(actions, -1, 1)

    return actions

num_eval_episodes = 10  # scored episodes per task
num_video_episodes = 1  # extra, unscored episodes per task that are rendered
video_frame_skip = 3  # render every `video_frame_skip`-th step, plus the final step
eval_temperature = 0.0
task_infos = env.unwrapped.task_infos if hasattr(env.unwrapped, 'task_infos') else env.task_infos
num_tasks = len(task_infos)


def _run_episode(actor_fn, task_id, should_render):
    """Roll out one episode of evaluation task `task_id` (1-based) until terminated/truncated.

    Actions from `actor_fn` are clipped to [-1, 1] before `env.step`. When `should_render`
    is true, a frame is captured every `video_frame_skip` steps and on the final step. If
    the reset info carries a 'goal_rendered' image, the frame is stacked below it along
    axis 0.

    Returns:
        (traj, final_info, frames): the transitions accumulated with `add_to`, the `info`
        dict from the last `env.step`, and the list of captured frames (empty unless
        `should_render`).
    """
    traj = defaultdict(list)
    frames = []

    observation, info = env.reset(options=dict(task_id=task_id))
    goal = info.get('goal')
    goal_frame = info.get('goal_rendered')
    done = False
    step = 0
    while not done:
        action = actor_fn(observations=observation, goals=goal, temperature=eval_temperature)
        action = np.clip(np.array(action), -1, 1)

        next_observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        step += 1

        if should_render and (step % video_frame_skip == 0 or done):
            frame = env.render().copy()
            frames.append(frame if goal_frame is None else np.concatenate([goal_frame, frame], axis=0))

        transition = dict(
            observation=observation,
            next_observation=next_observation,
            action=action,
            reward=reward,
            done=done,
            info=info,
        )
        add_to(traj, transition)
        observation = next_observation

    return traj, info, frames


# Flat '<task>/<metric>' -> mean-over-episodes mapping, read by the next cell.
task_stats = {}
renders = []
for task_id in tqdm.trange(1, num_tasks + 1):
    # Fresh sampling key per task, drawn from NumPy's (seeded) global RNG.
    actor_fn = supply_rng(sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32)))
    trajs = []
    stats = defaultdict(list)
    cur_renders = []

    for i in tqdm.trange(num_eval_episodes + num_video_episodes):
        # The first `num_eval_episodes` episodes are scored; the remaining ones are only rendered.
        should_render = i >= num_eval_episodes
        traj, final_info, frames = _run_episode(actor_fn, task_id, should_render)
        if should_render:
            cur_renders.append(np.array(frames))
        else:
            # Metrics come from the info of the episode's final step only.
            add_to(stats, flatten(final_info))
            trajs.append(traj)

    for k, v in stats.items():
        task_stats['task{}/'.format(task_id) + k] = np.mean(v)

    renders.extend(cur_renders)


100%|██████████| 11/11 [00:08<00:00,  1.28it/s]
100%|██████████| 11/11 [00:08<00:00,  1.23it/s]
100%|██████████| 11/11 [00:08<00:00,  1.25it/s]
100%|██████████| 11/11 [00:05<00:00,  2.06it/s]
100%|██████████| 11/11 [00:07<00:00,  1.40it/s]
100%|██████████| 5/5 [00:39<00:00,  7.93s/it]


In [11]:
# One line per task: the mean of the 'success' metric over that task's scored episodes.
for task in range(1, num_tasks + 1):
    key = f'task{task}/success'
    task_success = task_stats[key]
    print(f"{key}: {task_success}")

task1/success: 0.5
task2/success: 0.8
task3/success: 0.6
task4/success: 1.0
task5/success: 0.6


In [12]:
def get_video(renders=None, n_cols=None, fps=15):
    """Pad a list of videos to a common length and stack them into one array.

    Videos shorter than the longest one are extended by repeating their final frame,
    dimmed to half brightness, so an episode that ended early is visibly "finished"
    when the videos are played side by side. Every frame then gets a 1-pixel black
    border on the two spatial axes. The caller's list is not modified.

    Args:
        renders: List of videos. Each video must be a uint8 numpy array of shape
            (t, h, w, c); all videos must share h, w and c. Any video that needs
            padding must have at least one frame.
        n_cols: Unused; accepted for backward compatibility. The output is NOT tiled
            into a grid.
        fps: Unused; accepted for backward compatibility.

    Returns:
        uint8 numpy array of shape (n, max_t, h + 2, w + 2, c).

    Raises:
        ValueError: if `renders` is empty (from `max`).
        AssertionError: if a video's dtype is not uint8.
    """
    max_length = max(len(render) for render in renders)

    padded = []
    for render in renders:
        assert render.dtype == np.uint8

        num_pad = max_length - len(render)
        if num_pad > 0:
            # Dim the last frame to 50% brightness and repeat it to fill the gap.
            final_image = ImageEnhance.Brightness(Image.fromarray(render[-1])).enhance(0.5)
            final_frame = np.array(final_image)
            pad = np.repeat(final_frame[np.newaxis, ...], num_pad, axis=0)
            render = np.concatenate([render, pad], axis=0)

        # Border only on the spatial axes (h, w); time and channel axes are untouched.
        padded.append(np.pad(render, ((0, 0), (1, 1), (1, 1), (0, 0)), mode='constant', constant_values=0))

    return np.array(padded)  # (n, t, h, w, c)


# Stack the per-episode videos into one padded (n, t, h, w, c) array. This cell overwrites
# its own input, so skip the call on a re-run: `renders` is then already the stacked array
# and must not be padded again.
if not isinstance(renders, np.ndarray):
    renders = get_video(renders)

In [8]:
# NOTE(review): `moviepy.editor` is the moviepy 1.x entry point; moviepy 2.x reportedly
# removed it -- pin moviepy<2 or confirm against the installed version.
import moviepy.editor as mpy

video_fps = 30  # playback rate of the preview

# One clip per video in `renders` (stacked as (n, t, h, w, c)), tiled side by side in a single row.
clips = [mpy.ImageSequenceClip(list(frames), fps=video_fps) for frames in renders]
clip_row = mpy.clips_array([clips])

clip_row.ipython_display(fps=video_fps, loop=True, autoplay=True)

Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4



                                                               

Moviepy - Done !
Moviepy - video ready __temp__.mp4


