Make the _parallel decorator work with functions of any* signature by broadcasting numpy arrays and closing over non-arrays.

The first argument has to be a numpy array with the leading dimension of size `device_count` for pmapping. *NO ARRAYS ARE ALLOWED IN KEYWORD ARGUMENTS.
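A minimal sketch of the resulting calling convention (illustrative only: `ker_fun`, the shapes, and the availability of 2 XLA devices below are assumptions, not part of this change):

import jax.numpy as np
from neural_tangents.utils import batch

def ker_fun(x1, x2, do_square, params):
  res = np.matmul(x1, np.transpose(x2))
  return res ** 2 if do_square else res

ker_fun_pmapped = batch._jit_or_pmap_broadcast(ker_fun, device_count=2)

x1 = np.ones((2, 5, 10))  # leading axis == device_count; split across devices
x2 = np.ones((3, 10))     # `np.ndarray` in `args`: broadcast to (2, 3, 10)
# `do_square` and `params` are not arrays, so they are closed over; no arrays
# may appear in keyword arguments.
out = ker_fun_pmapped(x1, x2, True, None)  # out.shape == (2, 5, 3)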

PiperOrigin-RevId: 270212471
romanngg authored and sschoenholz committed Sep 21, 2019
1 parent f6892ca commit 7cef742
Showing 2 changed files with 160 additions and 20 deletions.
61 changes: 60 additions & 1 deletion neural_tangents/tests/batch_test.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Neural Tangents library."""

from __future__ import absolute_import
@@ -21,7 +20,9 @@
from jax import test_util as jtu
from jax.api import jit
from jax.config import config as jax_config
import jax.numpy as np
import jax.random as random
from jax.tree_util import tree_map
from neural_tangents import stax
from neural_tangents.utils import batch
from neural_tangents.utils import empirical
@@ -264,6 +265,64 @@ def testAutomatic(self, train_shape, test_shape, network, name, ker_fun):
    _test_kernel_against_batched(self, ker_fun, kernel_batched, data_self,
                                 data_other)

  def test_jit_or_pmap_broadcast(self):
    def ker_fun(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.65):
      res = np.abs(np.matmul(x1, x2))
      if do_square:
        res *= res
      if do_flip:
        res = -res

      res *= random.uniform(keys) * p
      return [res, params]

    params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
    x2 = np.arange(0, 10).reshape((10,))
    keys = random.PRNGKey(1)

    # device_count=0: the function should simply be jitted, not pmapped.
    ker_fun_pmapped = batch._jit_or_pmap_broadcast(ker_fun, device_count=0)
    x1 = np.arange(0, 10).reshape((1, 10))
    for do_flip in [True, False]:
      for do_square in [True, False]:
        with self.subTest(do_flip=do_flip, do_square=do_square, device_count=0):
          res_1 = ker_fun(
              x1, x2, do_flip, keys, do_square, params, _unused=True, p=0.65)
          res_2 = ker_fun_pmapped(
              x1, x2, do_flip, keys, do_square, params, _unused=True)
          self.assertAllClose(res_1, res_2, True)

    # Stub out `pmap` to emulate a single device.
    utils.stub_out_pmap(batch, 1)
    x1 = np.arange(0, 10).reshape((1, 10))
    ker_fun_pmapped = batch._jit_or_pmap_broadcast(ker_fun, device_count=1)
    for do_flip in [True, False]:
      for do_square in [True, False]:
        with self.subTest(do_flip=do_flip, do_square=do_square, device_count=1):
          res_1 = ker_fun(
              x1, x2, do_flip, keys, do_square, params, _unused=False, p=0.65)
          res_2 = ker_fun_pmapped(
              x1, x2, do_flip, keys, do_square, params, _unused=None)
          self.assertAllClose(res_1[0], res_2[0], True)
          self.assertAllClose(
              tree_map(partial(np.expand_dims, axis=0), res_1[1]), res_2[1],
              True)

    # Stub out `pmap` to emulate two devices; the leading axis of `x1` must
    # have size equal to `device_count`.
    ker_fun_pmapped = batch._jit_or_pmap_broadcast(ker_fun, device_count=2)
    x1 = np.arange(0, 20).reshape((2, 10))
    utils.stub_out_pmap(batch, 2)

    def broadcast(arg):
      return np.broadcast_to(arg, (2,) + arg.shape)

    for do_flip in [True, False]:
      for do_square in [True, False]:
        with self.subTest(do_flip=do_flip, do_square=do_square, device_count=2):
          res_1 = ker_fun(x1, x2, do_flip, keys, do_square, params, p=0.2)
          res_2 = ker_fun_pmapped(
              x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.2)
          self.assertAllClose(res_1[0][0], res_2[0][0], True)
          self.assertAllClose(res_1[0][1], res_2[0][1], True)
          self.assertAllClose(tree_map(broadcast, res_1[1]), res_2[1], True)


if __name__ == '__main__':
  jtu.absltest.main()
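As an aside, the per-device replication the pmapped path relies on can be reproduced directly with `tree_map` (a standalone sketch reusing the `params` pytree from the test above; 2 devices assumed):

from jax.tree_util import tree_map
import jax.numpy as np

params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
replicated = tree_map(lambda a: np.broadcast_to(a, (2,) + a.shape), params)
# Leaf shapes go from (2,), (1,), (1,) to (2, 2), (2, 1), (2, 1).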
119 changes: 100 additions & 19 deletions neural_tangents/utils/batch.py
@@ -22,6 +22,7 @@
from jax.api import pmap
from jax.lib import xla_bridge
import jax.numpy as np
from jax.tree_util import tree_all
from jax.tree_util import tree_map
from jax.tree_util import tree_multimap
from neural_tangents.utils.kernel import Kernel
@@ -211,16 +212,10 @@ def _parallel(ker_fun, device_count=-1):
    A new function with the same signature as ker_fun that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.
  """

  ker_fun = _jit_or_pmap_broadcast(ker_fun, device_count)
  if device_count == -1:
    device_count = xla_bridge.device_count()

  def broadcast(arg):
    # TODO(romann): remove this when JAX allows `axis_in` for `pmap`.
    return np.broadcast_to(arg, (device_count,) + arg.shape)

  ker_fun = pmap(ker_fun)

  def parallel_fn(x1, x2=None, *args, **kwargs):
    if x2 is None:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
@@ -242,17 +237,8 @@ def parallel_fn(x1, x2=None, *args, **kwargs):
      _device_count = ragged
      n1_per_device = 1

    if n1_per_device:
      x1s = np.reshape(
          x1, (_device_count, n1_per_device,) + input_shape)
    else:
      x1s = np.reshape(x1, (n1, 1,) + input_shape)

    x2s = broadcast(x2)
    args = tree_map(broadcast, args)
    kwargs = tree_map(broadcast, kwargs)

    kernel = ker_fun(x1s, x2s, *args, **kwargs)
    x1 = np.reshape(x1, (_device_count, n1_per_device,) + input_shape)
    kernel = ker_fun(x1, x2, *args, **kwargs)
    return _flatten_kernel(kernel)

  # Set function attributes so that `serial` can detect whether or not it is
@@ -289,9 +275,104 @@ def batch(ker_fun, batch_size=0, device_count=-1, store_on_device=True):
  if (device_count == -1 and xla_bridge.device_count() > 1) or device_count > 0:
    ker_fun = _parallel(ker_fun, device_count)
  else:
    ker_fun = jit(ker_fun)
    ker_fun = _jit_or_pmap_broadcast(ker_fun, device_count=0)

  if not batch_size:
    return ker_fun

  return _serial(ker_fun, batch_size, store_on_device)


def _is_np_ndarray(x):
  return tree_all(tree_map(lambda y: isinstance(y, np.ndarray), x))


def _merge_dicts(a, b):
  # TODO(schsam): Replace by {**a, **b} when Python 2 is deprecated.
  merged = dict(a)
  merged.update(b)
  return merged


def _get_jit_or_pmap_broadcast():
  """Initializes a cache of pmapped functions closed over non-`np.ndarray` args.

  Returns:
    A `jit_or_pmap_broadcast` function that jits or pmaps a given function as
    a closure over all non-`np.ndarray` args and all `kwargs`, while
    broadcasting all `np.ndarray`s in `args` except the first one.
  """
  cache = {}

  def jit_or_pmap_broadcast(f, device_count=-1):
    """Pmap `f` over the first argument by closing over or broadcasting others.

    Args:
      f: function to pmap. First argument must be a `np.ndarray` with leading
        axis having the size of `device_count`.
      device_count: number of XLA devices. `-1` means all available devices.
        `0` means to just `jit` the function.

    Returns:
      A function of the same signature as `f`, pmapped over the first argument
      with other arguments either closed over (non-`np.ndarray`s in `args` and
      all `kwargs`) or broadcast to `(device_count,) + old_shape` (for
      `np.ndarray`s). If `device_count == 0`, `f` is jitted as a closure over
      all non-array arguments and all `kwargs`.

    Raises:
      An error if `kwargs` contain a `np.ndarray`.
      TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it. See
        https://github.com/google/jax/issues/912
    """
    key = (f, device_count)

    if device_count == -1:
      device_count = xla_bridge.device_count()

    # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
    def broadcast(arg):
      if device_count == 0:
        return arg
      return np.broadcast_to(arg, (device_count,) + arg.shape)

    def f_pmapped(x, *args, **kwargs):
      args_np, args_np_idxs = [], []
      args_other = {}

      # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
      # https://github.com/google/jax/issues/912
      # Filter out `np.ndarray`s from other arguments.
      for i, arg in enumerate(args):
        if _is_np_ndarray(arg):
          args_np.append(arg)
          args_np_idxs.append(i)
        else:
          args_other[i] = arg

      # Check the cache before jitting.
      _key = key + tuple(args_other.items()) + tuple(kwargs.items())
      if _key in cache:
        _f = cache[_key]
      else:
        # Define a `np.ndarray`-only function as a closure over other arguments.
        def _f(_x, *_args_np):
          # Merge the `np.ndarray` arguments back with the closed-over ones.
          _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
          _args = _merge_dicts(_args_np, args_other)
          _args = tuple(v for k, v in sorted(_args.items()))
          return f(_x, *_args, **kwargs)

        _f = jit(_f) if device_count == 0 else pmap(_f)
        cache[_key] = _f

      # Broadcast `np.ndarray` arguments and apply the new function to them.
      args_np = tree_map(broadcast, args_np)
      return _f(x, *args_np)

    return f_pmapped

  return jit_or_pmap_broadcast


_jit_or_pmap_broadcast = _get_jit_or_pmap_broadcast()

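Note how the cache key above includes the values of all non-`np.ndarray` positional args and all `kwargs`: changing any of them compiles a new closure, while repeated calls with the same values reuse the cached one. A hypothetical illustration (`my_ker_fun`, `x1`, `x2`, and the boolean `do_square` argument are invented names):

f_pmapped = _jit_or_pmap_broadcast(my_ker_fun, device_count=0)
f_pmapped(x1, x2, True)   # cache miss: jits a closure with do_square=True
f_pmapped(x1, x2, True)   # cache hit: reuses the jitted closure
f_pmapped(x1, x2, False)  # cache miss: the bool is part of the cache key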
1 comment on commit 7cef742

@sschoenholz (Contributor)

Fixes issue #5.
