Skip to content

Conversation

@ftynse
Copy link
Contributor

@ftynse ftynse commented Oct 3, 2025

Please review per commit

@ftynse ftynse force-pushed the users/ftynse/wave-dialect-2 branch from 2feb1a9 to 2d52401 Compare October 3, 2025 15:25
@harsh-nod harsh-nod requested a review from Copilot October 6, 2025 18:40
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR imports the second batch of Wave dialect commits, introducing elements-per-thread propagation, vector type support, and additional Wave dialect attributes. The changes focus on enhancing the Wave dialect's type system and lowering infrastructure.

  • Implements elements-per-thread dataflow analysis for register-resident tensors
  • Adds support for vector types alongside Wave tensor types in operations
  • Introduces new attributes (WaveAddressSpaceAttr, WaveDistributedShapeAttr, WaveReadWriteBoundsAttr)

Reviewed Changes

Copilot reviewed 27 out of 27 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
water/test/python/sympy_converter.py Updated symbol handling to use string lists instead of SymPy Symbol objects
water/test/Dialect/Wave/python_bindings.py Added tests for new Wave attributes (address space, distributed shape)
water/test/Dialect/Wave/propagate-elements-per-thread.mlir New test file for elements-per-thread propagation pass
water/test/Dialect/Wave/ops.mlir Added tests for index magic symbols and write bounds
water/test/Dialect/Wave/ops-invalid.mlir Updated error messages and added validation tests
water/test/Dialect/Wave/lower-wave-to-mlir.mlir Updated lowering tests for new normal forms and vector types
water/test/Dialect/Wave/attr-type-invalid.mlir Added validation test for elements_per_thread attribute
water/python/water_mlir/sympy_to_affine_converter.py Refactored to accept string symbols instead of SymPy Symbol objects
water/python/WaterExtensionNanobind.cpp Added Python bindings for new Wave attributes
water/lib/Dialect/Wave/Transforms/TypeConverter.cpp Enhanced type converter with vector type support
water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp Updated lowering patterns for vector types
water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp Added hyperparameter validation and normal form checks
water/lib/Dialect/Wave/Transforms/InferTypes.cpp Added elements-per-thread propagation analysis implementation
water/lib/Dialect/Wave/IR/WaveOps.cpp Enhanced operation verification and type checking
water/lib/Dialect/Wave/IR/WaveInterfaces.cpp Added elements-per-thread interface implementations
water/lib/Dialect/Wave/IR/WaveDialect.cpp Enhanced attribute verification and hyperparameter validation
water/lib/Dialect/Wave/IR/WaveAttrs.cpp Added new attribute implementations and normal form verification
water/lib/CAPI/Dialects.cpp Added C API bindings for new attributes
water/include/water/c/Dialects.h Added C API declarations for new attributes
water/include/water/Dialect/Wave/Transforms/Passes.td Added elements-per-thread propagation pass definition
water/include/water/Dialect/Wave/IR/WaveTypes.td Updated type constraints for vector support
water/include/water/Dialect/Wave/IR/WaveTypes.h Added utility function for register tensor type checking
water/include/water/Dialect/Wave/IR/WaveOps.td Enhanced operation definitions with new attributes and interfaces
water/include/water/Dialect/Wave/IR/WaveInterfaces.td Added elements-per-thread interface and trait definitions
water/include/water/Dialect/Wave/IR/WaveInterfaces.h Added elements-per-thread lattice value and interface implementations
water/include/water/Dialect/Wave/IR/WaveDialect.td Updated dialect documentation and added elements_per_thread attribute name
water/include/water/Dialect/Wave/IR/WaveAttrs.td Added new attribute definitions (WaveReadWriteBoundsAttr)

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

if (diag.getSeverity() == mlir::DiagnosticSeverity::Error)
emittedError = true;

// Returning failure indicates that the diagnostic wan't handled
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'wan't' to 'wasn't'.

Suggested change
// Returning failure indicates that the diagnostic wan't handled
// Returning failure indicates that the diagnostic wasn't handled

Copilot uses AI. Check for mistakes.

//===----------------------------------------------------------------------===//
// WaveDistributedShapeAttr
// DistributedShapeAttr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest of the code references this as WaveDistributedShapeAttr so why this change?

@harsh-nod
Copy link
Collaborator

@ftynse @martin-luecke - this looks good to me! please let me know when you have a branch rebased on main and I can land that.

tyb0807 and others added 14 commits October 14, 2025 21:56
Original review: ftynse/water#49

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#50

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#51

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
affine map

Original review: ftynse/water#52

When converting sympy expression to affine map, we only care about the
symbol names and not the exact `sympy.Symbol` objects. This allows the
caller of `convert_sympy_to_affine_map` to freely add assumptions to the
symbols (hence creating a new symbol with the same name) and substitute
the original symbols in the expression.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#54

Add a dialect-level verifier checking that any occurrence of Wave
symbols, either as attributes or as stirng name in relevant
dictionaries, reference symbols listed as hyperparameters. Also report
unused symbols.

---------

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Co-authored-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#55

Completely overhaul the lowering procedure for register-resident Wave
dialect tensors. They are representing values that are virtually
distributed across multiple threads. When lowering to core MLIR
dialects, we also perform a programming model change: the lowered IR is
expected to be SIMT. Therefore, register-resident Wave dialect tensors
should be converted to vectors of a _distributed_ shape rather than a
shape that comes from direct substitution of symbolic values from
hyperparameters (unlike memory-resident tensors).

The distributed per-thread shape of the vector is derived from the
"elements per thread" property available on memory operations, in
hyperparameters and compilation options, or extracted from intrinsic
structure of the MMA operation. Implement a dataflow analysis for
propagating this value from the property for now and report eventual
conflicts.

Based on the above dataflow analysis, implement a pass rewriting
register-resident Wave dialect tensors to vectors with distributed
shapes. Note that this _cannot_ be implemented as a dialect conversion
(or at least is significantly complicated) since the type conversion
depends on the result of the analysis, and not only on the original type
or defining operation.

This in turn requires extending Wave dialect operations to also accept
vector-typed values whenever a register-resident tensor is accepted.
This is not a significant departure from the original representation as
both tensors and vectors are immutable SSA values.

Introduce an additional normalform checking for the absence of
register-resident tensors in the input IR and make it a precondition of
the lowering pass, and a postcondition of the elements-per-thread
propagation pass.

---------

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Co-authored-by: tyb0807 <sontuan.vu@amd.com>
…ead/write ops

Original review: ftynse/water#53

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#58

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#48

Minimalist lowering for read/write with innermost vectorization and
masking. Non-innermost vectorization and scatter/gather not yet
supported. Buffers ops emission is not yet supported and may be better
of as a reusable optimization.

---------

Signed-off-by: Aurore De Spirlet <aurore.despirlet@amd.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Co-authored-by: Alex Zinenko <git@ozinenko.com>
Original review: ftynse/water#25

Signed-off-by: Tim Gymnich <tim.gymnich@amd.com>
Original review: ftynse/water#60

This PR renames the attribute "DistributedShapeAttr" to "ExprAttr".
This generalizes the attribute so it can represent any tuple of affine expressions over Wave symbols.

Signed-off-by: Aurore De Spirlet <aurore.despirlet@amd.com>
Original review: ftynse/water#63

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Original review: ftynse/water#62

This PR adds support for lowering read/write along non innermost
dimensions.
It makes the read/write lowering choose the most appropriate Vector op
depending on the memory layout:
* **Fast path (contiguous minor dim):** If we vectorize the trailing
(innermost) memref dimension and that dimension has stride 1, we emit
vector.load / vector.store (or their masked variants).
* **General path (non-contiguous):** Otherwise we emit
vector.transfer_read / vector.transfer_write with a 1D permutation map
(assuming vectors of Rank-1 )selecting the vectorized dimension.

**Details**
* New helper isMinorContiguous() 
Returns true iff vectorizeddim ==
rank-1 and minorStride == 1 (using memref::getStridesAndOffset). Failure
if layout isn’t representable as simple strides, in which case we
conservatively use transfer ops.
* New helper make1DTransferCommonAttrs() 
Builds:
  * perm: affine_map<(d0, ..., d{R-1}) -> (d_vdim)>
  * in_bounds: [true] and [false]
Usage:
**Unmasked:** vector.transfer_read/write we set in_bounds = [false] ,
the compiler will emit checks that indices are within bounds of memref
shape
**Masked:** vector.transfer_read/write we set in_bounds = [true] as we
can assume the mask already makes sur we are in bounds.

**Tests**
* Added/updated checks to:
* Match vector.transfer_read/write with the expected permutation map and
in bounds Attribute
  * Verify the constant padding emission and usage of mask

For now in the TypeConverter the memref layouts are set to empty
(MemRefLayoutAttrInterface{})
(row major layout assumed and therefore all innermost dimesnion that are
vctorized have stride 1)

**Next**: Add tests with explicit MemRef layouts using strided<...>
(e.g., strided<[128, 1]> vs strided<[1, 128]>) to exercise both the
contiguous and non-contiguous paths

---------

Signed-off-by: Aurore De Spirlet <aurore.despirlet@amd.com>
Original review: ftynse/water#65

This is required since symbols prefixed by `$` are not accepted in
affine expressions.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
@ftynse ftynse force-pushed the users/ftynse/wave-dialect-2 branch from 2d52401 to c2451c3 Compare October 14, 2025 21:56
@harsh-nod harsh-nod merged commit c2451c3 into main Oct 14, 2025
17 checks passed
@harsh-nod harsh-nod deleted the users/ftynse/wave-dialect-2 branch October 14, 2025 22:48
@ftynse ftynse changed the title DO NOT MERGE: import the second batch of the wave dialect commits import the second batch of the wave dialect commits Oct 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants