[JAX] move example libraries from jax.experimental into `jax.example_libraries`

The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone example libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.
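For downstream code, the rename amounts to an import-path change; a minimal sketch (illustrative, not part of the commit message):

```python
# Old import path, prior to this change:
# from jax.experimental import stax, optimizers

# New import path introduced by this change:
from jax.example_libraries import stax, optimizers
```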

PiperOrigin-RevId: 404405186
froystig authored and jax authors committed Oct 20, 2021
1 parent 349d0d0 commit 623c201
Showing 25 changed files with 1,068 additions and 979 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...main).

+* Breaking changes
+  * Moved `jax.experimental.stax` to `jax.example_libraries.stax`
+  * Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`


## jax 0.2.24 (Oct 19, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).
6 changes: 3 additions & 3 deletions README.md
@@ -98,11 +98,11 @@ For a deeper dive into JAX:
notebooks](https://github.com/google/jax/tree/main/docs/notebooks).

You can also take a look at [the mini-libraries in
-`jax.experimental`](https://github.com/google/jax/tree/main/jax/experimental/README.md),
+`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md),
like [`stax` for building neural
-networks](https://github.com/google/jax/tree/main/jax/experimental/README.md#neural-net-building-with-stax)
+networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax)
and [`optimizers` for first-order stochastic
-optimization](https://github.com/google/jax/tree/main/jax/experimental/README.md#first-order-optimization),
+optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization),
or the [examples](https://github.com/google/jax/tree/main/examples).

## Transformations
7 changes: 7 additions & 0 deletions docs/jax.example_libraries.optimizers.rst
@@ -0,0 +1,7 @@
+jax.example_libraries.optimizers module
+=======================================
+
+.. automodule:: jax.example_libraries.optimizers
+    :members:
+    :undoc-members:
+    :show-inheritance:
10 changes: 10 additions & 0 deletions docs/jax.example_libraries.rst
@@ -0,0 +1,10 @@
+jax.example_libraries package
+=============================
+
+.. toctree::
+    :maxdepth: 1
+
+    jax.example_libraries.optimizers
+    jax.example_libraries.stax
+
+.. automodule:: jax.example_libraries
7 changes: 7 additions & 0 deletions docs/jax.example_libraries.stax.rst
@@ -0,0 +1,7 @@
+jax.example_libraries.stax module
+=================================
+
+.. automodule:: jax.example_libraries.stax
+    :members:
+    :undoc-members:
+    :show-inheritance:
7 changes: 0 additions & 7 deletions docs/jax.experimental.optimizers.rst

This file was deleted.

2 changes: 0 additions & 2 deletions docs/jax.experimental.rst
@@ -11,9 +11,7 @@ jax.experimental package
jax.experimental.loops
jax.experimental.maps
jax.experimental.pjit
-jax.experimental.optimizers
jax.experimental.sparse
-jax.experimental.stax

.. automodule:: jax.experimental

7 changes: 0 additions & 7 deletions docs/jax.experimental.stax.rst

This file was deleted.

1 change: 1 addition & 0 deletions docs/jax.rst
@@ -11,6 +11,7 @@ Subpackages

jax.numpy
jax.scipy
+jax.example_libraries
jax.experimental
jax.image
jax.lax
2 changes: 1 addition & 1 deletion examples/advi.py
@@ -24,7 +24,7 @@

from jax import jit, grad, vmap
from jax import random
-from jax.experimental import optimizers
+from jax.example_libraries import optimizers
import jax.numpy as jnp
import jax.scipy.stats.norm as norm

4 changes: 2 additions & 2 deletions examples/differentially_private_sgd.py
@@ -75,8 +75,8 @@
from jax import jit
from jax import random
from jax import vmap
-from jax.experimental import optimizers
-from jax.experimental import stax
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp
from examples import datasets
2 changes: 1 addition & 1 deletion examples/kernel_lsq.py
@@ -18,7 +18,7 @@
import numpy.random as npr

import jax.numpy as jnp
-from jax.experimental import optimizers
+from jax.example_libraries import optimizers
from jax import grad, jit, make_jaxpr, vmap, lax


10 changes: 5 additions & 5 deletions examples/mnist_classifier.py
@@ -14,8 +14,8 @@

"""A basic MNIST example using JAX with the mini-libraries stax and optimizers.
-The mini-library jax.experimental.stax is for neural network building, and
-the mini-library jax.experimental.optimizers is for first-order stochastic
+The mini-library jax.example_libraries.stax is for neural network building, and
+the mini-library jax.example_libraries.optimizers is for first-order stochastic
optimization.
"""

@@ -27,9 +27,9 @@

import jax.numpy as jnp
from jax import jit, grad, random
-from jax.experimental import optimizers
-from jax.experimental import stax
-from jax.experimental.stax import Dense, Relu, LogSoftmax
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
+from jax.example_libraries.stax import Dense, Relu, LogSoftmax
from examples import datasets


6 changes: 3 additions & 3 deletions examples/mnist_vae.py
@@ -27,9 +27,9 @@
import jax
import jax.numpy as jnp
from jax import jit, grad, lax, random
-from jax.experimental import optimizers
-from jax.experimental import stax
-from jax.experimental.stax import Dense, FanOut, Relu, Softplus
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
+from jax.example_libraries.stax import Dense, FanOut, Relu, Softplus
from examples import datasets


10 changes: 5 additions & 5 deletions examples/resnet50.py
@@ -22,11 +22,11 @@

import jax.numpy as jnp
from jax import jit, grad, random
-from jax.experimental import optimizers
-from jax.experimental import stax
-from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
-                                   FanOut, Flatten, GeneralConv, Identity,
-                                   MaxPool, Relu, LogSoftmax)
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
+from jax.example_libraries.stax import (AvgPool, BatchNorm, Conv, Dense,
+                                        FanInSum, FanOut, Flatten, GeneralConv,
+                                        Identity, MaxPool, Relu, LogSoftmax)


# ResNet blocks compose other layers
29 changes: 15 additions & 14 deletions jax/experimental/README.md → jax/example_libraries/README.md
@@ -25,13 +25,14 @@ constructor functions for common basic pairs, like `Conv` and `Relu`, and these
pairs can be composed in series using `stax.serial` or in parallel using
`stax.parallel`.

-Here’s an example:
+Here's an example:

```python
import jax.numpy as jnp
from jax import random
-from jax.experimental import stax
-from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
+from jax.example_libraries import stax
+from jax.example_libraries.stax import (
+    Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax)

# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
```
@@ -54,20 +55,20 @@ predictions = net_apply(net_params, inputs)

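The rest of this example is collapsed in the diff view above. As a self-contained sketch of how the returned `(init_fun, apply_fun)` pair is typically used; the layer stack and the `(8, 28, 28, 1)` input shape below are assumptions for illustration, not the collapsed content:

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import (
    Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax)

# stax.serial composes (init_fun, apply_fun) layer pairs into one such pair
net_init, net_apply = stax.serial(
    Conv(32, (3, 3)), Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(10), LogSoftmax,
)

# init_fun: rng key and input shape -> (output shape, initial parameters)
rng = random.PRNGKey(0)
out_shape, net_params = net_init(rng, (8, 28, 28, 1))

# apply_fun: parameters and a batch of inputs -> outputs
inputs = jnp.zeros((8, 28, 28, 1))
predictions = net_apply(net_params, inputs)
```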
### First-order optimization

-JAX has a minimal optimization library focused on stochastic first-order
-optimizers. Every optimizer is modeled as an `(init_fun, update_fun,
-get_params)` triple of functions. The `init_fun` is used to initialize the
-optimizer state, which could include things like momentum variables, and the
-`update_fun` accepts a gradient and an optimizer state to produce a new
-optimizer state. The `get_params` function extracts the current iterate (i.e.
-the current parameters) from the optimizer state. The parameters being optimized
-can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can
-store your parameters however you’d like.
+The file `optimizers.py` contains a minimal optimization library focused on
+stochastic first-order optimizers. Every optimizer is modeled as an
+`(init_fun, update_fun, get_params)` triple of functions. The `init_fun` is used
+to initialize the optimizer state, which could include things like momentum
+variables, and the `update_fun` accepts a gradient and an optimizer state to
+produce a new optimizer state. The `get_params` function extracts the current
+iterate (i.e. the current parameters) from the optimizer state. The parameters
+being optimized can be ndarrays or arbitrarily-nested list/tuple/dict
+structures, so you can store your parameters however you'd like.

-Here’s an example, using `jit` to compile the whole update end-to-end:
+Here's an example, using `jit` to compile the whole update end-to-end:

```python
-from jax.experimental import optimizers
+from jax.example_libraries import optimizers
from jax import jit, grad

# Define a simple squared-error loss
```
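The diff view truncates this example after the loss definition. For completeness, a minimal sketch of the full triple in use; the use of `optimizers.sgd`, the linear model, the step size, and the synthetic data are illustrative assumptions, not the truncated content:

```python
import jax.numpy as jnp
from jax import jit, grad
from jax.example_libraries import optimizers

# A simple squared-error loss over a (w, b) parameter pytree
def loss(params, batch):
  inputs, targets = batch
  w, b = params
  preds = jnp.dot(inputs, w) + b
  return jnp.mean((preds - targets) ** 2)

# Every optimizer is an (init_fun, update_fun, get_params) triple
opt_init, opt_update, get_params = optimizers.sgd(step_size=1e-3)

@jit
def step(i, opt_state, batch):
  params = get_params(opt_state)          # extract the current iterate
  grads = grad(loss)(params, batch)       # gradient pytree, same structure
  return opt_update(i, grads, opt_state)  # produce a new optimizer state

params = (jnp.zeros((3, 1)), jnp.zeros(()))  # parameters may be nested structures
opt_state = opt_init(params)
batch = (jnp.ones((8, 3)), jnp.ones((8, 1)))
for i in range(100):
  opt_state = step(i, opt_state, batch)
trained_params = get_params(opt_state)
```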
13 changes: 13 additions & 0 deletions jax/example_libraries/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
