diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 7f7f7d065c50e..716681fe9e187 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -13,8 +13,9 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
 include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
 include "mlir/IR/BuiltinTypes.td"
 
-def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
-def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
+def XeGPU_IntType : AnyTypeOf<[I1, I<4>, I8, I16, I32, I64, SI1, SI8, SI16,
+                               SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
+def XeGPU_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9c99a24bea8cd..2b162ec3f3bf4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -150,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -262,9 +270,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
-    if (elemBitSize % 8 != 0)
+    bool isSubByte = elemBitSize < 8;
+    uint64_t wScaleFactor = 1;
+
+    if (!isSubByte && (elemBitSize % 8 != 0))
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    // For sub-byte types, only 4-bit elements are currently supported.
+    if (isSubByte) {
+      if (elemBitSize != 4)
+        return rewriter.notifyMatchFailure(
+            op, "Only sub-byte types of 4 bits are supported.");
+      if (tileRank != 2)
+        return rewriter.notifyMatchFailure(
+            op, "Sub-byte types are only supported for 2D tensor descriptors.");
+      auto subByteFactor = 8 / elemBitSize;
+      auto tileH = tdescTy.getDimSize(0);
+      // Handle the special case of a packed load.
+      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        if (op.getPacked().value_or(false)) {
+          // A packed load is implemented as a packed load of 8-bit elements.
+          if (tileH == systolicDepth * 4 &&
+              tileW == executionSize * subByteFactor) {
+            // Use case: loading as Matrix B with a pack request.
+            // The source is assumed to be pre-packed into 8-bit elements.
+            // Emulate with 8-bit loads with a pack request.
+            // scaled_tileW = executionSize
+            elemType = rewriter.getIntegerType(8);
+            tileW = executionSize;
+            wScaleFactor = subByteFactor;
+          }
+        }
+      }
+      // If not handled by the packed load case above, handle the other cases.
+      if (wScaleFactor == 1) {
+        auto sub16BitFactor = subByteFactor * 2;
+        if (tileW == executionSize * sub16BitFactor) {
+          // Use case: loading as the Matrix A operand.
+          // Emulate with 16-bit loads/stores.
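+          // For example, with executionSize == 16, an 8x64xi4 tile is
+          // accessed as an 8x16xi16 block: 64 x 4 bits == 16 x 16 bits per row.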
+          // scaled_tileW = executionSize
+          elemType = rewriter.getIntegerType(16);
+          tileW = executionSize;
+          wScaleFactor = sub16BitFactor;
+        } else {
+          return rewriter.notifyMatchFailure(
+              op, "Unsupported tile shape for sub-byte types.");
+        }
+      }
+      // Recompute the element bit size for the emulation.
+      elemBitSize = elemType.getIntOrFloatBitWidth();
+    }
 
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
@@ -298,15 +354,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    // FIXME: The width/pitch is not the same as baseShapeW; it should be the
+    // stride of the second-to-last dimension in row-major layout.
     // Compute width in bytes.
-    Value baseWidthByte =
+    Value baseShapeWInBytes =
         arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
     // Compute pitch in bytes.
-    Value basePitchByte =
+    Value basePitchBytes =
         arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
-    // Get tile width from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    if (wScaleFactor > 1) {
+      // Scale baseShapeWInBytes, basePitchBytes, and offsetW for the sub-byte
+      // emulation. Note: tileW is already scaled above.
+      Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+      baseShapeWInBytes = arith::ShRSIOp::create(
+          rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+      basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+                                              wScaleFactorValLog2);
+      offsetW =
+          arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
+    }
     // Get tile height from the tensor descriptor type.
     auto tileH = tdescTy.getDimSize(0);
     // Get vblocks from the tensor descriptor type.
@@ -330,8 +398,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       auto storeCacheControl =
           translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-          basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+          rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+          basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
           xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
       rewriter.eraseOp(op);
     } else {
@@ -339,8 +407,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
         xevm::BlockPrefetch2dOp::create(
-            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
             vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         rewriter.eraseOp(op);
       } else {
@@ -354,9 +422,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                 : rewriter.getIntegerType(elemBitSize));
         Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
-            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
-            vblocks, transpose, vnni,
+            rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+            baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+            tileH, vblocks, transpose, vnni,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         resultFlatVec = vector::BitCastOp::create(
             rewriter, loc,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..97e5ce14f8539
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store_matrix_a
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1>
+  gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel {
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+    // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+    // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
+    // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
+    // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
+    %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+    // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
+    // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
+    // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
+    %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4>
+
+    // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+    // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
+    // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
+    // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
+    // CHECK-SAME: cache_control = #xevm.load_cache_control, elem_size_in_bits = 16 : i32,
+    // CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+    // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+    %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>
+
+    // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
+    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr>
+
+    // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
+    // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
+    // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
+    // CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
+    // CHECK-SAME: cache_control = #xevm.store_cache_control, elem_size_in_bits = 16 : i32,
+    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
+    xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr>
+    gpu.return
+  }
+
+  // CHECK-LABEL: gpu.func @load_matrix_b_request_pack
+  gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel {
+    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+    // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4>
+    %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4>
+
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>
+
+    // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
+    // CHECK-SAME: cache_control = #xevm.load_cache_control, elem_size_in_bits = 8 : i32,
+    // CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
+    // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>
+
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..f9254728bab41
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @prefetch_check {
+  // CHECK-LABEL: gpu.func @prefetch_matrix_a
+  gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel {
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+    %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+    // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
+    // CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 16 : i32,
+    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
+    xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<8x64xi4>
+
+    gpu.return
+  }
+}
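+
+// Note: the 8x64xi4 tile above is prefetched as an 8x16xi16 block
+// (64 x 4 bits == 16 x 16 bits per row), which is why elem_size_in_bits is 16,
+// tile_width is 16, and the i4 column offset of 64 becomes 16.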