[Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical: it allocates memory (which can deadlock due to NCCL), performs a
synchronous host-to-device copy, and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
apaszke authored and jax authors committed Apr 26, 2024
1 parent 268b39d commit 9b03195
Showing 6 changed files with 269 additions and 21 deletions.
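
Before the per-file diff, here is a minimal pure-Python model of the scheme the message describes, offered as a sketch only: a bump allocator hands out offsets into one scratch buffer and records host-side initializers, so a single asynchronous copy can ship every descriptor at once. The class and method names are illustrative, not the actual Mosaic GPU API.

```python
# Illustrative model only; the real implementation emits MLIR ops (see the
# diff of jax/experimental/mosaic/gpu/__init__.py below).
class ScratchAllocator:
    def __init__(self):
        self.next_offset = 0
        self.host_inits = []  # (offset, size, fill) recorded per allocation

    def alloc(self, size, alignment, fill):
        if self.next_offset % alignment:
            raise NotImplementedError  # the patch also punts on padding
        offset = self.next_offset
        self.next_offset += size
        self.host_inits.append((offset, size, fill))
        return offset

    def materialize(self):
        # One host staging buffer holds every descriptor back-to-back...
        buf = bytearray(self.next_offset)
        for offset, size, fill in self.host_inits:
            buf[offset:offset + size] = fill(size)
        # ...so one asynchronous host-to-device copy suffices.
        return bytes(buf)

allocator = ScratchAllocator()
slot = allocator.alloc(128, 64, lambda n: bytes(n))  # one TMA descriptor slot
assert slot == 0 and len(allocator.materialize()) == 128
```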
161 changes: 141 additions & 20 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -1,3 +1,4 @@
from collections.abc import Callable
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,6 +26,7 @@

import jax
from jax._src import config
from jax._src import core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import xla_client
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
@@ -58,6 +60,9 @@
PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")

TMA_DESCRIPTOR_BYTES = 128
TMA_DESCRIPTOR_ALIGNMENT = 64
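
These two constants match CUDA's CUtensorMap object: tensor map descriptors are 128 bytes and must sit at 64-byte-aligned addresses. Because the size is a multiple of the alignment, packing descriptors back-to-back keeps every slot aligned:

```python
# Back-to-back packing preserves alignment: each 128-byte slot starts at a
# multiple of 64 as long as the first one does.
assert TMA_DESCRIPTOR_BYTES % TMA_DESCRIPTOR_ALIGNMENT == 0
```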


c = mgpu.c # This is too common to fully qualify.

@@ -97,11 +102,13 @@


@mosaic_gpu_p.def_abstract_eval
def _mosaic_gpu_abstract_eval(*_, module, out_types):
def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
del module, gmem_scratch_bytes # Unused.
return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]


def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
del out_types # Unused.
runtime_path = (
pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent.parent.parent
/ "mosaic" / "gpu" / "libmlir_cuda_runtime.so"
@@ -127,11 +134,16 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
) # pytype: disable=attribute-error
op = mlir.custom_call(
"mosaic_gpu",
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
result_types=[
*(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out),
mlir.aval_to_ir_type(
jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
),
],
operands=args,
backend_config=ptr_bytes,
)
return op.results
return op.results[:-1] # Skip the scratch space.

mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")

@@ -227,7 +239,12 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
@dataclasses.dataclass()
class LaunchContext:
launch_op: gpu.LaunchOp
gmem_scratch_ptr: ir.Value
profiler: OnDeviceProfiler | None = None
next_scratch_offset: int = 0
host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
default_factory=list, init=False
)
tma_descriptors: dict[
tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
ir.Value,
@@ -241,6 +258,37 @@ def named_region(self, *args, **kwargs):
else:
yield

def _alloc_scratch(
self,
size: int,
alignment: int | None = None,
host_init: Callable[[ir.Value], None] = lambda _: None,
device_init: Callable[[ir.Value], Any] = lambda x: x,
) -> ir.Value:
"""Allocates a GMEM scratch buffer.
The buffer is initialized on the host and then copied to GMEM before the
kernel launch.
"""
i8 = ir.IntegerType.get_signless(8)
ptr_ty = ir.Type.parse("!llvm.ptr")
if alignment is None:
alignment = size
if self.next_scratch_offset % alignment:
raise NotImplementedError # TODO(apaszke): Pad to match alignment
alloc_base = self.next_scratch_offset
self.next_scratch_offset += size
def host_init_wrapped(host_ptr):
with ir.InsertionPoint(self.launch_op):
host_init(
llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
)
self.host_scratch_init.append(host_init_wrapped)
with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
return device_init(llvm.getelementptr(
ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8
))
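
Note that `_alloc_scratch` requires the running offset to already satisfy the requested alignment and raises otherwise. A hypothetical implementation of the padding the TODO alludes to would round the offset up first:

```python
# Hypothetical fix for the alignment TODO above: round the running offset up
# to the next multiple of `alignment` instead of raising.
def align_up(offset: int, alignment: int) -> int:
    return (offset + alignment - 1) // alignment * alignment

assert align_up(130, 64) == 192
assert align_up(128, 64) == 128  # already-aligned offsets are unchanged
```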

def _get_tma_desc(
self,
ref,
@@ -265,13 +313,42 @@ def _get_tma_desc(self,
with ir.InsertionPoint(self.launch_op):
for t in gmem_transform:
ref = t.apply(ref)
ref_unranked = memref.cast(
ir.UnrankedMemRefType.get(ref_ty.element_type, None), ref
)
tma_desc = nvgpu.tma_create_descriptor(
tensor_map_ty,
ref_unranked,
[c(s, index) for s in transformed_slice_shape],
ref_ty = ir.MemRefType(ref.type)

i64 = ir.IntegerType.get_signless(64)
ptr_ty = ir.Type.parse("!llvm.ptr")
def init_tma_desc(host_ptr):
_, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
as_i64 = lambda i: arith.index_cast(i64, i)
alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx))
llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings...
base_ptr = llvm.getelementptr(
ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type,
)
rank = ref_ty.rank
assert rank * 2 == len(sizes_and_strides)
args = [
host_ptr,
base_ptr,
c(utils.bytewidth(ref_ty.element_type), i64),
c(rank, i64),
utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
c(0 if swizzle is None else swizzle, i64),
utils.pack_array([c(v, i64) for v in transformed_slice_shape]),
]
func.call([], "mosaic_gpu_init_tma_desc", args)
def cast_tma_desc(device_ptr):
nvvm.prefetch_tensormap(device_ptr)
return builtin.unrealized_conversion_cast(
[tensor_map_ty], [device_ptr]
)
tma_desc = self._alloc_scratch(
TMA_DESCRIPTOR_BYTES,
alignment=TMA_DESCRIPTOR_ALIGNMENT,
host_init=init_tma_desc,
device_init=cast_tma_desc,
)
self.tma_descriptors[tma_desc_key] = tma_desc
return tma_desc
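
The eight operands assembled in `init_tma_desc` above are easier to read with names attached. Only their order and MLIR types ([ptr, ptr, i64, i64, ptr, ptr, i64, ptr]) come from the code; the field names below are descriptive guesses, not the runtime's actual C signature:

```python
# Descriptive record of the arguments passed to mosaic_gpu_init_tma_desc;
# field names are inferred from the values built in init_tma_desc.
from typing import NamedTuple

class TmaDescInitArgs(NamedTuple):
    host_desc_ptr: int     # slot in host scratch that receives the descriptor
    tensor_base_ptr: int   # GMEM base address after applying the memref offset
    elem_bytewidth: int    # bytes per element
    rank: int              # number of array dimensions
    sizes_ptr: int         # pointer to `rank` i64 dimension sizes
    strides_ptr: int       # pointer to `rank` i64 dimension strides
    swizzle_bytes: int     # 0 when swizzle is None
    window_shape_ptr: int  # pointer to the transformed slice shape
```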
@@ -378,18 +455,14 @@ def await_async_copy(
nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
gpu.barrier() # Groups are supposedly tracked per-thread

def _prefetch_tma_descs(self):
with ir.InsertionPoint(self.launch_op.body.blocks[0]):
with mgpu.once():
for desc in self.tma_descriptors.values():
nvgpu.tma_prefetch_descriptor(desc)


# TODO(apaszke): Inline this
@contextlib.contextmanager
def _launch(
token,
grid,
block,
gmem_scratch_ptr,
smem_buffers,
profiler_spec: profiler.ProfilerSpec | None = None,
maybe_prof_buffer: ir.Value | None = None,
Expand Down Expand Up @@ -449,7 +522,7 @@ def _launch(
else:
prof = None
smem_ref_tree = jax.tree.unflatten(smem_buffer_tree, smem_refs)
yield LaunchContext(launch_op, prof), smem_ref_tree
yield LaunchContext(launch_op, gmem_scratch_ptr, prof), smem_ref_tree
if prof is not None:
prof.finalize(grid=grid)
gpu.terminator()
@@ -466,6 +539,8 @@ def as_gpu_kernel(
):
ptr_ty = ir.Type.parse("!llvm.ptr")
token_ty = ir.Type.parse("!gpu.async.token")
i8 = ir.IntegerType.get_signless(8)
i64 = ir.IntegerType.get_signless(64)

def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
@@ -489,20 +564,46 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:

module = ir.Module.create()
with ir.InsertionPoint(module.body):
_declare_runtime_functions()
gmem_scratch_bytes = 0
@func.FuncOp.from_py_func(ptr_ty, ptr_ty)
def main(token_ptr, buffers):
nonlocal gmem_scratch_bytes
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
arg_refs = []
i = -1
for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
gmem_scratch_ptr = llvm.LoadOp(
ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i + 1], ptr_ty)
)
in_refs = arg_refs[:len(in_ref_tys)]
out_refs = arg_refs[len(in_ref_tys):]
prof_buffer = out_refs.pop() if prof_spec is not None else None
with _launch(
token, grid, block, smem_scratch_shape, prof_spec, prof_buffer
token, grid, block, gmem_scratch_ptr, smem_scratch_shape,
prof_spec, prof_buffer
) as (launch_ctx, smem_refs):
body(launch_ctx, *in_refs, *out_refs, smem_refs)
gmem_scratch_bytes = launch_ctx.next_scratch_offset
# Allocate and initialize the host buffer right before the launch.
# Note that we couldn't do that before, because we had to run the body
# to learn what the scratch contains.
with ir.InsertionPoint(launch_ctx.launch_op):
host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8)
for init_callback in launch_ctx.host_scratch_init:
init_callback(host_scratch_ptr)
func.call(
[],
"mosaic_gpu_memcpy_async_h2d",
[
gmem_scratch_ptr,
host_scratch_ptr,
c(gmem_scratch_bytes, i64),
token_ptr,
],
)
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
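
The ordering in `main` matters: the body has to be traced first, because only then is the total scratch size known; the host staging buffer, its initializers, and the single async copy are emitted afterwards, just ahead of the launch. Schematically, with plain Python standing in for the emitted MLIR (a sketch, not the actual emission API):

```python
# Schematic of main's two-pass structure: bytearray stands in for llvm.alloca,
# slice assignment for the mosaic_gpu_memcpy_async_h2d call.
def build_main(body, launch_ctx, gmem_scratch: bytearray):
    # Pass 1: trace the kernel body. Each _alloc_scratch call bumps
    # next_scratch_offset and appends a host-init callback.
    body(launch_ctx)
    total_bytes = launch_ctx.next_scratch_offset

    # Pass 2: the size is now known, so stage every slot on the host...
    host_scratch = bytearray(total_bytes)
    for init in launch_ctx.host_scratch_init:
        init(host_scratch)
    # ...and issue one asynchronous host-to-device copy before the launch.
    gmem_scratch[:total_bytes] = host_scratch
```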

@@ -523,7 +624,12 @@ def _check_args(args):
pass_manager.run(module.operation)

def bind(*args):
return mosaic_gpu_p.bind(*args, out_types=out_shape, module=module)
return mosaic_gpu_p.bind(
*args,
out_types=out_shape,
module=module,
gmem_scratch_bytes=gmem_scratch_bytes,
)

if prof_spec is not None:
@jax.jit
Expand Down Expand Up @@ -552,6 +658,21 @@ def kernel(*args):
return kernel


def _declare_runtime_functions():
"""Declares the runtime functions that can be used by the generated code."""
ptr_ty = ir.Type.parse("!llvm.ptr")
i64 = ir.IntegerType.get_signless(64)
arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
func.FuncOp(
"mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
)
memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
func.FuncOp(
"mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
)


def dump_low_level(module):
dump_ptx = mosaic_gpu_dump_ptx.value
dump_ptxas = mosaic_gpu_dump_ptxas.value
13 changes: 13 additions & 0 deletions jax/experimental/mosaic/gpu/utils.py
@@ -69,6 +69,19 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
return builtin.unrealized_conversion_cast([memref_ty], [desc])


def pack_array(values):
if not values:
raise ValueError("Empty array")
elem_ty = values[0].type
i64 = ir.IntegerType.get_signless(64)
ptr_ty = ir.Type.parse("!llvm.ptr")
arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty)
for i, v in enumerate(values):
elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty)
llvm.store(v, elem_ptr)
return arr_ptr
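
A hypothetical use of `pack_array`, mirroring how `init_tma_desc` passes per-dimension metadata: the values are stored behind one stack pointer so a runtime call can take a single `!llvm.ptr`. This sketch assumes an active MLIR context and insertion point, plus the `c` constant helper from this file:

```python
# Sketch only: pack two i64 dimension sizes behind a single pointer, as
# init_tma_desc does for sizes, strides, and the window shape.
i64 = ir.IntegerType.get_signless(64)
sizes_ptr = pack_array([c(s, i64) for s in (128, 64)])
```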


def get_contiguous_strides(xs):
strides_ret = []
stride = 1
1 change: 1 addition & 0 deletions jaxlib/cuda/BUILD
@@ -37,6 +37,7 @@ cc_library(
defines = ["JAX_GPU_CUDA=1"],
visibility = ["//visibility:public"],
deps = [
"@xla//xla/tsl/cuda:cupti",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
],
7 changes: 7 additions & 0 deletions jaxlib/mlir/_mlir_libs/BUILD.bazel
@@ -219,6 +219,12 @@ pybind_extension(
"-fexceptions",
"-fno-strict-aliasing",
],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
"-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
],
"//conditions:default": [],
}),
visibility = ["//third_party/py/jax:__subpackages__"],
deps = [
":jaxlib_mlir_capi_shared_library",
@@ -227,6 +233,7 @@
"//jaxlib/mosaic/gpu:mlir_capi",
"@nanobind",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
],
)

13 changes: 12 additions & 1 deletion jaxlib/mosaic/gpu/BUILD
@@ -87,9 +87,20 @@ cc_library(
alwayslink = True,
)

cc_library(
name = "runtime",
srcs = ["runtime.cc"],
deps = [
"@local_config_cuda//cuda:cuda_headers",
],
)

cc_binary(
name = "libmlir_cuda_runtime.so",
srcs = ["@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp"],
srcs = [
"runtime.cc",
"@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp",
],
copts = ["-fvisibility=default"],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
