[MLIR][Conversion] Convert XeGPU to XeVM pass: Remove lowering support for tensor descriptor with offsets. #157550
Conversation
Update load/store/prefetch test cases to use direct offsets.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

And update load/store/prefetch test cases to use direct offsets.

Patch is 38.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/157550.diff

5 Files Affected:
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a7f2dc2d6a43e..7d756620ecd16 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
matchAndRewrite(xegpu::CreateNdDescOp op,
xegpu::CreateNdDescOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ if (mixedOffsets.size() != 0)
+ return rewriter.notifyMatchFailure(op, "Offsets not supported.");
auto loc = op.getLoc();
auto source = op.getSource();
// Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
- SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- // Offsets can be either 2D or not provided (0 is used).
- if (mixedOffsets.size() == 2) {
- offsetW = createOffset(mixedOffsets, 1);
- offsetH = createOffset(mixedOffsets, 0);
- } else if (mixedOffsets.size() == 0) {
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- } else {
- return rewriter.notifyMatchFailure(op,
- "Expected 2D offsets or no offsets.");
- }
+ // Offsets are not supported not (0 is used).
+ offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
}
};
-class UpdateNdOffsetToXeVMPattern
- : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(xegpu::UpdateNdOffsetOp op,
- xegpu::UpdateNdOffsetOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto mixedOffsets = op.getMixedOffsets();
- // Only 2D offsets are supported for now.
- if (mixedOffsets.size() != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
- auto payload = adaptor.getTensorDesc();
- // Utility for updating payload offset values from op fold result.
- auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
- Value offset =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
- offset = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offset);
- Value oldOffset =
- vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
- Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
- return vector::InsertOp::create(rewriter, loc, newOffset, payload,
- payloadPos);
- };
- // Update offsets in the payload.
- payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
- payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
- rewriter.replaceOp(op, payload);
- return success();
- }
-};
-
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto mixedOffsets = op.getMixedOffsets();
+ int64_t opOffsetsSize = mixedOffsets.size();
+ if (opOffsetsSize != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
@@ -311,32 +276,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
- // Offsets provided in two ways:
- // 1. Offsets are extracted from the tensor descriptor.
- // 2. (Mixed) offsets which are provided by the op.
- Value offsetW;
- Value offsetH;
- auto mixedOffsets = op.getMixedOffsets();
- int64_t opOffsetsSize = mixedOffsets.size();
- if (opOffsetsSize != 0 && opOffsetsSize != 2)
- return rewriter.notifyMatchFailure(op,
- "Expected 2D offsets or no offsets.");
- if (opOffsetsSize) {
- // If mixed offsets are provided by the op convert them to i32.
- offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
- offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetW);
- offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
- offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetH);
- } else {
- // If offsets are not available, we need to extract them from the tensor
- // descriptor.
- offsetW = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
- offsetH = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
- }
+ // Offsets are provided by the op.
+ // convert them to i32.
+ Value offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ Value offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +369,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
return newAddr;
}
-class CreateDescToXeVMPattern
- : public OpConversionPattern<xegpu::CreateDescOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto eTy = op.getTensorDescType().getElementType();
- auto eBw = eTy.getIntOrFloatBitWidth();
- if (eBw % 8 != 0)
- return rewriter.notifyMatchFailure(
- op, "Expected element type bit width to be multiple of 8.");
- auto loc = op.getLoc();
- // Offsets are provided as scalar i64 by type converter.
- auto offsets = adaptor.getOffsets();
- // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
- // But type converter will convert them to integer types.
- Value addr = adaptor.getSource();
- // ui32 or i32 are passed as i32 so they need to be casted to i64.
- if (addr.getType() != rewriter.getI64Type())
- addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
- auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
- rewriter.replaceOp(op, laneAddr);
- return success();
- }
-};
-
-class UpdateOffsetToXeVMPattern
- : public OpConversionPattern<xegpu::UpdateOffsetOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(xegpu::UpdateOffsetOp op,
- xegpu::UpdateOffsetOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto eTy = op.getTensorDescType().getElementType();
- auto eBw = eTy.getIntOrFloatBitWidth();
- if (eBw % 8 != 0)
- return rewriter.notifyMatchFailure(
- op, "Expected element type bit width to be multiple of 8.");
- auto loc = op.getLoc();
- // Scatter descriptor is provided as scalar i64 by type converter.
- // Offsets are provided as scalar i64 by type converter.
- Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
- adaptor.getOffsets(), eBw / 8);
- rewriter.replaceOp(op, newOffset);
- return success();
- }
-};
-
template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +377,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Value offsets = adaptor.getOffsets();
+ if (!offsets)
+ return rewriter.notifyMatchFailure(op, "Expected offsets to be provided.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdescTy = op.getTensorDescType();
@@ -527,21 +429,18 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
basePtrI64);
}
- Value offsets = adaptor.getOffsets();
Value mask = adaptor.getMask();
- if (offsets) {
- if (dyn_cast<VectorType>(offsets.getType())) {
- // Offset needs be scalar. Single element vector is converted to scalar
- // by type converter.
- return rewriter.notifyMatchFailure(op,
- "Expected offsets to be a scalar.");
- } else {
- // If offsets are provided, we add them to the base pointer.
- // Offsets are in number of elements, we need to multiply by
- // element byte size.
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
- }
+ if (dyn_cast<VectorType>(offsets.getType())) {
+ // Offset needs be scalar. Single element vector is converted to scalar
+ // by type converter.
+ return rewriter.notifyMatchFailure(op,
+ "Expected offsets to be a scalar.");
+ } else {
+ // If offsets are provided, we add them to the base pointer.
+ // Offsets are in number of elements, we need to multiply by
+ // element byte size.
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
- patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+ patterns.add<CreateNdDescToXeVMPattern,
LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
typeConverter, patterns.getContext());
- patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
- AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+ patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index ed664a739d134..d6e36fa73bf04 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
- // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
- // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
- // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
- // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
- // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
- // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
- // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
- // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
- // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
- // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
- %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
- // CHECK: %[[C8:.*]] = arith.constant 8 : index
- %c8 = arith.constant 8 : index
- // CHECK: %[[C16:.*]] = arith.constant 16 : index
- %c16 = arith.constant 16 : index
- // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
- // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
- // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
- // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
- // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
- // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
- // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
- // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
- %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 0f67dc290689b..0b150e9d58153 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -1,239 +1,73 @@
// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
gpu.module @test {
-// CHECK-LABEL: @load_gather_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
- // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
- // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
- %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
- // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) {
- // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
- // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32>
- // CHECK: scf.yield %[[VAR9]] : vector<2xf32>
- // CHECK: } else {
- // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
- // CHECK: scf.yield %[[CST_1]] : vector<2xf32>
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32>
- gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
- // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
- %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
- // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) {
- // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
- // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32>
- // CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32>
- // CHECK: scf.yield %[[VAR9]] : f32
- // CHECK: } else {
- // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
- // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32>
- // CHECK: scf.yield %[[VAR8]] : f32
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32>
- gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
+// CHECK-LABEL: @load_gather_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: ve...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
return rewriter.notifyMatchFailure(op,
                                   "Expected 2D offsets or no offsets.");
}
// Offsets are not supported not (0 is used).
nit: drop the second "not" in "Offsets are not supported not"
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Value offsets = adaptor.getOffsets();
nit: offsets -> offset ? since it is only a scalar operand.
LGTM with minor comments
LGTM
And update load/store/prefetch test cases to use direct offsets.
Tensor descriptors with offsets are getting deprecated.
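For context, a minimal sketch of the IR shape this lowering accepts after the change: the descriptor is created without offsets, and the 2D offsets are supplied directly on the load op (the patterns now require exactly two mixed offsets there). The module/function names, constant values, and the bracketed offset form on xegpu.load_nd are illustrative assumptions, not lines taken from the updated tests.

gpu.module @example {
  gpu.func @load_with_direct_offsets(%src: memref<16x32xf32>) {
    // Offsets on create_nd_tdesc are no longer lowered; build the descriptor without them.
    %tdesc = xegpu.create_nd_tdesc %src : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    // The 2D offsets go directly on the load instead of into the descriptor.
    %data = xegpu.load_nd %tdesc[%c0, %c8] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    gpu.return
  }
}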