Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions lit_tests/kernel/wave/magic_number_division.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# RUN: python %s | FileCheck %s

from sympy import ceiling

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.general_utils import (
run_test,
)

# Symbolic problem sizes (M, N, K) and tile sizes (BLOCK_*); bound to
# concrete values — or marked dynamic — via WaveCompileOptions in the test.
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
# Number of workgroups per reordering group along dimension 1.
GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


@run_test
def test_magic_number_div():
    """Test that floordiv/mod by dynamic (runtime) divisors are lowered
    to the magic-number multiply-high trick instead of expensive hardware
    division.

    When kernel dimensions are dynamic, the compiler cannot fold
    floordiv/mod into compile-time constants. The magic-number
    optimisation precomputes ``ceil(2^32 / d)`` once per unique divisor
    and replaces every subsequent division with a 64-bit multiply + shift,
    which is significantly cheaper on GPU.

    We use a GEMM with GROUP_SIZE_N workgroup reordering to exercise
    this: the reordering delinearises the flat workgroup id via
    ``ceildiv(M, BLOCK_M)``, and the GEMM's multiple memory accesses
    (read A, read B, write C) each independently compute reordered
    indices, producing enough dynamic floordiv/mod expressions to
    demonstrate that the expensive magic-number precomputation
    (a single divui) is performed once per divisor and then reused
    by multiple cheap multiply-and-shift sequences.
    """
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(2, 2, 1),
            mma_type=tkw.MMAType.F32_16x16x16_F16,
        )
    ]

    # Grouped workgroup reordering: linearise (wg0, wg1), then re-split
    # into groups of GROUP_SIZE_N. The // and % by the dynamic quantities
    # num_wg_0 / num_wg_group are exactly what the magic-number lowering
    # targets when M is a runtime dimension.
    wg0, wg1 = WORKGROUP_0, WORKGROUP_1
    num_wg_0 = ceiling(M / BLOCK_M)

    flat_wg_index = wg1 * num_wg_0 + wg0
    num_wg_group = GROUP_SIZE_N * num_wg_0
    group_id = flat_wg_index // num_wg_group
    first_wg_id_1 = group_id * GROUP_SIZE_N
    new_wg0 = (flat_wg_index % num_wg_group) // GROUP_SIZE_N
    new_wg1 = first_wg_id_1 + (flat_wg_index % num_wg_group) % GROUP_SIZE_N

    constraints += [tkw.ReorderingConstraint(new_wg0, 0)]
    constraints += [tkw.ReorderingConstraint(new_wg1, 1)]

    @tkw.wave(constraints)
    def gemm(
        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
    ):
        c_reg = tkl.Register[M, N, tkl.f32](0.0)

        @tkw.iterate(K, init_args=[c_reg])
        def repeat(
            acc: tkl.Register[M, N, tkl.f32],
        ) -> tkl.Register[M, N, tkl.f32]:
            a_reg = tkw.read(a)
            b_reg = tkw.read(b)
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        tkw.write(repeat, c)

    options = WaveCompileOptions(
        subs={
            M: 512,
            N: 1024,
            K: 256,
            BLOCK_M: 64,
            BLOCK_N: 64,
            BLOCK_K: 32,
            GROUP_SIZE_N: 4,
            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
        },
        canonicalize=True,
        compile_to_mlir=True,
        magic_number_div=True,
    )

    # Make M, N, K dynamic so divisors derived from them cannot be folded
    # at compile time — otherwise the magic-number path never fires.
    options.dynamic_symbols = [M, N, K]
    for sym in options.dynamic_symbols:
        del options.subs[sym]

    gemm = wave_compile(options, gemm)
    print(gemm.asm)

    # CHECK-LABEL: func.func @gemm
    # CHECK-DAG: arith.constant 4294967295 : i64
    # CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i64
    #
    # Magic precomputation: one divui per unique dynamic divisor.
    # CHECK: arith.divui {{.*}} : i64
    # CHECK: arith.divui {{.*}} : i64
    #
    # Multiply-high (shrui >> 32) reusing precomputed magic numbers.
    # CHECK: arith.shrui {{.*}}, %[[C32]] : i64
    #
    # Amortised: mulhi reusing a previously computed magic number
    # with a different dividend — no new divui needed.
    # CHECK-NOT: arith.divui
    # CHECK-NOT: arith.divsi
    # CHECK: arith.shrui {{.*}}, %[[C32]] : i64
    # CHECK-NOT: arith.divui
    # CHECK-NOT: arith.divsi
    # CHECK: return
158 changes: 158 additions & 0 deletions wave_lang/kernel/compiler/wave_codegen/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,13 @@ def get_static_dim(s: Optional[IndexExpr]) -> int:
return func_op

def emit(self, graph: Optional[fx.Graph] = None) -> Operation:
global _magic_number_enabled, _magic_number_cache, _magic_entry_block, _magic_divisor_first_seen
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These module-level globals are set at R642 and then reset here on each call to emit().
Because the caches are read-write, per-compilation state, this is not thread-safe:
two concurrent compilations would clobber each other's cache and entry-block state.

_magic_number_enabled = self.options.magic_number_div
_magic_number_cache = {}
_magic_divisor_first_seen = {}

func = self.emit_func()
_magic_entry_block = func.entry_block
with InsertionPoint.at_block_terminator(func.entry_block), Location.unknown():
self._emit_graph(
graph if graph is not None else self.trace.get_root_graph()
Expand Down Expand Up @@ -633,10 +639,15 @@ def add_emitter_subs(

# Environment toggles controlling how index arithmetic is emitted.
_emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0)))
_use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1)))
# Magic-number-division state; (re)initialised per compilation in emit().
# NOTE(review): module-level read-write state like this is not
# thread-safe if compilations ever run concurrently — TODO confirm.
_magic_number_enabled = False
# Entry block of the function currently being emitted; magic-number
# precomputations are hoisted here so they dominate every use.
_magic_entry_block = None

_Rational = namedtuple("_Rational", ["numerator", "denominator"])
_ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"])

# Maps divisor key -> (magic_i32, d_i32) IR values for reuse.
_magic_number_cache: dict = {}
# Divisor keys seen exactly once; magic kicks in from the second use.
_magic_divisor_first_seen: dict = {}


def gen_sympy_index(dynamics: dict[IndexSymbol, Value], expr: sympy.Expr) -> Value:
use_affine_expr = _use_affine_expr
Expand Down Expand Up @@ -778,16 +789,163 @@ def muli_expr(lhs, rhs):

return op_expr(lhs, rhs, lambda a, b: a * b)

def _is_dynamic_divisor(val) -> bool:
    """Return True unless *val* is known to be a compile-time constant.

    An ``_ApplyExpr`` counts as constant only when every argument is an
    ``OpResult`` that folds to a constant; a bare ``OpResult`` is
    constant when ``get_const_val`` succeeds; anything else is treated
    as dynamic.
    """
    if isinstance(val, _ApplyExpr):
        return any(
            not (isinstance(arg, OpResult) and get_const_val(arg) is not None)
            for arg in val.args
        )
    if isinstance(val, OpResult):
        return get_const_val(val) is None
    return True

def _divisor_key(val):
    """Hashable key that identifies a divisor by its structure.

    Constant operands are keyed by their folded value; non-constant
    operands are keyed by ``id(...)``.

    NOTE(review): ``id(...)`` is the Python wrapper object's identity,
    so two structurally identical OpResults wrapped by distinct Python
    objects get different keys. That can only cause a cache miss (a
    redundant precomputation), never a wrong result — acceptable, but
    worth revisiting if cache hit rates matter.
    """
    if isinstance(val, _ApplyExpr):
        arg_keys = []
        for a in val.args:
            c = get_const_val(a) if isinstance(a, OpResult) else None
            arg_keys.append(("const", c) if c is not None else ("val", id(a)))
        return ("apply", str(val.expr), tuple(arg_keys))
    if isinstance(val, OpResult):
        c = get_const_val(val)
        if c is not None:
            return ("const", c)
        return ("val", id(val))
    return ("other", id(val))

def _should_use_magic(rhs_expr) -> bool:
    """Return True only when a dynamic divisor is seen for the second time.

    First encounter: record and decline (no benefit over a single div).
    Second+ encounter: the precomputation is amortised, so use magic.

    NOTE(review): this gives the *first* division by a divisor the slow
    path and later ones the fast path, so a divisor used in exactly one
    floordiv+mod pair gets mixed lowering. A two-pass design (count
    divisor occurrences first, then emit magic for any divisor used 2+
    times) would remove the asymmetry — TODO evaluate.
    """
    key = _divisor_key(rhs_expr)
    # Already precomputed, or previously recorded: amortised, go magic.
    if key in _magic_number_cache or key in _magic_divisor_first_seen:
        return True
    _magic_divisor_first_seen[key] = True
    return False

def _mulhi_u32(n_i32, m_i32):
    """Unsigned 32-bit multiply-high: (n * m) >> 32, via 64-bit multiply."""
    i64 = IntegerType.get_signless(64)
    i32 = IntegerType.get_signless(32)
    shift_amount = arith_d.constant(i64, 32)
    # Widen both operands so the full 64-bit product is available.
    wide_n = arith_d.extui(i64, n_i32)
    wide_m = arith_d.extui(i64, m_i32)
    wide_product = arith_d.muli(wide_n, wide_m)
    high_half = arith_d.shrui(wide_product, shift_amount)
    return arith_d.trunci(i32, high_half)

def _precompute_magic_number(divisor_index: Value):
    """
    Compute magic = ceil(2^32 / d) from a dynamic divisor.
    Returns (magic_i32, d_i32) both as i32 Values.

    NOTE: for d == 1 the magic value is 2^32, which truncates to 0 in
    i32; callers guard that case at runtime (see _magic_div_and_rem).
    d == 0 makes the divui below undefined behaviour, matching the
    pre-existing semantics of emitting a division by zero.
    """
    i32 = IntegerType.get_signless(32)
    i64 = IntegerType.get_signless(64)
    d_i32 = arith_d.index_cast(i32, divisor_index)
    d_i64 = arith_d.extui(i64, d_i32)
    c1_i64 = arith_d.constant(i64, 1)
    c32_i64 = arith_d.constant(i64, 32)
    pow32 = arith_d.shli(c1_i64, c32_i64)  # 2^32 as an i64 value
    d_minus_1_i64 = arith_d.subi(d_i64, c1_i64)
    numer_i64 = arith_d.addi(pow32, d_minus_1_i64)  # 2^32 + (d - 1)
    magic_i64 = arith_d.divui(numer_i64, d_i64)  # ceil(2^32 / d)
    magic_i32 = arith_d.trunci(i32, magic_i64)
    return magic_i32, d_i32

def _get_or_create_magic(divisor_expr):
    """Get cached (magic_i32, d_i32) or compute and cache them.

    On cache miss the precomputation is hoisted to the function entry
    block (when known) so that the magic constant dominates every use.

    The divisor IR must be materialised *inside* the chosen insertion
    point — hoisting the value computation out of the ``with`` block
    would emit it at the wrong place — hence the local helper instead
    of duplicating the materialisation in both branches.
    """
    key = _divisor_key(divisor_expr)
    if key in _magic_number_cache:
        return _magic_number_cache[key]

    def _materialise():
        # Emit the divisor (and its magic) at the current insertion point.
        divisor_val = (
            _get_ir_value(divisor_expr)
            if isinstance(divisor_expr, _ApplyExpr)
            else divisor_expr
        )
        return _precompute_magic_number(divisor_val)

    if _magic_entry_block is not None:
        with InsertionPoint.at_block_begin(_magic_entry_block):
            magic_i32, d_i32 = _materialise()
    else:
        magic_i32, d_i32 = _materialise()
    _magic_number_cache[key] = (magic_i32, d_i32)
    return magic_i32, d_i32

def _magic_div_and_rem(lhs_val, rhs_expr):
    """Compute (quotient, remainder) of lhs_val // rhs via mulhi.

    Uses unsigned 32-bit arithmetic (extui, divui, shrui, uge) even
    though the surrounding emitter uses signed ops (divsi/remsi): the
    ceil-magic overestimate detection below relies on unsigned wrap.
    Requires both operands to be non-negative and fit in 32 bits.
    This holds for GPU index computations: dividends are
    workgroup/thread indices and divisors are derived from
    positive kernel dimensions.

    NOTE: d == 0 remains undefined behaviour (divui by zero in the
    precomputation), the same as any other emitted division by zero.
    """
    i32 = IntegerType.get_signless(32)
    magic_i32, d_i32 = _get_or_create_magic(rhs_expr)
    n_i32 = arith_d.index_cast(i32, lhs_val)
    q_i32 = _mulhi_u32(n_i32, magic_i32)
    qd_i32 = arith_d.muli(q_i32, d_i32)
    r_i32 = arith_d.subi(n_i32, qd_i32)
    # Correction: ceil(2^32/d) can overestimate quotient by 1.
    # Detect via unsigned remainder >= divisor (wraps on overestimate).
    too_big = arith_d.cmpi(arith_d.CmpIPredicate.uge, r_i32, d_i32)
    c1_i32 = arith_d.constant(i32, 1)
    c0_i32 = arith_d.constant(i32, 0)
    corr = arith_d.select(too_big, c1_i32, c0_i32)
    q_final = arith_d.subi(q_i32, corr)
    d_or_zero = arith_d.select(too_big, d_i32, c0_i32)
    r_final = arith_d.addi(r_i32, d_or_zero)
    # Guard: when d == 1 the magic number overflows i32 to 0,
    # so fall back to the trivial n // 1 = n, n % 1 = 0.
    d_is_one = arith_d.cmpi(arith_d.CmpIPredicate.eq, d_i32, c1_i32)
    q_final = arith_d.select(d_is_one, n_i32, q_final)
    r_final = arith_d.select(d_is_one, c0_i32, r_final)
    q_index = arith_d.index_cast(IndexType.get(), q_final)
    r_index = arith_d.index_cast(IndexType.get(), r_final)
    return q_index, r_index

def rem_expr(lhs, rhs):
    """Emit lhs % rhs, preferring the magic-number path for repeated
    dynamic divisors; otherwise fall back to remsi or an affine mod."""
    if not use_affine_expr or not check_index_types(lhs, rhs):
        return arith_d.remsi(*_broadcast(lhs, rhs))

    # Single short-circuiting expression: _should_use_magic has side
    # effects (first-seen bookkeeping), so evaluation order matters.
    take_magic_path = (
        _magic_number_enabled
        and _is_dynamic_divisor(rhs)
        and _should_use_magic(rhs)
    )
    if take_magic_path:
        dividend = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs
        return _magic_div_and_rem(dividend, rhs)[1]

    return op_expr(lhs, rhs, lambda a, b: a % b)

def floordiv_expr(lhs, rhs):
    """Emit lhs // rhs, preferring the magic-number path for repeated
    dynamic divisors; otherwise fall back to divsi or an affine floordiv."""
    if not use_affine_expr or not check_index_types(lhs, rhs):
        return arith_d.divsi(*_broadcast(lhs, rhs))

    # Single short-circuiting expression: _should_use_magic has side
    # effects (first-seen bookkeeping), so evaluation order matters.
    take_magic_path = (
        _magic_number_enabled
        and _is_dynamic_divisor(rhs)
        and _should_use_magic(rhs)
    )
    if take_magic_path:
        dividend = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs
        return _magic_div_and_rem(dividend, rhs)[0]

    return op_expr(lhs, rhs, lambda a, b: AffineExpr.get_floor_div(a, b))

def ceildiv_expr(lhs, rhs):
Expand Down
1 change: 1 addition & 0 deletions wave_lang/kernel/wave/compile_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class WaveCompileOptions:
enable_mark_hardware_transpose_candidates: bool = True

# === Compiler options ===
magic_number_div: bool = False
minimize_shared_allocs: bool = True
reorder_allocs: bool = True
override_schedule: Optional[str] = None
Expand Down
Loading