Skip to content

Commit

Permalink
Added support for printing scalar values in Pallas TPU kernels
Browse files Browse the repository at this point in the history
The implementation uses the new tpu.log operation in the Mosaic TPU dialect.

Note that

* the logging only happens if --xla_tpu_enable_log_recorder is set;
* only scalars can be printed;
* placeholders only accept i32 arguments at the moment.

PiperOrigin-RevId: 636585852
  • Loading branch information
superbobry authored and jax authors committed May 23, 2024
1 parent 8f6fc11 commit daa81e6
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 24 deletions.
37 changes: 33 additions & 4 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,20 @@
"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""
from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import functools
import string
from typing import Any, Callable
from collections.abc import Sequence

from jaxlib.mlir.ir import Module

import jax
from jax import core as jax_core
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import linear_util as lu
from jax._src import ad_util
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import source_info_util
Expand Down Expand Up @@ -59,6 +58,7 @@
from jax._src.util import unzip2
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from jaxlib.mlir.ir import Module
import numpy as np

# TODO(sharadmv): enable type checking
Expand Down Expand Up @@ -2331,3 +2331,32 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int):


lowering_rules[tpu_primitives.delay_p] = _delay_rule


def _to_tpu_log_format(fmt: str) -> str:
  """Rewrites the ``{}`` placeholders of ``fmt`` as ``$0``, ``$1``, ...

  Only real replacement fields are numbered. Literal segments (including
  trailing text after the last placeholder and ``{{``/``}}`` escapes) are
  copied through without gaining a placeholder of their own.
  """
  pieces = []
  n_placeholders = 0
  for text, field, _, _ in string.Formatter().parse(fmt):
    pieces.append(text)
    # ``field`` is None for a literal-only segment and "" for a bare ``{}``.
    if field is not None:
      pieces.append(f"${n_placeholders}")
      n_placeholders += 1
  return "".join(pieces)


def _debug_print_rule(
    ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
  """Lowers ``debug_print_p`` to a ``tpu.log`` operation.

  Args:
    ctx: The lowering rule context. Not used by this rule.
    *args: The lowered values to print.
    fmt: The Python-style format string given to ``debug_print``.
    has_placeholders: Whether ``fmt`` contains ``{}`` placeholders.

  Returns:
    An empty tuple, since the primitive has no results.

  Raises:
    ValueError: If ``fmt`` has placeholders that use format specs,
      conversions, or reference arguments by position or name.
    TypeError: If ``fmt`` has placeholders and their number differs from the
      number of values, or any value is not a 32-bit integer.
  """
  if has_placeholders:
    # The placeholder/argument correspondence only exists in this mode. A
    # format string without placeholders may still be given values, which are
    # printed after it, so it must not be rejected by the count check.
    primitives.check_debug_print_format(fmt, *args)
    if not all(
        isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
        for arg in args
    ):
      raise TypeError(
          "All arguments must be 32-bit integers when using"
          " placeholders (`{...}`). If you need to print values of other types,"
          " remove placeholders from the format string."
      )

    # TPU expects $0, $1 etc as placeholders.
    tpu_fmt = _to_tpu_log_format(fmt)
  else:
    tpu_fmt = fmt
  tpu.log(args, tpu_fmt, formatted=has_placeholders)
  return ()


# Register `_debug_print_rule` as the lowering for `debug_print_p`.
lowering_rules[primitives.debug_print_p] = _debug_print_rule
21 changes: 2 additions & 19 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import dataclasses
import functools
import math
import string
from typing import Any, cast

import jax
Expand Down Expand Up @@ -356,26 +355,10 @@ def _debug_print_lowering_rule(
ctx: LoweringRuleContext,
*args,
fmt,
has_placeholders,
_formatter=string.Formatter(),
has_placeholders: bool,
):
del has_placeholders
n_placeholders = 0
for _, field, spec, conversion in string.Formatter().parse(fmt):
if spec or conversion:
raise ValueError(
"The format string should not contain any format specs orconversions"
)
if field:
raise ValueError(
"The format string should not reference arguments by position or name"
)
n_placeholders += field is not None
if len(args) != n_placeholders:
raise TypeError(
f"The format string expects {n_placeholders} "
f"argument{'' if n_placeholders == 1 else 's'}, but got {len(args)}"
)
primitives.check_debug_print_format(fmt, *args)
mgpu.debug_print(fmt, *args)
return ()

Expand Down
28 changes: 27 additions & 1 deletion jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,9 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
* On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
contain a placeholder for each value to be printed. Format specs and
conversions are not supported.
* On TPU, if ``fmt`` contains placeholders, all values must be 32-bit
integers. If there are no placeholders, the values are printed after
the format string.
*args: The scalar values to print.
""" # fmt: skip
has_placeholders = False
Expand All @@ -527,6 +530,29 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders)


def check_debug_print_format(fmt: str, *args: jax.ArrayLike):
  """Validates a ``debug_print`` format string against its arguments.

  Args:
    fmt: The format string. Only bare ``{}`` placeholders are accepted.
    *args: The values that the placeholders will be filled with.

  Raises:
    ValueError: If a placeholder carries a format spec or a conversion, or
      references an argument by position or name.
    TypeError: If the number of placeholders differs from ``len(args)``.
  """
  placeholder_count = 0
  # The parser is consumed lazily, so problems are reported in left-to-right
  # order of appearance in ``fmt``.
  for _, field_name, format_spec, conversion in string.Formatter().parse(fmt):
    # ``field_name`` is None for literal-only segments, "" for a bare ``{}``.
    placeholder_count += field_name is not None
    if format_spec or conversion:
      raise ValueError(
          "The format string should not contain any format specs or conversions"
      )
    if field_name:
      raise ValueError(
          "The format string should not reference arguments by position or name"
      )

  if placeholder_count == len(args):
    return
  plural = "" if placeholder_count == 1 else "s"
  raise TypeError(
      f"The format string expects {placeholder_count} "
      f"argument{plural}, but got {len(args)}"
  )


@debug_print_p.def_impl
def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool):
if has_placeholders:
Expand All @@ -538,7 +564,7 @@ def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool):

@debug_print_p.def_effectful_abstract_eval
def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool):
  """Abstract evaluation rule for ``debug_print_p``.

  Args:
    *avals: Abstract values of the arguments to print.
    fmt: The format string. Not used by this rule.
    has_placeholders: Whether ``fmt`` has placeholders. Not used by this rule.

  Returns:
    A pair of an empty list of outputs and the set ``{debug_print_effect}``.

  Raises:
    ValueError: If any argument has a non-empty shape.
  """
  # Each name is deleted exactly once; deleting ``fmt`` a second time would
  # fail because the name is already unbound.
  del fmt, has_placeholders
  if any(aval.shape for aval in avals):
    raise ValueError("Only scalar values are supported")
  return [], {debug_print_effect}
Expand Down
10 changes: 10 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,16 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> {
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
}

// `tpu.log`: takes any number of operands of any type plus a string `tag`
// attribute and a boolean `formatted` attribute, and produces no results.
// NOTE(review): the Pallas TPU lowering added alongside this op passes the
// (rewritten) format string as `tag` and sets `formatted` when that string
// contains placeholders — confirm the intended semantics with the consumer.
def TPU_LogOp : TPU_Op<"log"> {
let arguments = (ins
Variadic<AnyType>:$inputs,
StrAttr:$tag,
// Defaults to "false" when the attribute is omitted.
DefaultValuedAttr<BoolAttr, "false">:$formatted
);
let results = (outs);
// The bracketed operand list and its types are printed only when `inputs`
// is non-empty.
let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }];
}

def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::func::FuncDialect",
Expand Down
31 changes: 31 additions & 0 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,5 +2197,36 @@ def kernel(x_ref, y_ref, o_ref):
np.testing.assert_array_equal(r[5], result[5])


class PallasCallPrintTest(PallasTPUTest):
  """Smoke tests for ``pl.debug_print`` inside Pallas TPU kernels.

  Neither test inspects the printed output; each one only checks that the
  kernel can be built and run without raising.
  """

  def test_debug_print(self):
    """Runs a kernel that prints a format string with no values."""

    def kernel(x_ref, o_ref):
      pl.debug_print('It works!')

    out_shape = jax.ShapeDtypeStruct((2,), jnp.float32)
    call = self.pallas_call(kernel, out_shape=out_shape)

    x = jnp.array([4.2, 2.4]).astype(jnp.float32)
    call(x)

  def test_debug_print_with_values(self):
    """Runs a kernel that prints one value through a placeholder."""

    def kernel(x_ref, o_ref):
      pl.debug_print('x[0] == {}', x_ref[0])

    call = self.pallas_call(
        kernel,
        in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),),
        out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
    )

    x = jnp.array([42, 24]).astype(jnp.int32)
    # Compile explicitly so that the log recorder option can be passed in.
    compiler_options = {'xla_tpu_enable_log_recorder': 'true'}
    lowered = jax.jit(call).lower(x)
    compiled_kernel = lowered.compile(compiler_options)
    compiled_kernel(x)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit daa81e6

Please sign in to comment.