Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout by dropping InstData",
"xegpu::DistributeLayoutAttr",
"dropInstData">,
InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
"xegpu::DistributeLayoutAttr",
"setUnitDimData",
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims",
"xegpu::DistributeLayoutAttr",
"setUnitDimLayout",
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
indices based on the effective layout level.}],
"FailureOr<SmallVector<Value>>",
Expand Down Expand Up @@ -283,9 +291,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
}
return true;
}]>,
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
/*retTy=*/"bool",
/*methodName=*/"isSliceOf",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,

InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
/*retTy=*/"bool",
/*methodName=*/"isEqualTo",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
];
}
Expand Down Expand Up @@ -487,6 +500,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return {};
}

//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);

//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);

/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
FailureOr<SmallVector<Value>>
Expand All @@ -501,6 +520,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {

/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }

/// Check if this is identical to some other layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);

}];

Expand Down Expand Up @@ -649,6 +671,12 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
return SliceAttr::get(getContext(), parent, attr.getDims());
}

//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);

//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);

/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
/// it will coalese two slice operations and return a simplified SliceAttr
Expand All @@ -670,7 +698,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {

/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);


/// Check if this is identical to some other layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
}];

let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
OptionalAttr<DistributeLayoutAttr>:$layout);

let results = (outs XeGPU_ValueType: $value);
Expand Down
143 changes: 143 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,86 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
return genCoordinates(builder, loc, ids, layout, subShape, shape);
}

bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
if (dyn_cast<xegpu::SliceAttr>(other))
return false;

return *this == dyn_cast<xegpu::LayoutAttr>(other);
}

// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
auto sgDataOpt = getSgData();
auto instDataOpt = getInstData();
auto laneDataOpt = getLaneData();

SmallVector<int32_t> sgData;
SmallVector<int32_t> instData;
SmallVector<int32_t> laneData;

if (sgDataOpt) {
sgData = llvm::to_vector(sgDataOpt.asArrayRef());
}
if (instDataOpt) {
instData = llvm::to_vector(instDataOpt.asArrayRef());
}
if (laneDataOpt) {
laneData = llvm::to_vector(laneDataOpt.asArrayRef());
}

for (auto dim : unitDims) {
if (dim < static_cast<int64_t>(sgData.size()))
sgData[dim] = 1;
if (dim < static_cast<int64_t>(instData.size()))
instData[dim] = 1;
if (dim < static_cast<int64_t>(laneData.size()))
laneData[dim] = 1;
}

return LayoutAttr::get(
getContext(), getSgLayout(),
sgData.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), sgData),
instData.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), instData),
getLaneLayout(),
laneData.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), laneData),
getOrder());
}

// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
auto sgLayoutOpt = getSgLayout();
auto laneLayoutOpt = getLaneLayout();

SmallVector<int32_t> sgLayout;
SmallVector<int32_t> laneLayout;

if (sgLayoutOpt) {
sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
}
if (laneLayoutOpt) {
laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
}

for (auto dim : unitDims) {
if (dim < static_cast<int64_t>(sgLayout.size()))
sgLayout[dim] = 1;
if (dim < static_cast<int64_t>(laneLayout.size()))
laneLayout[dim] = 1;
}

return LayoutAttr::get(
getContext(),
sgLayout.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), sgLayout),
getSgData(), getInstData(),
laneLayout.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), laneLayout),
getLaneData(), getOrder());
}

//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -510,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
[&](int64_t dim) { return thisDims.contains(dim); });
}

bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
if (dyn_cast<xegpu::LayoutAttr>(other))
return false;

auto flattenedThis = flatten();
auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();

return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
(flattenedThis.getDims() == flattenedOther.getDims()));
}

// Helper function to adjust unit dimensions from sliced space to parent space
static SetVector<int64_t>
adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
ArrayRef<int64_t> sliceDims) {
// Reconstruct parent's non-sliced dimensions

int64_t parentRank = sliceDims.size() + unitDims.size();
llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
sliceDims.end());
SmallVector<int64_t> nonSlicedDims;
for (int64_t i = 0; i < parentRank; ++i) {
if (!slicedDimsSet.contains(i))
nonSlicedDims.push_back(i);
}

// Map unit dims from sliced space to parent space
SetVector<int64_t> adjustUnitDims;
for (auto dim : unitDims) {
if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
adjustUnitDims.insert(nonSlicedDims[dim]);
}
}

return adjustUnitDims;
}

// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
SliceAttr attr = flatten();
ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());

SetVector<int64_t> adjustUnitDims =
adjustUnitDimsWithSliceDims(unitDims, sliceDims);

return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
attr.getDims());
}

// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
SliceAttr attr = flatten();
ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());

SetVector<int64_t> adjustUnitDims =
adjustUnitDimsWithSliceDims(unitDims, sliceDims);

return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
attr.getDims());
}

//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
Expand Down
42 changes: 29 additions & 13 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
// Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
if (!sourceTy) {
broadcast.emitWarning("Expecting source type to be a vector type.");
// skip layout propagation for non-vector source operand.
if (!sourceTy)
return;
}

// Only consider nD -> nD broadcast.
// Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
if (sourceTy.getRank() != resultTy.getRank()) {
broadcast.emitWarning("Expecting source and result to have same rank.");
auto sourceDims = sourceTy.getShape();
auto resultDims = resultTy.getShape();
SmallVector<int64_t> bcastDims;
auto dimDiff = resultTy.getRank() - sourceTy.getRank();
// adding the missing leading dims
for (int i = 0; i < dimDiff; i++)
bcastDims.push_back(i);

// for the rest dims in the resultTy, if sourceTy dim is 1, then it's
// broadcasted dim
for (size_t i = 0; i < sourceDims.size(); i++)
if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
bcastDims.push_back(i + dimDiff);

// create a slice layout for the source
xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
broadcast->getContext(),
cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));

propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}

SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
if (broadcastUnitDims.size() != 1) {
broadcast.emitWarning("Expecting source type to be nD vector only with "
"one broadcasted dimension.");
return;
}
// Propagate the result layout to the source operand.
resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
.setUnitDimData(broadcastUnitDims);
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

Expand Down Expand Up @@ -917,7 +933,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
} else {

// The layout is strictly determined by the payload type.
auto payloadTy = dyn_cast<VectorType>(load.getValueType());
VectorType payloadTy = load.getValueType();
if (!payloadTy) {
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
Expand Down Expand Up @@ -987,7 +1003,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
// Currently, for 2D StoreScatterOp we expect that the height dimension of
// the tensor descriptor is equal to the subgroup size. This is ensured by
// the op verifier.
auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
VectorType payloadTy = storeScatter.getValueType();
if (!payloadTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
Expand Down
Loading