-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a Jupyter notebook for visualizing brax agents (#105)
* Add a Jupyter notebook for visualizing brax agents * Add a link to the brax visualization notebook in the notebooks README
- Loading branch information
1 parent
496e6e2
commit a9d434a
Showing
2 changed files
with
382 additions
and
1 deletion.
There are no files selected for viewing
381 changes: 381 additions & 0 deletions
381
examples/notebooks/Brax_Experiments_Visualization.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,381 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "fc7b28eb-db2b-4a02-87d4-dfc44a9fecc2", | ||
"metadata": {}, | ||
"source": [ | ||
"# Visualization of the `brax` Experiment Results\n", | ||
"\n", | ||
"Using this notebook, you can see the visualization of the agent trained by the notebook [Brax_Experiments_with_PGPE.ipynb](Brax_Experiments_with_PGPE.ipynb)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d5350a77-c84b-4d05-9e1c-4a6e0e2f6d0d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Name of the pickle file saved by PicklingLogger goes here:\n", | ||
"FNAME = \"...\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e4c4d0df-0480-4c5f-a4ab-d34727286cc1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Name of the environment\n", | ||
"ENV_NAME = \"brax::humanoid\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d64f3d04-680a-43c5-a807-dbaf9f917198", | ||
"metadata": {}, | ||
"source": [ | ||
"---" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7025f8b3-9982-4196-9790-9e8c1491dc06", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pickle\n", | ||
"import torch\n", | ||
"from torch import nn" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d5bccb74-a64a-4e55-9ae9-9686f772fa60", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"with open(FNAME, \"rb\") as f:\n", | ||
" loaded = pickle.load(f)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0d71c7c3-bb36-47fb-88dc-bd285c54c44a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# The unpickled object is a dictionary with these keys:\n", | ||
"list(loaded.keys())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "54c7bee4-bf5c-445c-b009-40484391461c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Loaded center solution\n", | ||
"center = loaded[\"center\"]\n", | ||
"center" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "30eca014-be71-4162-b675-0eaf8b037bce", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Loaded policy network\n", | ||
"policy = loaded[\"policy\"]\n", | ||
"policy" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "78ce1f2f-cb89-432e-ad98-07f7e1d5d457", | ||
"metadata": {}, | ||
"source": [ | ||
"Below, we put the values of the center solution into the policy network as parameters:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "545d693d-ccd9-4ea7-b068-06b832f32be5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"torch.nn.utils.vector_to_parameters(center, policy.parameters())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e7c7581d-24ba-4f06-88e2-41bb3274cd37", | ||
"metadata": {}, | ||
"source": [ | ||
"---\n", | ||
"\n", | ||
"# Visualizing the trained policy\n", | ||
"\n", | ||
"Now that we have our final policy, we manually run and visualize it." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8d798a71-003f-4bed-9f12-391dc7b797c0", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "30080d4b-227c-4bde-8009-7a97c6e177f1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"assert ENV_NAME.startswith(\"brax::\"), \"This notebook can only work with brax environments\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c64a6482-0533-40c8-b8f9-1b1ad3ee0019", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"GOT_OLD_BRAX_ENV = ENV_NAME.startswith(\"brax::old::\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ee69ba99-bcb7-4d21-9f5c-6c330cb3ed96", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"if GOT_OLD_BRAX_ENV:\n", | ||
" import brax.v1 as brax\n", | ||
" import brax.v1.envs as brax_envs\n", | ||
" import brax.v1.jumpy as jp\n", | ||
" from brax.v1.jumpy import random_prngkey\n", | ||
" from brax.v1.io import html, image\n", | ||
"else:\n", | ||
" import brax\n", | ||
" import brax.envs as brax_envs\n", | ||
" from jax.random import PRNGKey as random_prngkey\n", | ||
" from brax.io import html, image" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c99f0ea4-7638-4e8a-a0a2-d9f483c8c096", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from IPython.display import HTML, Image\n", | ||
"from typing import Iterable, Optional\n", | ||
"import random" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "19d38fe1-9d6f-40ec-926d-3c4066a4b66a", | ||
"metadata": {}, | ||
"source": [ | ||
"Below, we define a utility function named `use_policy(...)`.\n", | ||
"\n", | ||
"The expected arguments of `use_policy(...)` are as follows:\n", | ||
"\n", | ||
"- `torch_module`: The policy object, expected as a `nn.Module` instance.\n", | ||
"- `x`: The observation, as an iterable of real numbers.\n", | ||
"- `h`: The hidden state of the module, if such a state exists and if the module is recurrent. Otherwise, it can be left as None.\n", | ||
"\n", | ||
"The return values of this function are as follows:\n", | ||
"\n", | ||
"- The action recommended by the policy, as a numpy array\n", | ||
"- The hidden state of the module, if the module is a recurrent one." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e8c6554d-4bdb-49b9-8c2e-956bce6ddb8a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@torch.no_grad()\n", | ||
"def use_policy(torch_module: nn.Module, x: Iterable, h: Optional[Iterable] = None) -> tuple:\n", | ||
" x = torch.as_tensor(np.asarray(x), dtype=torch.float32)\n", | ||
" if h is None:\n", | ||
" result = torch_module(x)\n", | ||
" else:\n", | ||
" result = torch_module(x, h)\n", | ||
"\n", | ||
" if isinstance(result, tuple):\n", | ||
" x, h = result\n", | ||
" x = x.numpy()\n", | ||
" else:\n", | ||
" x = result.numpy()\n", | ||
" h = None\n", | ||
" \n", | ||
" return x, h" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "28c044fe-4639-411c-be78-ac09cbe5e78f", | ||
"metadata": {}, | ||
"source": [ | ||
"We now initialize a new instance of our brax environment, and trigger the jit compilation on its `reset` and `step` methods." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c1251e82-1d0f-4e43-a6c5-1ec4c0275dab", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"env = brax_envs.create(env_name=\"humanoid\")\n", | ||
"reset = jax.jit(env.reset)\n", | ||
"step = jax.jit(env.step)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "55229cc2-aad5-4c78-b095-010a538adb40", | ||
"metadata": {}, | ||
"source": [ | ||
"Below we run our policy and collect the states of the episodes." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "220ec837-1d1a-401e-a144-5516bdb3e493", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"seed = random.randint(0, (2 ** 32) - 1)\n", | ||
"state = reset(rng=random_prngkey(seed=seed))\n", | ||
"\n", | ||
"h = None\n", | ||
"states = []\n", | ||
"cumulative_reward = 0.0\n", | ||
"\n", | ||
"while True:\n", | ||
" action, h = use_policy(policy, state.obs, h)\n", | ||
" action = jnp.asarray(action)\n", | ||
"\n", | ||
" state = step(state, action)\n", | ||
" cumulative_reward += float(state.reward)\n", | ||
" states.append(state)\n", | ||
" if np.abs(np.array(state.done)) > 1e-4:\n", | ||
" break" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c9bbe53a-40f5-4bf5-a063-47f1d88032d6", | ||
"metadata": {}, | ||
"source": [ | ||
"Length of the episode and the total reward:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "80345e9c-0694-413d-81c2-205dfacdfb5a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"len(states), cumulative_reward" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8d2e2506-5bb2-4b1f-bb00-004465008fa3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def pipeline_state(state):\n", | ||
" if hasattr(state, \"qp\"):\n", | ||
" return state.qp\n", | ||
" elif hasattr(state, \"pipeline_state\"):\n", | ||
" return state.pipeline_state\n", | ||
" else:\n", | ||
" assert False" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8b6ec424-e6cc-453a-93fd-eb07e44c1bd6", | ||
"metadata": {}, | ||
"source": [ | ||
"Visualization of the policy:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "29b9c3aa-8068-40bc-9ec2-e5ff759ed22c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pipeline_states = [pipeline_state(state) for state in states]\n", | ||
"\n", | ||
"if hasattr(env.sys, \"tree_replace\"):\n", | ||
" env_sys = env.sys.tree_replace({'opt.timestep': env.dt})\n", | ||
"else:\n", | ||
" env_sys = env.sys" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b50e36a7-dd33-4fc3-9f9d-2d7dca93110b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"HTML(html.render(env_sys, pipeline_states))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters