Add notes on jitting batch / monte carlo kernel_fn
PiperOrigin-RevId: 318826347
romanngg committed Jun 29, 2020
1 parent 8d3ca2f commit 46edb0e
Showing 3 changed files with 27 additions and 7 deletions.
10 changes: 10 additions & 0 deletions neural_tangents/utils/batch.py
@@ -20,6 +20,11 @@
the result, allowing one to both use multiple accelerators and stay within memory
limits.
Note that you typically should not apply the `jax.jit` decorator to the
resulting `batched_kernel_fn`, as its purpose is explicitly serial execution in
order to save memory. Further, you do not need to apply `jax.jit` to the input
`kernel_fn` function, as it is JITted internally.
Example:
>>> from jax import numpy as np
>>> import neural_tangents as nt
@@ -61,6 +66,11 @@ def batch(kernel_fn: KernelFn,
store_on_device: bool = True) -> KernelFn:
"""Returns a function that computes a kernel in batches over all devices.
Note that you typically should not apply the `jax.jit` decorator to the
resulting `batched_kernel_fn`, as its purpose is explicitly serial execution
in order to save memory. Further, you do not need to apply `jax.jit` to the
input `kernel_fn` function, as it is JITted internally.
Args:
kernel_fn:
A function that computes a kernel on two batches,
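For concreteness, a minimal usage sketch of the notes added above (widths, input shapes, and batch size are arbitrary, and a single accelerator is assumed so the batch size divides the inputs evenly):

from jax import random
import neural_tangents as nt
from neural_tangents import stax

# Infinite-width NTK/NNGP kernel of a small fully-connected network.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

# Compute the kernel in batches of 16 over the available devices.
# Per the note above: do not wrap the result in `jax.jit` (batching is
# intentionally serial to bound memory), and the inner `kernel_fn` does
# not need `jax.jit` either, since it is JITted internally.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=16)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (64, 10))
x2 = random.normal(key2, (32, 10))

ntk = batched_kernel_fn(x1, x2, 'ntk')  # shape (64, 32)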
9 changes: 4 additions & 5 deletions neural_tangents/utils/empirical.py
@@ -37,8 +37,7 @@
from jax.tree_util import tree_multimap
from jax.tree_util import tree_reduce
from neural_tangents.utils import utils
from neural_tangents.utils.typing import \
ApplyFn, EmpiricalKernelFn, PyTree, PRNGKey, Axes
from neural_tangents.utils.typing import ApplyFn, EmpiricalKernelFn, PyTree, PRNGKey, Axes


def linearize(f: Callable[..., np.ndarray],
@@ -512,12 +511,12 @@ def kernel_fn(x1: np.ndarray,
x2:
second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
params:
A `PyTree` of parameters about which we would like to compute the
neural tangent kernel.
get:
type of the empirical kernel. `get=None` means `get=("nngp", "ntk")`.
Can be a string (`"nngp"`) or a tuple of strings (`("ntk", "nngp")`).
params:
A `PyTree` of parameters about which we would like to compute the
neural tangent kernel.
keys:
`None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
dtype `uint32`. If `key=None`, then the function `f` is deterministic
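As a point of reference for the argument list documented in this hunk, a hedged sketch of constructing and calling the empirical kernel function; the `x1, x2, get, params` call order used here is illustrative, so consult the final docstring for the authoritative order:

from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(128), stax.Relu(), stax.Dense(1))

key, net_key = random.split(random.PRNGKey(0))
x1 = random.normal(key, (8, 10))
_, params = init_fn(net_key, x1.shape)

# Empirical (finite-width) kernel of `apply_fn` at the given `params`.
kernel_fn = nt.empirical_kernel_fn(apply_fn)

# `x2=None` means `x2=x1`; `get` selects which kernels to compute.
ntk = kernel_fn(x1, None, 'ntk', params)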
15 changes: 13 additions & 2 deletions neural_tangents/utils/monte_carlo.py
@@ -12,11 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions to compute Monte Carlo NNGP and NTK estimates.
"""Function to compute Monte Carlo NNGP and NTK estimates.
The library has a public function `monte_carlo_kernel_fn` that allows one to compute
This module contains a function `monte_carlo_kernel_fn` that allows one to compute
Monte Carlo estimates of NNGP and NTK kernels of arbitrary functions. For more
details on how individual samples are computed, refer to `utils/empirical.py`.
Note that the `monte_carlo_kernel_fn` accepts arguments like `batch_size`,
`device_count`, and `store_on_device`, and is appropriately batched /
parallelized. You don't need to apply the `nt.batch` or `jax.jit` decorators to
it. Further, you do not need to apply `jax.jit` to the input `apply_fn`
function, as the resulting empirical kernel function is JITted internally.
"""


@@ -120,6 +126,11 @@ def monte_carlo_kernel_fn(
) -> MonteCarloKernelFn:
"""Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.
Note that the returned function is appropriately batched / parallelized. You
don't need to apply the `nt.batch` or `jax.jit` decorators to it. Further,
you do not need to apply `jax.jit` to the input `apply_fn` function, as the
resulting empirical kernel function is JITted internally.
Args:
init_fn:
a function initializing parameters of the neural network. From
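To make the notes above concrete, a minimal sketch of calling the Monte Carlo sampler directly, with no extra `nt.batch` or `jax.jit` wrapping; widths, sample counts, and batch size are arbitrary, and a single accelerator is assumed so the batch size divides the inputs evenly:

from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(2), 3)
x1 = random.normal(key1, (32, 10))
x2 = random.normal(key2, (16, 10))

# The sampler batches / parallelizes internally and JITs the empirical
# kernel of `apply_fn`, so neither `nt.batch` nor `jax.jit` is applied here.
kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key3, n_samples=128,
                                     batch_size=16)

ntk = kernel_fn(x1, x2, 'ntk')  # Monte Carlo NTK estimate, shape (32, 16)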
