In [None]:
import os

# Configure XLA / JAX through environment variables. This runs ahead of the
# `import jax` further down in this cell.
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
# Append the flag only when it is not already present, so re-running this cell
# in a live kernel does not keep growing XLA_FLAGS with duplicate entries.
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags += ' ' + _TRITON_GEMM_FLAG
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['XLA_FLAGS'] = xla_flags
# os.environ["MUJOCO_GL"] = "egl"
# NOTE(review): newer JAX releases appear to prefer JAX_PLATFORMS over
# JAX_PLATFORM_NAME — confirm against the installed JAX version.
os.environ['JAX_PLATFORM_NAME'] = 'gpu'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # change to your GPU id
import pickle
from datetime import datetime
from helper import parse_cfg
from omegaconf import OmegaConf
from flax.training import checkpoints

# import mediapy as media
import hydra
import functools
from brax.training.acme import running_statistics
from etils import epath
from omegaconf import OmegaConf
from typing import Any, Callable, Dict, Optional, Type, Union, Tuple
from mujoco import mjx
import numpy
from custom_envs import registry, dm_control_suite, locomotion
import jax
from learning.agents.ppo import networks as ppo_networks
from learning.configs import dm_control_training_config, locomotion_training_config
from custom_envs import mjx_env
import jax.numpy as jnp
import matplotlib.pyplot as plt


In [None]:

from learning.agents.sampler_ppo.train import _unpmap
from learning.module.wrapper.adv_wrapper import wrap_for_adv_training
from learning.module.wrapper.evaluator import AdvEvaluator


def evaluate(cfg, key):
    """Evaluate a saved policy over a batch of dynamics parameters.

    Loads the environment for ``cfg.task``, wraps it with
    ``wrap_for_adv_training``, restores pickled policy parameters from
    ``<cfg.work_dir>/models/<cfg.policy>_params_latest.pkl`` and calls
    ``AdvEvaluator.run_evaluation`` on one of two parameter sets:

    * exactly 2 randomized parameters: a regular 64 x 64 grid spanning the
      selected range; the rewards are reshaped to the grid and drawn as a
      filled contour plot;
    * more than 2 randomized parameters: ``num_eval_envs`` points drawn
      uniformly from the selected range; no figure is produced.

    The range is ``env.ood_range`` when ``cfg.ood_setting`` is true, otherwise
    ``env.dr_range_wide`` or ``env.dr_range`` depending on ``cfg.dr_wide``.

    Args:
        cfg: Configuration read by attribute access. Fields used: ``task``,
            ``policy``, ``work_dir``, ``ood_setting`` and ``dr_wide``.
        key: JAX PRNG key. It is split once for the evaluator and, in the
            >2-parameter case, once more for sampling the dynamics parameters.

    Returns:
        A 3-tuple ``(metrics, rewards, fig)``. ``metrics`` and ``rewards`` come
        from ``run_evaluation``. For 2 parameters ``rewards`` is reshaped to
        the grid shape and ``fig`` is the matplotlib figure; for more than 2
        parameters ``rewards`` is returned as delivered and ``fig`` is ``None``.
        Both cases now have the same arity, so callers can always unpack three
        values.

    Raises:
        ValueError: if fewer than 2 parameters are randomized, if
            ``cfg.policy`` does not contain ``"ppo"``, or if ``cfg.task`` is in
            neither ``dm_control_suite._envs`` nor ``locomotion._envs``.
    """
    num_eval_envs = 4096
    # 64 * 64 == 4096 grid points, i.e. the same count as num_eval_envs.
    # NOTE(review): presumably the evaluator expects one parameter row per
    # eval env — confirm against AdvEvaluator before changing either number.
    grid_side = 64

    # Load environment
    env = registry.load(cfg.task)
    env_cfg = registry.get_default_config(cfg.task)
    print("env nv", env._mj_model.nv)

    # Choose the randomizer and its parameter range once, so the wrapper and
    # the sampled/gridded parameters cannot drift apart.
    if cfg.ood_setting:
        randomizer = registry.get_domain_randomizer_ood(cfg.task)
        dr_range = env.ood_range
    else:
        randomizer = registry.get_domain_randomizer_eval(cfg.task)
        dr_range = env.dr_range_wide if cfg.dr_wide else env.dr_range
    dr_range_low, dr_range_high = dr_range
    num_params = len(dr_range_low)
    if num_params < 2:
        # Previously this case ran the whole setup and then silently returned
        # None, which only surfaced as an unpacking error at the call site.
        raise ValueError(
            f"evaluate() needs at least 2 randomized parameters, got {num_params}"
        )

    eval_env = wrap_for_adv_training(
        env,
        episode_length=env_cfg.episode_length,
        action_repeat=env_cfg.action_repeat,
        randomization_fn=functools.partial(randomizer, dr_range=dr_range),
        param_size=num_params,
        dr_range_low=dr_range_low,
        dr_range_high=dr_range_high,
    )  # pytype: disable=wrong-keyword-args

    # Build the policy network. Only PPO-family policies are handled here.
    if "ppo" not in cfg.policy:
        raise ValueError(
            f"Unsupported policy {cfg.policy!r}: only policies containing 'ppo' can be evaluated"
        )
    if cfg.task in dm_control_suite._envs:
        ppo_params = dm_control_training_config.brax_ppo_config(cfg.task)
    elif cfg.task in locomotion._envs:
        ppo_params = locomotion_training_config.locomotion_ppo_config(cfg.task)
    else:
        raise ValueError(f"No PPO training config is known for task {cfg.task!r}")

    network_factory = ppo_networks.make_ppo_networks
    if "network_factory" in ppo_params:
        network_factory = functools.partial(
            ppo_networks.make_ppo_networks,
            **ppo_params.network_factory
        )
    ppo_network = network_factory(
        observation_size=env.observation_size,
        action_size=env.action_size,
        preprocess_observations_fn=running_statistics.normalize if ppo_params.normalize_observations else None,
    )
    make_policy_fn = ppo_networks.make_inference_fn(ppo_network)

    # Load saved parameters
    save_dir = os.path.join(cfg.work_dir, "models")
    print(f"Loading parameters from {save_dir}")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point cfg.work_dir at checkpoints you produced or trust.
    with open(os.path.join(save_dir, f"{cfg.policy}_params_latest.pkl"), "rb") as f:
        params = pickle.load(f)

    eval_key, key = jax.random.split(key)
    evaluator = AdvEvaluator(
        eval_env,
        functools.partial(make_policy_fn, deterministic=True),
        num_eval_envs=num_eval_envs,
        episode_length=env_cfg.episode_length,
        action_repeat=env_cfg.action_repeat,
        key=eval_key,
    )

    if num_params == 2:
        # Regular grid so the rewards can be reshaped into a 2-D map.
        x, y = jnp.meshgrid(
            jnp.linspace(dr_range_low[0], dr_range_high[0], grid_side),
            jnp.linspace(dr_range_low[1], dr_range_high[1], grid_side),
        )
        dynamics_params_grid = jnp.c_[x.ravel(), y.ravel()]
    else:
        # Too many dimensions for a grid: sample uniformly inside the range.
        param_key, key = jax.random.split(key)
        dynamics_params_grid = jax.random.uniform(
            param_key,
            shape=(num_eval_envs, num_params),
            minval=dr_range_low,
            maxval=dr_range_high,
        )

    # The third return value (episode lengths) is not used here.
    metrics, reward_1d, _ = evaluator.run_evaluation(
        params,
        dynamics_params=dynamics_params_grid,
        training_metrics={},
        num_eval_seeds=10,
        success_threshold=0.7,
    )

    if num_params != 2:
        # No figure in this case; None keeps the return arity at three.
        return metrics, reward_1d, None

    reward_2d = reward_1d.reshape(x.shape)
    eval_fig = _plot_reward_heatmap(x, y, reward_2d, env, cfg)
    return metrics, reward_2d, eval_fig


def _plot_reward_heatmap(x, y, reward_2d, env, cfg,
                         boundary_threshold=600, vmin=0, vmax=1000):
    """Draw the reward grid as a filled contour plot and return the figure.

    Adds a dashed white contour at ``boundary_threshold`` when that value lies
    strictly inside the data range, and, when ``cfg.dr_wide`` is set, a red
    dashed rectangle spanning ``env.dr_range``.
    """
    import numpy as np
    import matplotlib.patches as patches

    eval_fig, ax = plt.subplots()

    # 1. Filled contours (heatmap) on 11 evenly spaced levels in [vmin, vmax].
    levels = np.linspace(vmin, vmax, 11)
    ctf = ax.contourf(x, y, reward_2d, levels=levels, cmap='viridis')
    eval_fig.colorbar(ctf, ax=ax, ticks=levels)

    # 2. Threshold boundary, drawn only if the threshold is inside the data
    #    range so the contour is not empty.
    if np.min(reward_2d) < boundary_threshold < np.max(reward_2d):
        boundary_line = ax.contour(
            x, y, reward_2d,
            levels=[boundary_threshold],
            colors='white',       # High contrast color
            linestyles='dashed',  # Distinct style
            linewidths=2
        )
        # A string fmt is applied as `fmt % level`, so it must contain a
        # conversion specifier; a pre-formatted string raises TypeError.
        ax.clabel(boundary_line, inline=True, fontsize=10, fmt='Thresh: %g')

    xlabel, ylabel = env.ood_label if cfg.ood_setting else env.dr_label
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # 3. Rectangle marking env.dr_range on top of the wide range.
    if cfg.dr_wide:
        nominal_low, nominal_high = env.dr_range
        rect = patches.Rectangle(
            (nominal_low[0], nominal_low[1]),
            nominal_high[0] - nominal_low[0],
            nominal_high[1] - nominal_low[1],
            linewidth=2,
            edgecolor='r',
            facecolor='none',
            linestyle='--',
            label='Trained Range'
        )
        ax.add_patch(rect)
        ax.legend(loc='upper left')

    eval_fig.suptitle("Evaluation on Each Params")
    eval_fig.tight_layout()
    eval_fig.canvas.draw()
    return eval_fig


In [None]:

# Evaluate the "ppo" run stored under ./logs/CheetahRun/1/ppo on the OOD range.
notebook_dir = epath.Path(".").resolve()
print(notebook_dir)
cfg_path = os.path.join(notebook_dir, "config.yaml")
cfg = OmegaConf.load(cfg_path)

# Run selection: these three fields determine the checkpoint directory below.
cfg.task = "CheetahRun"
cfg.policy = "ppo"
cfg.seed = 1
print("cfg:", cfg)
cfg = parse_cfg(cfg)

# Some policies keep their checkpoints in a hyper-parameter sub-folder.
policy_name = cfg.policy
if policy_name == "ppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}"
elif policy_name == "gmmppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/beta={cfg.beta}"
elif policy_name == "epoptppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/epsilon={cfg.epsilon}"

# Evaluation settings read by evaluate().
cfg.dr_wide = False
cfg.ood_setting = True
cfg.eval_seed = 53
print("work dir", cfg.work_dir)

eval_key = jax.random.PRNGKey(cfg.eval_seed)
metrics, reward_2d, fig = evaluate(cfg, eval_key)

metrics


In [None]:

# Evaluate the "gmmppo" run (seed 2, beta=-20) on the OOD range.
notebook_dir = epath.Path(".").resolve()
print(notebook_dir)
cfg_path = os.path.join(notebook_dir, "config.yaml")
cfg = OmegaConf.load(cfg_path)

# Run selection: task / policy / seed / beta determine the checkpoint
# directory chosen below.
cfg.task = "CheetahRun"
cfg.policy = "gmmppo"
cfg.seed = 2
cfg.beta = -20
print("cfg:", cfg)
cfg = parse_cfg(cfg)

# Some policies keep their checkpoints in a hyper-parameter sub-folder.
policy_name = cfg.policy
if policy_name == "ppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}"
elif policy_name == "gmmppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/beta={cfg.beta}"
elif policy_name == "epoptppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/epsilon={cfg.epsilon}"

# Evaluation settings read by evaluate().
cfg.dr_wide = False
cfg.ood_setting = True
cfg.eval_seed = 23
print("work dir", cfg.work_dir)

eval_key = jax.random.PRNGKey(cfg.eval_seed)
metrics, reward_2d, fig = evaluate(cfg, eval_key)

metrics


In [None]:

# Evaluate the "epoptppo" run (seed 1, epsilon=0.4) on the OOD range.
notebook_dir = epath.Path(".").resolve()
print(notebook_dir)
cfg_path = os.path.join(notebook_dir, "config.yaml")
cfg = OmegaConf.load(cfg_path)

# Run selection. The epoptppo path below reads epsilon but not beta; beta is
# still set as before because parse_cfg (not visible here) may read it.
cfg.task = "CheetahRun"
cfg.policy = "epoptppo"
cfg.seed = 1
cfg.beta = -20
cfg.epsilon = 0.4
print("cfg:", cfg)
cfg = parse_cfg(cfg)

# Some policies keep their checkpoints in a hyper-parameter sub-folder.
policy_name = cfg.policy
if policy_name == "ppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}"
elif policy_name == "gmmppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/beta={cfg.beta}"
elif policy_name == "epoptppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/epsilon={cfg.epsilon}"

# Evaluation settings read by evaluate().
cfg.dr_wide = False
cfg.ood_setting = True
cfg.eval_seed = 56
print("work dir", cfg.work_dir)

eval_key = jax.random.PRNGKey(cfg.eval_seed)
metrics, reward_2d, fig = evaluate(cfg, eval_key)

metrics


In [None]:

# Evaluate the "adrppo" run (seed 0) on the OOD range.
notebook_dir = epath.Path(".").resolve()
print(notebook_dir)
cfg_path = os.path.join(notebook_dir, "config.yaml")
cfg = OmegaConf.load(cfg_path)

# Run selection. The adrppo path below does not read epsilon; it is still set
# as before because parse_cfg (not visible here) may read it.
cfg.task = "CheetahRun"
cfg.policy = "adrppo"
cfg.seed = 0
cfg.epsilon = 0.4
print("cfg:", cfg)
cfg = parse_cfg(cfg)

# Policies without a hyper-parameter sub-folder share the plain layout.
policy_name = cfg.policy
if policy_name in ("ppo", "adrppo", ""):
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}"
elif policy_name == "gmmppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/beta={cfg.beta}"
elif policy_name == "epoptppo":
    cfg.work_dir = f"./logs/{cfg.task}/{cfg.seed}/{cfg.policy}/epsilon={cfg.epsilon}"

# Evaluation settings read by evaluate().
cfg.dr_wide = False
cfg.ood_setting = True
cfg.eval_seed = 56
print("work dir", cfg.work_dir)

eval_key = jax.random.PRNGKey(cfg.eval_seed)
metrics, reward_2d, fig = evaluate(cfg, eval_key)

metrics
