From be1c00cc486c3b2fe69c13b5477df5be8bd1c70e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 10 Sep 2025 17:51:57 +0000 Subject: [PATCH 1/3] add transpose function --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index cfe3e800484ce..24756318e4339 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -231,7 +231,51 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { multiple blocks according to round-robin distribution rules.}], "FailureOr>>", "getOffsets", - (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)> + (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)>, + InterfaceMethod": $perm), + /*methodBody=*/[{ + if (!other) + return false; + if ($_self.getRank() != other.getRank() || perm.size() != static_cast($_self.getRank())) + return false; + // check if the permutation is valid + int64_t rank = $_self.getRank(); + SmallVector seen(rank, false); + for (const auto &ta : llvm::enumerate(perm)) { + if (ta.value() < 0 || ta.value() >= rank) + return false; + if (seen[ta.value()]) + return false; + seen[ta.value()] = true; + } + auto checkTranspose = [](ArrayRef dst, ArrayRef src, ArrayRef perm) { + for (const auto &ta : llvm::enumerate(perm)) { + if (src[ta.index()] != dst[ta.value()]) + return false; + } + return true; + }; + // check sgLayout + if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm)) + return false; + // check sgData + if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm)) + return false; + // check instData + if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm)) + return false; + // check laneLayout + if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm)) + return false; + // check laneData + if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm)) + return false; + return true; + }]> ]; } From 916c75f12298f76b2f8c6e2b5645125e75d34a73 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 10 Sep 2025 23:15:18 +0000 Subject: [PATCH 2/3] add slice attribute utils --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 12 ++++++++++- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 24756318e4339..aa3e3c5cddc05 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -275,7 +275,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm)) return false; return true; - }]> + }]>, + InterfaceMethod ]; } @@ -477,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { FailureOr>> getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + /// Check if this is slice of some other layout. + bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; } + }]; let assemblyFormat = "`<` struct(params) `>`"; @@ -638,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { FailureOr>> getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + /// Check if this is slice of some other layout. + bool isSliceOf(const xegpu::DistributeLayoutAttr &other); + }]; let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`"; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 7f3be7f91c56b..a3783d5e05df6 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -409,6 +410,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, shape); } +bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { + auto flattenedThis = flatten(); + // If other is a LayoutAttr, just compare directly with parent of + // flattenedThis. + if (auto otherLayout = dyn_cast(other)) + return flattenedThis.getParent() == otherLayout; + // If other is a SliceAttr, flatten it first before comparing. + auto otherFlattened = dyn_cast(other).flatten(); + // Both must have common parent LayoutAttr. + if (flattenedThis.getParent() != otherFlattened.getParent()) + return false; + // otherFlattened's sliced dims must be a subset of flattenedThis's sliced + // dims. + llvm::SmallDenseSet thisDims( + flattenedThis.getDims().asArrayRef().begin(), + flattenedThis.getDims().asArrayRef().end()); + return llvm::all_of(otherFlattened.getDims().asArrayRef(), + [&](int64_t dim) { return thisDims.contains(dim); }); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// From 77e8a9477dbd76bf95e5d142a0a6e6a4596ab3d2 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 10 Sep 2025 23:54:57 +0000 Subject: [PATCH 3/3] fix name --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index a3783d5e05df6..cc133b110c95a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -417,16 +417,16 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { if (auto otherLayout = dyn_cast(other)) return flattenedThis.getParent() == otherLayout; // If other is a SliceAttr, flatten it first before comparing. - auto otherFlattened = dyn_cast(other).flatten(); + auto flattenedOther = dyn_cast(other).flatten(); // Both must have common parent LayoutAttr. - if (flattenedThis.getParent() != otherFlattened.getParent()) + if (flattenedThis.getParent() != flattenedOther.getParent()) return false; // otherFlattened's sliced dims must be a subset of flattenedThis's sliced // dims. llvm::SmallDenseSet thisDims( flattenedThis.getDims().asArrayRef().begin(), flattenedThis.getDims().asArrayRef().end()); - return llvm::all_of(otherFlattened.getDims().asArrayRef(), + return llvm::all_of(flattenedOther.getDims().asArrayRef(), [&](int64_t dim) { return thisDims.contains(dim); }); }