import the second batch of the wave dialect commits #338
2feb1a9 to 2d52401
Pull Request Overview
This PR imports the second batch of Wave dialect commits, introducing elements-per-thread propagation, vector type support, and additional Wave dialect attributes. The changes focus on enhancing the Wave dialect's type system and lowering infrastructure.
- Implements elements-per-thread dataflow analysis for register-resident tensors
- Adds support for vector types alongside Wave tensor types in operations
- Introduces new attributes (WaveAddressSpaceAttr, WaveDistributedShapeAttr, WaveReadWriteBoundsAttr)
Reviewed Changes
Copilot reviewed 27 out of 27 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| water/test/python/sympy_converter.py | Updated symbol handling to use string lists instead of SymPy Symbol objects |
| water/test/Dialect/Wave/python_bindings.py | Added tests for new Wave attributes (address space, distributed shape) |
| water/test/Dialect/Wave/propagate-elements-per-thread.mlir | New test file for elements-per-thread propagation pass |
| water/test/Dialect/Wave/ops.mlir | Added tests for index magic symbols and write bounds |
| water/test/Dialect/Wave/ops-invalid.mlir | Updated error messages and added validation tests |
| water/test/Dialect/Wave/lower-wave-to-mlir.mlir | Updated lowering tests for new normal forms and vector types |
| water/test/Dialect/Wave/attr-type-invalid.mlir | Added validation test for elements_per_thread attribute |
| water/python/water_mlir/sympy_to_affine_converter.py | Refactored to accept string symbols instead of SymPy Symbol objects |
| water/python/WaterExtensionNanobind.cpp | Added Python bindings for new Wave attributes |
| water/lib/Dialect/Wave/Transforms/TypeConverter.cpp | Enhanced type converter with vector type support |
| water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp | Updated lowering patterns for vector types |
| water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp | Added hyperparameter validation and normal form checks |
| water/lib/Dialect/Wave/Transforms/InferTypes.cpp | Added elements-per-thread propagation analysis implementation |
| water/lib/Dialect/Wave/IR/WaveOps.cpp | Enhanced operation verification and type checking |
| water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | Added elements-per-thread interface implementations |
| water/lib/Dialect/Wave/IR/WaveDialect.cpp | Enhanced attribute verification and hyperparameter validation |
| water/lib/Dialect/Wave/IR/WaveAttrs.cpp | Added new attribute implementations and normal form verification |
| water/lib/CAPI/Dialects.cpp | Added C API bindings for new attributes |
| water/include/water/c/Dialects.h | Added C API declarations for new attributes |
| water/include/water/Dialect/Wave/Transforms/Passes.td | Added elements-per-thread propagation pass definition |
| water/include/water/Dialect/Wave/IR/WaveTypes.td | Updated type constraints for vector support |
| water/include/water/Dialect/Wave/IR/WaveTypes.h | Added utility function for register tensor type checking |
| water/include/water/Dialect/Wave/IR/WaveOps.td | Enhanced operation definitions with new attributes and interfaces |
| water/include/water/Dialect/Wave/IR/WaveInterfaces.td | Added elements-per-thread interface and trait definitions |
| water/include/water/Dialect/Wave/IR/WaveInterfaces.h | Added elements-per-thread lattice value and interface implementations |
| water/include/water/Dialect/Wave/IR/WaveDialect.td | Updated dialect documentation and added elements_per_thread attribute name |
| water/include/water/Dialect/Wave/IR/WaveAttrs.td | Added new attribute definitions (WaveReadWriteBoundsAttr) |
    if (diag.getSeverity() == mlir::DiagnosticSeverity::Error)
      emittedError = true;

    // Returning failure indicates that the diagnostic wan't handled
Copilot
AI
Oct 6, 2025
Corrected spelling of 'wan't' to 'wasn't'.
    - // Returning failure indicates that the diagnostic wan't handled
    + // Returning failure indicates that the diagnostic wasn't handled
      //===----------------------------------------------------------------------===//
    - // WaveDistributedShapeAttr
    + // DistributedShapeAttr
The rest of the code references this as WaveDistributedShapeAttr so why this change?
@ftynse @martin-luecke - this looks good to me! Please let me know when you have a branch rebased on main and I can land that.
Original review: ftynse/water#49 Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#50 Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#51 Signed-off-by: tyb0807 <sontuan.vu@amd.com>
affine map Original review: ftynse/water#52 When converting sympy expression to affine map, we only care about the symbol names and not the exact `sympy.Symbol` objects. This allows the caller of `convert_sympy_to_affine_map` to freely add assumptions to the symbols (hence creating a new symbol with the same name) and substitute the original symbols in the expression. Signed-off-by: tyb0807 <sontuan.vu@amd.com>
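A minimal sketch of why name-based matching helps. It uses a hypothetical stand-in class `Sym` instead of real `sympy.Symbol` objects, and the helper `position_by_name` is an assumption made for illustration; neither is part of `convert_sympy_to_affine_map`. The stand-in mimics the situation described above, where adding an assumption produces a new symbol object with the same name:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Sym:
    """Hypothetical stand-in for a symbol whose identity includes its assumptions."""
    name: str
    positive: bool = False


def position_by_name(symbol_names, sym):
    """Return the index of `sym` in a list of plain string names."""
    return symbol_names.index(sym.name)


original = Sym("M")
refined = Sym("M", positive=True)  # same name, extra assumption

# Comparing the objects distinguishes the two symbols...
objects_match = original == refined

# ...while a lookup by name maps both to the same position.
names = ["M", "N"]
positions = (position_by_name(names, original), position_by_name(names, refined))
```

With string names as the key, the caller can swap in refined symbols without the converter losing track of which position each one occupies.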
Original review: ftynse/water#54 Add a dialect-level verifier checking that any occurrence of Wave symbols, either as attributes or as string names in relevant dictionaries, references symbols listed as hyperparameters. Also report unused symbols. --------- Signed-off-by: Alex Zinenko <git@ozinenko.com> Signed-off-by: tyb0807 <sontuan.vu@amd.com> Co-authored-by: tyb0807 <sontuan.vu@amd.com>
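The check described above boils down to two set differences. The following is a rough sketch only: the function name `check_symbols` and the plain-string representation of symbols are assumptions for illustration, and the actual verifier operates on MLIR attributes rather than Python sets.

```python
def check_symbols(hyperparameters, used_symbols):
    """Return (undefined, unused) symbol names, each as a sorted list.

    hyperparameters: iterable of declared symbol names.
    used_symbols: iterable of symbol names referenced anywhere in the IR.
    """
    declared = set(hyperparameters)
    used = set(used_symbols)
    undefined = sorted(used - declared)  # referenced but never declared
    unused = sorted(declared - used)     # declared but never referenced
    return undefined, unused


# "K" is used without being declared; "BLOCK_M" and "N" are declared but unused.
undefined, unused = check_symbols({"M", "N", "BLOCK_M"}, ["M", "K", "M"])
```

In this sketch an undefined symbol would correspond to a verification error, while an unused one would only be reported.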
Original review: ftynse/water#55

Completely overhaul the lowering procedure for register-resident Wave dialect tensors. They represent values that are virtually distributed across multiple threads. When lowering to core MLIR dialects, we also perform a programming model change: the lowered IR is expected to be SIMT. Therefore, register-resident Wave dialect tensors should be converted to vectors of a _distributed_ shape rather than a shape that comes from direct substitution of symbolic values from hyperparameters (unlike memory-resident tensors).

The distributed per-thread shape of the vector is derived from the "elements per thread" property available on memory operations, in hyperparameters and compilation options, or extracted from the intrinsic structure of the MMA operation. Implement a dataflow analysis that, for now, propagates this value from the property only, and reports any conflicts.

Based on the above dataflow analysis, implement a pass rewriting register-resident Wave dialect tensors to vectors with distributed shapes. Note that this _cannot_ be implemented as a dialect conversion (or at least doing so is significantly more complicated) since the type conversion depends on the result of the analysis, and not only on the original type or defining operation.

This in turn requires extending Wave dialect operations to also accept vector-typed values wherever a register-resident tensor is accepted. This is not a significant departure from the original representation as both tensors and vectors are immutable SSA values.

Introduce an additional normal form checking for the absence of register-resident tensors in the input IR and make it a precondition of the lowering pass, and a postcondition of the elements-per-thread propagation pass.

--------- Signed-off-by: Alex Zinenko <git@ozinenko.com> Signed-off-by: tyb0807 <sontuan.vu@amd.com> Co-authored-by: tyb0807 <sontuan.vu@amd.com>
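To illustrate the idea of propagating an elements-per-thread value and detecting conflicts, here is a toy fixpoint iteration over a three-level lattice (unknown, a concrete value, conflict). Everything in it is an assumption for illustration: the names `join` and `propagate`, the edge-list representation, and the lattice itself are not taken from the pass, which works on MLIR operations and values.

```python
UNKNOWN = None         # nothing known yet (lattice bottom)
CONFLICT = "conflict"  # two different values met (lattice top)


def join(a, b):
    """Combine two lattice values."""
    if a is UNKNOWN:
        return b
    if b is UNKNOWN:
        return a
    return a if a == b else CONFLICT


def propagate(seeds, edges):
    """Propagate values along edges until nothing changes.

    seeds: dict mapping a value name to its known elements per thread,
           e.g. taken from the property on a memory operation.
    edges: list of (a, b) pairs of values that must agree.
    """
    state = dict(seeds)
    changed = True
    while changed:
        changed = False
        for a, b in edges:
            merged = join(state.get(a), state.get(b))
            for name in (a, b):
                if state.get(name) != merged:
                    state[name] = merged
                    changed = True
    return state


# A consistent chain: the value from the read reaches the write.
ok = propagate({"read_a": 8}, [("read_a", "add"), ("add", "write")])

# Two reads with different values feeding the same op produce a conflict.
bad = propagate({"read_a": 8, "read_b": 4},
                [("read_a", "add"), ("read_b", "add")])
```

Because the rewritten vector type depends on the converged `state` rather than on each value's original type alone, a rewrite driven only by per-type conversion rules would not have the information it needs, which matches the reasoning given above for not using a dialect conversion.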
…ead/write ops Original review: ftynse/water#53 Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#58 Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#48 Minimalist lowering for read/write with innermost vectorization and masking. Non-innermost vectorization and scatter/gather are not yet supported. Buffer ops emission is not yet supported and may be better off as a reusable optimization. --------- Signed-off-by: Aurore De Spirlet <aurore.despirlet@amd.com> Signed-off-by: Alex Zinenko <git@ozinenko.com> Co-authored-by: Alex Zinenko <git@ozinenko.com>
Original review: ftynse/water#25 Signed-off-by: Tim Gymnich <tim.gymnich@amd.com>
Original review: ftynse/water#60 This PR renames the attribute "DistributedShapeAttr" to "ExprAttr". This generalizes the attribute so it can represent any tuple of affine expressions over Wave symbols. Signed-off-by: Aurore De Spirlet <aurore.despirlet@amd.com>
Original review: ftynse/water#63 Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#62

This PR adds support for lowering read/write along non-innermost dimensions. The read/write lowering now chooses the most appropriate Vector op depending on the memory layout:

* **Fast path (contiguous minor dim):** If we vectorize the trailing (innermost) memref dimension and that dimension has stride 1, we emit vector.load / vector.store (or their masked variants).
* **General path (non-contiguous):** Otherwise we emit vector.transfer_read / vector.transfer_write with a 1D permutation map (assuming rank-1 vectors) selecting the vectorized dimension.

**Details**

* New helper isMinorContiguous(): returns true iff vectorizeddim == rank-1 and minorStride == 1 (using memref::getStridesAndOffset). It fails if the layout isn't representable as simple strides, in which case we conservatively use transfer ops.
* New helper make1DTransferCommonAttrs() builds:
  * perm: affine_map<(d0, ..., d{R-1}) -> (d_vdim)>
  * in_bounds: [true] or [false]

Usage:

* **Unmasked:** for vector.transfer_read/write we set in_bounds = [false]; the compiler will emit checks that indices are within the bounds of the memref shape.
* **Masked:** for vector.transfer_read/write we set in_bounds = [true], as we can assume the mask already makes sure we are in bounds.

**Tests**

* Added/updated checks to:
  * Match vector.transfer_read/write with the expected permutation map and in_bounds attribute
  * Verify the constant padding emission and usage of the mask

For now, the TypeConverter sets the memref layouts to empty (MemRefLayoutAttrInterface{}); a row-major layout is assumed, and therefore every innermost dimension that is vectorized has stride 1.

**Next**: Add tests with explicit MemRef layouts using strided<...> (e.g., strided<[128, 1]> vs strided<[1, 128]>) to exercise both the contiguous and non-contiguous paths.

--------- Signed-off-by: Aurore De Spirlet <aurore.despirlet@amd.com>
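The selection logic described in this commit can be modeled roughly as follows. This is a sketch in Python rather than the C++ of the actual helpers: `is_minor_contiguous` only mirrors the stated condition for isMinorContiguous(), `choose_lowering` is a hypothetical name, and representing a non-strided layout as `None` is an assumption made for illustration.

```python
def is_minor_contiguous(strides, vectorized_dim):
    """True iff the vectorized dim is the innermost one and has stride 1.

    strides is None when the layout cannot be expressed as simple strides,
    in which case the conservative answer is False.
    """
    if strides is None:
        return False
    rank = len(strides)
    return vectorized_dim == rank - 1 and strides[-1] == 1


def choose_lowering(strides, vectorized_dim, masked, is_read=True):
    """Return (op name, in_bounds) for a 1-D vectorized access.

    in_bounds is None on the fast path, which does not use transfer ops.
    """
    if is_minor_contiguous(strides, vectorized_dim):
        if is_read:
            return ("vector.maskedload" if masked else "vector.load", None)
        return ("vector.maskedstore" if masked else "vector.store", None)
    op = "vector.transfer_read" if is_read else "vector.transfer_write"
    # Masked accesses assume the mask already guarantees in-bounds indices.
    return (op, [masked])


fast = choose_lowering([128, 1], vectorized_dim=1, masked=False)
outer_dim = choose_lowering([128, 1], vectorized_dim=0, masked=True)
column_major = choose_lowering([1, 128], vectorized_dim=1, masked=False)
```

The two stride lists correspond to the strided<[128, 1]> and strided<[1, 128]> layouts named as future test cases: the first takes the fast path when the innermost dimension is vectorized, and the second falls back to transfer ops.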
Original review: ftynse/water#65 This is required since symbols prefixed by `$` are not accepted in affine expressions. Signed-off-by: tyb0807 <sontuan.vu@amd.com>
2d52401 to c2451c3
Please review per commit