Skip to content

Commit

Permalink
Fix a bug when dynamic grids are used together with scalar inputs
Browse files Browse the repository at this point in the history
The scalar input lowering was incorrectly assuming that its going to be the `i`th input
to the HLO, which was incorrect even before dynamic grids were a thing (but only when
fusions happened, so it was never uncovered before).

PiperOrigin-RevId: 606998887
  • Loading branch information
apaszke authored and jax authors committed Feb 14, 2024
1 parent 8814362 commit b9824d7
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,28 @@ def dynamic_kernel(steps):
dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32)
)

# TODO(apaszke): Add tests for scalar_prefetch too
def test_dynamic_grid_scalar_input(self):
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

def kernel(scalar_input_ref, output_ref):
output_ref[...] = jnp.full_like(output_ref, scalar_input_ref[0, 0])

@jax.jit
def dynamic_kernel(steps):
return self.pallas_call(
kernel,
out_shape=result_ty,
in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
grid=(steps * 2,),
)(jnp.array([[42]], dtype=jnp.int32))

np.testing.assert_array_equal(
dynamic_kernel(jnp.int32(4)), np.full(shape, 42.0, np.float32)
)

def test_vmap_trivial_dynamic_grid(self):
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
Expand Down

0 comments on commit b9824d7

Please sign in to comment.