Add a Jupyter notebook for visualizing brax agents (#105)
* Add a Jupyter notebook for visualizing brax agents
* Add a link to the brax visualization notebook in the notebooks README
engintoklu committed Jun 14, 2024
1 parent 496e6e2 commit a9d434a
Showing 2 changed files with 382 additions and 1 deletion.
381 changes: 381 additions & 0 deletions examples/notebooks/Brax_Experiments_Visualization.ipynb
@@ -0,0 +1,381 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fc7b28eb-db2b-4a02-87d4-dfc44a9fecc2",
"metadata": {},
"source": [
"# Visualization of the `brax` Experiment Results\n",
"\n",
"Using this notebook, you can see the visualization of the agent trained by the notebook [Brax_Experiments_with_PGPE.ipynb](Brax_Experiments_with_PGPE.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5350a77-c84b-4d05-9e1c-4a6e0e2f6d0d",
"metadata": {},
"outputs": [],
"source": [
"# Name of the pickle file saved by PicklingLogger goes here:\n",
"FNAME = \"...\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4c4d0df-0480-4c5f-a4ab-d34727286cc1",
"metadata": {},
"outputs": [],
"source": [
"# Name of the environment\n",
"ENV_NAME = \"brax::humanoid\""
]
},
{
"cell_type": "markdown",
"id": "d64f3d04-680a-43c5-a807-dbaf9f917198",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7025f8b3-9982-4196-9790-9e8c1491dc06",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5bccb74-a64a-4e55-9ae9-9686f772fa60",
"metadata": {},
"outputs": [],
"source": [
"with open(FNAME, \"rb\") as f:\n",
" loaded = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d71c7c3-bb36-47fb-88dc-bd285c54c44a",
"metadata": {},
"outputs": [],
"source": [
"# The unpickled object is a dictionary with these keys:\n",
"list(loaded.keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54c7bee4-bf5c-445c-b009-40484391461c",
"metadata": {},
"outputs": [],
"source": [
"# Loaded center solution\n",
"center = loaded[\"center\"]\n",
"center"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30eca014-be71-4162-b675-0eaf8b037bce",
"metadata": {},
"outputs": [],
"source": [
"# Loaded policy network\n",
"policy = loaded[\"policy\"]\n",
"policy"
]
},
{
"cell_type": "markdown",
"id": "78ce1f2f-cb89-432e-ad98-07f7e1d5d457",
"metadata": {},
"source": [
"Below, we put the values of the center solution into the policy network as parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "545d693d-ccd9-4ea7-b068-06b832f32be5",
"metadata": {},
"outputs": [],
"source": [
"torch.nn.utils.vector_to_parameters(center, policy.parameters())"
]
},
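{
"cell_type": "markdown",
"id": "3f9a1b2c-4d5e-4f60-8a7b-9c0d1e2f3a4b",
"metadata": {},
"source": [
"As an optional sanity check (a sketch added here for illustration, not part of the original workflow), we can confirm that the center solution contains exactly as many values as the policy network has parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a0b2c3d-5e6f-4071-8b9c-0d1e2f3a4b5c",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: the center solution should contain exactly as\n",
"# many values as the policy network has parameters.\n",
"num_params = sum(p.numel() for p in policy.parameters())\n",
"assert int(center.numel()) == num_params, (int(center.numel()), num_params)"
]
},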
{
"cell_type": "markdown",
"id": "e7c7581d-24ba-4f06-88e2-41bb3274cd37",
"metadata": {},
"source": [
"---\n",
"\n",
"# Visualizing the trained policy\n",
"\n",
"Now that we have our final policy, we manually run and visualize it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d798a71-003f-4bed-9f12-391dc7b797c0",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30080d4b-227c-4bde-8009-7a97c6e177f1",
"metadata": {},
"outputs": [],
"source": [
"assert ENV_NAME.startswith(\"brax::\"), \"This notebook can only work with brax environments\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c64a6482-0533-40c8-b8f9-1b1ad3ee0019",
"metadata": {},
"outputs": [],
"source": [
"GOT_OLD_BRAX_ENV = ENV_NAME.startswith(\"brax::old::\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee69ba99-bcb7-4d21-9f5c-6c330cb3ed96",
"metadata": {},
"outputs": [],
"source": [
"if GOT_OLD_BRAX_ENV:\n",
" import brax.v1 as brax\n",
" import brax.v1.envs as brax_envs\n",
" import brax.v1.jumpy as jp\n",
" from brax.v1.jumpy import random_prngkey\n",
" from brax.v1.io import html, image\n",
"else:\n",
" import brax\n",
" import brax.envs as brax_envs\n",
" from jax.random import PRNGKey as random_prngkey\n",
" from brax.io import html, image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c99f0ea4-7638-4e8a-a0a2-d9f483c8c096",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import HTML, Image\n",
"from typing import Iterable, Optional\n",
"import random"
]
},
{
"cell_type": "markdown",
"id": "19d38fe1-9d6f-40ec-926d-3c4066a4b66a",
"metadata": {},
"source": [
"Below, we define a utility function named `use_policy(...)`.\n",
"\n",
"The expected arguments of `use_policy(...)` are as follows:\n",
"\n",
"- `torch_module`: The policy object, expected as a `nn.Module` instance.\n",
"- `x`: The observation, as an iterable of real numbers.\n",
"- `h`: The hidden state of the module, if such a state exists and if the module is recurrent. Otherwise, it can be left as None.\n",
"\n",
"The return values of this function are as follows:\n",
"\n",
"- The action recommended by the policy, as a numpy array\n",
"- The hidden state of the module, if the module is a recurrent one."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8c6554d-4bdb-49b9-8c2e-956bce6ddb8a",
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def use_policy(torch_module: nn.Module, x: Iterable, h: Optional[Iterable] = None) -> tuple:\n",
" x = torch.as_tensor(np.asarray(x), dtype=torch.float32)\n",
" if h is None:\n",
" result = torch_module(x)\n",
" else:\n",
" result = torch_module(x, h)\n",
"\n",
" if isinstance(result, tuple):\n",
" x, h = result\n",
" x = x.numpy()\n",
" else:\n",
" x = result.numpy()\n",
" h = None\n",
" \n",
" return x, h"
]
},
{
"cell_type": "markdown",
"id": "28c044fe-4639-411c-be78-ac09cbe5e78f",
"metadata": {},
"source": [
"We now initialize a new instance of our brax environment, and trigger the jit compilation on its `reset` and `step` methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1251e82-1d0f-4e43-a6c5-1ec4c0275dab",
"metadata": {},
"outputs": [],
"source": [
"env = brax_envs.create(env_name=\"humanoid\")\n",
"reset = jax.jit(env.reset)\n",
"step = jax.jit(env.step)"
]
},
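{
"cell_type": "markdown",
"id": "5b1c3d4e-6f70-4182-9cad-1e2f3a4b5c6d",
"metadata": {},
"source": [
"As a quick smoke test (a sketch, not part of the original workflow), we can feed a zero-filled observation through `use_policy(...)` and check that the resulting action has the environment's action size. This assumes a feedforward policy and an environment reporting flat integer `observation_size` and `action_size` values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c2d4e5f-7081-4293-8dbe-2f3a4b5c6d7e",
"metadata": {},
"outputs": [],
"source": [
"# Optional smoke test: the policy and the environment should agree on the\n",
"# observation and action dimensionalities. Assumes a feedforward policy.\n",
"dummy_obs = np.zeros(env.observation_size, dtype=\"float32\")\n",
"dummy_action, _ = use_policy(policy, dummy_obs)\n",
"assert dummy_action.shape == (env.action_size,)"
]
},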
{
"cell_type": "markdown",
"id": "55229cc2-aad5-4c78-b095-010a538adb40",
"metadata": {},
"source": [
"Below we run our policy and collect the states of the episodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "220ec837-1d1a-401e-a144-5516bdb3e493",
"metadata": {},
"outputs": [],
"source": [
"seed = random.randint(0, (2 ** 32) - 1)\n",
"state = reset(rng=random_prngkey(seed=seed))\n",
"\n",
"h = None\n",
"states = []\n",
"cumulative_reward = 0.0\n",
"\n",
"while True:\n",
" action, h = use_policy(policy, state.obs, h)\n",
" action = jnp.asarray(action)\n",
"\n",
" state = step(state, action)\n",
" cumulative_reward += float(state.reward)\n",
" states.append(state)\n",
" if np.abs(np.array(state.done)) > 1e-4:\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "c9bbe53a-40f5-4bf5-a063-47f1d88032d6",
"metadata": {},
"source": [
"Length of the episode and the total reward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80345e9c-0694-413d-81c2-205dfacdfb5a",
"metadata": {},
"outputs": [],
"source": [
"len(states), cumulative_reward"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d2e2506-5bb2-4b1f-bb00-004465008fa3",
"metadata": {},
"outputs": [],
"source": [
"def pipeline_state(state):\n",
" if hasattr(state, \"qp\"):\n",
" return state.qp\n",
" elif hasattr(state, \"pipeline_state\"):\n",
" return state.pipeline_state\n",
" else:\n",
" assert False"
]
},
{
"cell_type": "markdown",
"id": "8b6ec424-e6cc-453a-93fd-eb07e44c1bd6",
"metadata": {},
"source": [
"Visualization of the policy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29b9c3aa-8068-40bc-9ec2-e5ff759ed22c",
"metadata": {},
"outputs": [],
"source": [
"pipeline_states = [pipeline_state(state) for state in states]\n",
"\n",
"if hasattr(env.sys, \"tree_replace\"):\n",
" env_sys = env.sys.tree_replace({'opt.timestep': env.dt})\n",
"else:\n",
" env_sys = env.sys"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b50e36a7-dd33-4fc3-9f9d-2d7dca93110b",
"metadata": {},
"outputs": [],
"source": [
"HTML(html.render(env_sys, pipeline_states))"
]
}
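,
{
"cell_type": "markdown",
"id": "7d3e5f60-8192-43a4-8ecf-3a4b5c6d7e8f",
"metadata": {},
"source": [
"The `image` module imported earlier can, in principle, render the trajectory as an encoded image instead of interactive HTML. Since the exact signature of `image.render(...)` differs between brax versions, the call below is left commented out as a sketch to adapt to your installation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e4f6071-92a3-44b5-8fd0-4b5c6d7e8f90",
"metadata": {},
"outputs": [],
"source": [
"# A sketch only: brax.io.image (or brax.v1.io.image) can encode the\n",
"# trajectory as an image or animation, but the exact render(...) signature\n",
"# differs between brax versions. Check the module for your installed\n",
"# version, then adapt and uncomment the line below:\n",
"# Image(image.render(env_sys, pipeline_states))"
]
}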
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 1 addition & 1 deletion examples/notebooks/README.md
@@ -17,7 +17,7 @@ pip install 'gym[box2d]'
The notebook examples are listed below:

- **[Gym Experiments with PGPE and CoSyNE](Gym_Experiments_with_PGPE_and_CoSyNE.ipynb):** demonstrates how you can solve "LunarLanderContinuous-v2" using both `PGPE` and `CoSyNE` following the configurations described in [the paper proposing ClipUp](https://dl.acm.org/doi/abs/10.1007/978-3-030-58115-2_36) and [the JMLR paper on the CoSyNE algorithm](https://www.jmlr.org/papers/volume9/gomez08a/gomez08a.pdf).
-- **[Brax Experiments with PGPE](Brax_Experiments_with_PGPE.ipynb):** demonstrates how you can solve the brax task "humanoid" using PGPE, with GPU support, if available.
+- **[Brax Experiments with PGPE](Brax_Experiments_with_PGPE.ipynb):** demonstrates how you can solve the brax task "humanoid" using PGPE, with GPU support, if available. See also [Brax_Experiments_Visualization.ipynb](Brax_Experiments_Visualization.ipynb) for visualizing evolved brax policies.
- **[Minimizing Lennard-Jones Atom Cluster Potentials](Minimizing_Lennard-Jones_Atom_Cluster_Potentials.ipynb):** recreates experiments from [the paper introducing `SNES`](https://dl.acm.org/doi/abs/10.1145/2001576.2001692), showing that the algorithm can effectively solve the challenging task of [minimising Lennard-Jones atom cluster potentials](https://pubs.acs.org/doi/abs/10.1021/jp970984n).
- **[Model Predictive Control with CEM](Model_Predictive_Control_with_CEM/):** demonstrates the application of [the Cross-Entropy Method `CEM`](https://link.springer.com/article/10.1023/A:1010091220143) to Model Predictive Control (MPC) of the MuJoCo task named "Reacher-v4".
- **[Training MNIST30K](Training_MNIST30K.ipynb):** recreates experiments [from a recent paper](https://www.deepmind.com/publications/non-differentiable-supervised-learning-with-evolution-strategies-and-hybrid-methods) which demonstrates that `SNES` can be used to solve supervised learning problems. The script in particular recreates the training of the 30K-parameter 'MNIST30K' model on the MNIST dataset, but can easily be reconfigured to recreate other experiments from that paper.
