In [None]:
import sys
import os

from absl import flags
from ml_collections import config_flags
from collections import defaultdict
import tqdm

import random
import jax
import jax.numpy as jnp
import numpy as np

from agents import agents
from utils.env_utils import make_env_and_datasets
from utils.flax_utils import restore_agent
from utils.evaluation import supply_rng, add_to, flatten


In [2]:
FLAGS = flags.FLAGS

flags.DEFINE_string('env_name', 'antmaze-large-navigate-v0', 'Environment (dataset) name.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
config_flags.DEFINE_config_file('agent', '../impls/agents/psiql.py', lock_config=False)

if not FLAGS.is_parsed():
    FLAGS(sys.argv, known_only=True)

config = FLAGS.agent
env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name, frame_stack=config['frame_stack'])

# Initialize agent.
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)

example_batch = train_dataset.sample(1)

agent_class = agents[config['agent_name']]
agent = agent_class.create(
    FLAGS.seed,
    example_batch['observations'],
    example_batch['actions'],
    config,
)



In [3]:
restore_dir = "/n/fs/rl-chongyiz/exp_logs/hdualrl_logs/psiql/20250618_psiql_antmaze-large-navigate-v0_expectile=0.7_alpha=3.0_discount=0.99/1/debug/sd001_s_24251254.0.20250618_134110"
restore_epoch = 1_000_000

agent = restore_agent(agent, restore_dir, restore_epoch)

Restored from /n/fs/rl-chongyiz/exp_logs/hdualrl_logs/psiql/20250618_psiql_antmaze-large-navigate-v0_expectile=0.7_alpha=3.0_discount=0.99/1/debug/sd001_s_24251254.0.20250618_134110/params_1000000.pkl


In [15]:
@jax.jit
def sample_actions(
    observations,
    agent=agent,
    goals=None,
    seed=None,
    temperature=1.0,
):
    """Sample actions from the actor.

    It first queries the high-level actor to obtain subgoal representations, and then queries the low-level actor
    to obtain raw actions.
    """
    # seed, waypoint_seed, action_seed = jax.random.split(seed, 3)

    # TODO: posterior sampling
    # candidates = jnp.concatenate([
    #     jnp.expand_dims(observations, 0),
    #     candidates,
    #     jnp.expand_dims(goals, 0)
    # ], axis=0)

    # n_observations = jnp.repeat(jnp.expand_dims(observations, 0), candidates.shape[0], axis=0)
    # n_goals = jnp.repeat(jnp.expand_dims(observations, 0), candidates.shape[0], axis=0)

    # v_sw1, v_sw2 = agent.network.select('value')(n_observations, candidates)
    # v_sw = (v_sw1 + v_sw2) / 2

    # v_wg1, v_wg2 = agent.network.select('value')(candidates, n_goals)
    # v_wg = (v_wg1 + v_wg2) / 2

    # v_sg1, v_sg2 = agent.network.select('value')(n_observations, n_goals)
    # v_sg = (v_sg1 + v_sg2) / 2

    # logits = v_sw + v_wg - v_sg
    # waypoint_idx = jax.random.categorical(waypoint_seed, logits)
    # waypoint = candidates[waypoint_idx]

    dist = agent.network.select('actor')(observations, goals, temperature=temperature)
    actions = dist.sample(seed=seed)

    return actions


num_eval_episodes = 10
eval_temperature = 1.0
task_infos = env.unwrapped.task_infos if hasattr(env.unwrapped, 'task_infos') else env.task_infos
num_tasks = len(task_infos)

task_stats = defaultdict()
for task_id in tqdm.trange(1, num_tasks + 1):
    actor_fn = supply_rng(sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32)))
    trajs = []
    stats = defaultdict(list)

    renders = []
    for i in tqdm.trange(num_eval_episodes):
        traj = defaultdict(list)

        observation, info = env.reset(options=dict(task_id=task_id))
        goal = info.get('goal')
        goal_frame = info.get('goal_rendered')
        done = False
        step = 0
        render = []
        while not done:
            action = actor_fn(observations=observation, goals=goal, temperature=eval_temperature)
            action = np.array(action)
            action = np.clip(action, -1, 1)

            next_observation, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            step += 1

            transition = dict(
                observation=observation,
                next_observation=next_observation,
                action=action,
                reward=reward,
                done=done,
                info=info,
            )
            add_to(traj, transition)
            observation = next_observation
        add_to(stats, flatten(info))
        trajs.append(traj)

    for k, v in stats.items():
        stats[k] = np.mean(v)

    for k, v in stats.items():
        task_stats['task{}/'.format(task_id) + k] = v

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:17<00:00,  1.79s/it]
100%|██████████| 10/10 [00:17<00:00,  1.75s/it]
100%|██████████| 10/10 [00:17<00:00,  1.75s/it]
100%|██████████| 10/10 [00:16<00:00,  1.69s/it]
100%|██████████| 10/10 [00:17<00:00,  1.74s/it]
100%|██████████| 5/5 [01:27<00:00, 17.47s/it]


In [16]:
print(task_stats)

defaultdict(None, {'task1/xy': 6.951728410839584, 'task1/prev_qpos': 0.9962415901716694, 'task1/prev_qvel': 0.03219964167625454, 'task1/qpos': 1.0012937151198646, 'task1/qvel': 0.15622723963146778, 'task1/success': 0.0, 'task2/xy': 13.918956919054352, 'task2/prev_qpos': 1.9387436680231431, 'task2/prev_qvel': 0.02402689661112348, 'task2/qpos': 1.9367531003961342, 'task2/qvel': -0.03792620451454055, 'task2/success': 0.0, 'task3/xy': 18.707416452640224, 'task3/prev_qpos': 2.5601216861492175, 'task3/prev_qvel': -0.05887238485842694, 'task3/qpos': 2.553674647545911, 'task3/qvel': -0.07394577234188338, 'task3/success': 0.0, 'task4/xy': 18.88307265518422, 'task4/prev_qpos': 2.561894354156127, 'task4/prev_qvel': -0.21806561071730365, 'task4/qpos': 2.5476660818221615, 'task4/qvel': -0.07427161360334389, 'task4/success': 0.0, 'task5/xy': 8.665013821602543, 'task5/prev_qpos': 1.213787660182286, 'task5/prev_qvel': 0.044064202611714766, 'task5/qpos': 1.2160577203422989, 'task5/qvel': 0.064477273299