[Mosaic GPU] Prepare for writing warp-specialized kernels
PiperOrigin-RevId: 632854287
apaszke authored and jax authors committed May 12, 2024
1 parent 49bd4d6 commit a527b71
Showing 5 changed files with 120 additions and 19 deletions.
3 changes: 3 additions & 0 deletions jax/experimental/mosaic/gpu/dsl.py
@@ -39,6 +39,9 @@
     memref_unsqueeze,
     once,
     tile_shape,
+    thread_idx,
+    warp_idx,
+    warpgroup_idx,
 )
 from .wgmma import (
     WGMMAAccumulator,
46 changes: 38 additions & 8 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -65,18 +65,27 @@ def from_memref_type(cls, memref_ty: ir.Type):
     memref_type = ir.MemRefType(memref_ty)
     bw = mgpu.bytewidth(memref_type.element_type)
     assert 8 % bw == 0 and 8 // bw != 0, bw
-    return cls(shape=memref_type.shape, vec_size=8 // bw)
+    if np.prod(memref_type.shape) % WARPGROUP_SIZE != 0:
+      raise ValueError(
+          "Ref must have a number of elements that is a multiple of"
+          f" {WARPGROUP_SIZE}"
+      )
+    max_vec_size = np.prod(memref_type.shape) // WARPGROUP_SIZE
+    return cls(
+        shape=tuple(memref_type.shape), vec_size=min(8 // bw, max_vec_size)
+    )

   def thread_vec_idxs(self):
     """The indices to be used for vector load/store in WGStridedFragLayout.

     Yields:
       The indices of the vector that correspond to the current thread.
     """
     index = ir.IndexType.get()
     cardinality = np.prod(self.shape)
     assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0
     reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size)
-    tidx = gpu.thread_id(gpu.Dimension.x)
+    tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE, index))
     off = arith.muli(tidx, c(self.vec_size, tidx.type))
     for i in range(reg_num):
       yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))]
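
As a sanity check on the indexing above — each of the 128 threads in a warpgroup owns reg_num vectors of vec_size elements, strided by WARPGROUP_SIZE * vec_size — here is a small NumPy sketch (vec_starts is a hypothetical helper and the shapes are arbitrary; not part of the commit) verifying that the vectors tile the ref exactly once, and that the new cap keeps small refs splitting evenly:

import numpy as np

WARPGROUP_SIZE = 128

def vec_starts(shape, vec_size):
  # Mirrors thread_vec_idxs: thread t's i-th vector begins at
  # t * vec_size + i * WARPGROUP_SIZE * vec_size.
  cardinality = int(np.prod(shape))
  assert cardinality % (WARPGROUP_SIZE * vec_size) == 0
  reg_num = cardinality // (WARPGROUP_SIZE * vec_size)
  return np.array([[t * vec_size + i * WARPGROUP_SIZE * vec_size
                    for i in range(reg_num)] for t in range(WARPGROUP_SIZE)])

starts = vec_starts((256, 4), vec_size=2)  # f32-like: 8 // bytewidth == 2.
elems = (starts[..., None] + np.arange(2)).ravel()
assert np.array_equal(np.sort(elems), np.arange(256 * 4))  # Disjoint, exact cover.
# The new max_vec_size cap: a 128-element f32 ref gets vec_size min(2, 1) == 1.
assert min(2, 128 // WARPGROUP_SIZE) == 1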
@@ -194,23 +203,43 @@ def _pointwise(self, op, *other):
     return FragmentedArray(_registers=new_regs, _layout=self.layout)

   def __add__(self, other):
-    return self._pointwise(arith.addf, other)
+    if ir.FloatType.isinstance(self.mlir_dtype):
+      return self._pointwise(arith.addf, other)
+    elif ir.IntegerType.isinstance(self.mlir_dtype):
+      return self._pointwise(arith.addi, other)
+    else:
+      raise NotImplementedError(self.mlir_dtype)

   def __mul__(self, other):
-    return self._pointwise(arith.mulf, other)
+    if ir.FloatType.isinstance(self.mlir_dtype):
+      return self._pointwise(arith.mulf, other)
+    elif ir.IntegerType.isinstance(self.mlir_dtype):
+      return self._pointwise(arith.muli, other)
+    else:
+      raise NotImplementedError(self.mlir_dtype)

   def __sub__(self, other):
+    if not ir.FloatType.isinstance(self.mlir_dtype):
+      raise NotImplementedError
     return self._pointwise(arith.subf, other)

   def __truediv__(self, other):
+    if not ir.FloatType.isinstance(self.mlir_dtype):
+      raise NotImplementedError
     return self._pointwise(arith.divf, other)

   def max(self, other):
+    if not ir.FloatType.isinstance(self.mlir_dtype):
+      raise NotImplementedError
     return self._pointwise(arith.maximumf, other)

   def exp(self, approx: bool = False):
+    if not ir.FloatType.isinstance(self.mlir_dtype):
+      raise NotImplementedError
     def fast_exp(x):
       f32 = ir.F32Type.get()
+      if self.mlir_dtype != f32:
+        raise NotImplementedError
       log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
       if x.type == f32:
         scaled = arith.mulf(x, log2e)
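
For context on fast_exp: the scaling visible above relies on the identity exp(x) = 2^(x·log2(e)), presumably so the result can be computed with a hardware base-2 exponential (the hunk is truncated here, so that lowering is an assumption). A quick NumPy check of the identity, using the same constant:

import numpy as np

log2e = 1.4426950408889634
x = np.linspace(-10.0, 10.0, 101).astype(np.float32)
np.testing.assert_allclose(np.exp(x), 2.0 ** (x * log2e), rtol=1e-5)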
@@ -372,8 +401,9 @@ def store_untiled(self, ref: ir.Value):

   def _store_untiled_wg_strided(self, ref: ir.Value):
     ref_ty = ir.MemRefType(ref.type)
-    if ref_ty.shape != self.shape:
-      raise ValueError((ref_ty.shape, self.shape))
+    ref_shape = tuple(ref_ty.shape)
+    if ref_shape != self.shape:
+      raise ValueError((ref_shape, self.shape))
     smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
     for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat):
       vector.store(reg, smem_1d, idx)
@@ -390,7 +420,7 @@ def _store_untiled_wgmma(self, ref: ir.Value):
     def c(x):
       return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))

-    tidx = gpu.thread_id(gpu.Dimension.x)
+    tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
     lane_id = arith.remui(tidx, c(32))  # {0, 1, ..., 31}
     warp_id = arith.divui(tidx, c(32))  # {0, 1, 2, 3}
     row_base = arith.addi(
@@ -454,7 +484,7 @@ def transfer_tiled(shape, dtype, swizzle: int | None):
   def c(x):
     return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))

-  tidx = gpu.thread_id(gpu.Dimension.x)
+  tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
   lane_id = arith.remui(tidx, c(32))  # {0, 1, ..., 31}
   warp_id = arith.divui(tidx, c(32))  # {0, 1, 2, 3}
   sub_row_base = arith.divui(lane_id, c(4))  # {0, 1, ..., 7}
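
Both stores above now reduce the thread id modulo WARPGROUP_SIZE before deriving lane and warp indices, so the same layout math works in any warpgroup, not only the first. A plain-NumPy illustration of what the remui buys (a sketch only; the 256-thread block is an arbitrary choice):

import numpy as np

WARPGROUP_SIZE = 128
tidx = np.arange(256)               # Two warpgroups in one block.
local = tidx % WARPGROUP_SIZE       # What the arith.remui now computes.
lane_id = local % 32                # {0, ..., 31}
warp_id = local // 32               # {0, 1, 2, 3} within *each* warpgroup.
# Both warpgroups see identical local indices, so they can run the same
# store code against their own data:
assert np.array_equal(local[:128], local[128:])
assert warp_id.max() == 3           # Without the remui this would reach 7.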
43 changes: 36 additions & 7 deletions jax/experimental/mosaic/gpu/utils.py
@@ -185,15 +185,37 @@ def wrapper(f):
   return wrapper


-def get_warp_idx():
+def thread_idx():
   i32 = ir.IntegerType.get_signless(32)
-  tidx = arith.index_cast(i32, gpu.thread_id(gpu.Dimension.x))
-  warp_idx = arith.shrui(tidx, c(5, tidx.type))
+  as_i32 = lambda x: arith.index_cast(i32, x)
+  tidx = as_i32(gpu.thread_id(gpu.Dimension.x))
+  stride = as_i32(gpu.block_dim(gpu.Dimension.x))
+  for dim in (gpu.Dimension.y, gpu.Dimension.z):
+    tidx = arith.addi(tidx, arith.muli(as_i32(gpu.thread_id(dim)), stride))
+    stride = arith.muli(stride, as_i32(gpu.block_dim(dim)))
+  return tidx
+
+
+def warp_idx(sync=True):
+  i32 = ir.IntegerType.get_signless(32)
+  warp_idx = arith.shrui(thread_idx(), c(5, i32))
+  if not sync:
+    return warp_idx
   mask = c(0xFFFFFFFF, i32)
   return nvvm.shfl_sync(
       warp_idx.type, mask, warp_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx
   )
+
+def warpgroup_idx(sync=True):
+  i32 = ir.IntegerType.get_signless(32)
+  wg_idx = arith.shrui(thread_idx(), c(7, i32))
+  if not sync:
+    return wg_idx
+  mask = c(0xFFFFFFFF, i32)
+  return nvvm.shfl_sync(
+      wg_idx.type, mask, wg_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx
+  )


 # True within `once()` contexts.
_ONCE_REGION_ACTIVE = False
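
Stepping back to the new index helpers above: in plain Python terms, thread_idx linearizes the 3-D thread index with x fastest, and warp_idx/warpgroup_idx are right-shifts by 5 (32 threads per warp) and 7 (128 threads per warpgroup). A sketch of that arithmetic (linear_thread_idx is a hypothetical stand-in, the block dims are arbitrary, and the shfl_sync broadcast that makes the result warp-uniform has no Python analogue here):

import itertools

bdim = (128, 2, 2)  # (x, y, z) block dimensions, chosen for the check.

def linear_thread_idx(x, y, z):
  # Same accumulation as thread_idx: add each higher dim scaled by the
  # product of the lower block dims.
  tidx, stride = x, bdim[0]
  for t, d in ((y, bdim[1]), (z, bdim[2])):
    tidx += t * stride
    stride *= d
  return tidx

ids = [linear_thread_idx(x, y, z)
       for z, y, x in itertools.product(range(bdim[2]), range(bdim[1]), range(bdim[0]))]
assert ids == list(range(128 * 2 * 2))  # Unique, dense linear ids.
assert linear_thread_idx(5, 1, 0) >> 5 == 133 // 32   # warp_idx == 4
assert linear_thread_idx(5, 1, 0) >> 7 == 133 // 128  # warpgroup_idx == 1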
@@ -211,7 +233,7 @@ def once():
     yield
     return

-  warp = get_warp_idx()
+  warp = warp_idx()
   first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
   elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
   should_run = arith.andi(first_warp, elected)
@@ -314,7 +336,9 @@ def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
   new_shape = list(ref_ty.shape)
   new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])]
   identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
-  if ref_ty.layout == identity:
+  contig_strided_1d = ir.Attribute.parse("strided<[1]>")
+  # Not sure why, but MLIR expects the strided 1D layout to disappear in this op.
+  if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d:
     new_layout = ir.AffineMapAttr.get(
         ir.AffineMap.get_identity(ref_ty.rank - fold_rank + 1)
     )
@@ -466,13 +490,14 @@ def commit_shared():

 class BarrierArray:

-  def __init__(self, num_barriers):
+  def __init__(self, num_barriers: int, arrival_count: int = 1):
     barrier_group_ty = ir.Type.parse(
         "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>,"
         f" num_barriers={num_barriers}>"
     )

     self.value = nvgpu.mbarrier_create(barrier_group_ty)
+    self.num_barriers = num_barriers
     index = ir.IndexType.get()
     if num_barriers > 32:
       raise NotImplementedError("Only up to 32 barriers per group supported")
@@ -481,7 +506,7 @@ def __init__(self, num_barriers):
     memref.store(c(0, i32), self.phases, [])
     with once():
       for i in range(num_barriers):
-        nvgpu.mbarrier_init(self.value, c(1, index), c(i, index))
+        nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))

   def __getitem__(self, offset: ir.Value | int):
     if isinstance(offset, int):
@@ -512,6 +537,10 @@ def wait(self):
     memref.store(new_parities, self.barrier_array.phases, [])
     self.wait_parity(parity)

+  def arrive(self):
+    token_ty = ir.Type.parse("!nvgpu.mbarrier.token")
+    nvgpu.mbarrier_arrive(token_ty, self.barrier_array.value, self.offset)


class Partition:
source_bounds: tuple[int, ...]
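
The arrival_count parameter and the new arrive method are what enable producer/consumer signaling between warpgroups: an mbarrier only completes a phase once arrival_count arrivals have been recorded. A toy single-threaded model of that phase logic (ToyMBarrier is an illustration of the semantics, not of the hardware or the nvgpu API):

class ToyMBarrier:
  def __init__(self, arrival_count):
    self.arrival_count = arrival_count
    self.pending = arrival_count
    self.phase = 0

  def arrive(self):
    self.pending -= 1
    if self.pending == 0:          # The last arrival completes the phase,
      self.pending = self.arrival_count
      self.phase += 1              # releasing every thread waiting on it.

b = ToyMBarrier(arrival_count=128)  # One arrival per thread in a warpgroup.
for _ in range(128):
  b.arrive()
assert b.phase == 1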
4 changes: 0 additions & 4 deletions tests/mosaic/BUILD
@@ -86,10 +86,6 @@ jax_test(
     disable_backends = DISABLED_BACKENDS,
     disable_configs = DISABLED_CONFIGS,
     main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py",
-    tags = [
-        "manual",
-        "notap",
-    ],
     deps = [
         "//jax:mosaic_gpu",
     ] + py_deps("numpy"),
43 changes: 43 additions & 0 deletions tests/mosaic/gpu_test.py
@@ -616,6 +616,49 @@ def kernel(ctx, rhs, out, rhs_smem):
     np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)


+class BarrierTest(TestCase):
+
+  def test_wg_communication(self):
+    i32 = ir.IntegerType.get_signless(32)
+    def kernel(ctx, dst, tmp):
+      del ctx  # Unused.
+      barriers = BarrierArray(3, arrival_count=128)
+      gpu.barrier()  # Make sure the barriers are initialized.
+      wg_idx = arith.divui(mgpu.warp_idx(), c(4, i32))
+      is_first_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(0, i32))
+      is_second_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(1, i32))
+      arr = mgpu.FragmentedArray.splat(
+          arith.addi(wg_idx, c(1, i32)),
+          (128,),
+          mgpu.WGStridedFragLayout((128,), 1),
+      )
+      with ir.InsertionPoint(scf.IfOp(is_first_wg).then_block):
+        arr.store_untiled(tmp)
+        barriers[0].arrive()  # Signal that tmp is ready.
+        barriers[1].wait()  # Wait for the other warpgroup to produce tmp.
+        final_arr = arr + mgpu.FragmentedArray.load_strided(tmp)
+        final_arr.store_untiled(memref_slice(dst, 0))
+        scf.yield_([])
+      with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block):
+        barriers[0].wait()
+        final_arr = arr + mgpu.FragmentedArray.load_strided(tmp)
+        barriers[2].arrive()
+        barriers[2].wait()  # Synchronize this warpgroup before we overwrite tmp.
+        arr.store_untiled(tmp)
+        barriers[1].arrive()  # Signal that tmp is ready.
+        final_arr.store_untiled(memref_slice(dst, 1))
+        scf.yield_([])
+    out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32)
+    y = mosaic_gpu.as_gpu_kernel(
+        kernel,
+        (1, 1, 1),
+        (2 * 128, 1, 1),
+        (),
+        out_shape,
+        jax.ShapeDtypeStruct((128,), jnp.int32),
+    )()
+    np.testing.assert_array_equal(y, np.full_like(y, 3, dtype=np.int32))
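
The control flow of this test is easier to see in a plain-Python analogue, with one Python thread standing in for each warpgroup and threading.Barrier standing in for the mbarriers (a behavioral sketch only; the intra-warpgroup sync on barriers[2] collapses away because each warpgroup is a single thread here):

import threading

tmp = [None]
out = [None, None]
b0, b1 = threading.Barrier(2), threading.Barrier(2)  # barriers[0], barriers[1]

def first_wg():
  tmp[0] = 1          # arr = splat(wg_idx + 1) with wg_idx == 0.
  b0.wait()           # Signal that tmp is ready.
  b1.wait()           # Wait for the other warpgroup to rewrite tmp.
  out[0] = 1 + tmp[0]

def second_wg():
  b0.wait()           # Wait for the first warpgroup's tmp.
  got = tmp[0]
  tmp[0] = 2          # arr = splat(wg_idx + 1) with wg_idx == 1.
  b1.wait()           # Signal that tmp is ready.
  out[1] = 2 + got

threads = [threading.Thread(target=f) for f in (first_wg, second_wg)]
for t in threads:
  t.start()
for t in threads:
  t.join()
assert out == [3, 3]  # Matches the kernel's expected output.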

 class TMATest(TestCase):

   @parameterized.product(