diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bdf06a3a35438..73f9061f5debe 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -887,6 +887,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                   "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                   "IntegerAttr": $chunk_size,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
@@ -1016,6 +1022,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest,
+                   "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                   "IntegerAttr": $chunk_size,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7036996a68e0d..aca6654d89001 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -819,6 +819,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
         l1_hint, l2_hint, l3_hint);
 }
 
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreScatterOp
 //===----------------------------------------------------------------------===//
@@ -870,6 +886,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
         l2_hint, l3_hint);
 }
 
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest,
+                           ArrayRef<OpFoldResult> offsets, Value mask,
+                           IntegerAttr chunk_size,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_UpdateOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0b7fe81facfce..9f627c7e1e6d8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -765,6 +765,110 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadGatherOpWithOffset
+    : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    // The offsets need to be distributed.
+    auto offsetsVecType =
+        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+    auto maskVecType =
+        dyn_cast<VectorType>(adaptor.getMask().front().getType());
+    if (!offsetsVecType || !maskVecType ||
+        offsetsVecType.getShape() != maskVecType.getShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "offsets have not been distributed");
+    }
+
+    SmallVector<Value> newLoadOps;
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+    for (auto [offsets, mask] :
+         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreScatterOpWithOffset
+    : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueType)
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getValue());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    // The offsets need to be distributed.
+    auto offsetsVecType =
+        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+    auto maskVecType =
+        dyn_cast<VectorType>(adaptor.getMask().front().getType());
+    if (!offsetsVecType || !maskVecType ||
+        offsetsVecType.getShape() != maskVecType.getShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "offsets have not been distributed");
+    }
+
+    auto chunkSizeOpt = op.getChunkSize();
+    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+    for (auto [val, offs, mask] : llvm::zip(
+             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      // Update the layout attribute to drop sg_layout and sg_data.
+      if (auto newLayout = layout.dropSgLayoutAndData())
+        op->setAttr("layout", newLayout);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -826,8 +930,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
       WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
       WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-      WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-      patterns.getContext());
+      WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+      WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
+      WgToSgStoreMatrixOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -952,6 +1057,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+      [=](xegpu::LoadGatherOp op) -> bool {
+        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+      [=](xegpu::StoreScatterOp op) -> bool {
+        // Check if the layout attribute is present on the op.
+        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+        if (!layout)
+          return true;
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 32157a7911f62..afb2bf876c18f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -264,6 +264,50 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: @load_gather
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_gather(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @store_scatter
+  // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
+  gpu.func @store_scatter(%dest : memref<256xf16>) {
+    // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+      : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @load_with_non_unit_chunk_size
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_with_non_unit_chunk_size(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_load_matrix
   // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
   gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
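
Note (not part of the diff): a minimal sketch of how the new ArrayRef<OpFoldResult>-based LoadGatherOp builder could be invoked from a rewrite pattern. The names loc, src, mask, and vecTy are assumed to exist in the caller and are illustrative only.

    // Hypothetical caller: build an xegpu.load with explicit constant offsets.
    // The new builder materializes the OpFoldResult offsets into a
    // vector<index> value via vector.from_elements and then delegates to the
    // existing builder that takes an offsets vector operand.
    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
                                         rewriter.getIndexAttr(16)};
    auto hint = xegpu::CachePolicyAttr::get(rewriter.getContext(),
                                            xegpu::CachePolicy::CACHED);
    auto load = rewriter.create<xegpu::LoadGatherOp>(
        loc, /*valueType=*/vecTy, /*source=*/src, offsets, /*mask=*/mask,
        /*chunk_size=*/rewriter.getI64IntegerAttr(1),
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);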