# Example: Visualizing Brax and MuJoCo Playground Environments

First, enable checkpointing by passing `checkpoint.enable=true` to the training script. For example:

```shell
python scripts/train.py agent=exp/ppo/playground/locomotion/Go1Joystick env=playground/locomotion/Go1JoystickFlatTerrain checkpoint.enable=true
```

After training, checkpoints are stored in the `./outputs/train/<timestamp>/checkpoints` folder (or in `./multirun`, see [logging](https://evorl.readthedocs.io/latest/guide/quickstart.html#logging)). Each checkpoint lives in a `<step>/default` subfolder; pass its absolute path to `visualize()` in the code below.

## Prerequisites

* Make sure the required graphics drivers are installed; the code below uses EGL for headless GPU rendering. See the [dm_control rendering guide](https://github.com/google-deepmind/dm_control?tab=readme-ov-file#rendering).

## Visualization

In [None]:
from functools import partial
import os
import hydra
import mediapy
from html import escape
from IPython.display import HTML, display
from omegaconf import OmegaConf
from hydra_utils import (
    set_omegaconf_resolvers,
    set_absl_log_level,
)

# Runtime settings for rendering and JAX. They are set here, before `mujoco`
# and `jax` are imported in the next cell (presumably required so that both
# libraries pick them up at import / backend-init time).
os.environ.update(
    MUJOCO_GL="egl",  # headless GPU rendering through EGL
    XLA_PYTHON_CLIENT_PREALLOCATE="false",  # don't reserve GPU memory up front
)

# Quiet absl logging and register the project's OmegaConf resolvers
# (needed before composing/resolving Hydra configs below).
set_absl_log_level("warning")
set_omegaconf_resolvers()

In [None]:
import mujoco
import brax
import brax.io.html
import jax

from evorl.workflows import Workflow
from evorl.envs import create_env, AutoresetMode
from evorl.utils.orbax_utils import load
from evorl.utils.jax_utils import tree_get
from evorl.sample_batch import SampleBatch

In [None]:
def brax_render(env, rollout, **render_kwargs) -> None:
    """Render a rollout to frames and show it inline as a video.

    Args:
        env: Wrapped environment; the object two ``.unwrapped`` levels down
            must provide ``render(rollout, **kwargs)`` and ``dt``.
        rollout: Sequence of per-step states accepted by that ``render``.
        **render_kwargs: Forwarded unchanged to ``render`` (e.g. ``camera``,
            ``width``, ``height``).
    """
    base_env = env.unwrapped.unwrapped
    frames = base_env.render(rollout, **render_kwargs)

    # One frame per env step, so playback rate is the inverse of the step dt.
    frame_rate = 1.0 / base_env.dt
    print(f"FPS: {frame_rate:.2f}")
    mediapy.show_video(frames, fps=frame_rate)


def brax_html_render(env, rollout, **render_kwargs):
    """Show a rollout in brax's interactive HTML viewer inside the notebook.

    The page returned by ``brax.io.html.render`` is embedded through an
    iframe ``srcdoc`` attribute.

    Args:
        env: Wrapped environment; ``env.unwrapped.unwrapped`` must provide
            ``sys`` and ``dt``.
        rollout: Sequence of pipeline states to play back.
        **render_kwargs: Forwarded to ``brax.io.html.render``. ``height``
            (default 480) is also used as the iframe height.
    """
    braxenv = env.unwrapped.unwrapped
    height = render_kwargs.get("height", 480)

    # Replace the physics timestep with the env step dt: the rollout holds one
    # state per env step, so the viewer presumably needs dt to play it back
    # at the right speed.
    html_str = brax.io.html.render(
        braxenv.sys.tree_replace({"opt.timestep": braxenv.dt}),
        rollout,
        **render_kwargs,
    )

    # html.escape quotes `"` by default, so the whole document can be embedded
    # safely inside the double-quoted srcdoc attribute.
    # An explicit width is required: without it browsers fall back to the
    # 300px default iframe width, which makes the viewer cramped.
    iframe_str = f"""
        <iframe
            width="100%"
            height="{height}"
            srcdoc="{escape(html_str)}"
            frameborder="0"
            allowfullscreen
        ></iframe>
        """

    display(HTML(iframe_str))


def playground_render(env, rollout, **render_kwargs) -> None:
    """Render a MuJoCo Playground rollout as an inline video.

    Rendering is delegated to ``brax_render``, which calls ``render`` and
    reads ``dt`` on the env two ``.unwrapped`` levels down.
    """
    return brax_render(env, rollout, **render_kwargs)


def _retrieve_render_state(env_state, env_type):
    """Return the part of an evorl env state that the renderer for `env_type` consumes."""
    if env_type == "brax":
        return env_state.env_state.pipeline_state
    return env_state.env_state


def _collect_rollout(env, action_fn, env_type, max_steps, key):
    """Run one episode on a single-env batch and return the renderable states.

    The list starts with the reset state, followed by one state per step. It
    stops early once the (only) environment reports ``done``.
    """
    jit_reset = jax.jit(env.reset)
    jit_step = jax.jit(env.step)

    key, env_key = jax.random.split(key)
    env_state = jit_reset(env_key)
    # Index 0 drops the leading batch axis of the parallel=1 env.
    rollout = [tree_get(_retrieve_render_state(env_state, env_type), 0)]

    for _step in range(max_steps):
        key, action_key = jax.random.split(key)
        action, _ = action_fn(SampleBatch(obs=env_state.obs), action_key)

        env_state = jit_step(env_state, action)
        rollout.append(tree_get(_retrieve_render_state(env_state, env_type), 0))

        if env_state.done[0]:
            break

    print("rollout length:", len(rollout) - 1)

    return rollout


def visualize(path, agent, env, **render_kwargs) -> None:
    """Load a checkpoint, roll out one evaluation episode and render it.

    Args:
        path: Absolute path to a checkpoint directory
            (e.g. ``.../checkpoints/<step>/default``).
        agent: Override for the Hydra ``agent`` config group.
        env: Override for the Hydra ``env`` config group.
        **render_kwargs: Forwarded to the renderer. The special key
            ``use_html`` (brax envs only) selects the interactive HTML
            viewer; it is consumed here and never forwarded.

    Raises:
        NotImplementedError: more than one local JAX device is visible, or
            ``env_type`` is neither ``"brax"`` nor ``"playground"``.
        ValueError: ``use_html=True`` was requested for a non-brax env.
    """
    # Always consume the flag so it never reaches the renderer's kwargs
    # (previously `use_html=False` was forwarded and broke `render`).
    use_html = render_kwargs.pop("use_html", False)

    with hydra.initialize(
        version_base=None, config_path="../configs", job_name="visualize"
    ):
        config = hydra.compose(
            config_name="config", overrides=[f"agent={agent}", f"env={env}"]
        )

    print(OmegaConf.to_yaml(config, resolve=True))

    # Validate up front, before the expensive workflow init / checkpoint load.
    env_type = config.env.env_type
    if env_type not in ("brax", "playground"):
        raise NotImplementedError(f"Unsupported env_type: {env_type!r}")
    if use_html and env_type != "brax":
        raise ValueError("use_html=True is only supported for brax envs")

    if len(jax.local_devices()) > 1:
        raise NotImplementedError("visualize() only supports a single local device")

    workflow_cls = hydra.utils.get_class(config.workflow_cls)
    # Build from a fresh subclass; presumably so that class-level state set by
    # build_from_config does not leak into the original workflow class.
    workflow_cls = type(workflow_cls.__name__, (workflow_cls,), {})
    workflow: Workflow = workflow_cls.build_from_config(
        config, enable_jit=config.enable_jit
    )

    try:
        state = workflow.init(jax.random.PRNGKey(config.seed))
        state = load(path, state)
        action_fn = partial(workflow.agent.evaluate_actions, state.agent_state)

        # A single, non-autoresetting env so the rollout is exactly one episode.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=1,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        rollout = _collect_rollout(
            eval_env,
            action_fn,
            env_type,
            config.env.max_episode_steps,
            jax.random.PRNGKey(0),
        )

        if env_type == "brax":
            if use_html:
                brax_html_render(eval_env, rollout, **render_kwargs)
            else:
                brax_render(eval_env, rollout, **render_kwargs)
        else:
            playground_render(eval_env, rollout, **render_kwargs)
    finally:
        workflow.close()

In [None]:
# Example of a playground env:

# Optional MuJoCo scene options (uncomment here and in `render_kwargs` below),
# e.g. transparent geoms and hidden contact-force arrows:
# scene_option = mujoco.MjvOption()
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False

render_kwargs = dict(
    camera="track",
    # scene_option=scene_option,
    width=640,
    height=480,
)

root_path = os.path.abspath("../")

# Note: make sure `path` is an absolute path. Training writes checkpoints to
# `outputs/train/<timestamp>/checkpoints/<step>/default`.
visualize(
    path=os.path.join(
        root_path,
        "outputs/train/2025-04-16_12-59-31/checkpoints/1200/default",
    ),
    agent="exp/ppo/playground/locomotion/Go1Joystick",
    env="playground/locomotion/Go1JoystickRoughTerrain",
    **render_kwargs,
)

In [None]:
# Example of a brax env:

render_kwargs = {"camera": "track", "width": 640, "height": 480}

# `visualize` expects an absolute checkpoint path.
root_path = os.path.abspath("../")
checkpoint_path = os.path.join(
    root_path, "outputs/train/2025-04-22_23-14-16/checkpoints/153/default"
)

visualize(
    path=checkpoint_path,
    agent="exp/ppo/brax/humanoid",
    env="brax/humanoid",
    **render_kwargs,
)

In [None]:
# Example of a brax env with HTML rendering:

render_kwargs = dict(
    height=480,
    use_html=True,  # consumed by visualize() to select brax_html_render
)

root_path = os.path.abspath("../")

# Note: make sure `path` is an absolute path.
visualize(
    path=os.path.join(
        root_path,
        "outputs/train/2025-04-22_23-14-16/checkpoints/153/default",
    ),
    agent="exp/ppo/brax/humanoid",
    env="brax/humanoid",
    **render_kwargs,
)