@harsh-nod
Collaborator

This PR adds an asm lowering of a simple mma kernel. The lhs and rhs are promoted to shared memory and then fed to a 16x16x16 MMA. A lit test and an e2e test are added for correctness.

Copilot AI review requested due to automatic review settings October 30, 2025 22:52
Contributor

Copilot AI left a comment

Pull Request Overview

This PR significantly expands the ASM backend to support MFMA (Matrix Fused Multiply-Add) operations on AMD CDNA architectures. The key changes include:

  • Refactored SymPy alias from sp to sympy for consistency
  • Added comprehensive MFMA support with AGPR allocation and LDS staging
  • Implemented generic expression emitter for complex SymPy expressions
  • Added dynamic register allocation with architecture-specific granularities
  • Refactored operation handlers into a separate module for better organization
  • Enhanced instruction set with DS (LDS), MFMA, and additional buffer operations

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 15 comments.

| File | Description |
| --- | --- |
| wave_lang/kernel/wave/asm/utils.py | Changed SymPy import alias from sp to sympy; added TYPE_CHECKING for forward references; added byte offset expression builders |
| wave_lang/kernel/wave/asm/register_allocator.py | Removed implicit v0 reservation; added AGPRAllocator; enhanced VGPR allocator with reserve/free methods |
| wave_lang/kernel/wave/asm/mlir_walker.py | Refactored operation handling to handlers.py; added LDS and MFMA tracking state |
| wave_lang/kernel/wave/asm/kernel_model.py | Added lds_size_bytes field |
| wave_lang/kernel/wave/asm/instructions.py | Added LDS, MFMA, AGPR, and additional buffer instructions |
| wave_lang/kernel/wave/asm/handlers.py | New file: extracted operation handlers from mlir_walker.py |
| wave_lang/kernel/wave/asm/expression_emitter.py | New file: generic SymPy expression visitor for AMDGCN assembly |
| wave_lang/kernel/wave/asm/asm_emitter.py | Enhanced with MFMA support, dynamic register allocation, LDS operations |
| wave_lang/kernel/wave/asm/__init__.py | Removed affine_map_simplifies_to_tid_x export |
| tests/kernel/wave/expression_emitter_test.py | New file: comprehensive tests for expression emitter |
| tests/kernel/wave/asm_backend_test.py | Added MFMA kernel test |
| lit_tests/kernel/wave/asm.py | Updated tests for dynamic register allocation; added MFMA test |
| docs/wave/asm_backend.rst | Comprehensive documentation updates for MFMA and new features |


if hasattr(value_attribute, "value"):
constant_value = int(value_attribute.value)
kernel_info.index_env[str(operation.result)] = constant_value
except (AttributeError, ValueError, TypeError) as e:
Copilot AI Oct 30, 2025

The exception variable e is caught but never used. Either remove it from the except clause if it's not needed, or log it for debugging purposes. Consider: except (AttributeError, ValueError, TypeError):

Suggested change
- except (AttributeError, ValueError, TypeError) as e:
+ except (AttributeError, ValueError, TypeError):

Comment on lines 177 to 179
except (AttributeError, TypeError):
# If simplification fails, expression may not be convertible to constant
pass
Copilot AI Oct 30, 2025

Similar to Comment 1, consider removing the unused exception variables throughout the file for consistency. This pattern appears multiple times (lines 177, 314, 411, 435).

Comment on lines 455 to 456
import sympy as sp

Copilot AI Oct 30, 2025

Inconsistent SymPy import alias. The file imports sympy at the module level, but here it's imported as sp. For consistency with the PR's refactoring goal (changing sp → sympy), use import sympy or reference the module-level import.

Suggested change
- import sympy as sp

Uses the actual indices from vector.store and the memref's stride information
to compute the byte offset, rather than forcing lane-linear packing.
"""
import sympy as sp
Copilot AI Oct 30, 2025

Another inconsistent SymPy import alias. This should use the module-level sympy import for consistency.

Comment on lines 231 to 249
arg_names = [f"arg{i}_ptr" for i in range(num_args)]
if kernel_name == "mma" and num_args == 3:
arg_names = ["A_ptr", "B_ptr", "C_ptr"]
Copilot AI Oct 30, 2025

This hardcoded special case for kernel name 'mma' creates tight coupling and fragility. Consider making argument names configurable through KernelInfo or metadata, rather than hardcoding a specific kernel name check.

Suggested change
- arg_names = [f"arg{i}_ptr" for i in range(num_args)]
- if kernel_name == "mma" and num_args == 3:
-     arg_names = ["A_ptr", "B_ptr", "C_ptr"]
+ # Use argument names from KernelInfo if available, otherwise default
+ arg_names = getattr(self.kernel_info, "arg_names", None)
+ if not arg_names or len(arg_names) != num_args:
+     arg_names = [f"arg{i}_ptr" for i in range(num_args)]

Comment on lines +169 to +168
acc = tkw.mma(a_reg, b_reg, c_reg)
tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
Copilot AI Oct 30, 2025

The result of mma is used even though it is always None.

Suggested change
- acc = tkw.mma(a_reg, b_reg, c_reg)
- tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+ tkw.mma(a_reg, b_reg, c_reg)
+ tkw.write(c_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)


return self.walker.emitter.materialize_byte_offset_expr(kernel_info, byte_offset_expr)

def _extract_source_registers(self, vector_bytes):
Copilot AI Oct 30, 2025

OperationHandlers._extract_source_registers returns tuple of size 2 and tuple of size 4.

Copilot AI review requested due to automatic review settings October 31, 2025 14:59
@harsh-nod force-pushed the wave_gemm_asm branch 2 times, most recently from 825d8a0 to 80370a2 October 31, 2025 15:00
Contributor

Copilot AI left a comment

Pull Request Overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 6 comments.



Comment on lines +212 to +358
const_v = self.emitter.vgpr_allocator.alloc_v()
self.emitter.emit_instruction(VMovB32(const_v, multiplier))
self.emitter.emit_instruction(VMulLoU32(temp_v, src_v, const_v))
self.emitter.vgpr_allocator.free_v(const_v)
Copilot AI Oct 31, 2025

[nitpick] For non-power-of-2 multipliers, consider using inline constants instead of allocating a VGPR. AMDGCN instructions can accept inline constants directly in many cases, which would save a register allocation and a move instruction.

Suggested change
- const_v = self.emitter.vgpr_allocator.alloc_v()
- self.emitter.emit_instruction(VMovB32(const_v, multiplier))
- self.emitter.emit_instruction(VMulLoU32(temp_v, src_v, const_v))
- self.emitter.vgpr_allocator.free_v(const_v)
+ # Use inline constant for multiplier
+ self.emitter.emit_instruction(VMulLoU32(temp_v, src_v, multiplier))
+ # No VGPR allocation needed for inline constant
+ # (no free_v call needed)
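
Whether a multiplier can be folded into the instruction depends on its value, since only a small set of values is encodable as an inline constant. The following is a minimal sketch, assuming the commonly documented AMDGCN inline integer range of -16 to 64; `fits_inline_int` and `plan_multiply` are hypothetical helpers for illustration and are not part of this PR.

```python
INLINE_INT_MIN = -16  # assumed lower bound of the inline integer constant range
INLINE_INT_MAX = 64   # assumed upper bound of the inline integer constant range


def fits_inline_int(value: int) -> bool:
    """Return True if value can be encoded as an inline integer constant."""
    return INLINE_INT_MIN <= value <= INLINE_INT_MAX


def plan_multiply(multiplier: int) -> str:
    """Pick a lowering strategy for dst = src * multiplier (illustrative only)."""
    if multiplier > 0 and (multiplier & (multiplier - 1)) == 0:
        return "shift"             # power of two: a left shift is enough
    if fits_inline_int(multiplier):
        return "inline_constant"   # operand is encoded in the instruction itself
    return "materialize_vgpr"      # move the constant into a temporary VGPR first
```

Under these assumptions a multiplier such as 48 would use an inline constant, while a value such as 1000 would still need the temporary register that the original code allocates.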

Comment on lines +164 to +167
self.emit(" .amdhsa_system_sgpr_workgroup_id_x 0")
self.emit(" .amdhsa_system_sgpr_workgroup_id_y 0")
self.emit(" .amdhsa_system_sgpr_workgroup_id_z 0")
Copilot AI Oct 31, 2025

Workgroup IDs are now always disabled (set to 0). If any kernels require workgroup IDs in the future, this will be a hard-to-debug issue. Consider adding a comment explaining why these are disabled or making this configurable based on kernel requirements.

def __init__(self):
self.s_max = 1 # s[0:1] reserved for kernarg ptr
- self.v_used = set([0]) # v0 used by lane id; v2 will be used on demand
+ self.v_used = set() # No implicit reservations
Copilot AI Oct 31, 2025

The comment 'No implicit reservations' is misleading because the reserve() method (line 52-55) still allows explicit reservations. Consider updating to: '# Explicit reservations tracked via reserve() method' for clarity.

Suggested change
- self.v_used = set() # No implicit reservations
+ self.v_used = set() # Explicit reservations tracked via reserve() method

Copilot AI review requested due to automatic review settings October 31, 2025 17:16
Contributor

Copilot AI left a comment

Pull Request Overview

Copilot reviewed 14 out of 14 changed files in this pull request and generated 11 comments.



from typing import Dict, List, Optional, Tuple, TYPE_CHECKING

- import sympy as sp
+ import sympy
Copilot AI Oct 31, 2025

[nitpick] While changing from import sympy as sp to import sympy provides consistency, this change affects 50+ references throughout the file (e.g., sp.Expr → sympy.Expr, sp.Symbol → sympy.Symbol). Consider whether the benefits of this change outweigh the increased verbosity. If consistency is the goal, ensure this pattern is followed project-wide.


class VGPRAllocator:
-     def __init__(self, register_file: RegFile, base=2):
+     def __init__(self, register_file: RegFile, base=0):
Copilot AI Oct 31, 2025

The change from base=2 to base=0 alters the default starting VGPR index. This is a breaking change if any code relied on the previous default behavior where v2 was the standard voffset register. Ensure all callers are updated or explicitly pass base=2 if they need the old behavior.

Suggested change
- def __init__(self, register_file: RegFile, base=0):
+ def __init__(self, register_file: RegFile, base=2):

Comment on lines +186 to +196
def __init__(
self,
dst_regs: Tuple[int, int],
vindex_reg: Union[str, int],
srd_regs: Tuple[int, int, int, int],
offset: int,
comment: str = None,
):
# Format vindex_reg as string if it's an integer
if isinstance(vindex_reg, int):
vindex_reg = f"v{vindex_reg}"
Copilot AI Oct 31, 2025

The pattern of converting vindex_reg from int to string is repeated in multiple instruction classes (BufferLoadDwordx4, BufferStoreDwordx4, BufferLoadDwordx2, BufferStoreDword). Consider extracting this into a helper function or base class method to reduce code duplication and ensure consistent formatting.
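
One way to share the conversion is a small module-level function that each instruction constructor calls. This is only a sketch; `format_vgpr` is a hypothetical name and does not appear in the PR.

```python
from typing import Union


def format_vgpr(reg: Union[str, int]) -> str:
    """Return the assembly name of a VGPR given either an index or a name."""
    if isinstance(reg, int):
        return f"v{reg}"  # integer index 3 becomes the operand string "v3"
    return reg            # already formatted, pass through unchanged
```

Each buffer instruction class could then call `format_vgpr(vindex_reg)` once in its constructor instead of repeating the `isinstance` check.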

Comment on lines +206 to +352
elif multiplier > 0 and (multiplier & (multiplier - 1)) == 0:
# Power of 2: use shift
shift = multiplier.bit_length() - 1
self.emitter.emit_instruction(VLshlRevB32(temp_v, shift, src_v))
Copilot AI Oct 31, 2025

The power-of-2 check (multiplier & (multiplier - 1)) == 0 incorrectly treats 0 as a power of 2. While the condition multiplier > 0 prevents 0 from reaching this code, if multiplier is 0 (e.g., from constant folding), it would incorrectly enter this branch. Consider using a more explicit check like multiplier > 0 and (multiplier & (multiplier - 1)) == 0 or add a specific check for multiplier == 0.
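
For reference, the bit trick on its own does accept 0 (because `0 & -1 == 0`), which is why the `> 0` guard in the reviewed code matters. A minimal sketch with hypothetical helper names:

```python
def is_power_of_two(value: int) -> bool:
    """True for 1, 2, 4, 8, ...; False for 0, negatives, and everything else."""
    # The value > 0 guard is required: 0 & (0 - 1) is also 0.
    return value > 0 and (value & (value - 1)) == 0


def shift_for(multiplier: int) -> int:
    """Return the left-shift amount equivalent to multiplying by multiplier."""
    if not is_power_of_two(multiplier):
        raise ValueError(f"not a power of two: {multiplier}")
    return multiplier.bit_length() - 1
```

With the guard in place, the condition shown in the diff already excludes 0, so a multiplier of 0 would fall through to whichever branch follows.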

Comment on lines +242 to +386
if divisor_val <= 0 or (divisor_val & (divisor_val - 1)) != 0:
raise ValueError(f"Mod divisor must be power-of-two, got: {divisor_val}")
Copilot AI Oct 31, 2025

The same power-of-2 validation logic appears in both _handle_mod (line 242) and _handle_floor (line 267). Consider extracting this into a helper method like _validate_power_of_two(value, operation_name) to reduce duplication and ensure consistent error messages.
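
A shared validator along those lines might look like the sketch below. The function names are hypothetical, and the two callers only illustrate why the restriction exists: for a non-negative value, modulo by a power of two is a bit mask and floor division is a right shift.

```python
def require_power_of_two(value: int, operation_name: str) -> int:
    """Validate the divisor and return log2(value) for use as a shift amount."""
    if value <= 0 or (value & (value - 1)) != 0:
        raise ValueError(
            f"{operation_name} divisor must be power-of-two, got: {value}"
        )
    return value.bit_length() - 1


def mod_pow2(x: int, divisor: int) -> int:
    """x % divisor for non-negative x, computed as a mask."""
    require_power_of_two(divisor, "Mod")
    return x & (divisor - 1)


def floordiv_pow2(x: int, divisor: int) -> int:
    """x // divisor for non-negative x, computed as a right shift."""
    return x >> require_power_of_two(divisor, "Floor")
```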

Comment on lines +211 to +221
if self.register_file.a_used:
# AGPRs are allocated, compute how many we need
agprs_used = max(self.register_file.a_used) + 1
# Round up to AGPR granularity (same as VGPR granularity)
agprs_used = (
(agprs_used + vgpr_granularity - 1) // vgpr_granularity
) * vgpr_granularity
# Total arch VGPRs must accommodate both VGPRs and AGPRs
total_arch_vgprs = accum_offset + agprs_used
vgprs_used = max(vgprs_used, total_arch_vgprs)
Copilot AI Oct 31, 2025

[nitpick] The AGPR allocation logic modifies vgprs_used to include AGPR space, but this can be confusing since the variable name suggests it only tracks VGPRs. Consider renaming to total_vgpr_arch_allocation or adding a clear comment explaining that vgprs_used represents the total architectural VGPR count including AGPR-mapped space.
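
The arithmetic in the snippet can be read as a round-up followed by a max. The sketch below restates it with standalone, hypothetical function names and purely illustrative numbers; the real granularity and accum_offset values come from the target and the emitter.

```python
from typing import Optional


def round_up(count: int, granularity: int) -> int:
    """Round count up to the next multiple of granularity."""
    return ((count + granularity - 1) // granularity) * granularity


def total_arch_vgprs(
    vgprs_used: int,
    accum_offset: int,
    max_agpr_index: Optional[int],
    granularity: int,
) -> int:
    """Total architectural VGPR count, including AGPR space placed at accum_offset."""
    if max_agpr_index is None:
        return vgprs_used  # no AGPRs allocated, nothing to add
    agprs_used = round_up(max_agpr_index + 1, granularity)
    return max(vgprs_used, accum_offset + agprs_used)
```

For example, with 12 VGPRs in use, an accum_offset of 16, AGPRs a0 to a3, and a granularity of 8, the AGPR count rounds up to 8 and the total becomes 24.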

Comment on lines +65 to +84
elif (
hasattr(value, "is_integer")
and callable(value.is_integer)
and value.is_integer()
):
# Handle float-like types that represent exact integers
kernel_info.index_env[str(operation.result)] = int(value)
Copilot AI Oct 31, 2025

The logic for handling float-like types that represent exact integers is fragile. The check hasattr(value, 'is_integer') and callable(value.is_integer) and value.is_integer() may not cover all cases. Consider using isinstance(value, (int, float)) with float.is_integer() check, or catching ValueError when converting to int, for more robust type handling.

Suggested change
- elif (
-     hasattr(value, "is_integer")
-     and callable(value.is_integer)
-     and value.is_integer()
- ):
-     # Handle float-like types that represent exact integers
-     kernel_info.index_env[str(operation.result)] = int(value)
+ elif isinstance(value, float) and value.is_integer():
+     # Handle float values that represent exact integers
+     kernel_info.index_env[str(operation.result)] = int(value)
+ else:
+     # Attempt to convert other types to int if possible and safe
+     try:
+         int_value = int(value)
+         # Only store if conversion does not lose information
+         if float(int_value) == float(value):
+             kernel_info.index_env[str(operation.result)] = int_value
+     except (ValueError, TypeError):
+         pass
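
A narrower alternative is to accept only values that are already exact integers and reject everything else. The sketch below uses a hypothetical helper, `as_exact_int`. Note that Python's `float.is_integer` is a method, whereas SymPy numbers expose `is_integer` as a property, which is presumably why the original code checks `callable(...)` first.

```python
from typing import Any, Optional


def as_exact_int(value: Any) -> Optional[int]:
    """Return value as an int if it is exactly integral, otherwise None."""
    if isinstance(value, bool):
        return None  # design choice in this sketch: booleans are not indices
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)  # 4.0 -> 4; inf and nan report is_integer() == False
    return None
```

A caller would store the result in `index_env` only when it is not `None`, which avoids the broad `try`/`except` in the suggestion.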

Comment on lines +216 to +231
if not matched_tid:
# Store the simplified SymPy expression for later ASM emission
kernel_info.index_env[destination_ssa] = simplified_expression
Copilot AI Oct 31, 2025

The comment 'Store the simplified SymPy expression for later ASM emission' indicates that index_env can now contain both simple values (ints, 'tid.x' strings) and SymPy expressions. This is a significant change to the semantics of index_env. Consider documenting this in the KernelInfo class definition or adding a type hint to clarify that index_env values can be Union[int, str, sympy.Expr].
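
Such a type hint could be written as below. This is a sketch only: `KernelInfoSketch` is a hypothetical stand-in for the real class, which has other fields, and the SymPy import is kept under `TYPE_CHECKING` so the annotation costs nothing at runtime.

```python
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Union

if TYPE_CHECKING:
    import sympy  # only needed by type checkers, not at runtime

# An index value is a constant, a thread-id tag such as "tid.x",
# or a symbolic expression to be lowered later.
IndexValue = Union[int, str, "sympy.Expr"]


@dataclass
class KernelInfoSketch:
    """Hypothetical stand-in showing only the index_env annotation."""

    index_env: Dict[str, IndexValue] = field(default_factory=dict)
```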

Comment on lines +585 to +620
import sympy
from .utils import build_memref_byte_offset_expr
Copilot AI Oct 31, 2025

[nitpick] The import of sympy appears mid-function in handle_affine_apply_op (line 188) rather than at the module level. While this works, it's inconsistent with the module-level import on line 585 in _compute_lds_address. Consider moving all imports to the top of the file for consistency and clarity.

"""MMA kernel that computes C = A @ B^T."""
c_reg = tkl.Register[M, N, tkl.f32](0.0)
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
Copilot AI Oct 31, 2025

The result of read is used even though it is always None.

Copilot AI review requested due to automatic review settings October 31, 2025 23:35
Contributor

Copilot AI left a comment

Pull Request Overview

Copilot reviewed 14 out of 14 changed files in this pull request and generated 7 comments.



kernel_info.index_env[destination_ssa] = constant_value
# Check if the simplified expression is a thread ID symbol
else:
import sympy
Copilot AI Oct 31, 2025

The import sympy statement appears inside a function (line 201) rather than at the module level. This import should be moved to the top of the file for consistency and performance. The same pattern appears at line 619. Module-level imports are executed once, while imports inside functions are executed every time the function is called.
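
As a point of reference on the performance claim: the `import` statement inside a function does run on every call, but after the first import Python resolves it from the `sys.modules` cache, so the repeated cost is a lookup and a local name binding rather than re-executing the module. A small illustration using the standard-library `json` module in place of `sympy`:

```python
import sys


def parse_inline(text: str):
    # Runs on every call, but after the first import this is a
    # sys.modules lookup plus a local name binding.
    import json

    return json.loads(text)


def module_is_cached(name: str) -> bool:
    """Report whether a module has already been imported in this process."""
    return name in sys.modules
```

The consistency argument for module-level imports still stands; the runtime difference is usually small.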

kernel_info.index_env[destination_ssa] = constant_value
# Check if the simplified expression is a thread ID symbol
else:
import sympy
Copilot AI Oct 31, 2025

Second occurrence of the same issue: import sympy appears inside the _compute_lds_address function at line 619. This should be moved to the module-level imports at the top of the file.

else:
nonconst.append(a)
# Sort for commutativity
nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
Copilot AI Oct 31, 2025

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

const *= int(a)
else:
nonconst.append(a)
nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
Copilot AI Oct 31, 2025

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

_RationalReg = namedtuple("_RationalReg", ["numerator_reg", "denominator"])

# Canonicalization helpers for CSE
_TID_SYMBOL_NAMES = {"tid_x", "tid_y", "tid_z"}
Copilot AI Oct 31, 2025

The global variable '_TID_SYMBOL_NAMES' is not used.

Suggested change
- _TID_SYMBOL_NAMES = {"tid_x", "tid_y", "tid_z"}

Comment on lines +126 to +127
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
Copilot AI Oct 31, 2025

The result of read is used even though it is always None.

Suggested change
- a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
- b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+ a_reg = tkl.Register[M, K, tkl.f16](0.0)
+ b_reg = tkl.Register[N, K, tkl.f16](0.0)
+ tkw.read(a_reg, a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+ tkw.read(b_reg, b, elements_per_thread=LOAD_ELEMS_PER_THREAD)

Comment on lines +106 to +118
return ("int", int(e))
if isinstance(e, sympy.Symbol):
return ("sym", str(e))
if isinstance(e, sympy.Add):
return ("add", tuple(to_key(a) for a in e.args))
if isinstance(e, sympy.Mul):
return ("mul", tuple(to_key(a) for a in e.args))
if isinstance(e, sympy.Mod):
return ("mod", to_key(e.args[0]), to_key(e.args[1]))
if getattr(e, "func", None) == sympy.floor:
return ("floor", to_key(e.args[0]))
# Generic fallback
return ("raw", str(e))
Copilot AI Oct 31, 2025

expr_key.to_key returns tuple of size 2 and tuple of size 3.

Suggested change
-     return ("int", int(e))
- if isinstance(e, sympy.Symbol):
-     return ("sym", str(e))
- if isinstance(e, sympy.Add):
-     return ("add", tuple(to_key(a) for a in e.args))
- if isinstance(e, sympy.Mul):
-     return ("mul", tuple(to_key(a) for a in e.args))
- if isinstance(e, sympy.Mod):
-     return ("mod", to_key(e.args[0]), to_key(e.args[1]))
- if getattr(e, "func", None) == sympy.floor:
-     return ("floor", to_key(e.args[0]))
- # Generic fallback
- return ("raw", str(e))
+     return ("int", (int(e),))
+ if isinstance(e, sympy.Symbol):
+     return ("sym", (str(e),))
+ if isinstance(e, sympy.Add):
+     return ("add", tuple(to_key(a) for a in e.args))
+ if isinstance(e, sympy.Mul):
+     return ("mul", tuple(to_key(a) for a in e.args))
+ if isinstance(e, sympy.Mod):
+     return ("mod", (to_key(e.args[0]), to_key(e.args[1])))
+ if getattr(e, "func", None) == sympy.floor:
+     return ("floor", (to_key(e.args[0]),))
+ # Generic fallback
+ return ("raw", (str(e),))

Contributor

@ashay left a comment

There's a lot that I simply glossed over, since I don't have enough of a background to review carefully, but overall LGTM!

Comment on lines 58 to 61
return (4, 8)
else:
# Default to CDNA granularity for unknown targets
return (4, 8)
Contributor

Did you mean this to be different?

Collaborator Author

No, it was a placeholder for different architectures but will remove since we only support 2 for now.

Copilot AI review requested due to automatic review settings November 5, 2025 02:02
This PR adds an asm lowering of a simple mma kernel.
The lhs and rhs are promoted to shared memory and then
fed to a 16x16x16 MMA. A lit test and e2e test are
added for correctness.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Contributor

Copilot AI left a comment

Pull Request Overview

Copilot reviewed 14 out of 14 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (1)

wave_lang/kernel/wave/asm/expression_emitter.py:1

  • The check dynamic_expr.is_zero is missing parentheses for the method call. It should be dynamic_expr.is_zero() to actually invoke the method. Without parentheses, this evaluates the method object itself which is always truthy, making this condition ineffective.
"""



# If entirely constant, use zero VGPR
if dynamic_expr == 0 or (
hasattr(dynamic_expr, "is_zero") and dynamic_expr.is_zero
Copilot AI Nov 5, 2025

The check dynamic_expr.is_zero is missing parentheses for the method call. It should be dynamic_expr.is_zero(). This same issue appears in multiple locations (lines 521-523, 555-557, 755-757) and needs to be fixed consistently throughout the file.

Suggested change
- hasattr(dynamic_expr, "is_zero") and dynamic_expr.is_zero
+ hasattr(dynamic_expr, "is_zero") and dynamic_expr.is_zero()
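
It may be worth checking this suggestion against SymPy before applying it. For SymPy expressions, `is_zero` is an assumptions property that evaluates to `True`, `False`, or `None` (unknown), not a method, so the attribute access without parentheses appears to be the correct form and calling it would raise a `TypeError`. A minimal sketch, assuming `dynamic_expr` is a SymPy expression; `is_known_zero` is a hypothetical helper:

```python
import sympy


def is_known_zero(expr) -> bool:
    """True only when SymPy can prove the expression is zero."""
    # is_zero is a property: True, False, or None when the answer is unknown.
    return sympy.sympify(expr).is_zero is True


x = sympy.Symbol("x")
zero = sympy.Integer(0)
```

Comparing with `is True` also makes the unknown case explicit, since `None` is treated the same as "not zero".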

const *= int(a)
else:
nonconst.append(a)
nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
Copilot AI Nov 5, 2025

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

Suggested change
- nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
+ nonconst_sorted = sorted(nonconst, key=str)

"""Build a structural, hashable key for an expression."""
expr = canonicalize_expr(expr)

def to_key(e):
Copilot AI Nov 5, 2025

expr_key.to_key returns tuple of size 2 and tuple of size 3.

@harsh-nod merged commit 3609f01 into iree-org:main Nov 5, 2025
19 checks passed
Megan0704-1 pushed a commit to Megan0704-1/wave that referenced this pull request Nov 10, 2025