Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)
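
As a quick illustration of the behavior this enables (a sketch, not part of the change: the `0. / x` function and the "parallel computation" error text mirror the tests added below, and enabling the flag via `config.update` is just one way to turn it on, equivalent to passing `--jax_debug_nans`):

```
# Sketch: with jax_debug_nans enabled, a NaN produced inside a pmapped
# computation now raises FloatingPointError instead of propagating silently.
import jax
import jax.numpy as jnp
from jax.config import config

config.update("jax_debug_nans", True)

f = jax.pmap(lambda x: 0. / x)   # 0./0. produces a NaN on the device

try:
  ans = f(jnp.array([0.]))
  ans.block_until_ready()
except FloatingPointError as e:
  # Reports "... encountered in parallel computation"; unlike jit, the
  # function is not re-run op-by-op to pinpoint the offending primitive.
  print(e)
```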

This change doesn't appear to regress performance:

```
---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8   0.105598  5.06671      1               1.00693
    100          8   0.287756  0.870751     2.72502         0.973204
    500          8   1.20119   0.823624    11.3752          0.955185
   1000          8   2.56071   0           24.2497          0.983063
   5000          8  12.909     0          122.247           0.965925
    100          2   0.173727  5.15115      1.64518         0.98918
    100          4   0.207774  3.71411      1.9676          0.955849
    100          8   0.286103  1.60243      2.70937         0.971869
    100        100   2.34168   0           22.1755          0.904475
    100        500  15.9558    0          151.1             1.00483
```

Fixes #6044
skye committed Mar 13, 2021
1 parent 8d3b4ac commit c56649a
Showing 3 changed files with 80 additions and 6 deletions.
jax/interpreters/pxla.py: 3 additions & 0 deletions
@@ -1165,6 +1165,9 @@ def partitioned_sharding_spec(num_partitions: int,
 def execute_replicated(compiled, backend, in_handler, out_handler, *args):
   input_bufs = in_handler(args)
   out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
+  if xla.needs_check_special():
+    for bufs in out_bufs:
+      xla.check_special("parallel computation", bufs)
   return out_handler(out_bufs)

jax/interpreters/xla.py: 8 additions & 6 deletions
@@ -357,7 +357,7 @@ def _execute_compiled_primitive(prim, compiled, result_handler, *args):
   device, = compiled.local_devices()
   input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
   out_bufs = compiled.execute(input_bufs)
-  check_special(prim, out_bufs)
+  check_special(prim.name, out_bufs)
   return result_handler(*out_bufs)
 
 def _execute_replicated_primitive(prim, compiled, result_handler, *args):
@@ -370,11 +370,13 @@ def _execute_replicated_primitive(prim, compiled, result_handler, *args):
   ]
   return result_handler(*out_bufs)
 
+def needs_check_special():
+  return FLAGS.jax_debug_infs or FLAGS.jax_debug_nans
 
-def check_special(prim, bufs):
-  if FLAGS.jax_debug_infs or FLAGS.jax_debug_nans:
+def check_special(name, bufs):
+  if needs_check_special():
     for buf in bufs:
-      _check_special(prim.name, buf.xla_shape(), buf)
+      _check_special(name, buf.xla_shape(), buf)
 
 def _check_special(name, xla_shape, buf):
   assert not xla_shape.is_tuple()
@@ -845,7 +847,7 @@ def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):
   device, = compiled.local_devices()
   input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
   out_bufs = compiled.execute(input_bufs)
-  check_special(xla_call_p, out_bufs)
+  check_special(xla_call_p.name, out_bufs)
   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
 
 def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
@@ -856,7 +858,7 @@ def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
       buf[0] for buf in compiled.execute_sharded_on_local_devices(
          list(zip(*input_bufs)))
   ]
-  check_special(xla_call_p, out_bufs)
+  check_special(xla_call_p.name, out_bufs)
   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
 
 def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args):

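For completeness, a small sketch of the two flags these helpers consult (`FLAGS.jax_debug_nans` and `FLAGS.jax_debug_infs` above); this is illustrative usage, not code from the diff, and the exact error text for the inf case is paraphrased:

```
# Sketch: jax_debug_nans checks results for NaNs, jax_debug_infs for infs.
# needs_check_special() above returns True if either flag is set.
import jax
import jax.numpy as jnp
from jax.config import config

config.update("jax_debug_nans", True)   # flag NaN outputs
config.update("jax_debug_infs", True)   # additionally flag inf outputs

try:
  jax.jit(lambda x: 1. / x)(jnp.array(0.))   # 1./0. produces an inf
except FloatingPointError as e:
  print(e)   # names the offending computation, e.g. a div primitive
```
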
tests/debug_nans_test.py: 69 additions & 0 deletions
@@ -18,9 +18,11 @@
 
 import jax
 import numpy as np
+from unittest import SkipTest
 
 from jax import test_util as jtu
 from jax import numpy as jnp
+from jax.experimental import pjit
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -49,6 +51,12 @@ def testJitComputationNoNaN(self):
     ans = jax.jit(jnp.tanh)(A)
     ans.block_until_ready()
 
+  def testJitComputationNaN(self):
+    A = jnp.array(0.)
+    with self.assertRaises(FloatingPointError):
+      ans = jax.jit(lambda x: 0. / x)(A)
+      ans.block_until_ready()
+
   def testSingleResultPrimitiveNaN(self):
     A = jnp.array(0.)
     with self.assertRaises(FloatingPointError):
@@ -71,6 +79,67 @@ def f(x):
     with self.assertRaisesRegex(FloatingPointError, msg):
       f(1)
 
+  def testPmap(self):
+    f = jax.pmap(lambda x: 0. / x)
+
+    with self.assertRaisesRegex(
+        FloatingPointError,
+        r"invalid value \(nan\) encountered in parallel computation"):
+      ans = f(jnp.array([0.]))
+      ans.block_until_ready()
+
+    if jax.device_count() >= 2:
+      with self.assertRaisesRegex(
+          FloatingPointError,
+          r"invalid value \(nan\) encountered in parallel computation"):
+        ans = f(jnp.array([1., 0.]))
+        ans.block_until_ready()
+
+  def testPmapNoNaN(self):
+    ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
+    ans.block_until_ready()
+
+  @jtu.ignore_warning(message=".*is an experimental.*")
+  def testXmap(self):
+    if not config.omnistaging_enabled:
+      raise SkipTest("xmap requires omnistaging")
+
+    f = jax.experimental.maps.xmap(
+        lambda x: 0. / x,
+        in_axes=['i'],
+        out_axes=['i'],
+        axis_resources={'i': 'x'})
+
+    with jax.experimental.maps.mesh(np.array(jax.local_devices()[:1]), ('x',)):
+      with self.assertRaisesRegex(
+          FloatingPointError,
+          r"invalid value \(nan\) encountered in parallel computation"):
+        ans = f(jnp.array([0.]))
+        ans.block_until_ready()
+
+    if jax.device_count() >= 2:
+      with jax.experimental.maps.mesh(np.array(jax.local_devices()[:2]), ('x',)):
+        with self.assertRaises(FloatingPointError):
+          ans = f(jnp.array([1., 0.]))
+          ans.block_until_ready()
+
+  @jtu.ignore_warning(message=".*is an experimental.*")
+  @jtu.skip_on_devices("cpu", "gpu")
+  def testPjit(self):
+    if jax.device_count() < 2:
+      raise SkipTest("test requires >=2 devices")
+
+    p = jax.experimental.PartitionSpec('x')
+    f = pjit.pjit(lambda x: 0. / x,
+                  in_axis_resources=p,
+                  out_axis_resources=p)
+
+    with jax.experimental.maps.mesh(np.array(jax.local_devices()[:2]), ('x',)):
+      with self.assertRaises(FloatingPointError):
+        ans = f(jnp.array([0., 1.]))
+        ans.block_until_ready()
+
+  # TODO(skye): add parallel inf tests, ideally by factoring out test logic
+
 
 class DebugInfsTest(jtu.JaxTestCase):

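A note on exercising the multi-device branches in the tests above (the `jax.device_count() >= 2` checks): a common trick is to ask XLA for several host-platform (CPU) devices before JAX initializes its backends. A sketch under that assumption; it helps the pmap and xmap cases, while the pjit test still skips on CPU:

```
# Sketch: fake two CPU devices so the device_count() >= 2 branches run.
# XLA_FLAGS must be set before jax touches its backends.
import os
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") +
    " --xla_force_host_platform_device_count=2")

import jax
import jax.numpy as jnp
from jax.config import config

config.update("jax_debug_nans", True)
print(jax.device_count())   # expected: 2 on a CPU-only machine

try:
  jax.pmap(lambda x: 0. / x)(jnp.array([1., 0.]))   # NaN on the second device
except FloatingPointError as e:
  print(e)
```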
