diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a7f2dc2d6a43e..9ead1d89069d6 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
   matchAndRewrite(xegpu::CreateNdDescOp op,
                   xegpu::CreateNdDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 0)
+      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets can be either 2D or not provided (0 is used).
-    if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, 1);
-      offsetH = createOffset(mixedOffsets, 0);
-    } else if (mixedOffsets.size() == 0) {
-      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    }
+    // Offsets are not supported (0 is used).
+    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
   }
 };
 
-class UpdateNdOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
-                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto mixedOffsets = op.getMixedOffsets();
-    // Only 2D offsets are supported for now.
-    if (mixedOffsets.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto payload = adaptor.getTensorDesc();
-    // Utility for updating payload offset values from op fold result.
-    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
-      Value offset =
-          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
-      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                               rewriter.getI32Type(), offset);
-      Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
-      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
-                                      payloadPos);
-    };
-    // Update offsets in the payload.
-    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, payload);
-    return success();
-  }
-};
-
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<OpType, xegpu::LoadNdOp,
                                                 xegpu::StoreNdOp,
                                                 xegpu::PrefetchNdOp>::value>>
 class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   using OpConversionPattern<OpType>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
@@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets provided in two ways:
-    // 1. Offsets are extracted from the tensor descriptor.
-    // 2. (Mixed) offsets which are provided by the op.
-    Value offsetW;
-    Value offsetH;
-    auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    if (opOffsetsSize) {
-      // If mixed offsets are provided by the op convert them to i32.
-      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetW);
-      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetH);
-    } else {
-      // If offsets are not available, we need to extract them from the tensor
-      // descriptor.
-      offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
-      offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    }
+    // Offsets are provided by the op.
+    // Convert them to i32.
+    Value offsetW =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetW);
+    Value offsetH =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
   return newAddr;
 }
 
-class CreateDescToXeVMPattern
-    : public OpConversionPattern<xegpu::CreateDescOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Offsets are provided as scalar i64 by type converter.
-    auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
-    // But type converter will convert them to integer types.
-    Value addr = adaptor.getSource();
-    // ui32 or i32 are passed as i32 so they need to be casted to i64.
-    if (addr.getType() != rewriter.getI64Type())
-      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
-    rewriter.replaceOp(op, laneAddr);
-    return success();
-  }
-};
-
-class UpdateOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateOffsetOp op,
-                  xegpu::UpdateOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Scatter descriptor is provided as scalar i64 by type converter.
-    // Offsets are provided as scalar i64 by type converter.
-    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
-                                adaptor.getOffsets(), eBw / 8);
-    rewriter.replaceOp(op, newOffset);
-    return success();
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +379,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value offset = adaptor.getOffsets();
+    if (!offset)
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
@@ -527,21 +431,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       basePtrI64 =
          arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrI64);
     }
-    Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
-    if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
-        // Offset needs be scalar. Single element vector is converted to scalar
-        // by type converter.
-        return rewriter.notifyMatchFailure(op,
-                                           "Expected offsets to be a scalar.");
-      } else {
-        // If offsets are provided, we add them to the base pointer.
-        // Offsets are in number of elements, we need to multiply by
-        // element byte size.
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
+    if (dyn_cast<VectorType>(offset.getType())) {
+      // Offset needs to be scalar. A single-element vector is converted to
+      // scalar by the type converter.
+      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
+    } else {
+      // If an offset is provided, add it to the base pointer.
+      // The offset is in number of elements, so we need to multiply by the
+      // element byte size.
+      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+  patterns.add<CreateNdDescToXeVMPattern,
                LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
       typeConverter, patterns.getContext());
-  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
-               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
+  patterns.add<LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index ed664a739d134..d6e36fa73bf04 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
     // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
-    // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
-    // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-    // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
-    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
-    // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
-    // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
-    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    // CHECK: %[[C8:.*]] = arith.constant 8 : index
-    %c8 = arith.constant 8 : index
-    // CHECK: %[[C16:.*]] = arith.constant 16 : index
-    %c16 = arith.constant 16 : index
-    // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
-    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
-    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
-    // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
-    // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
-    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
-    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
-    // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
-    %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
diff
--git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index 0f67dc290689b..0b150e9d58153 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -1,239 +1,73 @@ // RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s gpu.module @test { -// CHECK-LABEL: @load_gather_ui64_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: ui64 -gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) { - // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> - // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64 - %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> - // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) { - // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control} - // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32> - // CHECK: scf.yield %[[VAR9]] : vector<2xf32> - // CHECK: } else { - // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> - // CHECK: scf.yield %[[CST_1]] : vector<2xf32> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xi1> -> vector<2xf32> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @load_gather_memref_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> -gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) { - // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control} - // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32> - 
// CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32> - // CHECK: scf.yield %[[VAR9]] : f32 - // CHECK: } else { - // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> - // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32> - // CHECK: scf.yield %[[VAR8]] : f32 - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @load_gather_memref_src_value_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex> -gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) { +// CHECK-LABEL: @load_gather_i64_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> +gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex> - -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) { - // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control} - // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16> - // CHECK: scf.yield %[[VAR8]] : vector<8xf16> - // CHECK: } else { - // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16> - // CHECK: scf.yield %[[CST_0]] : vector<8xf16> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr>, vector<1xi1> -> vector<8xf16> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @store_scatter_ui64_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: ui64 -gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) { - // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> - // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32> - %2 = arith.constant dense<2.9>: vector<2xf32> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR6:.*]] = arith.addi 
%[[VAR1]], %[[VAR5]] : i64 - %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> - // CHECK: scf.if %[[VAR4]] { - // CHECK: llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control} - // CHECK-SAME: : vector<2xf32>, !llvm.ptr<1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xi1> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @store_scatter_memref_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> -gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16> - %2 = arith.constant dense<2.9>: vector<2xf16> - // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: scf.if %[[VAR2]] { - // CHECK: llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control} - // CHECK-SAME: : vector<2xf16>, !llvm.ptr<1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr>, vector<1xi1> + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64 + // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) { + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control} : !llvm.ptr<1> -> vector<1xf16> + // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16> + // CHECK: scf.yield %[[VAR8]] : f16 + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16> + // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16> + // CHECK: scf.yield %[[VAR7]] : f16 + // CHECK: } + %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> gpu.return } } // ----- gpu.module @test { -// CHECK-LABEL: @store_scatter_memref_src_value_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> -gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { +// CHECK-LABEL: @store_scatter_i64_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: 
vector<1xindex> +gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> - // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32> + // CHECK: %[[VAR3:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32> %2 = arith.constant dense<2.9>: vector<1xf32> // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[VAR5:.*]] = arith.addi %[[ARG0]], %[[VAR4]] : i64 // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> // CHECK: scf.if %[[VAR2]] { - // CHECK: llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control} - // CHECK-SAME: : f32, !llvm.ptr<1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> + // CHECK: llvm.store %[[VAR3]], %[[VAR6]] {cache_control = #xevm.store_cache_control} : f32, !llvm.ptr<1> + // CHECK: } + xegpu.store %2, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<1xf32>, i64, vector<1xindex>, vector<1xi1> gpu.return } } // ----- gpu.module @test { -// CHECK-LABEL: @prefetch_ui64_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: ui64 -gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) { - // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64 - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) - xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @prefetch_memref_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> -gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to 
i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> +// CHECK-LABEL: @prefetch_i64_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> +gpu.func @prefetch_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 - %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> - // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) - xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR2:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR3:.*]] = arith.addi %[[ARG0]], %[[VAR2]] : i64 + // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[VAR3]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR4]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch %src[%offset] <{offset_align_byte=4, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : i64, vector<1xindex> gpu.return } } @@ -250,12 +84,10 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 - %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) - xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + xegpu.prefetch %src[%offset] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : memref<256xf32>, vector<1xindex> gpu.return } } diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir index b28a8c2ccf843..2a2b99f57cabd 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir @@ -9,9 +9,9 @@ gpu.module @materializecast { gpu.func @materialize_memref(%src: memref<128xf32>) kernel { // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %offset = arith.constant 0 : index + %mask = arith.constant 1 : i1 + %val = xegpu.load %src[%offset], %mask : memref<128xf32>, index, i1 -> f32 gpu.return } } @@ -23,9 +23,9 @@ gpu.module @materializecast { gpu.func @materialize_ui64(%src: ui64) kernel { // CHECK: %[[VAR0:.*]] = index.castu 
%[[ARG0]] : ui64 to index // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %offset = arith.constant 0 : index + %mask = arith.constant 1 : i1 + %val = xegpu.load %src[%offset], %mask : ui64, index, i1 -> vector<1xf32> gpu.return } } @@ -37,9 +37,9 @@ gpu.module @materializecast { gpu.func @materialize_ui32(%src: ui32) kernel { // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32 - %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %offset = arith.constant 0 : index + %mask = arith.constant 1 : i1 + %val = xegpu.load %src[%offset], %mask : ui32, index, i1 -> vector<1xf32> gpu.return } } @@ -52,24 +52,12 @@ gpu.module @materializecast { // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64 + // CHECK: %[[CST1:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR3:.*]] = vector.extract %[[CST1]][0] : i1 from vector<1xi1> %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %mask = arith.constant dense<1> : vector<1xi1> + %val = xegpu.load %src[%offset], %mask : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1xf32> gpu.return } } -// ----- -gpu.module @materializecast { - // CHECK-LABEL: gpu.func @materialize_single_elem_vector - // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> - gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel { - // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> - // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> - %mask = arith.constant dense<1>: vector<1xi1> - %offset = arith.constant dense<0> : vector<1xindex> - %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32> - gpu.return - } -} diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir deleted file mode 100644 index 6e59414c62582..0000000000000 --- a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s - -gpu.module @update_offset { - // CHECK-LABEL: gpu.func @update_offset - // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> - gpu.func @update_offset(%src: memref<128xf32>) kernel { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index - // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - %offset = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: 
%[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 - %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64 - // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64 - %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> - , vector<1xindex> - gpu.return - } -}
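
Note: for reference, below is a minimal sketch of the IR shapes this lowering now expects, with offsets carried by the access ops themselves rather than by create_nd_tdesc / update_nd_offset / create_tdesc / update_offset. Function and value names are illustrative, not taken from the tests above, and the bracketed-offset form of load_nd is assumed to follow the same assembly syntax as the scattered load used in the tests.

gpu.module @example {
  gpu.func @direct_offsets(%src2d: memref<16x32xf32>, %src1d: memref<128xf32>,
                           %h: index, %w: index) kernel {
    // The nd descriptor now carries only the base address and shape ...
    %td = xegpu.create_nd_tdesc %src2d : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
    // ... and the 2D offsets are passed to the access op itself.
    %tile = xegpu.load_nd %td[%h, %w] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    // Scattered accesses take the source, offset, and mask directly,
    // with no intermediate scatter tensor descriptor.
    %off = arith.constant dense<0> : vector<1xindex>
    %mask = arith.constant dense<true> : vector<1xi1>
    %val = xegpu.load %src1d[%off], %mask : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1xf32>
    gpu.return
  }
}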