-
Notifications
You must be signed in to change notification settings - Fork 32
Emit magic-number division for dynamic kernels #1222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fc2fbe4
64d1bfd
e0160c6
a0ac1e0
07a6063
421c49e
78cf1e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| # RUN: python %s | FileCheck %s | ||
|
|
||
| from sympy import ceiling | ||
|
|
||
| import wave_lang.kernel.lang as tkl | ||
| import wave_lang.kernel.wave as tkw | ||
| from wave_lang.kernel.lang.global_symbols import * | ||
| from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile | ||
| from wave_lang.kernel.wave.utils.general_utils import ( | ||
| run_test, | ||
| ) | ||
|
|
||
| M = tkl.sym.M | ||
| N = tkl.sym.N | ||
| K = tkl.sym.K | ||
| BLOCK_M = tkl.sym.BLOCK_M | ||
| BLOCK_N = tkl.sym.BLOCK_N | ||
| BLOCK_K = tkl.sym.BLOCK_K | ||
| GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N | ||
| ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE | ||
| ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 | ||
|
|
||
|
|
||
| @run_test | ||
| def test_magic_number_div(): | ||
| """Test that floordiv/mod by dynamic (runtime) divisors are lowered | ||
| to the magic-number multiply-high trick instead of expensive hardware | ||
| division. | ||
|
|
||
| When kernel dimensions are dynamic, the compiler cannot fold | ||
| floordiv/mod into compile-time constants. The magic-number | ||
| optimisation precomputes ``ceil(2^32 / d)`` once per unique divisor | ||
| and replaces every subsequent division with a 64-bit multiply + shift, | ||
| which is significantly cheaper on GPU. | ||
|
|
||
| We use a GEMM with GROUP_SIZE_N workgroup reordering to exercise | ||
| this: the reordering delinearises the flat workgroup id via | ||
| ``ceildiv(M, BLOCK_M)``, and the GEMM's multiple memory accesses | ||
| (read A, read B, write C) each independently compute reordered | ||
| indices, producing enough dynamic floordiv/mod expressions to | ||
| demonstrate that the expensive magic-number precomputation | ||
| (a single divui) is performed once per divisor and then reused | ||
| by multiple cheap multiply-and-shift sequences. | ||
| """ | ||
| constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] | ||
| constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] | ||
| constraints += [tkw.TilingConstraint(K, BLOCK_K)] | ||
| constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] | ||
| constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] | ||
| constraints += [ | ||
| tkw.HardwareConstraint( | ||
| threads_per_wave=64, | ||
| waves_per_block=(2, 2, 1), | ||
| mma_type=tkw.MMAType.F32_16x16x16_F16, | ||
| ) | ||
| ] | ||
|
|
||
| wg0, wg1 = WORKGROUP_0, WORKGROUP_1 | ||
| num_wg_0 = ceiling(M / BLOCK_M) | ||
|
|
||
| flat_wg_index = wg1 * num_wg_0 + wg0 | ||
| num_wg_group = GROUP_SIZE_N * num_wg_0 | ||
| group_id = flat_wg_index // num_wg_group | ||
| first_wg_id_1 = group_id * GROUP_SIZE_N | ||
| new_wg0 = (flat_wg_index % num_wg_group) // GROUP_SIZE_N | ||
| new_wg1 = first_wg_id_1 + (flat_wg_index % num_wg_group) % GROUP_SIZE_N | ||
|
|
||
| constraints += [tkw.ReorderingConstraint(new_wg0, 0)] | ||
| constraints += [tkw.ReorderingConstraint(new_wg1, 1)] | ||
|
|
||
| @tkw.wave(constraints) | ||
| def gemm( | ||
| a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], | ||
| b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], | ||
| c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], | ||
| ): | ||
| c_reg = tkl.Register[M, N, tkl.f32](0.0) | ||
|
|
||
| @tkw.iterate(K, init_args=[c_reg]) | ||
| def repeat( | ||
| acc: tkl.Register[M, N, tkl.f32], | ||
| ) -> tkl.Register[M, N, tkl.f32]: | ||
| a_reg = tkw.read(a) | ||
| b_reg = tkw.read(b) | ||
| acc = tkw.mma(a_reg, b_reg, acc) | ||
| return acc | ||
|
|
||
| tkw.write(repeat, c) | ||
|
|
||
| options = WaveCompileOptions( | ||
| subs={ | ||
| M: 512, | ||
| N: 1024, | ||
| K: 256, | ||
| BLOCK_M: 64, | ||
| BLOCK_N: 64, | ||
| BLOCK_K: 32, | ||
| GROUP_SIZE_N: 4, | ||
| ADDRESS_SPACE: SHARED_ADDRESS_SPACE, | ||
| ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, | ||
| }, | ||
| canonicalize=True, | ||
| compile_to_mlir=True, | ||
| magic_number_div=True, | ||
| ) | ||
|
|
||
| options.dynamic_symbols = [M, N, K] | ||
| for sym in options.dynamic_symbols: | ||
| del options.subs[sym] | ||
|
|
||
| gemm = wave_compile(options, gemm) | ||
| print(gemm.asm) | ||
|
|
||
| # CHECK-LABEL: func.func @gemm | ||
| # CHECK-DAG: arith.constant 4294967295 : i64 | ||
| # CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i64 | ||
| # | ||
| # Magic precomputation: one divui per unique dynamic divisor. | ||
| # CHECK: arith.divui {{.*}} : i64 | ||
| # CHECK: arith.divui {{.*}} : i64 | ||
| # | ||
| # Multiply-high (shrui >> 32) reusing precomputed magic numbers. | ||
| # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 | ||
| # | ||
| # Amortised: mulhi reusing a previously computed magic number | ||
| # with a different dividend — no new divui needed. | ||
| # CHECK-NOT: arith.divui | ||
| # CHECK-NOT: arith.divsi | ||
| # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 | ||
| # CHECK-NOT: arith.divui | ||
| # CHECK-NOT: arith.divsi | ||
| # CHECK: return |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -261,7 +261,13 @@ def get_static_dim(s: Optional[IndexExpr]) -> int: | |
| return func_op | ||
|
|
||
| def emit(self, graph: Optional[fx.Graph] = None) -> Operation: | ||
| global _magic_number_enabled, _magic_number_cache, _magic_entry_block, _magic_divisor_first_seen | ||
| _magic_number_enabled = self.options.magic_number_div | ||
| _magic_number_cache = {} | ||
| _magic_divisor_first_seen = {} | ||
|
|
||
| func = self.emit_func() | ||
| _magic_entry_block = func.entry_block | ||
| with InsertionPoint.at_block_terminator(func.entry_block), Location.unknown(): | ||
| self._emit_graph( | ||
| graph if graph is not None else self.trace.get_root_graph() | ||
|
|
@@ -633,10 +639,15 @@ def add_emitter_subs( | |
|
|
||
| _emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0))) | ||
| _use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1))) | ||
| _magic_number_enabled = False | ||
| _magic_entry_block = None | ||
|
|
||
| _Rational = namedtuple("_Rational", ["numerator", "denominator"]) | ||
| _ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"]) | ||
|
|
||
| _magic_number_cache: dict = {} | ||
| _magic_divisor_first_seen: dict = {} | ||
|
|
||
|
|
||
| def gen_sympy_index(dynamics: dict[IndexSymbol, Value], expr: sympy.Expr) -> Value: | ||
| use_affine_expr = _use_affine_expr | ||
|
|
@@ -778,16 +789,163 @@ def muli_expr(lhs, rhs): | |
|
|
||
| return op_expr(lhs, rhs, lambda a, b: a * b) | ||
|
|
||
| def _is_dynamic_divisor(val) -> bool: | ||
| """Check if a value is NOT a compile-time constant.""" | ||
| if isinstance(val, _ApplyExpr): | ||
| return not all( | ||
| isinstance(a, OpResult) and get_const_val(a) is not None | ||
| for a in val.args | ||
| ) | ||
| if isinstance(val, OpResult): | ||
| return get_const_val(val) is None | ||
| return True | ||
|
|
||
| def _divisor_key(val): | ||
| """Hashable key that identifies a divisor by its structure.""" | ||
| if isinstance(val, _ApplyExpr): | ||
| arg_keys = [] | ||
| for a in val.args: | ||
| c = get_const_val(a) if isinstance(a, OpResult) else None | ||
| arg_keys.append(("const", c) if c is not None else ("val", id(a))) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to this: |
||
| return ("apply", str(val.expr), tuple(arg_keys)) | ||
| if isinstance(val, OpResult): | ||
| c = get_const_val(val) | ||
| if c is not None: | ||
| return ("const", c) | ||
| return ("val", id(val)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Check this once: id(val) is the memory address of the Python object. Two structurally identical values created as distinct objects would get different keys, so the cache would miss and the magic number would be recomputed instead of reused. |
||
| return ("other", id(val)) | ||
|
|
||
| def _should_use_magic(rhs_expr) -> bool: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For discussion: this creates an asymmetry. If a divisor appears exactly once, it is lowered as a plain division, while all later occurrences of the same divisor take the magic-number path. Do you think a cleaner approach would be a two-pass design?
|
||
| """Return True only when a dynamic divisor is seen for the second time. | ||
|
|
||
| First encounter: record and decline (no benefit over a single div). | ||
| Second+ encounter: the precomputation is amortised, so use magic. | ||
| """ | ||
| key = _divisor_key(rhs_expr) | ||
| if key in _magic_number_cache: | ||
| return True | ||
| if key in _magic_divisor_first_seen: | ||
| return True | ||
| _magic_divisor_first_seen[key] = True | ||
| return False | ||
|
|
||
| def _mulhi_u32(n_i32, m_i32): | ||
| """Unsigned 32-bit multiply-high: (n * m) >> 32, via 64-bit multiply.""" | ||
| i64 = IntegerType.get_signless(64) | ||
| c32_i64 = arith_d.constant(i64, 32) | ||
| n_i64 = arith_d.extui(i64, n_i32) | ||
| m_i64 = arith_d.extui(i64, m_i32) | ||
| prod_i64 = arith_d.muli(n_i64, m_i64) | ||
| hi_i64 = arith_d.shrui(prod_i64, c32_i64) | ||
| i32 = IntegerType.get_signless(32) | ||
| return arith_d.trunci(i32, hi_i64) | ||
|
|
||
| def _precompute_magic_number(divisor_index: Value): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For d = 1: magic = ceil(2^32 / 1) = 2^32, which overflows i32 and truncates to 0, so the computed quotient would be wrong. Dynamic divisors derived from kernel dimensions (BLOCK_M, BLOCK_N, etc.) may always be > 1, but it is better to have a guard here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right — I added cheap IR that checks at runtime whether the divisor == 1; if it is, the dividend is used directly as the quotient and the remainder is set to 0. |
||
| """ | ||
| Compute magic = ceil(2^32 / d) from a dynamic divisor. | ||
| Returns (magic_i32, d_i32) both as i32 Values. | ||
| """ | ||
| i32 = IntegerType.get_signless(32) | ||
| i64 = IntegerType.get_signless(64) | ||
| d_i32 = arith_d.index_cast(i32, divisor_index) | ||
| d_i64 = arith_d.extui(i64, d_i32) | ||
| c1_i64 = arith_d.constant(i64, 1) | ||
| c32_i64 = arith_d.constant(i64, 32) | ||
| pow32 = arith_d.shli(c1_i64, c32_i64) | ||
| d_minus_1_i64 = arith_d.subi(d_i64, c1_i64) | ||
| numer_i64 = arith_d.addi(pow32, d_minus_1_i64) | ||
| magic_i64 = arith_d.divui(numer_i64, d_i64) | ||
| magic_i32 = arith_d.trunci(i32, magic_i64) | ||
| return magic_i32, d_i32 | ||
|
|
||
| def _get_or_create_magic(divisor_expr): | ||
| """Get cached (magic_i32, d_i32) or compute and cache them. | ||
|
|
||
| On cache miss the precomputation is hoisted to the function | ||
| entry block so that the magic constant dominates every use. | ||
| """ | ||
| key = _divisor_key(divisor_expr) | ||
| if key in _magic_number_cache: | ||
| return _magic_number_cache[key] | ||
| if _magic_entry_block is not None: | ||
| with InsertionPoint.at_block_begin(_magic_entry_block): | ||
| divisor_val = ( | ||
| _get_ir_value(divisor_expr) | ||
| if isinstance(divisor_expr, _ApplyExpr) | ||
| else divisor_expr | ||
| ) | ||
| magic_i32, d_i32 = _precompute_magic_number(divisor_val) | ||
| else: | ||
| divisor_val = ( | ||
| _get_ir_value(divisor_expr) | ||
| if isinstance(divisor_expr, _ApplyExpr) | ||
| else divisor_expr | ||
| ) | ||
| magic_i32, d_i32 = _precompute_magic_number(divisor_val) | ||
| _magic_number_cache[key] = (magic_i32, d_i32) | ||
| return magic_i32, d_i32 | ||
|
|
||
| def _magic_div_and_rem(lhs_val, rhs_expr): | ||
| """Compute (quotient, remainder) of lhs_val // rhs via mulhi. | ||
|
|
||
| Uses unsigned 32-bit arithmetic (extui, divui, shrui, uge). | ||
| Requires both operands to be non-negative and fit in 32 bits. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: the docstring says "requires both operands to be non-negative" but doesn't explain why unsigned arithmetic is used here when the rest of the emitter uses signed ops (arith.divsi / arith.remsi). |
||
| This holds for GPU index computations: dividends are | ||
| workgroup/thread indices and divisors are derived from | ||
| positive kernel dimensions. | ||
| """ | ||
| i32 = IntegerType.get_signless(32) | ||
| magic_i32, d_i32 = _get_or_create_magic(rhs_expr) | ||
| n_i32 = arith_d.index_cast(i32, lhs_val) | ||
| q_i32 = _mulhi_u32(n_i32, magic_i32) | ||
| qd_i32 = arith_d.muli(q_i32, d_i32) | ||
| r_i32 = arith_d.subi(n_i32, qd_i32) | ||
| # Correction: ceil(2^32/d) can overestimate quotient by 1. | ||
| # Detect via unsigned remainder >= divisor (wraps on overestimate). | ||
| too_big = arith_d.cmpi(arith_d.CmpIPredicate.uge, r_i32, d_i32) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The correction subtracts 1 from the quotient and adds d to the remainder when r >= d. This is correct for the ceil magic formula when d > 0. If d = 0 (division by zero), divui in the precomputation would produce undefined behavior. That possibility of undefined behavior already exists today; we can add a note documenting it. |
||
| c1_i32 = arith_d.constant(i32, 1) | ||
| c0_i32 = arith_d.constant(i32, 0) | ||
| corr = arith_d.select(too_big, c1_i32, c0_i32) | ||
| q_final = arith_d.subi(q_i32, corr) | ||
| d_or_zero = arith_d.select(too_big, d_i32, c0_i32) | ||
| r_final = arith_d.addi(r_i32, d_or_zero) | ||
| # Guard: when d == 1 the magic number overflows i32 to 0, | ||
| # so fall back to the trivial n // 1 = n, n % 1 = 0. | ||
| d_is_one = arith_d.cmpi(arith_d.CmpIPredicate.eq, d_i32, c1_i32) | ||
| q_final = arith_d.select(d_is_one, n_i32, q_final) | ||
| r_final = arith_d.select(d_is_one, c0_i32, r_final) | ||
| q_index = arith_d.index_cast(IndexType.get(), q_final) | ||
| r_index = arith_d.index_cast(IndexType.get(), r_final) | ||
| return q_index, r_index | ||
|
|
||
| def rem_expr(lhs, rhs): | ||
| if not use_affine_expr or not check_index_types(lhs, rhs): | ||
| return arith_d.remsi(*_broadcast(lhs, rhs)) | ||
|
|
||
| if ( | ||
| _magic_number_enabled | ||
| and _is_dynamic_divisor(rhs) | ||
| and _should_use_magic(rhs) | ||
| ): | ||
| lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs | ||
| _, r = _magic_div_and_rem(lhs_val, rhs) | ||
| return r | ||
|
|
||
| return op_expr(lhs, rhs, lambda a, b: a % b) | ||
|
|
||
| def floordiv_expr(lhs, rhs): | ||
| if not use_affine_expr or not check_index_types(lhs, rhs): | ||
| return arith_d.divsi(*_broadcast(lhs, rhs)) | ||
|
|
||
| if ( | ||
| _magic_number_enabled | ||
| and _is_dynamic_divisor(rhs) | ||
| and _should_use_magic(rhs) | ||
| ): | ||
| lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs | ||
| q, _ = _magic_div_and_rem(lhs_val, rhs) | ||
| return q | ||
|
|
||
| return op_expr(lhs, rhs, lambda a, b: AffineExpr.get_floor_div(a, b)) | ||
|
|
||
| def ceildiv_expr(lhs, rhs): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Setting module-level globals at R642 and then resetting them here is not thread-safe:
the caches are read-write per-compilation state, so concurrent compilations through
the same module would race on them.