Skip to content

Commit

Permalink
Use local device count for pmap in np.batch (#157)
Browse files Browse the repository at this point in the history
* Use local device count for pmap in np.batch

to work correctly in multi-host contexts
  • Loading branch information
jglaser committed Jul 17, 2022
1 parent a38c637 commit 23bea9f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions neural_tangents/_src/batching.py
Expand Up @@ -110,7 +110,7 @@ def batch(kernel_fn: _KernelFn,
input_req = getattr(kernel_fn, 'input_req', {})
dropout_in_analytic_kernel = input_req.get('use_dropout', False)
use_multidevice = device_count > 0 or (device_count == -1 and
jax.device_count() > 1)
jax.local_device_count() > 1)
use_serial = bool(batch_size)
if use_multidevice:
kernel_fn = _parallel(kernel_fn, use_serial,
Expand Down Expand Up @@ -522,7 +522,7 @@ def _parallel(kernel_fn: _KernelFn,
"""

if device_count == -1:
device_count = jax.device_count()
device_count = jax.local_device_count()

def _check_dropout(n1, n2, kwargs):
dropout_in_empirical_kernel = getattr(kwargs, 'rng', None) is not None
Expand Down Expand Up @@ -700,7 +700,7 @@ def jit_or_pmap_broadcast(f: Callable, device_count: int = -1) -> Callable:
key = (f, device_count)

if device_count == -1:
device_count = jax.device_count()
device_count = jax.local_device_count()

# TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
def broadcast(arg: np.ndarray) -> np.ndarray:
Expand Down

0 comments on commit 23bea9f

Please sign in to comment.