In [None]:
import sys

from absl import flags
from ml_collections import config_flags
from collections import defaultdict
from tqdm import trange
from PIL import Image, ImageEnhance
import moviepy.editor as mpy
import matplotlib.pyplot as plt

import tensorflow as tf
import random
import jax
import jax.numpy as jnp
import numpy as np

from agents import agents
from utils.env_utils import make_env_and_datasets
from utils.datasets import Dataset
from utils.flax_utils import restore_agent
from utils.evaluation import supply_rng, add_to, flatten

In [1]:
# Notebook setup: declare the absl flags, load the environment and dataset,
# and build an (untrained) agent from the selected agent config file.
# NOTE(review): absl flags are process-global, so re-running this cell in the
# same kernel presumably fails with a duplicate-flag error -- confirm, and
# restart the kernel before re-running if so.
FLAGS = flags.FLAGS

# Alternative environments used for other runs; exactly one 'env_name'
# definition may be active at a time.
# flags.DEFINE_string('env_name', 'quadruped_jump', 'Environment (dataset) name.')
# flags.DEFINE_string('env_name', 'scene-play-singletask-task1-v0', 'Environment (dataset) name.')
flags.DEFINE_string('env_name', 'cube-single-play-singletask-task2-v0', 'Environment (dataset) name.')
flags.DEFINE_integer('seed', 10, 'Random seed.')
flags.DEFINE_string('obs_norm_type', 'normal', 'Type of observation normalization. (none, normal, bounded)')
flags.DEFINE_float('p_aug', None, 'Probability of applying image augmentation.')
flags.DEFINE_integer('num_aug', 1, 'Number of image augmentations.')
flags.DEFINE_integer('inplace_aug', 1, 'Whether to replace the original image after applying augmentations.')
flags.DEFINE_integer('frame_stack', None, 'Number of frames to stack.')
# Agent config file; lock_config=False keeps the loaded config mutable so the
# overrides below can be assigned.
config_flags.DEFINE_config_file('agent', '../impls/agents/sarsa_ifql_vib_gpi.py', lock_config=False)
# Alternative agent (FOM baseline):
# config_flags.DEFINE_config_file('agent', '../impls/agents/fb_repr_fom.py', lock_config=False)

# Parse flags only if they have not been parsed yet. known_only=True is passed
# so argv entries that are not declared flags (e.g. the arguments Jupyter hands
# to the kernel) are tolerated instead of aborting parsing.
if not FLAGS.is_parsed():
    FLAGS(sys.argv, known_only=True)

# Hyperparameter overrides applied on top of the config file.
# NOTE(review): presumably these must match the settings the restored
# checkpoint was trained with -- confirm against the checkpoint's run name.
config = FLAGS.agent
config['latent_dim'] = 512
config['clip_flow_goals'] = True
# config['transition_layer_norm'] = True
config['value_layer_norm'] = True

# make_env_and_datasets returns four values; only the second (env) and the
# third (dataset) are kept here.
_, env, dataset, _ = make_env_and_datasets(
    FLAGS.env_name, frame_stack=FLAGS.frame_stack, max_size=10_000_000, reward_free=True)

# Seed the Python, NumPy, and TensorFlow global RNGs. Note that this runs
# after make_env_and_datasets has already been called.
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)
tf.random.set_seed(FLAGS.seed)

# Set up datasets: wrap the loaded mapping in a Dataset and copy the sampling
# options from the flags onto it as attributes.
dataset = Dataset.create(**dataset)
dataset.obs_norm_type = FLAGS.obs_norm_type
dataset.p_aug = FLAGS.p_aug
dataset.num_aug = FLAGS.num_aug
dataset.inplace_aug = FLAGS.inplace_aug
dataset.frame_stack = FLAGS.frame_stack
dataset.return_next_actions = True
# Called after obs_norm_type is set above.
# NOTE(review): presumably normalizes the stored observations in place
# according to obs_norm_type -- confirm in utils.datasets.
dataset.normalize_observations()

# A single sampled transition; its observations/actions serve as example
# inputs for constructing the agent.
example_batch = dataset.sample(1)

# Initialize agent: look up the agent class by the name stored in the config.
agent_class = agents[config['agent_name']]
agent = agent_class.create(
    FLAGS.seed,
    example_batch['observations'],
    example_batch['actions'],
    config,
)

In [3]:
# Checkpoint selection. Exactly one `restore_path` is active; the commented
# paths are checkpoints from other environments/agents kept for reference.
restore_epoch = 1_500_000

# InFOM
# restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/sarsa_ifql_vib_gpi_offline2offline/20250504_sarsa_ifql_vib_gpi_offline2offline_quadruped_jump_obs_norm=normal_alpha=0.3_num_fg=16_actor_freq=4_expectile=0.9_critic_z_type=prior_vf_time_emb=False_transition_ln=True_kl_weight=0.005_latent_dim=128/300/debug/sd300_s_23747733.0.20250504_064407'
# restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/sarsa_ifql_vib_gpi_offline2offline/20250509_sarsa_ifql_vib_gpi_offline2offline_scene-play-singletask-task1-v0_obs_norm=normal_alpha=300.0_num_fg=16_actor_freq=4_expectile=0.99_critic_z_type=prior_vf_time_emb=False_actor_ln=False_kl_weight=0.2_latent_dim=128_clip_fg=True/200/debug/sd200_s_2154372.0.20250509_015603'
restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/sarsa_ifql_vib_gpi_offline2offline/20250727_sarsa_ifql_vib_gpi_offline2offline_cube-single-play-singletask-task2-v0_obs_norm=normal_alpha=30.0_ft_size=500000_ft_steps=500000_eval_freq=10000_num_fg=16_actor_freq=4_expectile=0.95_actor_ln=False_kl_weight=0.05_latent_dim=512_value_ln=True/200/debug/sd200_s_24622607.0.20250727_230701'
# FOM
# restore_path = '/n/fs/rl-chongyiz/exp_logs/ogbench_logs/fb_repr_fom_offline2offline/20250512_fb_repr_fom_offline2offline_quadruped_jump_obs_norm_type=normal_repr_alpha=10.0_alpha=0.3_num_fg=16_expectile=0.9_actor_freq=4_latent_dim=128_clip_fg=True/200/debug/sd200_s_2156137.0.20250512_132446'

# Build a fresh agent from the current config first, so every run of this
# cell restores the checkpoint into a newly created agent.
agent_class = agents[config['agent_name']]
agent_init_args = (
    FLAGS.seed,
    example_batch['observations'],
    example_batch['actions'],
    config,
)
agent = agent_class.create(*agent_init_args)

# Replace the freshly created agent with the one returned by restore_agent
# for the selected path and epoch.
agent = restore_agent(agent, restore_path, restore_epoch)

##### InFOM quadruped_jump

In [12]:
# Compare future states sampled from the agent's flow model against the states
# of the dataset trajectories they were conditioned on, and report the mean
# squared error over all (sample, trajectory-state) pairs.
rng = jax.random.PRNGKey(FLAGS.seed)

# Trajectory boundaries from the terminal flags: each trajectory starts right
# after the previous terminal, and the first one starts at index 0.
(terminal_locs,) = np.nonzero(dataset['terminals'] > 0)
initial_locs = np.concatenate([[0], terminal_locs[:-1] + 1])

num_trajs = 40
num_flow_future_states = 400
# Misspelled original name, kept as an alias in case other cells still read it.
num_flow_futre_states = num_flow_future_states

# Draw random transition indices and map each one to the first/last index of
# the trajectory that contains it. This uses the global NumPy RNG, so every
# re-run of the cell evaluates a different set of trajectories.
idxs = np.random.randint(dataset.size, size=num_trajs)
initial_idxs = initial_locs[np.searchsorted(initial_locs, idxs, side='right') - 1]
terminal_idxs = terminal_locs[np.searchsorted(terminal_locs, idxs)]

trajs = defaultdict(list)
for init_idx, term_idx in zip(initial_idxs, terminal_idxs):
    trajs['observations'].append(dataset['observations'][init_idx:term_idx + 1])
    trajs['actions'].append(dataset['actions'][init_idx:term_idx + 1])

# Stacking into one array and the [:, 0] / [:, 1] indexing below assume every
# selected trajectory has the same length (>= 2) -- TODO confirm for this dataset.
trajs['observations'] = np.asarray(trajs['observations'])
trajs['actions'] = np.asarray(trajs['actions'])

initial_observations = jnp.asarray(trajs['observations'][:, 0])
initial_actions = jnp.asarray(trajs['actions'][:, 0])
initial_next_observations = jnp.asarray(trajs['observations'][:, 1])
initial_next_actions = jnp.asarray(trajs['actions'][:, 1])

# predict z: encode the second (observation, action) pair of each trajectory
# and take the mode of the resulting latent distribution.
latent_dist = agent.network.select('transition_encoder')(initial_next_observations, initial_next_actions)
latents = latent_dist.mode()

# sample future states: one noise tensor per requested sample, with the
# conditioning inputs broadcast along a new leading sample axis.
rng, noise_rng = jax.random.split(rng)
noises = jax.random.normal(
    noise_rng,
    shape=(num_flow_future_states, *initial_observations.shape),
    dtype=initial_observations.dtype
)
flow_future_states = agent.compute_fwd_flow_goals(
    noises,
    jnp.broadcast_to(
        initial_observations[None],
        (num_flow_future_states, *initial_observations.shape)
    ),
    jnp.broadcast_to(
        initial_actions[None],
        (num_flow_future_states, *initial_actions.shape)
    ),
    jnp.broadcast_to(
        latents[None],
        (num_flow_future_states, *latents.shape)
    ),
    observation_min=example_batch['observation_min'],
    observation_max=example_batch['observation_max'],
)
# Swap the first two axes so the trajectory axis leads (presumably giving
# (num_trajs, num_samples, obs_dim), assuming the output keeps the noise shape).
flow_future_states = flow_future_states.transpose([1, 0, 2])

# Build the broadcast squared-error tensor once and reuse it for both the shape
# printout and the mean; the original materialized this large tensor twice.
squared_errors = (flow_future_states[:, :, None] - trajs['observations'][:, None]) ** 2

print(flow_future_states.shape)
print(squared_errors.shape)

pairwise_mse = jnp.mean(squared_errors)
# Drop the large intermediate so it does not stay alive in the kernel.
del squared_errors
print(pairwise_mse)

In [None]:
# Mean and (population) standard deviation of four hand-entered values.
# NOTE(review): presumably pairwise-MSE results copied from separate runs of
# the evaluation cell -- confirm provenance.
a = np.array([1.8885841, 1.8694075, 1.8969587, 1.9226277])
print(a.mean(), a.std(), sep='\n')

##### InFOM scene-play-singletask-task1-v0

In [4]:
# Same evaluation as the cell under the previous header: sample future states
# from the agent's flow model, conditioned on the start of randomly chosen
# dataset trajectories, and report the mean squared error between every sample
# and every state of its trajectory.
rng = jax.random.PRNGKey(FLAGS.seed)

# Terminal flags mark the last index of each trajectory; the next index (or 0
# for the first trajectory) is where the following trajectory begins.
(terminal_locs,) = np.nonzero(dataset['terminals'] > 0)
initial_locs = np.concatenate([[0], terminal_locs[:-1] + 1])

num_trajs = 40
num_flow_future_states = 400
# Misspelled original name, kept as an alias in case other cells still read it.
num_flow_futre_states = num_flow_future_states

# Random transition indices from the global NumPy RNG (so re-running the cell
# picks different trajectories), mapped to their trajectory's start/end index.
idxs = np.random.randint(dataset.size, size=num_trajs)
initial_idxs = initial_locs[np.searchsorted(initial_locs, idxs, side='right') - 1]
terminal_idxs = terminal_locs[np.searchsorted(terminal_locs, idxs)]

trajs = defaultdict(list)
for init_idx, term_idx in zip(initial_idxs, terminal_idxs):
    trajs['observations'].append(dataset['observations'][init_idx:term_idx + 1])
    trajs['actions'].append(dataset['actions'][init_idx:term_idx + 1])

# Stacking into one array and the [:, 0] / [:, 1] indexing below assume every
# selected trajectory has the same length (>= 2) -- TODO confirm for this dataset.
trajs['observations'] = np.asarray(trajs['observations'])
trajs['actions'] = np.asarray(trajs['actions'])

initial_observations = jnp.asarray(trajs['observations'][:, 0])
initial_actions = jnp.asarray(trajs['actions'][:, 0])
initial_next_observations = jnp.asarray(trajs['observations'][:, 1])
initial_next_actions = jnp.asarray(trajs['actions'][:, 1])

# predict z: the latent is the mode of the transition encoder's distribution
# for the second (observation, action) pair of each trajectory.
latent_dist = agent.network.select('transition_encoder')(initial_next_observations, initial_next_actions)
latents = latent_dist.mode()

# sample future states: draw one noise tensor per sample and broadcast the
# conditioning inputs along a new leading sample axis to match it.
rng, noise_rng = jax.random.split(rng)
noises = jax.random.normal(
    noise_rng,
    shape=(num_flow_future_states, *initial_observations.shape),
    dtype=initial_observations.dtype
)
flow_future_states = agent.compute_fwd_flow_goals(
    noises,
    jnp.broadcast_to(
        initial_observations[None],
        (num_flow_future_states, *initial_observations.shape)
    ),
    jnp.broadcast_to(
        initial_actions[None],
        (num_flow_future_states, *initial_actions.shape)
    ),
    jnp.broadcast_to(
        latents[None],
        (num_flow_future_states, *latents.shape)
    ),
    observation_min=example_batch['observation_min'],
    observation_max=example_batch['observation_max'],
)
# Swap the first two axes so the trajectory axis leads (presumably giving
# (num_trajs, num_samples, obs_dim), assuming the output keeps the noise shape).
flow_future_states = flow_future_states.transpose([1, 0, 2])

# Materialize the broadcast squared-error tensor a single time; the original
# computed it once just to print its shape and then again for the mean.
squared_errors = (flow_future_states[:, :, None] - trajs['observations'][:, None]) ** 2

print(flow_future_states.shape)
print(squared_errors.shape)

pairwise_mse = jnp.mean(squared_errors)
# Release the large intermediate instead of leaving it in kernel memory.
del squared_errors
print(pairwise_mse)