In [1]:
import sys

from absl import flags
from ml_collections import config_flags
from collections import defaultdict
from tqdm import trange
from PIL import Image, ImageEnhance
import moviepy.editor as mpy
import matplotlib.pyplot as plt

import tensorflow as tf
import random
import jax
import jax.numpy as jnp
import numpy as np

from agents import agents
from utils.env_utils import make_env_and_datasets
from utils.datasets import Dataset
from utils.flax_utils import restore_agent
from utils.evaluation import supply_rng, add_to, flatten

error: XDG_RUNTIME_DIR is invalid or not set in the environment.
2025-07-30 17:06:56.992862: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-30 17:06:57.033618: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-30 17:06:57.033642: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-30 17:06:57.034734: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-0

In [2]:
FLAGS = flags.FLAGS


def _define_flag_once(define_fn, name, *args, **kwargs):
    """Register the flag `name` via `define_fn` unless it is already registered.

    absl raises DuplicateFlagError when a flag is defined twice in the same
    process, which is exactly what happens when this cell is re-run in a live
    kernel. Skipping already-registered flags keeps the cell re-runnable.

    Args:
        define_fn: An absl-style DEFINE_* function taking the flag name first.
        name: Name of the flag to register.
        *args: Remaining positional arguments forwarded to `define_fn`.
        **kwargs: Keyword arguments forwarded to `define_fn`.
    """
    if name not in FLAGS:
        define_fn(name, *args, **kwargs)


# NOTE: once a flag is registered, editing its default below has no effect until the
# kernel is restarted. For the simple (string/int/float) flags, assigning the value
# directly (e.g. `FLAGS.env_name = '...'`) before re-running also works.
# Alternative env_name: 'jaco_reach_top_left'
_define_flag_once(flags.DEFINE_string, 'env_name', 'jaco_reach_bottom_right', 'Environment (dataset) name.')
_define_flag_once(flags.DEFINE_integer, 'seed', 10, 'Random seed.')
_define_flag_once(
    flags.DEFINE_string, 'obs_norm_type', 'normal', 'Type of observation normalization. (none, normal, bounded)')
_define_flag_once(flags.DEFINE_float, 'p_aug', None, 'Probability of applying image augmentation.')
_define_flag_once(flags.DEFINE_integer, 'num_aug', 1, 'Number of image augmentations.')
_define_flag_once(
    flags.DEFINE_integer, 'inplace_aug', 1, 'Whether to replace the original image after applying augmentations.')
_define_flag_once(flags.DEFINE_integer, 'frame_stack', None, 'Number of frames to stack.')
# Alternative agent config: '../impls/agents/sarsa_ifql_vib_gpi.py'
_define_flag_once(config_flags.DEFINE_config_file, 'agent', '../impls/agents/iql.py', lock_config=False)

# Parse only once; known_only=True ignores the extra arguments Jupyter puts in sys.argv.
if not FLAGS.is_parsed():
    FLAGS(sys.argv, known_only=True)

config = FLAGS.agent
_, env, dataset, _ = make_env_and_datasets(
    FLAGS.env_name, frame_stack=FLAGS.frame_stack, max_size=10_000_000, reward_free=True)

# Seed the global RNGs of every library used below.
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)
tf.random.set_seed(FLAGS.seed)

# Set up datasets.
dataset = Dataset.create(**dataset)
dataset.obs_norm_type = FLAGS.obs_norm_type
dataset.p_aug = FLAGS.p_aug
dataset.num_aug = FLAGS.num_aug
dataset.inplace_aug = FLAGS.inplace_aug
dataset.frame_stack = FLAGS.frame_stack
dataset.return_next_actions = True
# Called without arguments here and with an observation during rollouts;
# presumably this call prepares the normalization of the stored data — TODO confirm.
dataset.normalize_observations()

# A single sampled transition is passed to the agent constructor as an example input.
example_batch = dataset.sample(1)

# Initialize a fresh (untrained) agent.
agent_class = agents[config['agent_name']]
agent = agent_class.create(
    FLAGS.seed,
    example_batch['observations'],
    example_batch['actions'],
    config,
)


load datafile: 100%|██████████| 5/5 [00:03<00:00,  1.37it/s]
load datafile: 100%|██████████| 5/5 [00:00<00:00, 13.23it/s]


In [3]:
# Checkpoint to evaluate: exactly one `restore_path` should be active, and it should
# match both the agent config and the environment selected by the flags.
# InFOM
# restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/sarsa_ifql_vib_gpi_offline2offline/20250505_sarsa_ifql_vib_gpi_offline2offline_jaco_reach_top_left_obs_norm=normal_alpha=0.1_num_fg=16_actor_freq=4_expectile=0.9_critic_z_type=prior_vf_time_emb=False_transition_ln=True_kl_weight=0.2_latent_dim=128_clip_fg=True/300/debug/sd300_s_23762371.0.20250505_072019'
# restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/sarsa_ifql_vib_gpi_offline2offline/20250505_sarsa_ifql_vib_gpi_offline2offline_jaco_reach_bottom_right_obs_norm=normal_alpha=0.1_num_fg=16_actor_freq=4_expectile=0.9_critic_z_type=prior_vf_time_emb=False_transition_ln=True_kl_weight=0.2_latent_dim=128_clip_fg=True/200/debug/sd200_s_23764687.0.20250505_125850'
# IQL
# restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/iql_offline2offline/20250420_iql_offline2offline_jaco_reach_top_left_obs_norm_type=normal_alpha=10.0_expectile=0.99_actor_freq=4/40/debug/sd040_s_23654785.0.20250420_045306'
restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/iql_offline2offline/20250420_iql_offline2offline_jaco_reach_bottom_left_obs_norm_type=normal_alpha=10.0_expectile=0.99_actor_freq=4/40/debug/sd040_s_23654851.0.20250420_054227'
restore_epoch = 1_500_000

# NOTE(review): the active path names jaco_reach_bottom_left, while the env_name flag
# defaults to jaco_reach_bottom_right — confirm that this pairing is intended.
# The check is only a heuristic on the directory name, so it warns instead of failing.
if FLAGS.env_name not in restore_path:
    print(
        f'WARNING: restore_path does not mention env_name={FLAGS.env_name!r}; '
        'the checkpoint may belong to a different task.'
    )

# Build a fresh agent from the example batch and config, then restore the saved
# checkpoint into it. Rebuilding here lets the cell be re-run to swap checkpoints.
agent_class = agents[config['agent_name']]
agent = agent_class.create(
    FLAGS.seed,
    example_batch['observations'],
    example_batch['actions'],
    config,
)
agent = restore_agent(agent, restore_path, restore_epoch)

Restored from /n/fs/rl-chongyiz/exp_logs/ogbench_logs/iql_offline2offline/20250420_iql_offline2offline_jaco_reach_bottom_left_obs_norm_type=normal_alpha=10.0_expectile=0.99_actor_freq=4/40/debug/sd040_s_23654851.0.20250420_054227/params_1500000.pkl


In [4]:
def get_video(renders):
    """Pad a list of episode videos to equal length and stack them into one array.

    Videos shorter than the longest one are extended by repeating a
    half-brightness copy of their final frame. A 1-pixel black border is then
    added around every frame.

    The input list and its arrays are left untouched, so the function can be
    called repeatedly on the same `renders` without the borders accumulating.

    Args:
        renders: Non-empty list of videos. Each video should be a uint8 numpy
            array of shape (t, h, w, c) with t >= 1.

    Returns:
        A numpy array of shape (n, max_t, h + 2, w + 2, c) holding all padded,
        bordered videos.

    Raises:
        ValueError: If `renders` is empty.
        AssertionError: If a video is not of dtype uint8.
    """
    if len(renders) == 0:
        raise ValueError('get_video() needs at least one video.')

    # Pad videos to the same length.
    max_length = max(len(render) for render in renders)
    padded = []
    for render in renders:
        assert render.dtype == np.uint8

        num_missing = max_length - len(render)
        if num_missing > 0:
            # Decrease brightness of the padded frames.
            final_image = Image.fromarray(render[-1])
            final_frame = np.array(ImageEnhance.Brightness(final_image).enhance(0.5))

            pad = np.repeat(final_frame[np.newaxis, ...], num_missing, axis=0)
            render = np.concatenate([render, pad], axis=0)

        # Add borders. np.pad returns a new array, so the caller's video is not modified.
        padded.append(np.pad(render, ((0, 0), (1, 1), (1, 1), (0, 0)), mode='constant', constant_values=0))

    return np.array(padded)  # (n, t, h, w, c)

# Episode budget: the first `num_eval_episodes` rollouts only collect statistics,
# the remaining `num_video_episodes` rollouts are recorded as videos.
num_eval_episodes = 0
num_video_episodes = 6
video_frame_skip = 3

actor_fn = supply_rng(agent.sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32)))
trajs = []
stats = defaultdict(list)

renders = []
for episode_idx in trange(num_eval_episodes + num_video_episodes):
    traj = defaultdict(list)
    should_render = episode_idx >= num_eval_episodes

    observation, info = env.reset()
    if dataset is not None:
        observation = dataset.normalize_observations(observation)

    frames = []
    num_steps = 0
    done = False
    while not done:
        # Query the policy with temperature 0 and clip the action to [-1, 1].
        action = np.clip(np.array(actor_fn(observations=observation, temperature=0)), -1, 1)

        next_observation, reward, terminated, truncated, info = env.step(action)
        if dataset is not None:
            next_observation = dataset.normalize_observations(next_observation)
        done = terminated or truncated
        num_steps += 1

        # Record every `video_frame_skip`-th frame, and always the last one.
        if should_render and (done or num_steps % video_frame_skip == 0):
            # Render through `env.physics` at 200x200 when the env exposes it,
            # otherwise fall back to the env's own render().
            raw_frame = env.physics.render(height=200, width=200) if hasattr(env, 'physics') else env.render()
            frames.append(raw_frame.copy())

        add_to(traj, {
            'observation': observation,
            'next_observation': next_observation,
            'action': action,
            'reward': reward,
            'done': done,
            'info': info,
        })
        observation = next_observation

    if should_render:
        renders.append(np.array(frames))
    else:
        # Evaluation episode: keep the final step's info and the whole trajectory.
        add_to(stats, flatten(info))
        trajs.append(traj)

# Collapse each per-episode list of statistics into its mean.
for key in list(stats):
    stats[key] = np.mean(stats[key])

  logger.warn(

100%|██████████| 6/6 [00:05<00:00,  1.09it/s]


##### InFOM jaco_reach_top_left

In [None]:
videos = get_video(renders)

fps = 15
num_rows = 2
num_cols = 3

# Arrange the videos row-major on a num_rows x num_cols grid, one clip per video.
clip_grid = [
    [
        mpy.ImageSequenceClip(list(videos[row * num_cols + col]), fps=fps)
        for col in range(num_cols)
    ]
    for row in range(num_rows)
]

# Tile the grid into one composite clip and show it inline.
clip_array = mpy.clips_array(clip_grid)
clip_array.ipython_display(fps=fps, loop=True, autoplay=True)

Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4



t:   0%|          | 0/84 [00:00<?, ?it/s, now=None]

                                                             

Moviepy - Done !
Moviepy - video ready __temp__.mp4




##### IQL jaco_reach_top_left

In [5]:
videos = get_video(renders)

fps = 15
num_rows = 2
num_cols = 3

# Arrange the videos row-major on a num_rows x num_cols grid, one clip per video.
clip_grid = [
    [
        mpy.ImageSequenceClip(list(videos[row * num_cols + col]), fps=fps)
        for col in range(num_cols)
    ]
    for row in range(num_rows)
]

# Tile the grid into one composite clip and show it inline.
clip_array = mpy.clips_array(clip_grid)
clip_array.ipython_display(fps=fps, loop=True, autoplay=True)

Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4



                                                             

Moviepy - Done !
Moviepy - video ready __temp__.mp4




##### InFOM jaco_reach_bottom_right

In [8]:
videos = get_video(renders)

fps = 15
num_rows = 2
num_cols = 3

# Arrange the videos row-major on a num_rows x num_cols grid, one clip per video.
clip_grid = [
    [
        mpy.ImageSequenceClip(list(videos[row * num_cols + col]), fps=fps)
        for col in range(num_cols)
    ]
    for row in range(num_rows)
]

# Tile the grid into one composite clip and show it inline.
clip_array = mpy.clips_array(clip_grid)
clip_array.ipython_display(fps=fps, loop=True, autoplay=True)

Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4



                                                             

Moviepy - Done !
Moviepy - video ready __temp__.mp4


##### IQL jaco_reach_bottom_right

In [5]:
videos = get_video(renders)

fps = 15
num_rows = 2
num_cols = 3

# Arrange the videos row-major on a num_rows x num_cols grid, one clip per video.
clip_grid = [
    [
        mpy.ImageSequenceClip(list(videos[row * num_cols + col]), fps=fps)
        for col in range(num_cols)
    ]
    for row in range(num_rows)
]

# Tile the grid into one composite clip and show it inline.
clip_array = mpy.clips_array(clip_grid)
clip_array.ipython_display(fps=fps, loop=True, autoplay=True)

Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4



                                                             

Moviepy - Done !
Moviepy - video ready __temp__.mp4


