In [None]:
import os
import pickle
import wandb
import jax
from datetime import datetime
from mujoco_playground import registry, wrapper
from helper import parse_cfg
from omegaconf import OmegaConf
from flax.training import checkpoints
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.sac import networks as sac_networks
import mediapy as media
import hydra
from mujoco_playground import registry
import functools
from mujoco_playground.config import dm_control_suite_params
from brax.training.acme import running_statistics
from mujoco_playground._src.wrapper import BraxDomainRandomizationVmapWrapper
from etils import epath
from omegaconf import OmegaConf
from mujoco_playground._src import mjx_env
from typing import Any, Callable, Dict, Optional, Type, Union, Tuple
from mujoco import mjx

# XLA / JAX runtime configuration, exported through environment variables.
# XLA_FLAGS is assigned unconditionally, so any value inherited from the
# parent shell is replaced by exactly these two flags.
xla_flags = "--xla_gpu_autotune_level=2" + ' --xla_gpu_triton_gemm_any=True'
os.environ.update(
    {
        "XLA_FLAGS": xla_flags,
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        # NOTE(review): `jax` is already imported earlier in this cell; JAX may
        # read JAX_PLATFORM_NAME at import time, in which case this assignment
        # has no effect -- confirm, or set it before importing jax.
        'JAX_PLATFORM_NAME': 'gpu',
    }
)



In [2]:
from mujoco_playground._src.wrapper import Wrapper
class BraxDomainRandomizationWrapper(Wrapper):
  """Wrapper that installs a single randomized MJX model on the wrapped env.

  `randomization_fn` is called exactly once, in the constructor, with the
  wrapped environment's `mjx_model`. The first element of its result replaces
  `_mjx_model` on the unwrapped environment; the second element is kept as
  `_in_axes` but is not read anywhere in this class. `reset` and `step`
  forward directly to the wrapped environment.
  """

  def __init__(
      self,
      env: mjx_env.MjxEnv,
      randomization_fn: Callable[[mjx.Model], Tuple[mjx.Model, mjx.Model]],
  ):
    """Randomizes the model once and patches it into the underlying env.

    Args:
      env: Environment to wrap.
      randomization_fn: Callable taking the env's `mjx_model` and returning a
        2-tuple whose first element is the model to install.
    """
    super().__init__(env)
    randomized_model, in_axes = randomization_fn(self.env.mjx_model)
    self._mjx_model = randomized_model
    self._in_axes = in_axes
    # Swap the model in place so every later reset/step uses it.
    self.env.unwrapped._mjx_model = randomized_model

  def reset(self, rng: jax.Array) -> mjx_env.State:
    """Resets the wrapped environment with `rng` and returns its state."""
    return self.env.reset(rng)

  def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
    """Advances the wrapped environment by one step and returns the state."""
    return self.env.step(state, action)


In [3]:

# Module-level accumulators: three independent empty lists plus a list of
# timestamps seeded with the moment this cell is executed. None of them is
# referenced by the other cells visible in this notebook.
x_data = []
y_data = []
y_dataerr = []
times = [datetime.now()]

_SUPPORTED_POLICIES = ("sac", "ppo")
_SUPPORTED_SHIFT_TYPES = ("deterministic", "stochastic")


def _validate_eval_cfg(cfg):
    """Reject config values `evaluate` cannot handle, before any heavy work."""
    if cfg.policy not in _SUPPORTED_POLICIES:
        raise ValueError(
            f"Unsupported policy {cfg.policy!r}; expected one of {_SUPPORTED_POLICIES}."
        )
    if cfg.dynamics_shift and cfg.dynamics_shift_type not in _SUPPORTED_SHIFT_TYPES:
        raise ValueError(
            f"Unsupported dynamics_shift_type {cfg.dynamics_shift_type!r}; "
            f"expected one of {_SUPPORTED_SHIFT_TYPES}."
        )


def _apply_dynamics_shift(env, cfg, rng):
    """Wrap `env` with the task's eval randomizer; return `(wrapped_env, rng)`."""
    shift_type = cfg.dynamics_shift_type
    root = epath.Path(".").resolve()
    # Shift parameters live in dynamics_shift/<type>/<task>.yaml under the CWD.
    dynamics_path = os.path.join(root, "dynamics_shift", shift_type, f"{cfg.task}.yaml")
    print("dynamics_path:", dynamics_path)
    print("file exists?", os.path.exists(dynamics_path))
    dynamics_cfg = OmegaConf.load(dynamics_path)

    rng, dynamics_rng = jax.random.split(rng)
    # Exactly one of the two cfg arguments carries the loaded config.
    randomization_fn = functools.partial(
        registry.get_domain_randomizer_eval(cfg.task),
        rng=dynamics_rng,
        deterministic_cfg=dynamics_cfg if shift_type == "deterministic" else None,
        stochastic_cfg=dynamics_cfg if shift_type == "stochastic" else None,
    )
    wrapped = BraxDomainRandomizationWrapper(env, randomization_fn=randomization_fn)
    return wrapped, rng


def _identity_preprocessor(obs, processor_params):
    """Observation preprocessor that returns `obs` unchanged."""
    del processor_params  # unused
    return obs


def _make_policy_factory(cfg, obs_size, act_size):
    """Build the Brax networks for `cfg.policy` and return their inference-fn maker."""
    if cfg.policy == "sac":
        algo_params = dm_control_suite_params.brax_sac_config(cfg.task)
        networks_module = sac_networks
        make_networks = sac_networks.make_sac_networks
    else:  # "ppo" -- guaranteed by _validate_eval_cfg
        algo_params = dm_control_suite_params.brax_ppo_config(cfg.task)
        networks_module = ppo_networks
        make_networks = ppo_networks.make_ppo_networks

    network_factory = make_networks
    if "network_factory" in algo_params:
        network_factory = functools.partial(make_networks, **algo_params.network_factory)

    # Brax calls the preprocessor on every forward pass, so it must be callable
    # even when normalization is off; passing None would fail at inference.
    if algo_params.normalize_observations:
        preprocess_fn = running_statistics.normalize
    else:
        preprocess_fn = _identity_preprocessor

    network = network_factory(
        observation_size=obs_size,
        action_size=act_size,
        preprocess_observations_fn=preprocess_fn,
    )
    return networks_module.make_inference_fn(network)


def evaluate(cfg):
    """Roll out the latest saved policy for one episode and show the video.

    Reads from `cfg`: `task`, `seed`, `policy` ("sac" or "ppo"),
    `dynamics_shift`, `dynamics_shift_type` ("deterministic" or "stochastic",
    only consulted when `dynamics_shift` is truthy) and `work_dir`.

    Parameters are unpickled from
    `<work_dir>/models/<policy>_params_latest.pkl`. The episode runs for the
    task default config's `episode_length` steps with a deterministic policy;
    the rendered rollout is displayed and the summed reward is printed.

    Raises:
        ValueError: if `cfg.policy` or (when shifting) `cfg.dynamics_shift_type`
            is not one of the supported values. Raised before the environment
            is loaded.
    """
    _validate_eval_cfg(cfg)

    # Load environment
    env = registry.load(cfg.task)
    env_cfg = registry.get_default_config(cfg.task)
    print("env nbody", env._mj_model.nbody)
    print("env nv", env._mj_model.nv)
    # Sizes are taken from the unwrapped env, before any dynamics wrapper.
    obs_size = env.observation_size
    act_size = env.action_size
    rng = jax.random.PRNGKey(cfg.seed)

    if cfg.dynamics_shift:
        env, rng = _apply_dynamics_shift(env, cfg, rng)

    make_policy_fn = _make_policy_factory(cfg, obs_size, act_size)

    # Load saved parameters
    save_dir = os.path.join(cfg.work_dir, "models")
    print(f"Loading parameters from {save_dir}")
    params_path = os.path.join(save_dir, f"{cfg.policy}_params_latest.pkl")
    with open(params_path, "rb") as f:
        # SECURITY: pickle can execute arbitrary code; only load trusted files.
        params = pickle.load(f)

    jit_inference_fn = jax.jit(make_policy_fn(params, deterministic=True))
    jit_reset = jax.jit(env.reset)
    jit_step = jax.jit(env.step)

    # Evaluation loop
    # NOTE(review): the reset key is fixed at 0 and does not follow cfg.seed;
    # kept as-is so results stay comparable with earlier runs -- confirm intent.
    state = jit_reset(jax.random.PRNGKey(0))
    total_reward = 0.0
    rollout = [state]

    # `rng` continues from the key state above instead of being re-seeded from
    # cfg.seed, so action keys never coincide with the dynamics key. With
    # deterministic=True the policy presumably ignores the key, so rewards
    # should be unchanged -- TODO confirm.
    for _ in range(env_cfg.episode_length):
        act_rng, rng = jax.random.split(rng)
        action, _info = jit_inference_fn(state.obs, act_rng)
        state = jit_step(state, action)
        rollout.append(state)
        total_reward += state.reward

    frames = env.render(rollout)
    media.show_video(frames, fps=1.0 / env.dt)
    print(f"Total reward: {total_reward}")

In [7]:

# Locate config.yaml in the current working directory and load it.
project_root = epath.Path(".").resolve()
print(project_root)
cfg_path = os.path.join(project_root, "config.yaml")
cfg = OmegaConf.load(cfg_path)

# Per-run overrides applied on top of the values stored in config.yaml.
overrides = {
    "task": "CheetahRun",
    "policy": "sac",
    "dynamics_shift": True,
    "dynamics_shift_type": "deterministic",  # or "stochastic"
    "seed": 3,
}
for key, value in overrides.items():
    cfg[key] = value
print("cfg:", cfg)

# parse_cfg comes from the local `helper` module; its result is what
# evaluate() consumes.
cfg = parse_cfg(cfg)

evaluate(cfg)

/home/sukchul/distributionally_robust_learning/learning
cfg: {'benchmark': 'dm_control', 'task': 'CheetahRun', 'obs': 'state', 'exp_name': 'test', 'checkpoint': '???', 'eval_episodes': 1, 'eval_pi': True, 'eval_value': True, 'eval_freq': 50000, 'policy': 'sac', 'dynamics_shift': True, 'dynamics_shift_type': 'deterministic', 'wandb_project': 'td-mpc(jax)', 'wandb_entity': 'tjrcjf410-seoul-national-university', 'wandb_silent': False, 'use_wandb': True, 'save_csv': True, 'save_video': True, 'save_agent': True, 'seed': 3, 'work_dir': '???', 'task_title': '???', 'multitask': '???', 'tasks': '???', 'obs_shape': '???', 'action_dim': '???', 'episode_length': '???', 'obs_shapes': '???', 'action_dims': '???', 'episode_lengths': '???', 'seed_steps': '???', 'bin_size': '???'}
env nbody 8
env nv 9
dynamics_path: /home/sukchul/distributionally_robust_learning/learning/dynamics_shift/deterministic/CheetahRun.yaml
file exists? True
Loading parameters from /home/sukchul/distributionally_robust_learning

100%|██████████| 1001/1001 [00:01<00:00, 605.72it/s]


0
This browser does not support the video tag.


Total reward: 706.5291137695312
