Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
OptionalAttr<XeGPU_LayoutAttr>:$layout);
OptionalAttr<DistributeLayoutAttr>:$layout);
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Expand Down Expand Up @@ -903,7 +903,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint,
"xegpu::LayoutAttr": $layout)>
"xegpu::DistributeLayoutAttr": $layout)>
];

let hasVerifier = 1;
Expand Down Expand Up @@ -988,7 +988,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
OptionalAttr<XeGPU_LayoutAttr>:$layout);
OptionalAttr<DistributeLayoutAttr>:$layout);

let extraClassDeclaration = extraBaseClassDeclaration#[{
Type getDestType() {
Expand Down Expand Up @@ -1046,7 +1046,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint,
"xegpu::LayoutAttr": $layout)>
"xegpu::DistributeLayoutAttr": $layout)>
];

let hasVerifier = 1;
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint,
xegpu::LayoutAttr layout) {
DistributeLayoutAttr layout) {
auto loc = source.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
Expand Down Expand Up @@ -985,7 +985,7 @@ void StoreScatterOp::build(
OpBuilder &builder, OperationState &state, Value value, Value dest,
ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
auto loc = dest.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ struct UnrollLoadGatherOpWithOffset
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
}

auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
auto layout = op.getLayoutAttr();
if (layout)
layout = layout.dropInstData();

Expand Down Expand Up @@ -778,7 +778,7 @@ struct UnrollStoreScatterOpWithOffsets
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
auto layout = op.getLayoutAttr();
if (layout)
layout = layout.dropInstData();

Expand Down
12 changes: 7 additions & 5 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,8 @@ struct WgToSgLoadGatherOpWithOffset
return failure();
ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
xegpu::getDistributeLayoutAttr(op.getResult()));
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
return failure();

Expand All @@ -913,10 +913,12 @@ struct WgToSgLoadGatherOpWithOffset
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
auto newLayout = layout.dropSgLayoutAndData();
auto newLoadOp = xegpu::LoadGatherOp::create(
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
layout.dropSgLayoutAndData());
newLayout);
xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
Expand All @@ -941,8 +943,8 @@ struct WgToSgStoreScatterOpWithOffset
if (!valueType)
return failure();

xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
xegpu::getDistributeLayoutAttr(op.getOperand(0)));
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getOperand(0));
if (!layout || !layout.isForWorkgroup())
return failure();

Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -547,4 +547,21 @@ gpu.module @test_distribution {
%broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
gpu.return
}

// CHECK-LABEL: distribute_load_slice_attr
gpu.func @distribute_load_slice_attr() {
%2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1>

// CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>}>
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>} :
// CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
%3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>

// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32xf32> to vector<32x32xf32>
%4 = vector.broadcast %3 {layout_result_0 =
#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
gpu.return
}
}
Loading