Skip to content

Commit

Permalink
Try out a slightly different Sphinx configuration.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 300834922
  • Loading branch information
sschoenholz committed Mar 13, 2020
1 parent 640a1df commit a060538
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 44 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('../..'))


# -- Project information -----------------------------------------------------

project = u'Neural Tangents'
copyright = u'2020 Google Inc.'
copyright = u'2019, Google LLC.'
author = u'The Neural Tangents Authors'

# The short X.Y version
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
Neural Tangents reference documentation
===========================================

Neural Tangents is a set of tools for constructing and training infinitely wide
neural networks.

.. toctree::
:maxdepth: 2
:caption: Contents:
:caption: Reference:

neural_tangents.stax
neural_tangents.empirical
neural_tangents.predict
neural_tangents.batching
neural_tangents.monte_carlo

Indices and tables
==================
Expand Down
File renamed without changes.
File renamed without changes.
5 changes: 5 additions & 0 deletions docs/neural_tangents.monte_carlo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Monte Carlo Sampling
===========================

.. automodule:: neural_tangents.utils.monte_carlo
:members:
File renamed without changes.
File renamed without changes.
78 changes: 38 additions & 40 deletions neural_tangents/utils/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,26 @@ def monte_carlo_kernel_fn(init_fn,
"""Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.
Args:
init_fn: a function initializing parameters of the neural network. From
:init_fn: a function initializing parameters of the neural network. From
`jax.experimental.stax`: "takes an rng key and an input shape and returns
an `(output_shape, params)` pair".
apply_fn: a function computing the output of the neural network.
:apply_fn: a function computing the output of the neural network.
From `jax.experimental.stax`: "takes params, inputs, and an rng key and
applies the layer".
key: RNG (`jax.random.PRNGKey`) for sampling random networks. Must have
:key: RNG (`jax.random.PRNGKey`) for sampling random networks. Must have
shape `(2,)`.
n_samples: number of Monte Carlo samples. Can be either an integer or an
:n_samples: number of Monte Carlo samples. Can be either an integer or an
iterable of integers at which the resulting generator will yield
estimates. Example: use `n_samples=[2**k for k in range(10)]` for the
generator to yield estimates using 1, 2, 4, ..., 512 Monte Carlo samples.
batch_size: an integer making the kernel computed in batches of `x1` and
:batch_size: an integer making the kernel computed in batches of `x1` and
`x2` of this size. `0` means computing the whole kernel. Must divide
`x1.shape[0]` and `x2.shape[0]`.
device_count: an integer making the kernel be computed in parallel across
:device_count: an integer making the kernel be computed in parallel across
this number of devices (e.g. GPUs or TPU cores). `-1` means use all
available devices. `0` means compute on a single device sequentially. If
not `0`, must divide `x1.shape[0]`.
store_on_device: a boolean, indicating whether to store the resulting
:store_on_device: a boolean, indicating whether to store the resulting
kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger
kernels may fit.
Expand All @@ -120,39 +120,37 @@ def monte_carlo_kernel_fn(init_fn,
`n` samples for `n in n_samples`.
Example:
```python
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>> stax.Conv(128, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(256, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(512, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>>
>>> n_samples = 200
>>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
>>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
>>>
>>> n_samples = [1, 10, 100, 1000]
>>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
>>> n_samples)
>>> kernel_samples = kernel_fn_generator(x_train, x_test, get=('nngp', 'ntk'))
>>> for n, kernel in zip(n_samples, kernel_samples):
>>> print(n, kernel)
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.
```
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>> stax.Conv(128, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(256, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(512, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>>
>>> n_samples = 200
>>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
>>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
>>>
>>> n_samples = [1, 10, 100, 1000]
>>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
>>> n_samples)
>>> kernel_samples = kernel_fn_generator(x_train, x_test, get=('nngp', 'ntk'))
>>> for n, kernel in zip(n_samples, kernel_samples):
>>> print(n, kernel)
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.
"""
kernel_fn = empirical.empirical_kernel_fn(apply_fn)

Expand Down

0 comments on commit a060538

Please sign in to comment.