| Category                   | Function          | Description                                                             | Import Statement                                   |
|----------------------------|-------------------|-------------------------------------------------------------------------|----------------------------------------------------|
| Key Creation & Manipulation | PRNGKey          | Create a legacy PRNG key (raw `uint32` key data) from an integer seed   | `from jax.random import PRNGKey`                  |
|                            | key               | Create a typed PRNG key array (new-style API) from an integer seed      | `from jax.random import key`                      |
|                            | key_data          | Recover the bits of key data underlying a PRNG key array                | `from jax.random import key_data`                 |
|                            | wrap_key_data     | Wrap an array of key data bits into a PRNG key array                    | `from jax.random import wrap_key_data`            |
|                            | fold_in           | Folds in data to a PRNG key to form a new PRNG key                       | `from jax.random import fold_in`                  |
|                            | split             | Split a PRNG key into `num` new keys along a new leading axis           | `from jax.random import split`                    |
|                            | clone             | Clone a key for reuse                                                   | `from jax.random import clone`                    |
| Random Samplers           | ball              | Sample uniformly from the unit Lp ball                                   | `from jax.random import ball`                     |
|                            | bernoulli         | Sample Bernoulli random values with given shape and mean                 | `from jax.random import bernoulli`                |
|                            | beta              | Sample Beta random values with given shape and float dtype               | `from jax.random import beta`                     |
|                            | binomial          | Sample Binomial random values with given shape and float dtype           | `from jax.random import binomial`                 |
|                            | bits              | Sample uniform bits in the form of unsigned integers                     | `from jax.random import bits`                     |
|                            | categorical       | Sample random values from categorical distributions                      | `from jax.random import categorical`              |
|                            | cauchy            | Sample Cauchy random values with given shape and float dtype            | `from jax.random import cauchy`                   |
|                            | chisquare         | Sample chi-square random values with given shape and float dtype        | `from jax.random import chisquare`                |
|                            | choice            | Generates a random sample from a given array                              | `from jax.random import choice`                   |
|                            | dirichlet         | Sample Dirichlet random values with given shape and float dtype          | `from jax.random import dirichlet`                |
|                            | double_sided_maxwell | Sample from a double-sided Maxwell distribution                      | `from jax.random import double_sided_maxwell`     |
|                            | exponential       | Sample Exponential random values with given shape and float dtype        | `from jax.random import exponential`              |
|                            | f                 | Sample F-distribution random values with given shape and float dtype     | `from jax.random import f`                        |
|                            | gamma             | Sample Gamma random values with given shape and float dtype              | `from jax.random import gamma`                    |
|                            | generalized_normal| Sample from the generalized normal distribution                          | `from jax.random import generalized_normal`       |
|                            | geometric         | Sample Geometric random values with given shape and integer dtype       | `from jax.random import geometric`                |
|                            | gumbel            | Sample Gumbel random values with given shape and float dtype             | `from jax.random import gumbel`                   |
|                            | laplace           | Sample Laplace random values with given shape and float dtype            | `from jax.random import laplace`                  |
|                            | loggamma          | Sample log-gamma random values with given shape and float dtype          | `from jax.random import loggamma`                 |
|                            | logistic          | Sample logistic random values with given shape and float dtype           | `from jax.random import logistic`                 |
|                            | lognormal         | Sample lognormal random values with given shape and float dtype          | `from jax.random import lognormal`                |
|                            | maxwell           | Sample from a one-sided Maxwell distribution                             | `from jax.random import maxwell`                  |
|                            | multivariate_normal | Sample multivariate normal random values with given mean and covariance  | `from jax.random import multivariate_normal`     |
|                            | normal            | Sample standard normal random values with given shape and float dtype    | `from jax.random import normal`                   |
|                            | orthogonal        | Sample uniformly from the orthogonal group O(n)                           | `from jax.random import orthogonal`               |
|                            | pareto            | Sample Pareto random values with given shape and float dtype             | `from jax.random import pareto`                   |
|                            | permutation       | Returns a randomly permuted array or range                                | `from jax.random import permutation`              |
|                            | poisson           | Sample Poisson random values with given shape and integer dtype           | `from jax.random import poisson`                  |
|                            | rademacher        | Sample from a Rademacher distribution                                    | `from jax.random import rademacher`               |
|                            | randint           | Sample uniform random values in [minval, maxval) with given shape/dtype  | `from jax.random import randint`                  |
|                            | rayleigh          | Sample Rayleigh random values with given shape and float dtype           | `from jax.random import rayleigh`                 |
|                            | t                 | Sample Student's t random values with given shape and float dtype         | `from jax.random import t`                        |
|                            | triangular        | Sample Triangular random values with given shape and float dtype         | `from jax.random import triangular`               |
|                            | truncated_normal  | Sample truncated standard normal random values with given shape and dtype| `from jax.random import truncated_normal`         |
|                            | uniform           | Sample uniform random values in [minval, maxval) with given shape/dtype  | `from jax.random import uniform`                  |
|                            | wald              | Sample Wald random values with given shape and float dtype               | `from jax.random import wald`                     |
|                            | weibull_min       | Sample from a Weibull distribution                                        | `from jax.random import weibull_min`              |

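The key-handling functions in the first rows are easiest to understand side by side. The sketch below assumes a recent JAX release that ships the typed-key API (`jax.random.key`) alongside the legacy `PRNGKey`, and the default threefry PRNG; the variable names (`k_init`, `k_dropout`, `step_key`, and so on) are illustrative only.

```python
import jax
import jax.numpy as jnp

# New-style typed key: a scalar array whose dtype is a PRNG key type.
key = jax.random.key(0)

# Legacy key: raw uint32 key data (shape (2,) under the default threefry PRNG).
legacy_key = jax.random.PRNGKey(0)

# key_data exposes the raw uint32 bits behind a typed key...
raw_bits = jax.random.key_data(key)
# ...and wrap_key_data turns such bits back into a typed key.
rewrapped = jax.random.wrap_key_data(raw_bits)

# split: one key in, several independent keys out (stacked on a leading axis).
k_init, k_dropout, k_noise = jax.random.split(key, 3)

# fold_in: derive a key deterministically from a base key and an integer,
# e.g. a training-step counter.
step_key = jax.random.fold_in(key, 7)

# Draw from a subkey rather than reusing the parent key for independent draws.
noise = jax.random.normal(k_noise, (4,))
```

Typed and legacy keys built from the same seed carry the same bits, so with the default PRNG they produce identical samples; the typed form mainly adds dtype safety.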

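Every sampler in the table takes a key as its first argument and is a pure function of that key. A minimal sketch using a few of the samplers listed above, assuming the default 32-bit mode (`jax_enable_x64` off); the seed and the logits are arbitrary illustrative values.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
# One subkey per sampler so that no two draws share randomness.
k_unif, k_norm, k_int, k_cat, k_perm = jax.random.split(key, 5)

# Continuous: uniform on [minval, maxval) and standard normal.
u = jax.random.uniform(k_unif, shape=(1000,), minval=-1.0, maxval=1.0)
z = jax.random.normal(k_norm, shape=(3, 4))

# Discrete: integers in [minval, maxval) and categorical draws from logits.
ints = jax.random.randint(k_int, shape=(10,), minval=0, maxval=5)
logits = jnp.log(jnp.array([0.1, 0.2, 0.7]))
labels = jax.random.categorical(k_cat, logits, shape=(8,))

# permutation with an integer n shuffles jnp.arange(n).
perm = jax.random.permutation(k_perm, 10)

# Same key, same values: sampling is deterministic given the key.
z_again = jax.random.normal(k_norm, shape=(3, 4))
```
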
| Category        | Function     | Description                                                        | Import Statement                  |
|-----------------|--------------|--------------------------------------------------------------------|-----------------------------------|
| Activation      | relu         | Rectified linear unit activation function                          | `from jax.nn import relu`         |
|                 | relu6        | Rectified linear unit 6 activation function                        | `from jax.nn import relu6`        |
|                 | sigmoid      | Sigmoid activation function                                        | `from jax.nn import sigmoid`      |
|                 | softplus     | Softplus activation function                                       | `from jax.nn import softplus`     |
|                 | sparse_plus  | Sparse plus function                                               | `from jax.nn import sparse_plus`  |
|                 | soft_sign    | Soft-sign activation function                                      | `from jax.nn import soft_sign`    |
|                 | silu         | SiLU (aka swish) activation function                               | `from jax.nn import silu`         |
|                 | swish        | SiLU (aka swish) activation function; alias of `silu`              | `from jax.nn import swish`        |
|                 | log_sigmoid  | Log-sigmoid activation function                                    | `from jax.nn import log_sigmoid`  |
|                 | leaky_relu   | Leaky rectified linear unit activation function                    | `from jax.nn import leaky_relu`   |
|                 | hard_sigmoid | Hard sigmoid activation function                                   | `from jax.nn import hard_sigmoid` |
|                 | hard_silu    | Hard SiLU (swish) activation function                              | `from jax.nn import hard_silu`    |
|                 | hard_swish   | Hard SiLU (swish) activation function; alias of `hard_silu`        | `from jax.nn import hard_swish`   |
|                 | hard_tanh    | Hard tanh activation function                                      | `from jax.nn import hard_tanh`    |
|                 | elu          | Exponential linear unit activation function                        | `from jax.nn import elu`          |
|                 | celu         | Continuously-differentiable exponential linear unit activation     | `from jax.nn import celu`         |
|                 | selu         | Scaled exponential linear unit activation                          | `from jax.nn import selu`         |
|                 | gelu         | Gaussian error linear unit activation function                     | `from jax.nn import gelu`         |
|                 | glu          | Gated linear unit activation function                              | `from jax.nn import glu`          |
|                 | squareplus   | Squareplus activation function                                     | `from jax.nn import squareplus`   |
| Other Functions | softmax      | Softmax function                                                   | `from jax.nn import softmax`      |
|                 | log_softmax  | Log-softmax function                                               | `from jax.nn import log_softmax`  |
|                 | logsumexp    | Compute the log of the sum of exponentials of input elements       | `from jax.nn import logsumexp`    |
|                 | standardize  | Normalize an array by subtracting the mean and dividing by the standard deviation | `from jax.nn import standardize` |
|                 | one_hot      | One-hot encode the given indices                                   | `from jax.nn import one_hot`      |

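The activation functions above are plain elementwise functions on arrays, so they compose with `jax.grad` and `jax.vmap` like any other JAX code. If you work in Flax Linen, many of them are also re-exported there (for example `flax.linen.relu`). A minimal sketch; the input values and the `negative_slope` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

relu_out = jax.nn.relu(x)                             # max(x, 0)
leaky_out = jax.nn.leaky_relu(x, negative_slope=0.1)  # 0.1 * x where x < 0
silu_out = jax.nn.silu(x)                             # x * sigmoid(x)
swish_out = jax.nn.swish(x)                           # alias of silu
gelu_out = jax.nn.gelu(x, approximate=False)          # exact erf-based GELU

# Activations are ordinary differentiable functions: per-element derivative
# of relu via grad (scalar function) vectorized with vmap.
relu_grad = jax.vmap(jax.grad(jax.nn.relu))(x)
```
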

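The "other functions" typically appear together in a classification loss. The sketch below computes a cross-entropy from `log_softmax` and `one_hot`; the logits, labels, and the loss formulation are illustrative assumptions rather than anything prescribed by the table.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0],
                    [1.0, 1.0, 1.0]])

probs = jax.nn.softmax(logits, axis=-1)          # each row sums to 1
log_probs = jax.nn.log_softmax(logits, axis=-1)  # numerically stable log(softmax)
lse = jax.nn.logsumexp(logits, axis=-1)          # log(sum(exp(logits))) per row

# Mathematically, log_softmax(x) == x - logsumexp(x).
check = logits - lse[:, None]

labels = jnp.array([2, 0])
targets = jax.nn.one_hot(labels, num_classes=3)  # shape (2, 3)

# Mean cross-entropy between one-hot targets and predicted log-probabilities.
loss = -jnp.mean(jnp.sum(targets * log_probs, axis=-1))

# standardize: subtract the mean, divide by sqrt(variance + epsilon).
z = jax.nn.standardize(jnp.array([1.0, 2.0, 3.0, 4.0]))
```
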
| Module   | Function  | Description                                                                                      | Usage (after `import jax`)                              | `jax.tree_util` equivalent |
|----------|-----------|--------------------------------------------------------------------------------------------------|---------------------------------------------------------|----------------------------|
| jax.tree | all       | Call `all()` over the leaves of a tree.                                                          | `jax.tree.all(tree)`                                    | `tree_all`                 |
|          | flatten   | Flatten a pytree into a list of leaves and a treedef.                                            | `jax.tree.flatten(tree)`                                | `tree_flatten`             |
|          | leaves    | Get the leaves of a pytree.                                                                      | `jax.tree.leaves(tree)`                                 | `tree_leaves`              |
|          | map       | Map a multi-input function over pytree arguments to produce a new pytree.                        | `jax.tree.map(f, tree, *rest)`                          | `tree_map`                 |
|          | reduce    | Call `reduce()` over the leaves of a tree.                                                       | `jax.tree.reduce(f, tree, initializer)`                 | `tree_reduce`              |
|          | structure | Get the treedef for a pytree.                                                                    | `jax.tree.structure(tree)`                              | `tree_structure`           |
|          | transpose | Transform a tree with structure (outer, inner) into one with structure (inner, outer).           | `jax.tree.transpose(outer_treedef, inner_treedef, tree)`| `tree_transpose`           |
|          | unflatten | Reconstruct a pytree from a treedef and its leaves.                                              | `jax.tree.unflatten(treedef, leaves)`                   | `tree_unflatten`           |

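A minimal tour of the `jax.tree` functions on a small nested structure, assuming a JAX version recent enough to ship the `jax.tree` namespace. The `params` dict and `records` list are hypothetical examples; note that dict leaves are visited in sorted-key order, so the leaves here are `b`, `extra[0]`, `extra[1]`, `w`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree: arrays plus a list of Python floats.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3), "extra": [1.0, 2.0]}

leaf_list = jax.tree.leaves(params)          # 4 leaves, sorted-key order
leaves, treedef = jax.tree.flatten(params)   # leaves plus structure
rebuilt = jax.tree.unflatten(treedef, leaves)

# map over one tree, then over two trees with matching structure.
doubled = jax.tree.map(lambda x: x * 2, params)
summed = jax.tree.map(lambda a, b: a + b, params, doubled)

# reduce: fold a function over the leaves, starting from an initializer.
total = jax.tree.reduce(lambda acc, x: acc + jnp.sum(x), params, 0.0)

# all: true only if every leaf of the (boolean) tree is truthy.
all_finite = jax.tree.all(
    jax.tree.map(lambda x: bool(jnp.all(jnp.isfinite(x))), params))

# transpose: list of dicts (outer=list, inner=dict) -> dict of lists.
records = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]
outer = jax.tree.structure([0, 0])
inner = jax.tree.structure({"x": 0, "y": 0})
columns = jax.tree.transpose(outer, inner, records)
```

Accessing the functions as `jax.tree.map`, `jax.tree.all`, and so on, rather than importing them bare, avoids shadowing Python's built-in `map` and `all`.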