In [44]:
import os
import pickle
import wandb
import jax
from datetime import datetime
from mujoco_playground import registry, wrapper
from helper import parse_cfg
from omegaconf import OmegaConf
from flax.training import checkpoints
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.sac import networks as sac_networks
import mediapy as media
import hydra
from custom_envs import registry
import functools
from mujoco_playground.config import dm_control_suite_params
from brax.training.acme import running_statistics
from mujoco_playground._src.wrapper import BraxDomainRandomizationVmapWrapper
from etils import epath
from omegaconf import OmegaConf
from mujoco_playground._src import mjx_env
from typing import Any, Callable, Dict, Optional, Type, Union, Tuple
from mujoco import mjx
import numpy
# Runtime configuration for XLA / JAX / MuJoCo rendering.
# NOTE(review): these variables are assigned after `import jax` and the
# mujoco imports earlier in this cell; confirm each library still honours
# them when they are set this late.
triton_gemm_flag = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
# Append the flag only when it is missing, so re-running this cell in a live
# kernel does not keep growing XLA_FLAGS with duplicate entries.
if triton_gemm_flag not in xla_flags.split():
    xla_flags += ' ' + triton_gemm_flag
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['XLA_FLAGS'] = xla_flags
os.environ["MUJOCO_GL"] = "egl"
os.environ['JAX_PLATFORM_NAME'] = 'gpu'
os.environ['CUDA_VISIBLE_DEVICES'] = '6'  # change to your GPU id

In [45]:
from mujoco_playground._src.wrapper import Wrapper
class BraxDomainRandomizationWrapper(Wrapper):
  """Wrapper that installs a randomized MJX model once, at construction time.

  ``randomization_fn`` is invoked a single time with the wrapped environment's
  ``mjx_model``. The model it returns is stored on this wrapper and written
  onto the unwrapped environment, so ``reset`` and ``step`` only need to
  forward to the wrapped environment.
  """

  def __init__(
      self,
      env: mjx_env.MjxEnv,
      randomization_fn: Callable[[mjx.Model], Tuple[mjx.Model, mjx.Model]],
  ):
    """Randomizes the model of ``env`` and installs it on the unwrapped env.

    Args:
      env: Environment to wrap.
      randomization_fn: Callable taking the environment's ``mjx_model`` and
        returning a pair; the first element is used as the new model and the
        second is stored as ``_in_axes``.
    """
    super().__init__(env)
    randomized_model, in_axes = randomization_fn(self.env.mjx_model)
    self._mjx_model = randomized_model
    # Stored for reference only; nothing in this class reads it.
    self._in_axes = in_axes
    # Overwrite the model on the innermost environment so every subsequent
    # reset/step runs against the randomized model.
    self.env.unwrapped._mjx_model = randomized_model

  def reset(self, rng: jax.Array) -> mjx_env.State:
    """Forwards ``reset`` to the wrapped environment."""
    return self.env.reset(rng)

  def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
    """Forwards ``step`` to the wrapped environment."""
    return self.env.step(state, action)


In [46]:
# Camera name passed to `env.render` for each supported task name.
CAMERAS = dict(
    AcrobotSwingup="fixed",
    AcrobotSwingupSparse="fixed",
    BallInCup="cam0",
    CartpoleBalance="fixed",
    CartpoleBalanceSparse="fixed",
    CartpoleSwingup="fixed",
    CartpoleSwingupSparse="fixed",
    CheetahRun="side",
    FingerSpin="cam0",
    FingerTurnEasy="cam0",
    FingerTurnHard="cam0",
    FishSwim="fixed_top",
    HopperHop="cam0",
    HopperStand="cam0",
    HumanoidStand="side",
    HumanoidWalk="side",
    HumanoidRun="side",
    PendulumSwingup="fixed",
    PointMass="cam0",
    ReacherEasy="fixed",
    ReacherHard="fixed",
    SwimmerSwimmer6="tracking1",
    WalkerRun="side",
    WalkerWalk="side",
    WalkerStand="side",
)


In [47]:

from learning.configs.dm_control_training_config import brax_rambo_config
from learning.configs.dm_control_training_config import brax_flowsac_config
from agents.flowsac import networks as flowsac_networks
from agents.rambo import networks as rambo_networks

# Module-level scratch containers; no other cell visible in this notebook
# reads them.
x_data = []
y_data = []
y_dataerr = []
# Single-element list holding the timestamp at which this cell was executed.
times = []
times.append(datetime.now())

_SUPPORTED_POLICIES = ("sac", "ppo", "rambo", "flowsac")


def _identity_preprocessor(observation, preprocessor_params):
    """Observation preprocessor that returns the observation unchanged."""
    del preprocessor_params
    return observation


def _build_make_policy_fn(cfg, env, obs_size, act_size):
    """Builds the policy-constructor for ``cfg.policy``.

    Args:
        cfg: Config providing ``policy`` and ``task``.
        env: Environment; only ``dr_range`` is read, and only for "flowsac".
        obs_size: Observation size passed to the network factory.
        act_size: Action size passed to the network factory.

    Returns:
        The ``make_inference_fn(network)`` result of the selected algorithm.

    Raises:
        ValueError: If ``cfg.policy`` is not one of the supported policies.
    """
    policy = cfg.policy
    extra_network_kwargs = {}
    if policy == "sac":
        algo_params = dm_control_suite_params.brax_sac_config(cfg.task)
        networks_module = sac_networks
        make_networks = sac_networks.make_sac_networks
    elif policy == "ppo":
        algo_params = dm_control_suite_params.brax_ppo_config(cfg.task)
        networks_module = ppo_networks
        make_networks = ppo_networks.make_ppo_networks
    elif policy == "rambo":
        algo_params = brax_rambo_config(cfg.task)
        networks_module = rambo_networks
        make_networks = rambo_networks.make_rambo_networks
    elif policy == "flowsac":
        algo_params = brax_flowsac_config(cfg.task)
        networks_module = flowsac_networks
        make_networks = flowsac_networks.make_flowsac_networks
        # NOTE(review): len() counts the top-level entries of dr_range; confirm
        # this is the dynamics-parameter count the flowsac networks expect.
        extra_network_kwargs["dynamics_param_size"] = len(env.dr_range)
    else:
        raise ValueError(
            f"Unsupported policy {policy!r}; expected one of {_SUPPORTED_POLICIES}"
        )

    network_factory = make_networks
    if "network_factory" in algo_params:
        network_factory = functools.partial(
            make_networks, **algo_params.network_factory
        )

    # Fall back to an identity preprocessor rather than None: Brax network
    # apply functions call the preprocessor unconditionally.
    if algo_params.normalize_observations:
        preprocess_fn = running_statistics.normalize
    else:
        preprocess_fn = _identity_preprocessor

    network = network_factory(
        observation_size=obs_size,
        action_size=act_size,
        preprocess_observations_fn=preprocess_fn,
        **extra_network_kwargs,
    )
    return networks_module.make_inference_fn(network)


def evaluate(cfg, num_episodes: int = 10):
    """Rolls out a saved policy and reports the mean episode return.

    Loads the environment for ``cfg.task`` (optionally wrapped with the
    evaluation domain randomizer), restores pickled policy parameters from
    ``<cfg.work_dir>/models/<cfg.policy>_params_latest.pkl``, runs
    ``num_episodes`` deterministic-policy episodes, shows a video of the last
    episode and prints the mean of the per-episode summed rewards.

    Args:
        cfg: Config providing ``task``, ``policy``, ``eval_seed``,
            ``randomization`` and ``work_dir``.
        num_episodes: Number of evaluation episodes to run (default 10).

    Raises:
        ValueError: If ``cfg.policy`` is unsupported or ``num_episodes < 1``.
    """
    # Fail fast on bad arguments, before any environment or checkpoint is loaded.
    if cfg.policy not in _SUPPORTED_POLICIES:
        raise ValueError(
            f"Unsupported policy {cfg.policy!r}; expected one of {_SUPPORTED_POLICIES}"
        )
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    # Load environment
    env = registry.load(cfg.task)
    env_cfg = registry.get_default_config(cfg.task)
    print("env nbody", env._mj_model.nbody)
    print("env nv", env._mj_model.nv)
    obs_size = env.observation_size
    act_size = env.action_size
    seed_key = jax.random.PRNGKey(cfg.eval_seed)

    if cfg.randomization:
        randomization_fn = registry.get_domain_randomizer_eval(cfg.task)
        _, dynamics_rng = jax.random.split(seed_key)
        randomization_fn = functools.partial(
            randomization_fn, rng=dynamics_rng, params=env.dr_range
        )
        env = BraxDomainRandomizationWrapper(
            env,
            randomization_fn=randomization_fn,
        )

    make_policy_fn = _build_make_policy_fn(cfg, env, obs_size, act_size)

    # Load saved parameters
    save_dir = os.path.join(cfg.work_dir, "models")
    print(f"Loading parameters from {save_dir}")
    params_path = os.path.join(save_dir, f"{cfg.policy}_params_latest.pkl")
    # SECURITY: pickle.load can execute arbitrary code; only point this at
    # checkpoints produced by a trusted training run.
    with open(params_path, "rb") as f:
        params = pickle.load(f)

    jit_inference_fn = jax.jit(make_policy_fn(params, deterministic=True))
    jit_reset = jax.jit(env.reset)
    jit_step = jax.jit(env.step)

    # The rollout keys are derived from the eval seed directly, so they do not
    # depend on whether the randomization branch above consumed a key.
    # NOTE(review): as a consequence the action key stream starts from the same
    # key as `dynamics_rng`; left unchanged to keep evaluation results stable.
    reset_rng, rng = jax.random.split(seed_key)
    reset_rngs = jax.random.split(reset_rng, num_episodes)

    rewards = []
    rollout = []
    for episode in range(num_episodes):
        state = jit_reset(reset_rngs[episode])
        total_reward = 0.0
        # Only the most recent episode's states are kept for rendering.
        rollout = [state]

        for _ in range(env_cfg.episode_length):
            act_rng, rng = jax.random.split(rng)
            action, _ = jit_inference_fn(state.obs, act_rng)
            state = jit_step(state, action)
            rollout.append(state)
            total_reward += state.reward
        rewards.append(total_reward)

    # Render the last episode only.
    frames = env.render(rollout, camera=CAMERAS[cfg.task])
    media.show_video(frames, fps=1.0 / env.dt)
    print(f"Total reward: {numpy.array(rewards).mean()}")


In [51]:

# Locate config.yaml in the current working directory.
working_dir = epath.Path(".").resolve()
print(working_dir)
cfg_path = os.path.join(working_dir, "config.yaml")

# Load the base config, then pick the task / policy / seed to evaluate.
cfg = OmegaConf.load(cfg_path)
cfg.task = "CartpoleSwingup"
cfg.policy = "sac"
cfg.seed = 1
print("cfg:", cfg)

# parse_cfg runs after the overrides above; eval_seed is assigned on its result.
cfg = parse_cfg(cfg)
cfg.eval_seed = 55
print("work dir", cfg.work_dir)

evaluate(cfg)
# 843.25  (NOTE(review): presumably a previously recorded result -- confirm)
#843.25


/raid/users/tjrcjf410/distributionally_robust_learning/learning
cfg: {'benchmark': 'dm_control', 'task': 'CartpoleSwingup', 'obs': 'state', 'exp_name': 'test', 'checkpoint': '???', 'eval_episodes': 1, 'eval_pi': True, 'eval_value': True, 'eval_freq': 50000, 'eval_with_training_env': False, 'policy': 'sac', 'asymmetric_critic': False, 'randomization': True, 'eval_randomization': True, 'wandb_project': 'wdsac-exp', 'wandb_entity': 'tjrcjf410-seoul-national-university', 'wandb_silent': False, 'use_wandb': True, 'save_csv': True, 'save_video': True, 'save_agent': True, 'seed': 1, 'work_dir': '???', 'task_title': '???', 'multitask': '???', 'tasks': '???', 'obs_shape': '???', 'action_dim': '???', 'episode_length': '???', 'obs_shapes': '???', 'action_dims': '???', 'episode_lengths': '???', 'seed_steps': '???', 'bin_size': '???', 'real_ratio': '???', 'rollout_length': '???', 'adv_weight': '???', 'batch_size': '???', 'rollout_batch_size': '???', 'n_nominals': '???', 'delta': '???', 'lambda_upda

100%|██████████| 1001/1001 [00:01<00:00, 545.14it/s]


0
This browser does not support the video tag.


Total reward: 798.7515869140625


In [49]:
# Build a domain-randomized environment outside of `evaluate`, seeded from
# cfg.seed, so that it can be inspected interactively.
base_randomizer = registry.get_domain_randomizer_eval(cfg.task)
rng, dynamics_rng = jax.random.split(jax.random.PRNGKey(cfg.seed))
env = registry.load(cfg.task)
randomization_fn = functools.partial(
    base_randomizer,
    rng=dynamics_rng,
    params=env.dr_range,
)
env = BraxDomainRandomizationWrapper(env, randomization_fn=randomization_fn)


In [50]:
# Display `dr_range` as exposed through the wrapped environment built in the
# previous cell (the captured output below shows a pair of arrays).
env.dr_range


(Array([ 0.9,  0. ,  0. , -0.3, -0.3,  0.5,  0.5], dtype=float32),
 Array([1. , 1. , 1. , 0.3, 0.3, 1.5, 1.5], dtype=float32))