In [1]:
import os
# Runtime environment settings. They are set before `import jax` below so they are
# already in place when the JAX/XLA backend starts.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"  # physical GPUs 2,3 appear as CudaDevice ids 0,1 (see output)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8" # 0.9 causes too much lag.
os.environ['MUJOCO_GL'] = 'egl'  # headless EGL rendering for MuJoCo

os.environ['JAX_LOG_COMPILES'] = '0'  # do not log every JIT compilation

import time

import functools

import jax.numpy as jp
import numpy as np
import jax
print("JAX Device:", jax.devices())
from jax import config
# NOTE(review): jax_debug_nans is a debugging aid that adds runtime overhead; consider
# disabling it when only converting/exporting the policy.
config.update("jax_debug_nans", True)
# Analytical gradients work much better with double precision.
config.update("jax_enable_x64", True)
config.update('jax_default_matmul_precision', 'high')

# NOTE(review): duplicates the device print above.
print("jax.devices():", jax.devices())
print("local_device_count:", jax.local_device_count())

from absl import logging
# DEBUG verbosity: DEBUG records (e.g. the orbax handler messages in the output) are shown.
logging.set_verbosity(logging.DEBUG)

from mujoco_playground import registry
from mujoco_playground import wrapper
from mujoco_playground.config import locomotion_params

# from brax.training.agents.apg import train as apg
from apg_alg.algorithm import apg  # Local modified APG version # type: ignore
from brax.training.agents.apg import networks as apg_networks

from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo

from brax.envs.wrappers import training as brax_training

from brax.training import acting

from brax.io import model

from brax import envs

import matplotlib.pyplot as plt
from IPython.display import HTML, clear_output
from datetime import datetime
import mediapy as media

import wandb

# Environment whose trained APG policy is converted below.
env_name = "Go2Trot"
env_cfg = registry.get_default_config(env_name)  # NOTE(review): not referenced in the cells shown (demo_cfg is loaded separately)
randomizer = registry.get_domain_randomizer(env_name)  # NOTE(review): not referenced in the cells shown

JAX Device: [CudaDevice(id=0), CudaDevice(id=1)]
jax.devices(): [CudaDevice(id=0), CudaDevice(id=1)]
local_device_count: 2


DEBUG:absl:Created `ArrayHandler` with primary_host=0, replica_id=0, use_replica_parallel=True, array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7fd0c42e3890>
DEBUG:absl:Handler "orbax.checkpoint._src.handlers.base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler'>. Skipping registration.
DEBUG:absl:Handler "orbax.checkpoint._src.handlers.array_checkpoint_handler.ArrayCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.array_checkpoint_handler.ArrayCheckpointHandler'>. Skipping registration.
DEBUG:absl:Handler "orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler'>. Skipping

# Load the evaluation environment and trained policy parameters

In [3]:
# Evaluation environment: reference-based resets/initialisation and perturbations are
# switched off, and the env is vmapped so reset takes a batch of RNG keys.
demo_cfg = registry.get_default_config(env_name)
demo_cfg['env']['reset2ref'] = False
demo_cfg['env']['reference_state_init'] = False
demo_cfg['pert_config']['enable'] = False
demo_env = brax_training.VmapWrapper(registry.load(env_name, demo_cfg))

# Training hyper-parameters (used later to rebuild the APG network) and the saved policy.
apg_params = locomotion_params.brax_apg_config(env_name)
params = model.load_params('/tmp/trotting_apg_2hz_policy')
# Keep the first two entries only: params[0] holds the observation statistics
# (mean/std) and params[1] the policy weights.
params = params[0], params[1]

In [4]:
# Sanity-check that the loaded normaliser statistics match the environment's spaces.
# (`params` was already reduced to (normaliser, policy) in the previous cell, so it is
# not re-assigned here — keeping this cell idempotent.)
print("obs_size:", demo_env.observation_size)
print("action_size:", demo_env.action_size)
print(params[0].mean.shape, params[0].std.shape)
assert params[0].mean.shape == params[0].std.shape == (demo_env.observation_size,), \
    "observation statistics do not match the environment's observation size"

obs_size: 40
action_size: 12
(40,) (40,)


# Convert the Brax (JAX) policy to PyTorch

In [5]:
import torch
import torch.nn as nn
import numpy as np
import jax
import jax.numpy as jnp
import flax
from flax import linen


# ================================================================
# 🧩 1. Go2Policy definition
# ================================================================
class Go2Policy(nn.Module):
    """PyTorch port of the Brax APG policy network, intended for ONNX export.

    ``forward`` standardises the observation with the stored statistics, runs it through
    ``len(hidden_sizes)`` blocks of ``Linear -> ELU -> LayerNorm`` followed by a final
    ``Linear`` with ``2 * act_dim`` outputs, and splits that output into two halves: the
    first half goes through ``tanh`` (the deterministic action), the second through ``exp``.

    Args:
        obs_dim: Length of the observation vector.
        act_dim: Number of action dimensions.
        obs_mean: Optional array-like observation mean of length ``obs_dim``; zeros if omitted.
        obs_std: Optional array-like observation std of length ``obs_dim``; ones if omitted.
        eps: Added to ``obs_std`` in the normalisation denominator.
        hidden_sizes: Hidden layer widths. The default ``(256, 128)`` matches the checkpoint
            loaded in this notebook and yields the same ``net`` indices as before.
        layer_norm_eps: ``LayerNorm`` epsilon; ``1e-6`` is the Flax ``LayerNorm`` default.

    NOTE(review): Brax's ``NormalTanhDistribution`` builds its scale as
    ``(softplus(raw) + min_std) * var_scale`` rather than ``exp(raw)``. Only the action
    output is compared against JAX in this notebook, so verify ``std`` before relying on it.
    """

    def __init__(self, obs_dim=40, act_dim=12, obs_mean=None, obs_std=None, eps=1e-10,
                 hidden_sizes=(256, 128), layer_norm_eps=1e-6):
        super().__init__()
        self.eps = eps

        # Normalisation statistics are registered as buffers (saved in state_dict, not
        # trained). Each falls back independently, so a provided mean is never discarded
        # just because the std was omitted (or vice versa).
        if obs_mean is None:
            mean_buf = torch.zeros(obs_dim)
        else:
            mean_buf = torch.tensor(obs_mean, dtype=torch.float32)
        if obs_std is None:
            std_buf = torch.ones(obs_dim)
        else:
            std_buf = torch.tensor(obs_std, dtype=torch.float32)
        self.register_buffer("obs_mean", mean_buf)
        self.register_buffer("obs_std", std_buf)

        # Linear -> ELU -> LayerNorm per hidden width, then the output head. With the default
        # widths the Sequential indices are 0/3/6 (Linear) and 2/5 (LayerNorm); saved
        # state_dicts (keys like ``net.0.weight``) depend on these indices.
        layers = []
        width_in = obs_dim
        for width in hidden_sizes:
            layers.extend([
                nn.Linear(width_in, width),
                nn.ELU(),
                nn.LayerNorm(width, eps=layer_norm_eps),
            ])
            width_in = width
        layers.append(nn.Linear(width_in, act_dim * 2))
        self.net = nn.Sequential(*layers)

    def normalize_obs(self, x):
        """Standardise ``x`` in the manner of Brax ``running_statistics.normalize`` (plus ``eps``)."""
        return (x - self.obs_mean) / (self.obs_std + self.eps)

    def forward(self, x):
        """Return ``(tanh(mean), exp(log_std))`` for observations ``x`` (last dim ``obs_dim``)."""
        x = self.normalize_obs(x)
        out = self.net(x)

        mean, log_std = out.chunk(2, dim=-1)
        std = torch.exp(log_std)
        mean_tanh = torch.tanh(mean)
        return mean_tanh, std


# ================================================================
# ⚙️ 2. Weight loading: JAX -> PyTorch
# ================================================================
def load_jax_to_torch(jax_params, torch_model):
    """Copy Flax/Brax MLP parameters into ``torch_model.net`` in place.

    The ``nn.Linear`` modules of ``torch_model.net`` are filled, in order, from
    ``jax_params["hidden_0"]``, ``jax_params["hidden_1"]``, ...; each ``nn.LayerNorm`` that
    follows the k-th ``Linear`` is filled from ``jax_params["LayerNorm_{k}"]``. Flax
    ``Dense`` kernels are stored as ``(in, out)`` while ``nn.Linear.weight`` is
    ``(out, in)``, hence the transpose.

    Args:
        jax_params: Mapping of layer name -> ``{"kernel", "bias"}`` (Dense) or
            ``{"scale", "bias"}`` (LayerNorm) arrays, e.g. ``params[1]['params']``.
        torch_model: Module exposing an ``nn.Sequential`` attribute named ``net``.

    Raises:
        KeyError: if an expected layer or parameter is missing from ``jax_params``.
        ValueError: if a JAX array's shape does not match the torch parameter it targets.
    """
    def _assign(param, value, label):
        # Copy in place (keeping the parameter's dtype/device) after an explicit shape
        # check; assigning to `.data` would accept a mismatched shape silently.
        src = torch.tensor(np.array(value), dtype=param.dtype)
        if src.shape != param.shape:
            raise ValueError(
                f"{label}: JAX array has shape {tuple(src.shape)}, "
                f"expected {tuple(param.shape)}"
            )
        param.copy_(src.to(param.device))

    with torch.no_grad():
        dense_idx = 0
        for module in torch_model.net:
            if isinstance(module, nn.Linear):
                name = f"hidden_{dense_idx}"
                dense = jax_params[name]
                _assign(module.weight, np.array(dense["kernel"]).T, f"{name}.kernel")
                _assign(module.bias, dense["bias"], f"{name}.bias")
                dense_idx += 1
            elif isinstance(module, nn.LayerNorm):
                # The LayerNorm belongs to the Dense layer just loaded: LayerNorm_{k} after hidden_{k}.
                name = f"LayerNorm_{dense_idx - 1}"
                norm = jax_params[name]
                _assign(module.weight, norm["scale"], f"{name}.scale")
                _assign(module.bias, norm["bias"], f"{name}.bias")

    print("‚úÖ JAX ÊùÉÈáçÂ∑≤ÊàêÂäüÂ§çÂà∂Âà∞ PyTorch Sequential Ê®°ÂûãÔºÅ")


# ================================================================
# 🔍 3. Layer-by-layer comparison: pinpoint the first layer where outputs diverge
# ================================================================
def compare_layers(jax_params, torch_model, obs):
    """Compare JAX and PyTorch activations layer by layer and print the difference.

    The reference side is a NumPy/``jnp`` re-implementation driven by ``jax_params``;
    the PyTorch side runs each module of ``torch_model.net`` in turn. After every
    ``Linear``, ``ELU`` and ``LayerNorm`` the maximum absolute difference is printed.
    Modules of any other type are skipped on both sides.

    Note: ``obs`` is fed to both sides as-is; the observation normalisation performed by
    ``Go2Policy.forward`` is not applied here.

    Args:
        jax_params: Flax parameter dict (e.g. ``params[1]['params']``); must contain
            ``hidden_0``, ``hidden_1`` and ``hidden_2`` plus ``LayerNorm_k`` entries.
        torch_model: Module whose ``net`` attribute is the ``nn.Sequential`` to inspect.
        obs: Input observations (array-like).
    """
    def _gap(reference, candidate):
        # Max |reference - candidate| between a JAX/NumPy array and a torch tensor.
        return np.max(np.abs(np.array(reference) - candidate.detach().numpy()))

    # Dense parameters in network order (looked up up-front).
    dense_params = [jax_params[name] for name in ("hidden_0", "hidden_1", "hidden_2")]

    ref_act = obs
    torch_act = torch.tensor(np.array(obs), dtype=torch.float32)
    layer_idx = 0

    for module in torch_model.net:
        if isinstance(module, nn.Linear):
            dense = dense_params[layer_idx]
            ref_act = jnp.dot(ref_act, np.array(dense["kernel"])) + np.array(dense["bias"])
            torch_act = module(torch_act)
            diff = _gap(ref_act, torch_act)
            print(f"[Linear {layer_idx}] max diff = {diff:.6f}")
            layer_idx += 1
            continue

        if isinstance(module, nn.LayerNorm):
            # Reference LayerNorm: biased variance and epsilon 1e-6, as in Flax.
            ln_params = jax_params[f"LayerNorm_{layer_idx-1}"]
            gamma = np.array(ln_params["scale"])
            beta = np.array(ln_params["bias"])
            act = np.array(ref_act)
            mu = np.mean(act, axis=-1, keepdims=True)
            sigma2 = np.var(act, axis=-1, keepdims=True)
            ref_act = gamma * ((act - mu) / np.sqrt(sigma2 + 1e-6)) + beta

            torch_act = module(torch_act)
            diff = _gap(ref_act, torch_act)
            print(f"[LayerNorm {layer_idx-1}] max diff = {diff:.6f}")
            continue

        if isinstance(module, nn.ELU):
            ref_act = linen.elu(ref_act)
            torch_act = module(torch_act)
            diff = _gap(ref_act, torch_act)
            print(f"[ELU {layer_idx-1}] max diff = {diff:.6f}")


In [6]:
# Instantiate the PyTorch port with the trained observation statistics, then copy the
# Flax policy weights (params[1]['params']) into it.
obs_mean, obs_std = np.array(params[0].mean), np.array(params[0].std)
torch_model = Go2Policy(
    obs_dim=demo_env.observation_size,
    act_dim=demo_env.action_size,
    obs_mean=obs_mean,
    obs_std=obs_std,
)
load_jax_to_torch(params[1]['params'], torch_model)

‚úÖ JAX ÊùÉÈáçÂ∑≤ÊàêÂäüÂ§çÂà∂Âà∞ PyTorch Sequential Ê®°ÂûãÔºÅ


# Verify the PyTorch port against the JAX policy

In [7]:
# Recreate the APG network factory used during training. A `network_factory` entry in the
# config holds extra keyword arguments for make_apg_networks; it is removed from the copy
# of the training params.
apg_training_params = dict(apg_params)
if "network_factory" in apg_params:
    apg_training_params.pop("network_factory")
    network_factory = functools.partial(
        apg_networks.make_apg_networks, **apg_params.network_factory
    )
else:
    network_factory = apg_networks.make_apg_networks


# Default observation preprocessor: observations pass through unchanged and the second
# argument (processor params) is ignored. Replaced below when normalisation is enabled.
def normalize(x, y):
    return x

def make_normalize_fn(mean, std, max_abs_value=None):
    """Build a preprocessor that standardises observations with fixed statistics.

    Mirrors the shape of Brax's ``running_statistics.normalize`` but closes over ``mean``
    and ``std`` instead of reading them from the processor params.

    Args:
        mean: Pytree of means, structured like the observation batch.
        std: Pytree of standard deviations, structured like the observation batch.
        max_abs_value: If given, normalised values are clipped to
            ``[-max_abs_value, +max_abs_value]``.

    Returns:
        ``normalize_fn(batch, _unused_processor_params=None)`` returning the normalised
        pytree; non-floating leaves are returned untouched.
    """
    def _standardize(leaf, leaf_mean, leaf_std):
        # Integer / boolean leaves carry no statistics to apply.
        if jp.issubdtype(leaf.dtype, jp.inexact):
            leaf = (leaf - leaf_mean) / leaf_std
            if max_abs_value is not None:
                leaf = jp.clip(leaf, -max_abs_value, +max_abs_value)
        return leaf

    def normalize_fn(batch, _unused_processor_params=None):
        return jax.tree_util.tree_map(_standardize, batch, mean, std)

    return normalize_fn

# When the policy was trained with observation normalisation, use the stored statistics
# (params[0]) instead of the identity preprocessor.
if apg_params['normalize_observations']:
    mean = params[0].mean
    std = params[0].std
    normalize = make_normalize_fn(mean, std)

apg_network = network_factory(demo_env.observation_size, demo_env.action_size,
                              preprocess_observations_fn=normalize)
# Deterministic inference policy, jitted once for the comparison cells below.
make_inference_fn = apg_networks.make_inference_fn(apg_network)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

In [8]:
# Reset one environment (the vmapped env takes a batch of keys, here a batch of 1) with a
# fixed seed, then run the jitted JAX policy on the resulting observation.
demo_reset_fn = jax.jit(demo_env.reset)
rng = jax.random.PRNGKey(42)
rngs = jax.random.split(rng, num=1)
state = demo_reset_fn(rngs)
obs = state.obs
# NOTE(review): `rng` is reused as the policy key; this only matters if the policy samples.
actions, _ = jit_inference_fn(obs, rng)
print("type of observations:", type(obs))
print("obs shape:", obs.shape)

type of observations: <class 'jaxlib._jax.ArrayImpl'>
obs shape: (1, 40)


In [9]:
# Run the PyTorch port on the same observation (no autograd needed for inference) and
# compare its deterministic actions with the JAX policy's.
obs_torch = torch.tensor(np.array(obs), dtype=torch.float32)
with torch.no_grad():
    torch_actions, torch_std = torch_model(obs_torch)
torch_actions = torch_actions.numpy()

jax_actions_np = np.array(actions)

diff = np.abs(jax_actions_np - torch_actions)

print("JAX actions:", jax_actions_np)
print("Torch actions:", torch_actions)
print("Max difference:", diff.max())
# Fail loudly if the port drifts (the recorded run agrees to ~4e-8).
np.testing.assert_allclose(torch_actions, jax_actions_np, rtol=0, atol=1e-5)

JAX actions: [[-0.00673548  0.06204866 -0.21067517  0.102987    0.03954133 -0.17692343
  -0.10118392  0.03572394 -0.10818056 -0.02216607  0.07562327 -0.15920609]]
Torch actions: [[-0.00673548  0.06204866 -0.21067518  0.10298702  0.03954135 -0.17692345
  -0.10118396  0.03572395 -0.10818054 -0.02216607  0.07562327 -0.15920605]]
Max difference: 4.233897325789382e-08


In [10]:
# Layer-by-layer comparison. Note: compare_layers feeds `obs` without the policy's observation normalisation.
compare_layers(params[1]['params'], torch_model, obs)


[Linear 0] max diff = 0.000000
[ELU 0] max diff = 0.000000
[LayerNorm 0] max diff = 0.000000
[Linear 1] max diff = 0.000000
[ELU 1] max diff = 0.000000
[LayerNorm 1] max diff = 0.000001
[Linear 2] max diff = 0.000001


In [15]:
import onnx
import onnxruntime as ort

# Export destination. Override with the GO2_ONNX_EXPORT_PATH environment variable rather
# than editing this machine-specific default.
ONNX_EXPORT_PATH = os.environ.get(
    "GO2_ONNX_EXPORT_PATH",
    "/home/yxma/develop/mujoco_envs/mujoco_playground/mujoco_playground/experimental/sim2sim/onnx/go2_apg_policy.onnx",
)
# abspath guarantees a non-empty directory even for a bare filename.
os.makedirs(os.path.dirname(os.path.abspath(ONNX_EXPORT_PATH)), exist_ok=True)

torch_model.eval()
torch.onnx.export(
    torch_model,
    obs_torch,
    ONNX_EXPORT_PATH,
    input_names=["obs"],
    output_names=["actions", "std"],
    dynamic_axes={
        "obs": {0: "batch_size"},
        "actions": {0: "batch_size"},
        "std": {0: "batch_size"},
    },
    opset_version=17,
    dynamo=False
)

# Validate the exported graph and check that ONNX Runtime reproduces the PyTorch outputs.
onnx.checker.check_model(onnx.load(ONNX_EXPORT_PATH))
ort_session = ort.InferenceSession(ONNX_EXPORT_PATH, providers=["CPUExecutionProvider"])
onnx_actions, onnx_std = ort_session.run(None, {"obs": obs_torch.numpy()})
with torch.no_grad():
    ref_actions, ref_std = torch_model(obs_torch)
np.testing.assert_allclose(onnx_actions, ref_actions.numpy(), rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(onnx_std, ref_std.numpy(), rtol=1e-4, atol=1e-5)

  torch.onnx.export(
