[Pallas] Pad input/outputs in interpret mode to fix errors in OOB memory accesses.

PiperOrigin-RevId: 633283991
justinjfu authored and jax authors committed May 13, 2024
1 parent b8ed346 commit 1e48adc
Showing 2 changed files with 121 additions and 12 deletions.
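
A minimal sketch of the approach (illustrative only; pad_to_block_multiple is a hypothetical helper, not code from this commit): each operand is rounded up to the next multiple of its block shape, and the new elements are filled with NaN, so a kernel that reads past the true extent of a buffer in interpret mode sees NaN instead of silently clamped data.

import jax.numpy as jnp

def pad_to_block_multiple(x, block_shape):
  # Round each dimension up to the next multiple of its block size.
  padded = tuple(((d - 1) // b + 1) * b for d, b in zip(x.shape, block_shape))
  pad_width = tuple((0, p - d) for p, d in zip(padded, x.shape))
  # Poison the padded region with NaN so out-of-bounds reads are visible.
  return jnp.pad(x, pad_width, constant_values=jnp.nan)

x = jnp.ones((33, 5), jnp.float32)
print(pad_to_block_multiple(x, (32, 4)).shape)  # (64, 8)
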
60 changes: 48 additions & 12 deletions jax/_src/pallas/pallas_call.py
@@ -83,6 +83,32 @@ def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
  assert update.shape == block_shape
  return lax.dynamic_update_slice(value, update, start_idx)

def _pad_values_to_block_dimension(value,
                                   block_shape):
  """Pads values so the shape evenly divides into block dimensions.

  For example, if values has a shape of (33, 2, 5) with a block_shape of
  (32, 2, 4), this function will pad the value to a shape of (64, 2, 8).

  Args:
    value: Array to be padded.
    block_shape: Block shapes to use for padding. If None, no padding will
      be performed.

  Returns:
    A padded array.
  """
  if block_shape is None:
    return value
  padded_shape = tuple(
      ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
  )
  if padded_shape != value.shape:
    pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape))
    pad_value = _uninitialized_value(shape=(), dtype=value.dtype)
    value = jnp.pad(value, pad_width, constant_values=pad_value)
  return value
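
# A quick worked check of the rounding above (illustration, not part of the
# change): for value.shape == (33, 2, 5) and block_shape == (32, 2, 4),
#   ((33 - 1) // 32 + 1) * 32 == 64
#   (( 2 - 1) //  2 + 1) *  2 ==  2
#   (( 5 - 1) //  4 + 1) *  4 ==  8
# so padded_shape == (64, 2, 8), matching the docstring example.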

def _uninitialized_value(shape, dtype):
  if jnp.issubdtype(dtype, jnp.floating):
    return jnp.full(shape, jnp.nan, dtype)
@@ -157,7 +183,29 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
            raise NotImplementedError("Padding with aliasing not supported.")
          x = lax.pad(x, jnp.zeros((), x.dtype), [(*p, 0) for p in padding])
      carry.append(x)

    block_shapes_without_mapped_dims = [
        None if block_mapping is None else block_mapping.block_shape
        for block_mapping in grid_mapping.block_mappings
    ]
    is_indexing_dim = [
        None if bm is None else tuple(b is pallas_core.mapped for b in bm)
        for bm in block_shapes_without_mapped_dims
    ]
    block_shapes = [
        None if (bm is None or iid is None)
        else tuple(1 if i else b for i, b in zip(iid, bm))
        for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
    ]

    # Pad values to evenly divide into block dimensions.
    # This allows interpret mode to catch errors on OOB memory accesses
    # by poisoning values with NaN. It also fixes an inconsistency with
    # lax.dynamic_slice, where a slice that goes out of bounds has its
    # start_index moved backwards so that the slice fits in memory.
    carry = map(_pad_values_to_block_dimension, carry, block_shapes)
    carry.extend(scratch_values)
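
    # Illustration of the dynamic_slice behavior mentioned above (not part of
    # this change): for x = jnp.arange(8), lax.dynamic_slice(x, (6,), (4,))
    # returns [4, 5, 6, 7] because the start index is clamped from 6 to 4 so
    # that the slice stays in bounds. Padding to a block multiple avoids this
    # silent clamping and turns true OOB reads into NaNs instead.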

    num_inout = len(args) + len(out)
    grid_start_indices = (jnp.int32(0),) * len(grid)
    if grid:
@@ -180,18 +228,6 @@ def body(carry):
      start_indices = [
          None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings]
      block_shapes_without_mapped_dims = [
          None if block_mapping is None else block_mapping.block_shape
          for block_mapping in grid_mapping.block_mappings
      ]
      is_indexing_dim = [
          None if bm is None else tuple(b is pallas_core.mapped for b in bm)
          for bm in block_shapes_without_mapped_dims
      ]
      block_shapes = [
          None if bm is None else tuple(1 if i else b for i, b in zip(iid, bm))
          for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
      ]
      blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
                   is_indexing_dim)
      with pallas_core.grid_env(local_grid_env):
73 changes: 73 additions & 0 deletions tests/pallas/pallas_test.py
@@ -2026,5 +2026,78 @@ def test_softmax(self, shape, dtype):
class SoftmaxInterpreterTest(PallasTest):
  INTERPRET = True


class PallasInterpretModeOutOfBoundsTest(PallasTest):

  INTERPRET: bool = True

  def test_interpret_mode_out_of_bounds_access(self):
    block_size = 32
    # Create input tensors which require a reduction along an axis
    # not divisible by block_size.
    x = jax.random.normal(jax.random.key(0), (block_size, block_size + 1))
    y = jax.random.normal(jax.random.key(1), (block_size + 1, block_size))
    expected = jnp.dot(x, y)

    in_specs = [
        pl.BlockSpec(lambda i, j, k: (i, k), (block_size, block_size)),
        pl.BlockSpec(lambda i, j, k: (k, j), (block_size, block_size)),
    ]
    out_spec = pl.BlockSpec(lambda i, j, k: (i, j), (block_size, block_size))

    def _unmasked_matmul_kernel(x_ref, y_ref, o_ref):
      @pl.when(pl.program_id(2) == 0)
      def _():
        o_ref[...] = jnp.zeros_like(o_ref)

      o_ref[...] += x_ref[...] @ y_ref[...]

    out = pl.pallas_call(
        _unmasked_matmul_kernel,
        out_shape=expected,
        grid=(1, 1, 2),
        in_specs=in_specs,
        out_specs=out_spec,
        interpret=True,
    )(x, y)

    # With a naive matmul implementation, using uninitialized values (NaN) will
    # cause the overall output to be NaN.
    with self.subTest('UnmaskedIsNaN'):
      np.testing.assert_allclose(
          np.isnan(out), jnp.ones_like(out, dtype=jnp.bool_)
      )

    def _masked_matmul_kernel(x_ref, y_ref, o_ref):
      @pl.when(pl.program_id(2) == 0)
      def _():
        o_ref[:, :] = jnp.zeros_like(o_ref)

      # Create a validity mask for OOB values.
      num_valid = x.shape[1] - pl.program_id(2) * block_size
      num_valid = jnp.minimum(num_valid, block_size)
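      # Row `num_valid - 1` of tril(ones) has ones in its first `num_valid`
      # columns, i.e. exactly the in-bounds positions along the reduced axis.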
      mask = jnp.tril(jnp.ones_like(x_ref[:, :]))[num_valid - 1][jnp.newaxis, :]
      mask = jnp.repeat(mask, block_size, axis=0)

      # Mask and multiply.
      masked_x = jnp.where(mask, x_ref[:, :], 0.0)
      masked_y = jnp.where(mask.T, y_ref[:, :], 0.0)
      o_ref[:, :] += masked_x @ masked_y

    out = pl.pallas_call(
        _masked_matmul_kernel,
        out_shape=expected,
        grid=(1, 1, 2),
        in_specs=in_specs,
        out_specs=out_spec,
        interpret=True,
    )(x, y)

    # With a masked matmul implementation, uninitialized values will be
    # masked before computation. This should return the correct result.
    with self.subTest('MaskedOutputIsCorrect'):
      np.testing.assert_allclose(out, expected, atol=1e-5)


if __name__ == "__main__":
  absltest.main()
