DOC: add transformations doc to HTML & reorganize contents
jakevdp committed Mar 5, 2021
1 parent 76c1ec3 commit 8e7b405
Showing 5 changed files with 312 additions and 21 deletions.
10 changes: 4 additions & 6 deletions README.md
@@ -20,8 +20,7 @@

## What is JAX?

JAX is [Autograd](https://github.com/hips/autograd) and
[XLA](https://www.tensorflow.org/xla),
JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla),
brought together for high-performance machine learning research.

With its updated version of [Autograd](https://github.com/hips/autograd),
@@ -32,8 +31,7 @@ derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.
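
(Aside, not part of the commit.) A minimal sketch of that composability, using a made-up scalar function: `jax.grad` is reverse mode, `jax.jacfwd` is forward mode, and the two nest freely.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x ** 2          # toy scalar function

d1 = jax.grad(f)                        # reverse mode
d2 = jax.grad(jax.grad(f))              # reverse over reverse: 2nd derivative
d2_fwd = jax.jacfwd(jax.grad(f))        # forward over reverse: same value

x = 1.5
print(d1(x), d2(x), d2_fwd(x))
```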

What’s new is that JAX uses
[XLA](https://www.tensorflow.org/xla)
What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla)
to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
under the hood by default, with library calls getting just-in-time compiled and
executed. But JAX also lets you just-in-time compile your own Python functions
@@ -220,14 +218,14 @@ function:
```python
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = inputs
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b # `input_vec` on the right-hand side!
activations = jnp.tanh(outputs)
return outputs
```

We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the
We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the
left side of `inputs`, but we’ve written this particular prediction function to
apply only to single input vectors. If we wanted to apply this function to a
batch of inputs at once, semantically we could just write
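(Aside, not part of the commit.) A runnable sketch of the batching point made in this README excerpt: keep `predict` written for a single input vector and let `jax.vmap` add the batch dimension. The layer sizes and random parameters below are made up for illustration.

```python
import jax
import jax.numpy as jnp

def predict(params, input_vec):
    # Same per-example network as in the README excerpt above.
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs

# Toy parameters: a 3-layer MLP with made-up sizes (4 -> 8 -> 8 -> 2).
key = jax.random.PRNGKey(0)
sizes = [4, 8, 8, 2]
params = [
    (jax.random.normal(k, (n_out, n_in)), jnp.zeros(n_out))
    for k, n_in, n_out in zip(jax.random.split(key, 3), sizes[:-1], sizes[1:])
]

# Map over a batch of input vectors without rewriting predict.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
inputs = jax.random.normal(key, (16, 4))       # batch of 16 input vectors
print(batched_predict(params, inputs).shape)   # (16, 2)
```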
37 changes: 22 additions & 15 deletions docs/index.rst
@@ -1,11 +1,9 @@
JAX reference documentation
===========================

Composable transformations of Python+NumPy programs: differentiate, vectorize,
JIT to GPU/TPU, and more.

For an introduction to JAX, start at the
`JAX GitHub page <https://github.com/google/jax>`_.
JAX is Autograd_ and XLA_, brought together for high-performance numerical computing and machine learning research.
It provides composable transformations of Python+NumPy programs: differentiate, vectorize,
parallelize, Just-In-Time compile to GPU/TPU, and more.

.. toctree::
:maxdepth: 1
@@ -22,9 +20,23 @@ For an introduction to JAX, start at the

.. toctree::
:maxdepth: 1
:caption: Advanced JAX Tutorials
:caption: Reference Documentation

faq
transformations
async_dispatch
jaxpr
notebooks/convolutions
pytrees
type_promotion
errors
glossary
CHANGELOG

.. toctree::
:maxdepth: 1
:caption: Advanced JAX Tutorials

notebooks/autodiff_cookbook
notebooks/vmapped_log_probs
notebooks/neural_network_with_tfds_data
@@ -36,25 +48,16 @@ For an introduction to JAX, start at the
notebooks/maml
notebooks/score_matching


.. toctree::
:maxdepth: 1
:caption: Notes

CHANGELOG
faq
errors
jaxpr
async_dispatch
concurrency
gpu_memory_allocation
profiling
device_memory_profiling
pytrees
rank_promotion_warning
type_promotion
custom_vjp_update
glossary

.. toctree::
:maxdepth: 2
@@ -77,3 +80,7 @@ Indices and tables
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`


.. _Autograd: https://github.com/hips/autograd
.. _XLA: https://www.tensorflow.org/xla
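
(Aside, not part of the commit.) The reworked intro above advertises composable transformations; a small sketch of what that composition looks like on a throwaway function:

```python
import jax
import jax.numpy as jnp

def loss(x):
    # A toy scalar function standing in for a real loss.
    return jnp.sum(jnp.tanh(x) ** 2)

# Differentiate, vectorize, and JIT-compile, composed freely.
grad_loss = jax.grad(loss)                 # d(loss)/dx
batched_grad = jax.vmap(grad_loss)         # map over a leading batch axis
fast_batched_grad = jax.jit(batched_grad)  # compile with XLA

x = jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)
print(fast_batched_grad(x).shape)  # (3, 4): one gradient per batch row
```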
2 changes: 2 additions & 0 deletions docs/jax.lax.rst
@@ -163,6 +163,8 @@ Custom gradient operators
custom_linear_solve
custom_root

.. _jax-parallel-operators:

Parallel operators
------------------

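(Aside, not part of the commit.) The newly labelled parallel-operators section covers collectives such as `jax.lax.psum`; a minimal sketch that runs even with a single local device:

```python
import jax
import jax.numpy as jnp

def fraction_of_total(x):
    # Each device's share of the sum across the mapped axis.
    return x / jax.lax.psum(x, axis_name="devices")

n = jax.local_device_count()          # 1 on a typical CPU setup
values = jnp.arange(1.0, n + 1.0)     # one value per device
print(jax.pmap(fraction_of_total, axis_name="devices")(values))
```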
4 changes: 4 additions & 0 deletions docs/jax.rst
@@ -21,6 +21,8 @@ Subpackages
jax.dlpack
jax.profiler

.. _jax-jit:

Just-in-time compilation (:code:`jit`)
--------------------------------------

@@ -37,6 +39,8 @@ Just-in-time compilation (:code:`jit`)
default_backend
named_call

.. _jax-grad:

Automatic differentiation
-------------------------

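(Aside, not part of the commit.) The `jax-jit` and `jax-grad` labels added above make these sections linkable; a toy illustration of the two APIs they document, using a made-up `selu` with arbitrary constants:

```python
import jax
import jax.numpy as jnp

@jax.jit
def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

print(jax.default_backend())      # e.g. "cpu", "gpu", or "tpu"
print(selu(jnp.arange(5.0)))

# grad needs a scalar output, so reduce with sum before differentiating.
dselu = jax.grad(lambda x: jnp.sum(selu(x)))
print(dselu(jnp.arange(5.0)))
```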
