diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index cfe3e800484ce..aa3e3c5cddc05 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -231,7 +231,55 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { multiple blocks according to round-robin distribution rules.}], "FailureOr>>", "getOffsets", - (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)> + (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)>, + InterfaceMethod": $perm), + /*methodBody=*/[{ + if (!other) + return false; + if ($_self.getRank() != other.getRank() || perm.size() != static_cast($_self.getRank())) + return false; + // check if the permutation is valid + int64_t rank = $_self.getRank(); + SmallVector seen(rank, false); + for (const auto &ta : llvm::enumerate(perm)) { + if (ta.value() < 0 || ta.value() >= rank) + return false; + if (seen[ta.value()]) + return false; + seen[ta.value()] = true; + } + auto checkTranspose = [](ArrayRef dst, ArrayRef src, ArrayRef perm) { + for (const auto &ta : llvm::enumerate(perm)) { + if (src[ta.index()] != dst[ta.value()]) + return false; + } + return true; + }; + // check sgLayout + if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm)) + return false; + // check sgData + if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm)) + return false; + // check instData + if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm)) + return false; + // check laneLayout + if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm)) + return false; + // check laneData + if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm)) + return false; + return true; + }]>, + InterfaceMethod ]; } @@ -433,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { FailureOr>> getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + /// Check if this is slice of some other layout. + bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; } + }]; let assemblyFormat = "`<` struct(params) `>`"; @@ -594,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { FailureOr>> getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + /// Check if this is slice of some other layout. + bool isSliceOf(const xegpu::DistributeLayoutAttr &other); + }]; let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`"; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 7f3be7f91c56b..cc133b110c95a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -409,6 +410,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, shape); } +bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { + auto flattenedThis = flatten(); + // If other is a LayoutAttr, just compare directly with parent of + // flattenedThis. + if (auto otherLayout = dyn_cast(other)) + return flattenedThis.getParent() == otherLayout; + // If other is a SliceAttr, flatten it first before comparing. + auto flattenedOther = dyn_cast(other).flatten(); + // Both must have common parent LayoutAttr. + if (flattenedThis.getParent() != flattenedOther.getParent()) + return false; + // otherFlattened's sliced dims must be a subset of flattenedThis's sliced + // dims. + llvm::SmallDenseSet thisDims( + flattenedThis.getDims().asArrayRef().begin(), + flattenedThis.getDims().asArrayRef().end()); + return llvm::all_of(flattenedOther.getDims().asArrayRef(), + [&](int64_t dim) { return thisDims.contains(dim); }); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===//