Skip to content

Commit

Permalink
inline and remove empty_mlir rules
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed May 18, 2023
1 parent 180e26d commit 129a4a5
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 12 deletions.
6 changes: 3 additions & 3 deletions jax/_src/lax/lax.py
Expand Up @@ -4720,9 +4720,9 @@ def empty(dtype):
empty_p = core.Primitive('empty')
empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
def _empty_lower(ctx, *, dtype):
if dtypes.is_opaque_dtype(dtype):
return dtype._rules.empty_mlir(ctx, ctx.avals_out[0])
return mlir.ir_constants(np.zeros((), np.dtype(dtype)))
dtype = dtype if dtypes.is_opaque_dtype(dtype) else np.dtype(dtype)
phys_aval = core.physical_aval(core.ShapedArray((), dtype))
return mlir.ir_constants(np.zeros(phys_aval.shape, phys_aval.dtype))
mlir.register_lowering(empty_p, _empty_lower)


Expand Down
5 changes: 0 additions & 5 deletions jax/_src/prng.py
Expand Up @@ -521,11 +521,6 @@ def make_sharded_array(aval, sharding, arrays, committed):

# element-type-polymorphic primitive lowering rules

@staticmethod
def empty_mlir(ctx, aval_out) -> Sequence[ir.Value]:
return mlir.ir_constants(np.zeros(aval_out.dtype.impl.key_shape,
dtype=np.dtype('uint32')))

@staticmethod
def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides) -> ir.Value:
key_shape = aval_out.dtype.impl.key_shape
Expand Down
4 changes: 0 additions & 4 deletions tests/lax_test.py
Expand Up @@ -2869,10 +2869,6 @@ def handler(arr):

# element-type-polymorphic primitive lowering rules

@staticmethod
def empty_mlir(ctx, aval_out):
return mlir.ir_constants(np.zeros((2,), dtype=np.dtype('uint32')))

@staticmethod
def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides):
start_indices = (*start_indices, 0)
Expand Down

0 comments on commit 129a4a5

Please sign in to comment.