Skip to content

Latest commit

 

History

History
65 lines (49 loc) · 1.31 KB

index.rst

File metadata and controls

65 lines (49 loc) · 1.31 KB

JAX reference documentation

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.

For an introduction to JAX, start at the JAX GitHub page.

.. toctree::
   :maxdepth: 1
   :caption: Tutorials

   notebooks/quickstart
   notebooks/autodiff_cookbook
   notebooks/vmapped_log_probs
   Training a Simple Neural Network, with Tensorflow Datasets Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb>


.. toctree::
   :maxdepth: 1
   :caption: Advanced JAX Tutorials

   notebooks/Common_Gotchas_in_JAX
   notebooks/Custom_derivative_rules_for_Python_code
   notebooks/XLA_in_Python
   notebooks/How_JAX_primitives_work
   notebooks/Writing_custom_interpreters_in_Jax.ipynb

.. toctree::
   :maxdepth: 1
   :caption: Notes

   CHANGELOG
   faq
   jaxpr
   async_dispatch
   concurrency
   gpu_memory_allocation
   profiling
   device_memory_profiling
   pytrees
   rank_promotion_warning
   type_promotion

.. toctree::
   :maxdepth: 2
   :caption: Developer documentation

   developer
   jax_internal_api

.. toctree::
   :maxdepth: 3
   :caption: API documentation

   jax


Indices and tables