From daa81e6fb5079cb69141f57a287974aa4500cc19 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 23 May 2024 10:01:22 -0700 Subject: [PATCH] Added support for printing scalar values in Pallas TPU kernels The implementation uses the new tpu.log operation in the Mosaic TPU dialect. Note that * the logging only happens if --xla_tpu_enable_log_recorder is set; * only scalars can be printed; * placeholders only accept i32 arguments at the moment. PiperOrigin-RevId: 636585852 --- jax/_src/pallas/mosaic/lowering.py | 37 +++++++++++++++++++++++--- jax/_src/pallas/mosaic_gpu/lowering.py | 21 ++------------- jax/_src/pallas/primitives.py | 28 ++++++++++++++++++- jaxlib/mosaic/dialect/tpu/tpu.td | 10 +++++++ tests/pallas/pallas_call_tpu_test.py | 31 +++++++++++++++++++++ 5 files changed, 103 insertions(+), 24 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 75e320ba2070..f43e8d633e28 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -15,21 +15,20 @@ """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" from __future__ import annotations +from collections.abc import Sequence import dataclasses import functools +import string from typing import Any, Callable -from collections.abc import Sequence - -from jaxlib.mlir.ir import Module import jax from jax import core as jax_core from jax import lax from jax import tree_util +from jax._src import ad_util from jax._src import custom_derivatives from jax._src import debugging from jax._src import linear_util as lu -from jax._src import ad_util from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import source_info_util @@ -59,6 +58,7 @@ from jax._src.util import unzip2 from jax.experimental.mosaic.dialects import tpu import jax.numpy as jnp +from jaxlib.mlir.ir import Module import numpy as np # TODO(sharadmv): enable type checking @@ -2331,3 +2331,32 @@ def _delay_rule(ctx: LoweringRuleContext, 
nanos: int): lowering_rules[tpu_primitives.delay_p] = _delay_rule + + +def _debug_print_rule( + ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool +): + primitives.check_debug_print_format(fmt, *args) + if has_placeholders: + if not all( + isinstance(arg.type, ir.IntegerType) and arg.type.width == 32 + for arg in args + ): + raise TypeError( + "All arguments must be 32-bit integers when using" + " placeholders (`{...}`). If you need to print values of other types," + " remove placeholders from the format string." + ) + + # TPU expects $0, $1 etc as placeholders. + tpu_fmt = "".join( + f"{text}${idx}" + for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt)) + ) + else: + tpu_fmt = fmt + tpu.log(args, tpu_fmt, formatted=has_placeholders) + return () + + +lowering_rules[primitives.debug_print_p] = _debug_print_rule diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 82442a45b7bc..a2ab59267a05 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -20,7 +20,6 @@ import dataclasses import functools import math -import string from typing import Any, cast import jax @@ -356,26 +355,10 @@ def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, fmt, - has_placeholders, - _formatter=string.Formatter(), + has_placeholders: bool, ): del has_placeholders - n_placeholders = 0 - for _, field, spec, conversion in string.Formatter().parse(fmt): - if spec or conversion: - raise ValueError( - "The format string should not contain any format specs orconversions" - ) - if field: - raise ValueError( - "The format string should not reference arguments by position or name" - ) - n_placeholders += field is not None - if len(args) != n_placeholders: - raise TypeError( - f"The format string expects {n_placeholders} " - f"argument{'' if n_placeholders == 1 else 's'}, but got {len(args)}" - ) + primitives.check_debug_print_format(fmt, *args) mgpu.debug_print(fmt, 
*args) return () diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index f4d1d81ee685..a8264d858026 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -518,6 +518,9 @@ def debug_print(fmt: str, *args: jax.ArrayLike): * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must contain a placeholder for each value to be printed. Format specs and conversions are not supported. + * On TPU, if ``fmt`` contains placeholders, all values must be 32-bit + integers. If there are no placeholders, the values are printed after + the format string. *args: The scalar values to print. """ # fmt: skip has_placeholders = False @@ -527,6 +530,29 @@ def debug_print(fmt: str, *args: jax.ArrayLike): return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders) +def check_debug_print_format( + fmt: str, *args: jax.ArrayLike +): + n_placeholders = 0 + for _, field, spec, conversion in string.Formatter().parse(fmt): + if field is not None: + n_placeholders += 1 + if spec or conversion: + raise ValueError( + "The format string should not contain any format specs or conversions" + ) + if field: + raise ValueError( + "The format string should not reference arguments by position or name" + ) + + if len(args) != n_placeholders: + raise TypeError( + f"The format string expects {n_placeholders} " + f"argument{'' if n_placeholders == 1 else 's'}, but got {len(args)}" + ) + + @debug_print_p.def_impl def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool): if has_placeholders: @@ -538,7 +564,7 @@ def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool): @debug_print_p.def_effectful_abstract_eval def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool): - del fmt + del fmt, has_placeholders if any(aval.shape for aval in avals): raise ValueError("Only scalar values are supported") return [], {debug_print_effect} diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td 
b/jaxlib/mosaic/dialect/tpu/tpu.td index cf123bc90c15..457a485325ab 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -561,6 +561,16 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; } +def TPU_LogOp : TPU_Op<"log"> { + let arguments = (ins + Variadic<AnyType>:$inputs, + StrAttr:$tag, + DefaultValuedAttr<BoolAttr, "false">:$formatted + ); + let results = (outs); + let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; +} + def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> { let dependentDialects = [ "::mlir::func::FuncDialect", diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index 46bccfda1681..7c9e1baa1388 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -2197,5 +2197,36 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_array_equal(r[5], result[5]) +class PallasCallPrintTest(PallasTPUTest): + + def test_debug_print(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + ) + def kernel(x_ref, o_ref): + pl.debug_print('It works!') + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + kernel(x) + + def test_debug_print_with_values(self): + @functools.partial( + self.pallas_call, + in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),), + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + ) + def kernel(x_ref, o_ref): + pl.debug_print('x[0] == {}', x_ref[0]) + + x = jnp.array([42, 24]).astype(jnp.int32) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + compiled_kernel(x) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())