Skip to content

Commit

Permalink
Fix missing handler when lexically capturing a ShardedDeviceArray whe…
Browse files Browse the repository at this point in the history
…n MLIR enabled.
  • Loading branch information
hawkinsp committed Feb 8, 2022
1 parent 44c6c05 commit 5679fed
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,9 +740,15 @@ def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
canonicalize_types=canonicalize_types)


def _sharded_device_array_mlir_constant_handler(val, canonicalize_types=True):
return mlir.ir_constants(np.asarray(val),
canonicalize_types=canonicalize_types)

def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)
mlir.register_constant_handler(sda,
_sharded_device_array_mlir_constant_handler)

core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
dispatch.device_put_handlers[sda] = dispatch._device_put_array
Expand Down
5 changes: 5 additions & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,11 @@ def testShardedDeviceArrays(self):
# test that the repr doesn't crash
repr(z)

# test that we can lexically capture a sda as a constant.
g = jit(lambda z: z + y)
self.assertAllClose(g(7), y + 7)


# Tests edge cases in lax._reshape_sharded_device_array
@parameterized.named_parameters(
{"testcase_name": "_in={}_out={}".format(in_shape, out_shape)
Expand Down

0 comments on commit 5679fed

Please sign in to comment.