Add sample MMA lowering for asm backend #404
Conversation
Pull Request Overview
This PR significantly expands the ASM backend to support MFMA (Matrix Multiply-Accumulate) operations on AMD CDNA architectures. The key changes include:
- Refactored SymPy import alias from `sp` to `sympy` for consistency
- Added comprehensive MFMA support with AGPR allocation and LDS staging
- Implemented generic expression emitter for complex SymPy expressions
- Added dynamic register allocation with architecture-specific granularities
- Refactored operation handlers into a separate module for better organization
- Enhanced instruction set with DS (LDS), MFMA, and additional buffer operations
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| wave_lang/kernel/wave/asm/utils.py | Changed SymPy import alias from sp to sympy; added TYPE_CHECKING for forward references; added byte offset expression builders |
| wave_lang/kernel/wave/asm/register_allocator.py | Removed implicit v0 reservation; added AGPRAllocator; enhanced VGPR allocator with reserve/free methods |
| wave_lang/kernel/wave/asm/mlir_walker.py | Refactored operation handling to handlers.py; added LDS and MFMA tracking state |
| wave_lang/kernel/wave/asm/kernel_model.py | Added lds_size_bytes field |
| wave_lang/kernel/wave/asm/instructions.py | Added LDS, MFMA, AGPR, and additional buffer instructions |
| wave_lang/kernel/wave/asm/handlers.py | New file: extracted operation handlers from mlir_walker.py |
| wave_lang/kernel/wave/asm/expression_emitter.py | New file: generic SymPy expression visitor for AMDGCN assembly |
| wave_lang/kernel/wave/asm/asm_emitter.py | Enhanced with MFMA support, dynamic register allocation, LDS operations |
| wave_lang/kernel/wave/asm/__init__.py | Removed affine_map_simplifies_to_tid_x export |
| tests/kernel/wave/expression_emitter_test.py | New file: comprehensive tests for expression emitter |
| tests/kernel/wave/asm_backend_test.py | Added MFMA kernel test |
| lit_tests/kernel/wave/asm.py | Updated tests for dynamic register allocation; added MFMA test |
| docs/wave/asm_backend.rst | Comprehensive documentation updates for MFMA and new features |
```python
            if hasattr(value_attribute, "value"):
                constant_value = int(value_attribute.value)
                kernel_info.index_env[str(operation.result)] = constant_value
        except (AttributeError, ValueError, TypeError) as e:
```
Copilot AI · Oct 30, 2025
The exception variable e is caught but never used. Either remove it from the except clause if it's not needed, or log it for debugging purposes. Consider: except (AttributeError, ValueError, TypeError):
```diff
- except (AttributeError, ValueError, TypeError) as e:
+ except (AttributeError, ValueError, TypeError):
```
```python
        except (AttributeError, TypeError):
            # If simplification fails, expression may not be convertible to constant
            pass
```
Copilot AI · Oct 30, 2025
Similar to Comment 1, consider removing the unused exception variables throughout the file for consistency. This pattern appears multiple times (lines 177, 314, 411, 435).
```python
    import sympy as sp
```
Copilot AI · Oct 30, 2025
Inconsistent SymPy import alias. The file imports sympy at the module level, but here it's imported as sp. For consistency with the PR's refactoring goal (changing sp → sympy), use import sympy or reference the module-level import.
```diff
- import sympy as sp
```
```python
    Uses the actual indices from vector.store and the memref's stride information
    to compute the byte offset, rather than forcing lane-linear packing.
    """
    import sympy as sp
```
Copilot AI · Oct 30, 2025
Another inconsistent SymPy import alias. This should use the module-level sympy import for consistency.
```python
        arg_names = [f"arg{i}_ptr" for i in range(num_args)]
        if kernel_name == "mma" and num_args == 3:
            arg_names = ["A_ptr", "B_ptr", "C_ptr"]
```
Copilot AI · Oct 30, 2025
This hardcoded special case for kernel name 'mma' creates tight coupling and fragility. Consider making argument names configurable through KernelInfo or metadata, rather than hardcoding a specific kernel name check.
```diff
- arg_names = [f"arg{i}_ptr" for i in range(num_args)]
- if kernel_name == "mma" and num_args == 3:
-     arg_names = ["A_ptr", "B_ptr", "C_ptr"]
+ # Use argument names from KernelInfo if available, otherwise default
+ arg_names = getattr(self.kernel_info, "arg_names", None)
+ if not arg_names or len(arg_names) != num_args:
+     arg_names = [f"arg{i}_ptr" for i in range(num_args)]
```
```python
        acc = tkw.mma(a_reg, b_reg, c_reg)
        tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
```
Copilot AI · Oct 30, 2025
The result of mma is used even though it is always None.
```diff
- acc = tkw.mma(a_reg, b_reg, c_reg)
- tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+ tkw.mma(a_reg, b_reg, c_reg)
+ tkw.write(c_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
```
```python
        return self.walker.emitter.materialize_byte_offset_expr(kernel_info, byte_offset_expr)

    def _extract_source_registers(self, vector_bytes):
```
Copilot AI · Oct 30, 2025
OperationHandlers._extract_source_registers returns tuple of size 2 and tuple of size 4.
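One way to make the return shape uniform is to always return a tuple of register indices and let callers branch on its length. A minimal sketch under assumptions not taken from the PR: `extract_source_registers` is shown as a free function, and each VGPR is assumed to hold 4 bytes of the vector.

```python
# Hypothetical sketch: one return shape regardless of vector width.
# Assumes 4 bytes of vector data per VGPR (illustrative, not from the PR).
def extract_source_registers(base_reg: int, vector_bytes: int) -> tuple:
    """Return the consecutive VGPR indices covering vector_bytes."""
    num_regs = vector_bytes // 4
    return tuple(base_reg + i for i in range(num_regs))
```

Callers can then use `len(regs)` instead of pattern-matching two different tuple sizes.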
Force-pushed 5c0015c to f0824c4.
Force-pushed 825d8a0 to 80370a2.
Pull Request Overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 6 comments.
```python
        const_v = self.emitter.vgpr_allocator.alloc_v()
        self.emitter.emit_instruction(VMovB32(const_v, multiplier))
        self.emitter.emit_instruction(VMulLoU32(temp_v, src_v, const_v))
        self.emitter.vgpr_allocator.free_v(const_v)
```
Copilot AI · Oct 31, 2025
[nitpick] For non-power-of-2 multipliers, consider using inline constants instead of allocating a VGPR. AMDGCN instructions can accept inline constants directly in many cases, which would save a register allocation and a move instruction.
```diff
- const_v = self.emitter.vgpr_allocator.alloc_v()
- self.emitter.emit_instruction(VMovB32(const_v, multiplier))
- self.emitter.emit_instruction(VMulLoU32(temp_v, src_v, const_v))
- self.emitter.vgpr_allocator.free_v(const_v)
+ # Use inline constant for multiplier
+ self.emitter.emit_instruction(VMulLoU32(temp_v, src_v, multiplier))
+ # No VGPR allocation needed for inline constant
+ # (no free_v call needed)
```
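Such an optimization would need an eligibility guard, since only small values can be encoded inline. The sketch below reflects the AMDGCN integer inline-constant range (−16 through 64); the helper name is illustrative.

```python
def fits_inline_constant(value: int) -> bool:
    """AMDGCN VOP source operands can encode integers -16..64 inline,
    avoiding a VGPR allocation and a v_mov for small multipliers."""
    return -16 <= value <= 64
```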
```python
        self.emit("  .amdhsa_system_sgpr_workgroup_id_x 0")
        self.emit("  .amdhsa_system_sgpr_workgroup_id_y 0")
        self.emit("  .amdhsa_system_sgpr_workgroup_id_z 0")
```
Copilot AI · Oct 31, 2025
Workgroup IDs are now always disabled (set to 0). If any kernels require workgroup IDs in the future, this will be a hard-to-debug issue. Consider adding a comment explaining why these are disabled or making this configurable based on kernel requirements.
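A configurable variant could gate the directives on a kernel-level flag, roughly along the lines the comment suggests; the `needs_workgroup_ids` flag and the function name below are hypothetical, not from the PR.

```python
def emit_workgroup_id_directives(emit, needs_workgroup_ids: bool) -> None:
    """Emit .amdhsa workgroup-ID SGPR directives gated on a kernel flag
    (flag name is illustrative)."""
    enabled = 1 if needs_workgroup_ids else 0
    for axis in ("x", "y", "z"):
        emit(f"  .amdhsa_system_sgpr_workgroup_id_{axis} {enabled}")
```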
```diff
  def __init__(self):
      self.s_max = 1  # s[0:1] reserved for kernarg ptr
-     self.v_used = set([0])  # v0 used by lane id; v2 will be used on demand
+     self.v_used = set()  # No implicit reservations
```
Copilot AI · Oct 31, 2025
The comment 'No implicit reservations' is misleading because the reserve() method (line 52-55) still allows explicit reservations. Consider updating to: '# Explicit reservations tracked via reserve() method' for clarity.
```diff
- self.v_used = set()  # No implicit reservations
+ self.v_used = set()  # Explicit reservations tracked via reserve() method
```
Force-pushed 80370a2 to 78f84ca.
Force-pushed 78f84ca to d386b83.
Pull Request Overview
Copilot reviewed 14 out of 14 changed files in this pull request and generated 11 comments.
```diff
  from typing import Dict, List, Optional, Tuple, TYPE_CHECKING

- import sympy as sp
+ import sympy
```
Copilot AI · Oct 31, 2025
[nitpick] While changing from import sympy as sp to import sympy provides consistency, this change affects ~50+ references throughout the file (e.g., sp.Expr → sympy.Expr, sp.Symbol → sympy.Symbol). Consider whether the benefits of this change outweigh the increased verbosity. If consistency is the goal, ensure this pattern is followed project-wide.
```diff
  class VGPRAllocator:
-     def __init__(self, register_file: RegFile, base=2):
+     def __init__(self, register_file: RegFile, base=0):
```
Copilot AI · Oct 31, 2025
The change from base=2 to base=0 alters the default starting VGPR index. This is a breaking change if any code relied on the previous default behavior where v2 was the standard voffset register. Ensure all callers are updated or explicitly pass base=2 if they need the old behavior.
```diff
- def __init__(self, register_file: RegFile, base=0):
+ def __init__(self, register_file: RegFile, base=2):
```
```python
    def __init__(
        self,
        dst_regs: Tuple[int, int],
        vindex_reg: Union[str, int],
        srd_regs: Tuple[int, int, int, int],
        offset: int,
        comment: str = None,
    ):
        # Format vindex_reg as string if it's an integer
        if isinstance(vindex_reg, int):
            vindex_reg = f"v{vindex_reg}"
```
Copilot AI · Oct 31, 2025
The pattern of converting vindex_reg from int to string is repeated in multiple instruction classes (BufferLoadDwordx4, BufferStoreDwordx4, BufferLoadDwordx2, BufferStoreDword). Consider extracting this into a helper function or base class method to reduce code duplication and ensure consistent formatting.
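A shared helper along the lines suggested might look like this; the name `as_vreg` is hypothetical.

```python
def as_vreg(reg) -> str:
    """Format a VGPR operand: accept either an int index (3 -> "v3")
    or an already-formatted operand string ("v7" -> "v7")."""
    return f"v{reg}" if isinstance(reg, int) else reg
```

Each instruction class could then call `as_vreg(vindex_reg)` instead of repeating the isinstance check.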
```python
        elif multiplier > 0 and (multiplier & (multiplier - 1)) == 0:
            # Power of 2: use shift
            shift = multiplier.bit_length() - 1
            self.emitter.emit_instruction(VLshlRevB32(temp_v, shift, src_v))
```
Copilot AI · Oct 31, 2025
The power-of-2 check (multiplier & (multiplier - 1)) == 0 incorrectly treats 0 as a power of 2. While the condition multiplier > 0 prevents 0 from reaching this code, if multiplier is 0 (e.g., from constant folding), it would incorrectly enter this branch. Consider using a more explicit check like multiplier > 0 and (multiplier & (multiplier - 1)) == 0 or add a specific check for multiplier == 0.
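An explicit predicate, as suggested, could centralize the check; the helper name is hypothetical.

```python
def is_power_of_two(n: int) -> bool:
    """True only for 1, 2, 4, 8, ...; 0 and negatives are rejected,
    so the bit-trick never misclassifies 0 as a power of two."""
    return n > 0 and (n & (n - 1)) == 0
```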
```python
        if divisor_val <= 0 or (divisor_val & (divisor_val - 1)) != 0:
            raise ValueError(f"Mod divisor must be power-of-two, got: {divisor_val}")
```
Copilot AI · Oct 31, 2025
The same power-of-2 validation logic appears in both _handle_mod (line 242) and _handle_floor (line 267). Consider extracting this into a helper method like _validate_power_of_two(value, operation_name) to reduce duplication and ensure consistent error messages.
```python
        if self.register_file.a_used:
            # AGPRs are allocated, compute how many we need
            agprs_used = max(self.register_file.a_used) + 1
            # Round up to AGPR granularity (same as VGPR granularity)
            agprs_used = (
                (agprs_used + vgpr_granularity - 1) // vgpr_granularity
            ) * vgpr_granularity
            # Total arch VGPRs must accommodate both VGPRs and AGPRs
            total_arch_vgprs = accum_offset + agprs_used
            vgprs_used = max(vgprs_used, total_arch_vgprs)
```
Copilot AI · Oct 31, 2025
[nitpick] The AGPR allocation logic modifies vgprs_used to include AGPR space, but this can be confusing since the variable name suggests it only tracks VGPRs. Consider renaming to total_vgpr_arch_allocation or adding a clear comment explaining that vgprs_used represents the total architectural VGPR count including AGPR-mapped space.
```python
        elif (
            hasattr(value, "is_integer")
            and callable(value.is_integer)
            and value.is_integer()
        ):
            # Handle float-like types that represent exact integers
            kernel_info.index_env[str(operation.result)] = int(value)
```
Copilot AI · Oct 31, 2025
The logic for handling float-like types that represent exact integers is fragile. The check hasattr(value, 'is_integer') and callable(value.is_integer) and value.is_integer() may not cover all cases. Consider using isinstance(value, (int, float)) with float.is_integer() check, or catching ValueError when converting to int, for more robust type handling.
```diff
- elif (
-     hasattr(value, "is_integer")
-     and callable(value.is_integer)
-     and value.is_integer()
- ):
-     # Handle float-like types that represent exact integers
-     kernel_info.index_env[str(operation.result)] = int(value)
+ elif isinstance(value, float) and value.is_integer():
+     # Handle float values that represent exact integers
+     kernel_info.index_env[str(operation.result)] = int(value)
+ else:
+     # Attempt to convert other types to int if possible and safe
+     try:
+         int_value = int(value)
+         # Only store if conversion does not lose information
+         if float(int_value) == float(value):
+             kernel_info.index_env[str(operation.result)] = int_value
+     except (ValueError, TypeError):
+         pass
```
```python
        if not matched_tid:
            # Store the simplified SymPy expression for later ASM emission
            kernel_info.index_env[destination_ssa] = simplified_expression
```
Copilot AI · Oct 31, 2025
The comment 'Store the simplified SymPy expression for later ASM emission' indicates that index_env can now contain both simple values (ints, 'tid.x' strings) and SymPy expressions. This is a significant change to the semantics of index_env. Consider documenting this in the KernelInfo class definition or adding a type hint to clarify that index_env values can be Union[int, str, sympy.Expr].
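A sketch of the suggested documentation, assuming no runtime sympy import is wanted in the annotation itself; `IndexEnvValue` is a hypothetical name.

```python
from typing import Dict, Union, TYPE_CHECKING

if TYPE_CHECKING:
    import sympy  # imported only for static type checking

# index_env values may be a constant, a "tid.x"-style marker string,
# or a deferred SymPy expression (per this review comment).
IndexEnvValue = Union[int, str, "sympy.Expr"]

index_env: Dict[str, IndexEnvValue] = {"%c0": 0, "%tid": "tid.x"}
```

The string forward reference keeps the alias importable even when sympy is absent at runtime.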
```python
            import sympy
            from .utils import build_memref_byte_offset_expr
```
Copilot AI · Oct 31, 2025
[nitpick] The import of sympy appears mid-function in handle_affine_apply_op (line 188) rather than at the module level. While this works, it's inconsistent with the module-level import on line 585 in _compute_lds_address. Consider moving all imports to the top of the file for consistency and clarity.
| """MMA kernel that computes C = A @ B^T.""" | ||
| c_reg = tkl.Register[M, N, tkl.f32](0.0) | ||
| a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) | ||
| b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) |
Copilot AI · Oct 31, 2025
The result of read is used even though it is always None.
Force-pushed d386b83 to bdf76f5.
Force-pushed bdf76f5 to fc39984.
Pull Request Overview
Copilot reviewed 14 out of 14 changed files in this pull request and generated 7 comments.
```python
            kernel_info.index_env[destination_ssa] = constant_value
        # Check if the simplified expression is a thread ID symbol
        else:
            import sympy
```
Copilot AI · Oct 31, 2025
The import sympy statement appears inside a function (line 201) rather than at the module level. This import should be moved to the top of the file for consistency and performance. The same pattern appears at line 619. Module-level imports are executed once, while imports inside functions are executed every time the function is called.
```python
            kernel_info.index_env[destination_ssa] = constant_value
        # Check if the simplified expression is a thread ID symbol
        else:
            import sympy
```
Copilot AI · Oct 31, 2025
Second occurrence of the same issue: import sympy appears inside the _compute_lds_address function at line 619. This should be moved to the module-level imports at the top of the file.
```python
        else:
            nonconst.append(a)
        # Sort for commutativity
        nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
```
Copilot AI · Oct 31, 2025
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
```python
            const *= int(a)
        else:
            nonconst.append(a)
        nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
```
Copilot AI · Oct 31, 2025
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
```python
_RationalReg = namedtuple("_RationalReg", ["numerator_reg", "denominator"])

# Canonicalization helpers for CSE
_TID_SYMBOL_NAMES = {"tid_x", "tid_y", "tid_z"}
```
Copilot AI · Oct 31, 2025
The global variable '_TID_SYMBOL_NAMES' is not used.
```diff
- _TID_SYMBOL_NAMES = {"tid_x", "tid_y", "tid_z"}
```
```python
    a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
```
Copilot AI · Oct 31, 2025
The result of read is used even though it is always None.
```diff
- a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
- b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+ a_reg = tkl.Register[M, K, tkl.f16](0.0)
+ b_reg = tkl.Register[N, K, tkl.f16](0.0)
+ tkw.read(a_reg, a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+ tkw.read(b_reg, b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
```
| return ("int", int(e)) | ||
| if isinstance(e, sympy.Symbol): | ||
| return ("sym", str(e)) | ||
| if isinstance(e, sympy.Add): | ||
| return ("add", tuple(to_key(a) for a in e.args)) | ||
| if isinstance(e, sympy.Mul): | ||
| return ("mul", tuple(to_key(a) for a in e.args)) | ||
| if isinstance(e, sympy.Mod): | ||
| return ("mod", to_key(e.args[0]), to_key(e.args[1])) | ||
| if getattr(e, "func", None) == sympy.floor: | ||
| return ("floor", to_key(e.args[0])) | ||
| # Generic fallback | ||
| return ("raw", str(e)) |
Copilot AI · Oct 31, 2025
expr_key.to_key returns tuple of size 2 and tuple of size 3.
| return ("int", int(e)) | |
| if isinstance(e, sympy.Symbol): | |
| return ("sym", str(e)) | |
| if isinstance(e, sympy.Add): | |
| return ("add", tuple(to_key(a) for a in e.args)) | |
| if isinstance(e, sympy.Mul): | |
| return ("mul", tuple(to_key(a) for a in e.args)) | |
| if isinstance(e, sympy.Mod): | |
| return ("mod", to_key(e.args[0]), to_key(e.args[1])) | |
| if getattr(e, "func", None) == sympy.floor: | |
| return ("floor", to_key(e.args[0])) | |
| # Generic fallback | |
| return ("raw", str(e)) | |
| return ("int", (int(e),)) | |
| if isinstance(e, sympy.Symbol): | |
| return ("sym", (str(e),)) | |
| if isinstance(e, sympy.Add): | |
| return ("add", tuple(to_key(a) for a in e.args)) | |
| if isinstance(e, sympy.Mul): | |
| return ("mul", tuple(to_key(a) for a in e.args)) | |
| if isinstance(e, sympy.Mod): | |
| return ("mod", (to_key(e.args[0]), to_key(e.args[1]))) | |
| if getattr(e, "func", None) == sympy.floor: | |
| return ("floor", (to_key(e.args[0]),)) | |
| # Generic fallback | |
| return ("raw", (str(e),)) |
Force-pushed fc39984 to 52d301b.
ashay left a comment:
There's a lot that I simply glossed over, since I don't have enough of a background to review carefully, but overall LGTM!
```python
            return (4, 8)
        else:
            # Default to CDNA granularity for unknown targets
            return (4, 8)
```
Did you mean this to be different?
No, it was a placeholder for different architectures but will remove since we only support 2 for now.
Force-pushed 52d301b to b98546c.
This PR adds an asm lowering of a simple MMA kernel. The lhs and rhs are promoted to shared memory and then fed to a 16x16x16 MMA. A lit test and e2e test are added for correctness. Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Force-pushed b98546c to 9887b8f.
Pull Request Overview
Copilot reviewed 14 out of 14 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
wave_lang/kernel/wave/asm/expression_emitter.py:1
- The check `dynamic_expr.is_zero` is missing parentheses for the method call. It should be `dynamic_expr.is_zero()` to actually invoke the method. Without parentheses, this evaluates the method object itself, which is always truthy, making this condition ineffective.
```python
        # If entirely constant, use zero VGPR
        if dynamic_expr == 0 or (
            hasattr(dynamic_expr, "is_zero") and dynamic_expr.is_zero
```
Copilot AI · Nov 5, 2025
The check dynamic_expr.is_zero is missing parentheses for the method call. It should be dynamic_expr.is_zero(). This same issue appears in multiple locations (lines 521-523, 555-557, 755-757) and needs to be fixed consistently throughout the file.
```diff
- hasattr(dynamic_expr, "is_zero") and dynamic_expr.is_zero
+ hasattr(dynamic_expr, "is_zero") and dynamic_expr.is_zero()
```
```python
            const *= int(a)
        else:
            nonconst.append(a)
        nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
```
Copilot AI · Nov 5, 2025
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
```diff
- nonconst_sorted = sorted(nonconst, key=lambda x: str(x))
+ nonconst_sorted = sorted(nonconst, key=str)
```
| """Build a structural, hashable key for an expression.""" | ||
| expr = canonicalize_expr(expr) | ||
|
|
||
| def to_key(e): |
Copilot AI · Nov 5, 2025
expr_key.to_key returns tuple of size 2 and tuple of size 3.