From 2a41c5c599416d213fb5be59abf495c33217b5cd Mon Sep 17 00:00:00 2001
From: Kate Baumli
Date: Fri, 26 Nov 2021 06:36:10 -0800
Subject: [PATCH] Set up RLax sphinx documentation for readthedocs to build and
 serve documentation from the public github.

PiperOrigin-RevId: 412441791
---
 .readthedocs.yaml                  |  17 +
 docs/.gitignore                    |   1 +
 docs/Makefile                      |  19 +
 docs/api.rst                       | 731 +++++++++++++++++++++++++++++
 docs/conf.py                       | 192 ++++++++
 docs/ext/coverage_check.py         |  58 +++
 docs/index.rst                     | 153 ++++++
 requirements/requirements-docs.txt |  11 +
 rlax/_src/test_utils.py            |  42 ++
 9 files changed, 1224 insertions(+)
 create mode 100644 .readthedocs.yaml
 create mode 100644 docs/.gitignore
 create mode 100644 docs/Makefile
 create mode 100644 docs/api.rst
 create mode 100644 docs/conf.py
 create mode 100644 docs/ext/coverage_check.py
 create mode 100644 docs/index.rst
 create mode 100644 requirements/requirements-docs.txt
 create mode 100644 rlax/_src/test_utils.py

diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 0000000..f003f39
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,17 @@
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+version: 2
+
+sphinx:
+  builder: html
+  configuration: docs/conf.py
+  fail_on_warning: false
+
+python:
+  version: 3.7
+  install:
+  - requirements: requirements/requirements-docs.txt
+  - requirements: requirements/requirements.txt
+  - method: setuptools
+    path: .
diff --git a/docs/.gitignore b/docs/.gitignore
new file mode 100644
index 0000000..e35d885
--- /dev/null
+++ b/docs/.gitignore
@@ -0,0 +1 @@
+_build
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..5128596
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,19 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS    =
+SPHINXBUILD   = sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/api.rst b/docs/api.rst
new file mode 100644
index 0000000..38caf73
--- /dev/null
+++ b/docs/api.rst
@@ -0,0 +1,731 @@
+* Values, including both state and action-values;
+* Values for Non-linear generalizations of the Bellman equations.
+* Return Distributions, aka distributional value functions;
+* General Value Functions, for cumulants other than the main reward;
+* Policies, via policy-gradients in both continuous and discrete action spaces.
+
+Value Learning
+==============
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    categorical_double_q_learning
+    categorical_l2_project
+    categorical_q_learning
+    categorical_td_learning
+    discounted_returns
+    double_q_learning
+    expected_sarsa
+    general_off_policy_returns_from_action_values
+    general_off_policy_returns_from_q_and_v
+    lambda_returns
+    leaky_vtrace
+    leaky_vtrace_td_error_and_advantage
+    n_step_bootstrapped_returns
+    persistent_q_learning
+    q_lambda
+    q_learning
+    quantile_expected_sarsa
+    quantile_q_learning
+    qv_learning
+    qv_max
+    retrace
+    retrace_continuous
+    sarsa
+    sarsa_lambda
+    td_lambda
+    td_learning
+    transformed_general_off_policy_returns_from_action_values
+    transformed_lambda_returns
+    transformed_n_step_q_learning
+    transformed_n_step_returns
+    transformed_q_lambda
+    transformed_retrace
+    vtrace
+    vtrace_td_error_and_advantage
+
+
+Categorical Double Q Learning
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_double_q_learning
+
+Categorical L2 Project
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_l2_project
+
+Categorical Q Learning
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_q_learning
+
+Categorical TD Learning
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_td_learning
+
+Discounted Returns
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: discounted_returns
+
+Double Q Learning
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: double_q_learning
+
+Expected SARSA
+~~~~~~~~~~~~~~
+
+.. autofunction:: expected_sarsa
+
+General Off Policy Returns From Action Values
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: general_off_policy_returns_from_action_values
+
+General Off Policy Returns From Q and V
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: general_off_policy_returns_from_q_and_v
+
+Lambda Returns
+~~~~~~~~~~~~~~
+
+.. autofunction:: lambda_returns
+
+Leaky VTrace
+~~~~~~~~~~~~
+
+.. autofunction:: leaky_vtrace
+
+N Step Bootstrapped Returns
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: n_step_bootstrapped_returns
+
+Leaky VTrace TD Error and Advantage
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: leaky_vtrace_td_error_and_advantage
+
+Persistent Q Learning
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: persistent_q_learning
+
+Q-Lambda
+~~~~~~~~
+
+.. autofunction:: q_lambda
+
+Q Learning
+~~~~~~~~~~
+
+.. autofunction:: q_learning
+
+
+Quantile Expected Sarsa
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: quantile_expected_sarsa
+
+Quantile Q Learning
+~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: quantile_q_learning
+
+QV Learning
+~~~~~~~~~~~
+
+.. autofunction:: qv_learning
+
+QV Max
+~~~~~~~
+
+.. autofunction:: qv_max
+
+Retrace
+~~~~~~~
+
+.. autofunction:: retrace
+
+Retrace Continuous
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: retrace_continuous
+
+SARSA
+~~~~~
+
+.. autofunction:: sarsa
+
+SARSA Lambda
+~~~~~~~~~~~~
+
+.. autofunction:: sarsa_lambda
+
+TD Lambda
+~~~~~~~~~
+
+.. autofunction:: td_lambda
+
+TD Learning
+~~~~~~~~~~~
+
+.. autofunction:: td_learning
+
+Transformed General Off Policy Returns from Action Values
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transformed_general_off_policy_returns_from_action_values
+
+Transformed Lambda Returns
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transformed_lambda_returns
+
+Transformed N Step Q Learning
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transformed_n_step_q_learning
+
+Transformed N Step Returns
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transformed_n_step_returns
+
+Transformed Q Lambda
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transformed_q_lambda
+
+
+Transformed Retrace
+~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transformed_retrace
+
+Truncated Generalized Advantage Estimation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: truncated_generalized_advantage_estimation
+
+VTrace
+~~~~~~
+
+.. autofunction:: vtrace
+
+
+Policy Optimization
+===================
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    clipped_surrogate_pg_loss
+    dpg_loss
+    entropy_loss
+    LagrangePenalty
+    mpo_loss
+    mpo_compute_weights_and_temperature_loss
+    policy_gradient_loss
+    qpg_loss
+    rm_loss
+    rpg_loss
+
+
+Clipped Surrogate PG Loss
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: clipped_surrogate_pg_loss
+
+
+Compute Parametric KL Penalty and Dual Loss
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: compute_parametric_kl_penalty_and_dual_loss
+
+DPG Loss
+~~~~~~~~
+
+.. autofunction:: dpg_loss
+
+Entropy Loss
+~~~~~~~~~~~~
+
+.. autofunction:: entropy_loss
+
+Lagrange Penalty
+~~~~~~~~~~~~~~~~
+
+.. autoclass:: LagrangePenalty
+   :members:
+
+
+MPO Compute Weights and Temperature Loss
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: mpo_compute_weights_and_temperature_loss
+
+MPO Loss
+~~~~~~~~
+
+.. autofunction:: mpo_loss
+
+Policy Gradient Loss
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: policy_gradient_loss
+
+QPG Loss
+~~~~~~~~
+
+.. autofunction:: qpg_loss
+
+RM Loss
+~~~~~~~~
+
+.. autofunction:: rm_loss
+
+RPG Loss
+~~~~~~~~
+
+.. autofunction:: rpg_loss
+
+VMPO Compute Weights and Temperature Loss
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: vmpo_compute_weights_and_temperature_loss
+
+VMPO Loss
+~~~~~~~~~
+
+.. autofunction:: vmpo_loss
+
+
+Exploration
+===========
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    add_dirichlet_noise
+    add_gaussian_noise
+    add_ornstein_uhlenbeck_noise
+    episodic_memory_intrinsic_rewards
+    knn_query
+
+
+Add Dirichlet Noise
+~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: add_dirichlet_noise
+
+Add Gaussian Noise
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: add_gaussian_noise
+
+Add Ornstein Uhlenbeck Noise
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: add_ornstein_uhlenbeck_noise
+
+Episodic Memory Intrinsic Rewards
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: episodic_memory_intrinsic_rewards
+
+KNN Query
+~~~~~~~~~
+
+.. autofunction:: knn_query
+
+
+Utilities
+=========
+
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    AllSum
+    batched_index
+    clip_gradient
+    lhs_broadcast
+    one_hot
+    embed_oar
+    tree_map_zipped
+    tree_select
+    tree_split_key
+    tree_split_leaves
+    conditional_update
+    incremental_update
+    periodic_update
+
+
+All Sum
+~~~~~~~
+
+.. autoclass:: AllSum
+   :members:
+
+Batched Index
+~~~~~~~~~~~~~
+
+.. autofunction:: batched_index
+
+Clip Gradient
+~~~~~~~~~~~~~
+
+.. autofunction:: clip_gradient
+
+LHS Broadcast
+~~~~~~~~~~~~~
+
+.. autofunction:: lhs_broadcast
+
+One Hot
+~~~~~~~
+
+.. autofunction:: one_hot
+
+Embed OAR
+~~~~~~~~~
+
+.. autofunction:: embed_oar
+
+Tree Map Zipped
+~~~~~~~~~~~~~~~
+
+.. autofunction:: tree_map_zipped
+
+Tree Select
+~~~~~~~~~~~
+
+.. autofunction:: tree_select
+
+Tree Split Key
+~~~~~~~~~~~~~~
+
+.. autofunction:: tree_split_key
+
+Tree Split Leaves
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: tree_split_leaves
+
+Conditional Update
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: conditional_update
+
+Incremental Update
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: incremental_update
+
+Periodic Update
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: periodic_update
+
+
+General Value Functions
+=======================
+
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    pixel_control_rewards
+    feature_control_rewards
+
+
+Pixel Control Rewards
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: pixel_control_rewards
+
+Feature Control Rewards
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: feature_control_rewards
+
+
+
+Pop Art
+========
+
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    art
+    normalize
+    pop
+    popart
+    PopArtState
+    unnormalize
+    unnormalize_linear
+
+Art
+~~~
+
+.. autofunction:: art
+
+Normalize
+~~~~~~~~~
+
+.. autofunction:: normalize
+
+Pop
+~~~
+
+.. autofunction:: pop
+
+PopArt
+~~~~~~
+
+.. autofunction:: popart
+
+PopArtState
+~~~~~~~~~~~
+
+.. autoclass:: PopArtState
+   :members:
+
+Unnormalize
+~~~~~~~~~~~
+
+.. autofunction:: unnormalize
+
+Unnormalize Linear
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: unnormalize_linear
+
+
+
+Transforms
+==========
+
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    HYPERBOLIC_SIN_PAIR
+    identity
+    IDENTITY_PAIR
+    logit
+    power
+    sigmoid
+    signed_expm1
+    signed_hyperbolic
+    SIGNED_HYPERBOLIC_PAIR
+    signed_logp1
+    SIGNED_LOGP1_PAIR
+    signed_parabolic
+    transform_from_2hot
+    transform_to_2hot
+    TxPair
+
+
+Identity
+~~~~~~~~
+
+.. autofunction:: identity
+
+
+Logit
+~~~~~
+
+.. autofunction:: logit
+
+
+Power
+~~~~~
+
+.. autofunction:: power
+
+
+Sigmoid
+~~~~~~~
+
+.. autofunction:: sigmoid
+
+
+Signed Exponential
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: signed_expm1
+
+
+Signed Hyperbolic
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: signed_hyperbolic
+
+
+Signed Logarithm
+~~~~~~~~~~~~~~~~
+
+.. autofunction:: signed_logp1
+
+
+Signed Parabolic
+~~~~~~~~~~~~~~~~
+
+.. autofunction:: signed_parabolic
+
+
+Transform from 2 Hot
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transform_from_2hot
+
+Transform to 2 Hot
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transform_to_2hot
+
+
+Losses
+======
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    l2_loss
+    likelihood
+    log_loss
+    huber_loss
+    pixel_control_loss
+
+
+L2 Loss
+~~~~~~~
+
+.. autofunction:: l2_loss
+
+
+Likelihood
+~~~~~~~~~~
+
+.. autofunction:: likelihood
+
+Log Loss
+~~~~~~~~
+
+.. autofunction:: log_loss
+
+Huber Loss
+~~~~~~~~~~
+
+.. autofunction:: huber_loss
+
+Pixel Control Loss
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: pixel_control_loss
+
+
+Distributions
+=============
+
+.. currentmodule:: rlax
+
+.. autosummary::
+
+    categorical_cross_entropy
+    categorical_importance_sampling_ratios
+    categorical_kl_divergence
+    categorical_sample
+    clipped_entropy_softmax
+    epsilon_greedy
+    epsilon_softmax
+    gaussian_diagonal
+    greedy
+    multivariate_normal_kl_divergence
+    safe_epsilon_softmax
+    softmax
+    squashed_gaussian
+
+
+Categorical Cross Entropy
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_cross_entropy
+
+
+Categorical Importance Sampling Ratios
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_importance_sampling_ratios
+
+Categorical KL Divergence
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_kl_divergence
+
+Categorical Sample
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: categorical_sample
+
+Clipped Entropy Softmax
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: clipped_entropy_softmax
+
+Epsilon Greedy
+~~~~~~~~~~~~~~
+
+.. autofunction:: epsilon_greedy
+
+Epsilon Softmax
+~~~~~~~~~~~~~~~
+
+.. autofunction:: epsilon_softmax
+
+Gaussian Diagonal
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: gaussian_diagonal
+
+Greedy
+~~~~~~
+
+.. autofunction:: greedy
+
+Multivariate Normal KL Divergence
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: multivariate_normal_kl_divergence
+
+Safe Epsilon Softmax
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: safe_epsilon_softmax
+
+Softmax
+~~~~~~~
+
+.. autofunction:: softmax
+
+Squashed Gaussian
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: squashed_gaussian
+
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..e5df9a2
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,192 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Configuration file for the Sphinx documentation builder."""
+
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+# pylint: disable=g-bad-import-order
+# pylint: disable=g-import-not-at-top
+import inspect
+import os
+import sys
+import typing
+
+
+def _add_annotations_import(path):
+  """Appends a future annotations import to the file at the given path."""
+  with open(path) as f:
+    contents = f.read()
+  if contents.startswith('from __future__ import annotations'):
+    # If we run sphinx multiple times then we will append the future import
+    # multiple times too.
+    return
+
+  assert contents.startswith('#'), (path, contents.split('\n')[0])
+  with open(path, 'w') as f:
+    # NOTE: This is subtle and not unit tested: we're prefixing the first line
+    # in each Python file with this future import. It is important to prefix,
+    # rather than insert a newline, so that source code locations stay accurate
+    # (we link to GitHub). The assertion above ensures that the first line in
+    # the file is a comment, so it is safe to prefix it.
+    f.write('from __future__ import annotations ')
+    f.write(contents)
+
+
+def _recursive_add_annotations_import():
+  for path, _, files in os.walk('../rlax/'):
+    for file in files:
+      if file.endswith('.py'):
+        _add_annotations_import(os.path.abspath(os.path.join(path, file)))
+
+if 'READTHEDOCS' in os.environ:
+  _recursive_add_annotations_import()
+
+typing.get_type_hints = lambda obj, *unused: obj.__annotations__
+sys.path.insert(0, os.path.abspath('../'))
+sys.path.append(os.path.abspath('ext'))
+
+import rlax
+import sphinxcontrib.katex as katex
+
+# -- Project information -----------------------------------------------------
+
+project = 'RLax'
+copyright = '2021, DeepMind'  # pylint: disable=redefined-builtin
+author = 'RLax Contributors'
+
+# -- General configuration ---------------------------------------------------
+
+master_doc = 'index'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'sphinx.ext.doctest',
+    'sphinx.ext.inheritance_diagram',
+    'sphinx.ext.intersphinx',
+    'sphinx.ext.linkcode',
+    'sphinx.ext.napoleon',
+    'sphinxcontrib.bibtex',
+    'sphinxcontrib.katex',
+    'sphinx_autodoc_typehints',
+    'sphinx_rtd_theme',
+    'coverage_check',
+    'myst_nb',  # This is used for the .ipynb notebooks
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# -- Options for autodoc -----------------------------------------------------
+
+autodoc_default_options = {
+    'member-order': 'bysource',
+    'special-members': True,
+    'exclude-members': '__repr__, __str__, __weakref__',
+}
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+# html_favicon = '_static/favicon.ico'
+
+# -- Options for myst -------------------------------------------------------
+
+jupyter_execute_notebooks = 'force'
+execution_allow_errors = False
+
+# -- Options for katex ------------------------------------------------------
+
+# See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html
+latex_macros = r"""
+    \def \d #1{\operatorname{#1}}
+"""
+
+# Translate LaTeX macros to KaTeX and add to options for HTML builder
+katex_macros = katex.latex_defs_to_katex_macros(latex_macros)
+katex_options = 'macros: {' + katex_macros + '}'
+
+# Add LaTeX macros for LATEX builder
+latex_elements = {'preamble': latex_macros}
+
+# -- Source code links -------------------------------------------------------
+
+
+def linkcode_resolve(domain, info):
+  """Resolve a GitHub URL corresponding to Python object."""
+  if domain != 'py':
+    return None
+
+  try:
+    mod = sys.modules[info['module']]
+  except KeyError:
+    return None
+
+  obj = mod
+  try:
+    for attr in info['fullname'].split('.'):
+      obj = getattr(obj, attr)
+  except AttributeError:
+    return None
+  else:
+    obj = inspect.unwrap(obj)
+
+  try:
+    filename = inspect.getsourcefile(obj)
+  except TypeError:
+    return None
+
+  try:
+    source, lineno = inspect.getsourcelines(obj)
+  except OSError:
+    return None
+
+  # TODO(slebedev): support tags after we release an initial version.
+  return 'https://github.com/deepmind/rlax/tree/master/rlax/%s#L%d-L%d' % (
+      os.path.relpath(filename, start=os.path.dirname(
+          rlax.__file__)), lineno, lineno + len(source) - 1)
+
+
+# -- Intersphinx configuration -----------------------------------------------
+
+intersphinx_mapping = {
+    'jax': ('https://jax.readthedocs.io/en/latest/', None),
+}
+
+source_suffix = ['.rst', '.md', '.ipynb']
diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py
new file mode 100644
index 0000000..a6340ea
--- /dev/null
+++ b/docs/ext/coverage_check.py
@@ -0,0 +1,58 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Asserts all public symbols are covered in the docs."""
+
+from typing import Any, Mapping
+
+import rlax
+from rlax._src import test_utils
+from sphinx import application
+from sphinx import builders
+from sphinx import errors
+
+
+def rlax_public_symbols():
+  names = set()
+  for module_name, module in test_utils.find_internal_python_modules(rlax):
+    for name in module.__all__:
+      names.add(module_name + "." + name)
+  return names
+
+
+class RLaxCoverageCheck(builders.Builder):
+  """Builder that checks all public symbols are included."""
+
+  name = "coverage_check"
+
+  def get_outdated_docs(self) -> str:
+    return "coverage_check"
+
+  def write(self, *ignored: Any) -> None:
+    pass
+
+  def finish(self) -> None:
+    documented_objects = frozenset(self.env.domaindata["py"]["objects"])
+    undocumented_objects = set(rlax_public_symbols()) - documented_objects
+    if undocumented_objects:
+      undocumented_objects = tuple(sorted(undocumented_objects))
+      raise errors.SphinxError(
+          "All public symbols must be included in our documentation, did you "
+          "forget to add an entry to `api.rst`?\n"
+          f"Undocumented symbols: {undocumented_objects}")
+
+
+def setup(app: application.Sphinx) -> Mapping[str, Any]:
+  app.add_builder(RLaxCoverageCheck)
+  return dict(version=rlax.__version__, parallel_read_safe=True)
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..a6e395e
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,153 @@
+:github_url: https://github.com/deepmind/rlax/tree/master/docs
+
+RLax
+-----
+
+RLax (pronounced "relax") is a library built on top of JAX that exposes useful
+building blocks for implementing reinforcement learning agents.
+
+
+
+Installation
+------------
+
+RLax can be installed with pip directly from GitHub, with the following command:
+
+`pip install git+git://github.com/deepmind/rlax.git`
+
+or from PyPI:
+
+`pip install rlax`
+
+All RLax code may then be just-in-time compiled for different hardware
+(e.g. CPU, GPU, TPU) using `jax.jit`.
+
+In order to run the `examples/` you will also need to clone the repo and
+install the additional requirements:
+`optax <https://github.com/deepmind/optax>`_,
+`haiku <https://github.com/deepmind/dm-haiku>`_, and
+`bsuite <https://github.com/deepmind/bsuite>`_.
+
+
+Content
+-------
+The operations and functions provided are not complete algorithms, but
+implementations of reinforcement learning specific mathematical operations that
+are needed when building fully-functional agents capable of learning:
+
+
+
+The library supports both on-policy and off-policy learning (i.e. learning from
+data sampled from a policy different from the agent's policy).
+
+
+Usage
+-----
+
+See `examples/` for examples of using some of the functions in RLax to
+implement a few simple reinforcement learning agents, and demonstrate learning
+on BSuite's version of the Catch environment (a common unit-test for
+agent development in the reinforcement learning literature).
+
+Other examples of JAX reinforcement learning agents using `rlax` can be found in
+`bsuite <https://github.com/deepmind/bsuite>`_.
+
+
+Background
+----------
+Reinforcement learning studies the problem of a learning system (the *agent*),
+which must learn to interact with the universe it is embedded in (the
+*environment*).
+
+Agent and environment interact on discrete steps. On each step the agent selects
+an *action*, and is provided in return a (partial) snapshot of the state of the
+environment (the *observation*), and a scalar feedback signal (the *reward*).
+
+The behaviour of the agent is characterized by a probability distribution over
+actions, conditioned on past observations of the environment (the *policy*). The
+agent seeks a policy that, from any given step, maximises the discounted
+cumulative reward that will be collected from that point onwards (the *return*).
+
+Often the agent policy or the environment dynamics itself are stochastic. In
+this case the return is a random variable, and the agent's optimal policy is
+typically more precisely specified as a policy that maximises the expectation of
+the return (the *value*), under the agent's and environment's stochasticity.
+
+Reinforcement Learning Algorithms
+---------------------------------
+
+
+There are three prototypical families of reinforcement learning algorithms:
+
+1. those that estimate the value of states and actions, and infer a policy by
+   *inspection* (e.g. by selecting the action with the highest estimated value);
+2. those that learn a model of the environment (capable of predicting the
+   observations and rewards) and infer a policy via *planning*;
+3. those that parameterize a policy that can be directly *executed*.
+
+In any case, policies, values or models are just functions. In deep
+reinforcement learning such functions are represented by a neural network.
+In this setting, it is common to formulate reinforcement learning updates as
+differentiable pseudo-loss functions (analogously to (un-)supervised learning).
+Under automatic differentiation, the original update rule is recovered.
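+
+As an illustrative sketch only (the tiny linear policy and the shapes below are
+made up for this example; only `rlax.policy_gradient_loss`, `jax.grad` and
+`jax.numpy` are taken from the libraries), a policy-gradient pseudo-loss might
+be differentiated like this:
+
+.. code-block:: python
+
+    import jax
+    import jax.numpy as jnp
+    import rlax
+
+    num_actions, obs_dim, num_steps = 3, 4, 5
+
+    def pseudo_loss(params, obs_t, a_t, adv_t):
+      # A deliberately trivial linear policy, just to make the sketch concrete.
+      logits_t = obs_t @ params            # [num_steps, num_actions] action logits
+      w_t = jnp.ones_like(adv_t)           # per-step weights
+      return rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t)
+
+    params = jnp.zeros((obs_dim, num_actions))
+    obs_t = jnp.ones((num_steps, obs_dim))
+    a_t = jnp.zeros((num_steps,), dtype=jnp.int32)
+    adv_t = jnp.ones((num_steps,))
+
+    # Differentiating the pseudo-loss recovers the policy-gradient update.
+    grads = jax.grad(pseudo_loss)(params, obs_t, a_t, adv_t)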
In +this case the return is a random variable, and the optimal agent's policy is +typically more precisely specified as a policy that maximises the expectation of +the return (the *value*), under the agent's and environment's stochasticity. + +Reinforcement Learning Algorithms +--------------------------------- + + +There are three prototypical families of reinforcement learning algorithms: + +1. those that estimate the value of states and actions, and infer a policy by + *inspection* (e.g. by selecting the action with highest estimated value) +2. those that learn a model of the environment (capable of predicting the + observations and rewards) and infer a policy via *planning*. +3. those that parameterize a policy that can be directly *executed*, + +In any case, policies, values or models are just functions. In deep +reinforcement learning such functions are represented by a neural network. +In this setting, it is common to formulate reinforcement learning updates as +differentiable pseudo-loss functions (analogously to (un-)supervised learning). +Under automatic differentiation, the original update rule is recovered. + +Note however, that in particular, the updates are only valid if the input data +is sampled in the correct manner. For example, a policy gradient loss is only +valid if the input trajectory is an unbiased sample from the current policy; +i.e. the data are on-policy. The library cannot check or enforce such +constraints. Links to papers describing how each operation is used are however +provided in the functions' doc-strings. + + +Naming Conventions and Developer Guidelines +------------------------------------------- + +We define functions and operations for agents interacting with a single stream +of experience. The JAX construct `vmap` can be used to apply these same +functions to batches (e.g. to support *replay* and *parallel* data generation). + +Many functions consider policies, actions, rewards, values, in consecutive +timesteps in order to compute their outputs. In this case the suffix `_t` and +`tm1` is often to clarify on which step each input was generated, e.g: + +* `q_tm1`: the action value in the `source` state of a transition. +* `a_tm1`: the action that was selected in the `source` state. +* `r_t`: the resulting rewards collected in the `destination` state. +* `discount_t`: the `discount` associated with a transition. +* `q_t`: the action values in the `destination` state. + +Extensive testing is provided for each function. All tests should also verify +the output of `rlax` functions when compiled to XLA using `jax.jit` and when +performing batch operations using `jax.vmap`. + + + +.. toctree:: + :caption: API Documentation + :maxdepth: 2 + + api + + +Contribute +---------- + +- `Issue tracker `_ +- `Source code `_ + +Support +------- + +If you are having issues, please let us know by filing an issue on our +`issue tracker `_. + +License +------- + +RLax is licensed under the Apache 2.0 License. 
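+
+As an illustrative sketch only (assuming standard `jax` imports and the
+`rlax.q_learning` signature documented in the API reference), these conventions
+look roughly as follows:
+
+.. code-block:: python
+
+    import jax
+    import jax.numpy as jnp
+    import rlax
+
+    # A single transition, named with the `_tm1`/`_t` convention.
+    q_tm1 = jnp.array([1.0, 2.0, -1.0])   # action values in the source state
+    a_tm1 = jnp.asarray(1)                # action selected in the source state
+    r_t = jnp.asarray(0.5)                # reward collected in the destination state
+    discount_t = jnp.asarray(0.99)        # discount for the transition
+    q_t = jnp.array([0.0, 1.5, 3.0])      # action values in the destination state
+
+    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
+
+    # The same single-transition function applied to a (size-1) batch via vmap.
+    batched_td_errors = jax.vmap(rlax.q_learning)(
+        q_tm1[None], a_tm1[None], r_t[None], discount_t[None], q_t[None])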
+
+
+
+.. toctree::
+   :caption: API Documentation
+   :maxdepth: 2
+
+   api
+
+
+Contribute
+----------
+
+- `Issue tracker <https://github.com/deepmind/rlax/issues>`_
+- `Source code <https://github.com/deepmind/rlax/tree/master>`_
+
+Support
+-------
+
+If you are having issues, please let us know by filing an issue on our
+`issue tracker <https://github.com/deepmind/rlax/issues>`_.
+
+License
+-------
+
+RLax is licensed under the Apache 2.0 License.
+
+
+Indices and Tables
+==================
+
+* :ref:`genindex`
diff --git a/requirements/requirements-docs.txt b/requirements/requirements-docs.txt
new file mode 100644
index 0000000..ef04820
--- /dev/null
+++ b/requirements/requirements-docs.txt
@@ -0,0 +1,11 @@
+sphinx==3.3.0
+sphinx_rtd_theme==0.5.0
+sphinxcontrib-katex==0.7.1
+sphinxcontrib-bibtex==1.0.0
+sphinx-autodoc-typehints==1.11.1
+IPython==7.16.1
+ipykernel==5.3.4
+pandoc==1.0.2
+myst_nb==0.13.1
+docutils==0.16
+matplotlib==3.5.0
diff --git a/rlax/_src/test_utils.py b/rlax/_src/test_utils.py
new file mode 100644
index 0000000..811692f
--- /dev/null
+++ b/rlax/_src/test_utils.py
@@ -0,0 +1,42 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Testing utilities for RLax."""
+
+import inspect
+import types
+from typing import Sequence, Tuple
+
+
+def find_internal_python_modules(
+    root_module: types.ModuleType,
+) -> Sequence[Tuple[str, types.ModuleType]]:
+  """Returns `(name, module)` for all RLax submodules under `root_module`."""
+  modules = set([(root_module.__name__, root_module)])
+  visited = set()
+  to_visit = [root_module]
+
+  while to_visit:
+    mod = to_visit.pop()
+    visited.add(mod)
+
+    for name in dir(mod):
+      obj = getattr(mod, name)
+      if inspect.ismodule(obj) and obj not in visited:
+        if obj.__name__.startswith('rlax'):
+          if '_src' not in obj.__name__:
+            to_visit.append(obj)
+            modules.add((obj.__name__, obj))
+
+  return sorted(modules)
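
For reference, a rough sketch of what the coverage check wired up above does,
assuming an environment where `rlax` and the new `rlax._src.test_utils` helper
are importable (this snippet only restates the logic of
`docs/ext/coverage_check.py` and is not part of the patch):

    import rlax
    from rlax._src import test_utils

    # (name, module) pairs for rlax and any public (non-`_src`) submodules.
    modules = test_utils.find_internal_python_modules(rlax)

    # Fully qualified names that the `coverage_check` Sphinx builder expects
    # to find documented in docs/api.rst.
    public_symbols = set()
    for module_name, module in modules:
      for name in module.__all__:
        public_symbols.add(module_name + "." + name)
    print(sorted(public_symbols)[:3])  # first few fully qualified public names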