Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,55 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
multiple blocks according to round-robin distribution rules.}],
"FailureOr<SmallVector<SmallVector<Value>>>",
"getOffsets",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
to some other layout according to given permutation of (0...n-1).}],
/*retTy=*/"bool",
/*methodName=*/"isTransposeOf",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other, "ArrayRef<int64_t>": $perm),
/*methodBody=*/[{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in the other PR - nit: could you move it to the source file too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @adam-smnk , please ignore this PR. changes are already there with #155517

if (!other)
return false;
if ($_self.getRank() != other.getRank() || perm.size() != static_cast<size_t>($_self.getRank()))
return false;
// check if the permutation is valid
int64_t rank = $_self.getRank();
SmallVector<bool, 8> seen(rank, false);
for (const auto &ta : llvm::enumerate(perm)) {
if (ta.value() < 0 || ta.value() >= rank)
return false;
if (seen[ta.value()])
return false;
seen[ta.value()] = true;
}
auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
for (const auto &ta : llvm::enumerate(perm)) {
if (src[ta.index()] != dst[ta.value()])
return false;
}
return true;
};
// check sgLayout
if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm))
return false;
// check sgData
if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm))
return false;
// check instData
if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm))
return false;
// check laneLayout
if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm))
return false;
// check laneData
if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

order attribute needs to be transposed.

return false;
return true;
}]>,
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
/*retTy=*/"bool",
/*methodName=*/"isSliceOf",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
];
}

Expand Down Expand Up @@ -433,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);

/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }

}];

let assemblyFormat = "`<` struct(params) `>`";
Expand Down Expand Up @@ -594,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);

/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);

}];

let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -409,6 +410,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
shape);
}

bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
auto flattenedThis = flatten();
// If other is a LayoutAttr, just compare directly with parent of
// flattenedThis.
if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
return flattenedThis.getParent() == otherLayout;
// If other is a SliceAttr, flatten it first before comparing.
auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
// Both must have common parent LayoutAttr.
if (flattenedThis.getParent() != flattenedOther.getParent())
return false;
// otherFlattened's sliced dims must be a subset of flattenedThis's sliced
// dims.
llvm::SmallDenseSet<int64_t> thisDims(
flattenedThis.getDims().asArrayRef().begin(),
flattenedThis.getDims().asArrayRef().end());
return llvm::all_of(flattenedOther.getDims().asArrayRef(),
[&](int64_t dim) { return thisDims.contains(dim); });
}

//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
Expand Down