DOC: move transformations doc from README to HTML doc
jakevdp committed Mar 3, 2021
1 parent 9187873 commit 23fbc42
Showing 5 changed files with 308 additions and 260 deletions.
247 changes: 1 addition & 246 deletions README.md
@@ -8,7 +8,6 @@
![PyPI version](https://img.shields.io/pypi/v/jax)

[**Quickstart**](#quickstart-colab-in-the-cloud)
| [**Transformations**](#transformations)
| [**Install guide**](#installation)
| [**Neural net libraries**](#neural-network-libraries)
| [**Change logs**](https://jax.readthedocs.io/en/latest/CHANGELOG.html)
@@ -32,8 +31,7 @@ derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.

What’s new is that JAX uses
[XLA](https://www.tensorflow.org/xla)
What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla)
to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
under the hood by default, with library calls getting just-in-time compiled and
executed. But JAX also lets you just-in-time compile your own Python functions
@@ -78,7 +76,6 @@ perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grad

### Contents
* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
* [Transformations](#transformations)
* [Current gotchas](#current-gotchas)
* [Installation](#installation)
* [Neural net libraries](#neural-network-libraries)
@@ -108,248 +105,6 @@ and [`optimizers` for first-order stochastic
optimization](https://github.com/google/jax/tree/master/jax/experimental/README.md#first-order-optimization),
or the [examples](https://github.com/google/jax/tree/master/examples).

## Transformations

At its core, JAX is an extensible system for transforming numerical functions.
Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`.

### Automatic differentiation with `grad`

JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
The most popular function is
[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
for reverse-mode gradients:

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh) # Obtain its gradient function
print(grad_tanh(1.0)) # Evaluate it at x = 1.0
# prints 0.4199743
```

You can differentiate to any order with `grad`.

```python
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
```

For more advanced autodiff, you can use
[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
reverse-mode vector-Jacobian products and
[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
forward-mode Jacobian-vector products. The two can be composed arbitrarily with
one another, and with other JAX transformations. Here's one way to compose those
to make a function that efficiently computes [full Hessian
matrices](https://jax.readthedocs.io/en/latest/jax.html#jax.hessian):

```python
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```
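
For instance, here's a minimal sketch (using the `tanh` function from above) of how `jvp` and `vjp` relate to `grad`:

```python
from jax import jvp, vjp

# Forward mode: push a tangent through tanh at x = 1.0
y, y_dot = jvp(tanh, (1.0,), (1.0,))   # y_dot equals grad(tanh)(1.0)

# Reverse mode: get a function that pulls cotangents back
y, vjp_fun = vjp(tanh, 1.0)
x_bar, = vjp_fun(1.0)                  # x_bar also equals grad(tanh)(1.0)
```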

As with [Autograd](https://github.com/hips/autograd), you're free to use
differentiation with Python control structures:

```python
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0)) # prints 1.0
print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
```

See the [reference docs on automatic
differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
and the [JAX Autodiff
Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
for more.

### Compilation with `jit`

You can use XLA to compile your functions end-to-end with
[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
used either as an `@jit` decorator or as a higher-order function.

```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
```

You can mix `jit` and `grad` and any other JAX transformation however you like.
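
For instance, a minimal sketch (the `loss` function here is illustrative):

```python
import jax.numpy as jnp
from jax import grad, jit

def loss(x):
  return jnp.sum(x ** 2)

grad_loss = jit(grad(loss))        # compile the gradient computation too
print(grad_loss(jnp.arange(3.0)))  # [0. 2. 4.]
```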

Using `jit` puts constraints on the kind of Python control flow
the function can use; see
the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
for more.
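
One common pattern, sketched here with a hypothetical `apply_n_times` helper, is to mark a Python-level argument as static so each value gets its own compiled specialization:

```python
from functools import partial
from jax import jit

@partial(jit, static_argnums=(1,))  # `n` is a static Python int
def apply_n_times(x, n):
  for _ in range(n):                # Python loop, unrolled at trace time
    x = x * x
  return x

print(apply_n_times(2.0, 3))        # 256.0
```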

### Auto-vectorization with `vmap`

[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but
instead of keeping the loop on the outside, it pushes the loop down into a
function’s primitive operations for better performance.

Using `vmap` can save you from having to carry around batch dimensions in your
code. For example, consider this simple *unbatched* neural network prediction
function:

```python
def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)
  return outputs
```

We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the
left side of `inputs`, but we’ve written this particular prediction function to
apply only to single input vectors. If we wanted to apply this function to a
batch of inputs at once, semantically we could just write

```python
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
```

But pushing one example through the network at a time would be slow! It’s better
to vectorize the computation, so that at every layer we’re doing matrix-matrix
multiplication rather than matrix-vector multiplication.

The `vmap` function does that transformation for us. That is, if we write

```python
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```

then the `vmap` function will push the outer loop inside the function, and our
machine will end up executing matrix-matrix multiplications exactly as if we’d
done the batching by hand.

It’s easy enough to manually batch a simple neural network without `vmap`, but
in other cases manual vectorization can be impractical or impossible. Take the
problem of efficiently computing per-example gradients: that is, for a fixed set
of parameters, we want to compute the gradient of our loss function evaluated
separately at each example in a batch. With `vmap`, it’s easy:

```python
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```
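
Spelled out with a toy loss function (the names below are illustrative, not a fixed JAX API):

```python
from functools import partial
import jax.numpy as jnp
from jax import grad, vmap

def loss(params, x, y):              # toy per-example scalar loss
  w, b = params
  prediction = jnp.dot(x, w) + b
  return (prediction - y) ** 2

params = (jnp.ones(3), jnp.float32(0.5))
inputs = jnp.ones((10, 3))
targets = jnp.zeros(10)

# One gradient per example; each leaf gains a leading batch axis of 10
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```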

Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other
JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
differentiation for fast Jacobian and Hessian matrix calculations in
`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
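
For intuition, here's a sketch of the reverse-mode construction, assuming a function from a 1D array to a 1D array (a simplified version of what `jax.jacrev` does):

```python
from jax import vjp, vmap

def jacrev_sketch(f):
  def jacfun(x):
    y, pullback = vjp(f, x)
    # Pull back one basis cotangent per output row, all at once with vmap
    J, = vmap(pullback)(jnp.eye(y.shape[0]))
    return J
  return jacfun

f = lambda x: jnp.sin(x) * x[0]
print(jacrev_sketch(f)(jnp.arange(1.0, 4.0)))
```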

### SPMD programming with `pmap`

For parallel programming of multiple accelerators, like multiple GPUs, use
[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
With `pmap` you write single-program multiple-data (SPMD) programs, including
fast parallel collective communication operations. Applying `pmap` will mean
that the function you write is compiled by XLA (similarly to `jit`), then
replicated and executed in parallel across devices.

Here's an example on an 8-GPU machine:

```python
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```

In addition to expressing pure maps, you can use fast [collective communication
operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
between devices:

```python
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```

You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
sophisticated communication patterns.
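
As a hedged sketch of what nesting looks like, assuming a machine with at least 8 devices viewed as a 2×4 grid:

```python
from functools import partial
from jax import lax, pmap
import jax.numpy as jnp

@partial(pmap, axis_name='rows')  # outer map over an axis of size 2
@partial(pmap, axis_name='cols')  # inner map over an axis of size 4
def normalize_all(x):
  # Sum over every device in the 2 x 4 grid
  return x / lax.psum(x, ('rows', 'cols'))

print(normalize_all(jnp.arange(8.).reshape(2, 4)))  # each entry is x / 28
```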

It all composes, so you're free to differentiate through parallel computations:

```python
from jax import grad

# NOTE: `x` was not defined in the original snippet; here we assume a (4, 2)
# array on the 8-device machine above, so the printed values are illustrative.
x = random.normal(random.PRNGKey(0), (4, 2))

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
```

When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass.

See the [SPMD
Cookbook](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
and the [SPMD MNIST classifier from scratch
example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py)
for more.

## Current gotchas

For a more thorough survey of current gotchas, with examples and explanations,
35 changes: 21 additions & 14 deletions docs/index.rst
@@ -1,11 +1,9 @@
JAX reference documentation
===========================

Composable transformations of Python+NumPy programs: differentiate, vectorize,
JIT to GPU/TPU, and more.

For an introduction to JAX, start at the
`JAX GitHub page <https://github.com/google/jax>`_.
JAX is Autograd_ and XLA_, brought together for high-performance machine learning research.
It provides composable transformations of Python+NumPy programs: differentiate, vectorize,
parallelize, Just-In-Time compile to GPU/TPU, and more.

.. toctree::
   :maxdepth: 1
@@ -22,9 +20,22 @@ For an introduction to JAX, start at the

.. toctree::
   :maxdepth: 1
   :caption: Advanced JAX Tutorials
   :caption: Reference Documentation

   faq
   transformations
   async_dispatch
   jaxpr
   notebooks/convolutions
   pytrees
   type_promotion
   glossary
   CHANGELOG

.. toctree::
   :maxdepth: 1
   :caption: Advanced JAX Tutorials

   notebooks/autodiff_cookbook
   notebooks/vmapped_log_probs
   notebooks/neural_network_with_tfds_data
@@ -36,24 +47,16 @@ For an introduction to JAX, start at the
   notebooks/maml
   notebooks/score_matching


.. toctree::
   :maxdepth: 1
   :caption: Notes

   CHANGELOG
   faq
   jaxpr
   async_dispatch
   concurrency
   gpu_memory_allocation
   profiling
   device_memory_profiling
   pytrees
   rank_promotion_warning
   type_promotion
   custom_vjp_update
   glossary

.. toctree::
   :maxdepth: 2
@@ -76,3 +79,7 @@ Indices and tables
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`


.. _Autograd: https://github.com/hips/autograd
.. _XLA: https://www.tensorflow.org/xla
2 changes: 2 additions & 0 deletions docs/jax.lax.rst
@@ -162,6 +162,8 @@ Custom gradient operators
   custom_linear_solve
   custom_root

.. _jax-parallel-operators:

Parallel operators
------------------

4 changes: 4 additions & 0 deletions docs/jax.rst
@@ -21,6 +21,8 @@ Subpackages
   jax.dlpack
   jax.profiler

.. _jax-jit:

Just-in-time compilation (:code:`jit`)
--------------------------------------

@@ -37,6 +39,8 @@ Just-in-time compilation (:code:`jit`)
   default_backend
   named_call

.. _jax-grad:

Automatic differentiation
-------------------------

