Skip to content

Commit

Permalink
[Mosaic GPU] Implement a simple profiling tool using CUDA events
Browse files Browse the repository at this point in the history
The other JAX profiling tools are a little heavyweight when we only care about
timing a single kernel programmatically.

Also adapt wgmma.py to match failures triggered by upstream MLIR changes.

PiperOrigin-RevId: 628096973
  • Loading branch information
apaszke authored and jax authors committed Apr 25, 2024
1 parent fad2c0e commit ded9272
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 1 deletion.
57 changes: 57 additions & 0 deletions jax/experimental/mosaic/gpu/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
# ==============================================================================

import contextlib
import ctypes
import functools
import json

import jax
from jax._src.interpreters import mlir
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
from jax._src.lib import xla_client
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
Expand All @@ -30,6 +35,58 @@
# ruff: noqa: F405
# mypy: ignore-errors

# Make the native event-recording entry point available to XLA under the
# custom-call target name that the `record_event` lowering rule emits.
xla_client.register_custom_call_target(
    "mosaic_gpu_record_event",
    mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(),
    platform="CUDA",
)

# Primitive used to record a GPU event at a given point of a computation. Its
# abstract evaluation rule returns the operands' avals unchanged.
record_event_p = jax.core.Primitive("record_event")
# The primitive produces one output per operand, so bind() returns a sequence.
record_event_p.multiple_results = True

@record_event_p.def_abstract_eval
def _record_event_abstract_eval(*args, event):
  """Abstract evaluation rule for ``record_event_p``.

  Args:
    *args: Abstract values of the operands.
    event: The event handle; it does not influence the output avals.

  Returns:
    The operand avals, unchanged (one output per operand).
  """
  del event  # Only the lowering rule needs the event handle.
  return args

@functools.partial(mlir.register_lowering, record_event_p, platform="cuda")
def _record_event_lowering_rule(ctx, *args, event):
  """Lowers ``record_event_p`` to a ``mosaic_gpu_record_event`` custom call.

  Args:
    ctx: The lowering rule context; its ``avals_out`` determine the result
      types of the emitted operation.
    *args: IR values of the operands. Each operand is aliased to the result
      at the same position.
    event: The event handle, passed to the custom call through its backend
      config.

  Returns:
    The results of the emitted custom call operation.
  """
  # The handle is shipped to the native side as 8 little-endian bytes.
  event_address = ctypes.cast(event, ctypes.c_void_p).value
  ptr_bytes = event_address.to_bytes(8, byteorder="little")  # pytype: disable=attribute-error
  out_types = list(map(mlir.aval_to_ir_type, ctx.avals_out))
  # Operand i is aliased to output i.
  aliases = {idx: idx for idx, _ in enumerate(args)}
  custom_call_op = mlir.custom_call(
      "mosaic_gpu_record_event",
      result_types=out_types,
      operands=args,
      backend_config=ptr_bytes,
      operand_output_aliases=aliases,
  )
  return custom_call_op.results

def _record_event(args, event):
  """Binds ``record_event_p`` over every leaf of the pytree ``args``.

  Args:
    args: An arbitrary pytree of values to pass through the primitive.
    event: The event handle forwarded to the primitive as a parameter.

  Returns:
    A pytree with the same structure as ``args`` holding the primitive's
    outputs.
  """
  leaves, structure = jax.tree.flatten(args)
  recorded_leaves = record_event_p.bind(*leaves, event=event)
  return jax.tree.unflatten(structure, recorded_leaves)

def measure(f, *args, **kwargs):
  """Runs ``f`` once under ``jax.jit`` and times it with a pair of GPU events.

  The inputs are passed through an event-recording primitive before being
  handed to ``f`` and the outputs of ``f`` are passed through a second one, so
  the two events bracket the computation of ``f`` inside a single jitted
  function.

  Args:
    f: The callable to time. It is traced as part of a jitted wrapper.
    *args: Positional arguments (pytrees) forwarded to ``f``.
    **kwargs: Keyword arguments (pytrees) forwarded to ``f``.

  Returns:
    A ``(results, elapsed)`` pair, where ``results`` are the outputs of ``f``
    (after ``jax.block_until_ready``) and ``elapsed`` is the time between the
    two events as reported by the native ``_gpu_event_elapsed`` helper (in
    milliseconds).
  """
  # TODO(apaszke): Raise if this is called under jit.
  ext = mosaic_gpu_lib._mosaic_gpu_ext
  # Register each event for destruction as soon as it exists, so that neither
  # leaks if a later step (including creating the second event) raises, and so
  # that a failing destroy does not prevent the other one from running.
  with contextlib.ExitStack() as cleanup:
    start_event = ext._gpu_event_create()
    cleanup.callback(ext._gpu_event_destroy, start_event)
    end_event = ext._gpu_event_create()
    cleanup.callback(ext._gpu_event_destroy, end_event)

    @jax.jit
    def run(*args, **kwargs):
      args, kwargs = _record_event((args, kwargs), start_event)
      return _record_event(f(*args, **kwargs), end_event)

    # The events are only meaningful once the computation has finished.
    results = jax.block_until_ready(run(*args, **kwargs))
    elapsed = ext._gpu_event_elapsed(start_event, end_event)
  return results, elapsed


class ProfilerSpec:
ENTER = 0
EXIT = 1 << 31
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/wgmma.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def take_regs(n):
ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"

def lc(x):
return llvm.mlir_constant(i32, ir.IntegerAttr.get(i32, x))
return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result

def as_i32_reg(v):
return llvm.extractelement(
Expand Down
1 change: 1 addition & 0 deletions jaxlib/mlir/_mlir_libs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ pybind_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"//jaxlib/mosaic/gpu:mlir_capi",
"@nanobind",
"@xla//xla/service:custom_call_status",
Expand Down
31 changes: 31 additions & 0 deletions jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include <cstdint>
#include "nanobind/nanobind.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "jaxlib/mosaic/gpu/integrations/c/passes.h"
#include "xla/service/custom_call_status.h"
Expand All @@ -18,10 +20,39 @@ void MosaicKernelCall(void* stream, void** buffers, char* opaque,
func(args);
}

void EventRecordCall(void* stream, void** buffers, char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto* event = reinterpret_cast<gpuEvent_t**>(opaque);
if (gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream)) !=
gpuSuccess) {
const char message[] = "Failed to record event";
XlaCustomCallStatusSetFailure(status, message, sizeof(message));
}
}

// Python bindings of the Mosaic GPU extension module.
NB_MODULE(_mosaic_gpu_ext, m) {
  m.def("_custom_call_capsule",
        []() { return EncapsulateFunction(MosaicKernelCall); });
  m.def("register_passes", []() { return mlirMosaicGpuRegisterPasses(); });
  // Creates a GPU event stored in a heap-allocated slot and returns the
  // address of that slot as an integer handle. The handle must be released
  // with `_gpu_event_destroy`.
  m.def("_gpu_event_create", []() {
    gpuEvent_t* event = new gpuEvent_t();
    if (gpuEventCreate(event, GPU_EVENT_DEFAULT) != gpuSuccess) {
      // Do not hand out a handle to an event that was never created.
      delete event;
      throw std::runtime_error("Failed to create event");
    }
    return reinterpret_cast<uintptr_t>(event);
  });
  // Destroys the event behind a handle returned by `_gpu_event_create` and
  // frees the slot that holds it. The handle must not be used afterwards.
  // Destruction is best-effort: its status is not reported, so that cleanup
  // code in Python never raises from here.
  m.def("_gpu_event_destroy", [](uintptr_t event) {
    auto* event_ptr = reinterpret_cast<gpuEvent_t*>(event);
    gpuEventDestroy(*event_ptr);
    delete event_ptr;
  });
  // Returns the time elapsed between two recorded events, in milliseconds.
  m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) {
    float elapsed_ms = -1;
    if (gpuEventElapsedTime(
            &elapsed_ms, *reinterpret_cast<gpuEvent_t*>(start_event),
            *reinterpret_cast<gpuEvent_t*>(end_event)) != gpuSuccess) {
      throw std::runtime_error("Failed to get elapsed time between events");
    }
    return elapsed_ms;
  });
  m.def("_record_event_capsule",
        []() { return EncapsulateFunction(EventRecordCall); });
}

} // namespace
Expand Down
8 changes: 8 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
else:
from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import dsl as mgpu
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
Expand Down Expand Up @@ -799,5 +800,12 @@ def kernel(ctx, *args):
np.testing.assert_array_equal(inp, result)


class ProfilerTest(TestCase):
  """Smoke tests for ``profiler.measure``."""

  def test_measure(self):
    # This is just a smoke test: it only checks that measuring a trivial
    # computation runs end to end; the reported time is not inspected.
    operand = jnp.arange(1024 * 1024)
    profiler.measure(lambda lhs, rhs: lhs + rhs, operand, operand)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit ded9272

Please sign in to comment.