diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..33e8f2ed1f6ed 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -194,8 +194,8 @@ class CreateNdDescToXeVMPattern
     // If source is a memref, we need to extract the aligned pointer as index.
     // Pointer type is passed as i32 or i64 by type converter.
     if (sourceMemrefTy) {
-      if (!sourceMemrefTy.hasStaticShape()) {
-        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+      if (!sourceMemrefTy.hasRank()) {
+        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index d6e36fa73bf04..09ef76c9d1740 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -4,8 +4,9 @@ gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
   // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
+  // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
-      %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
+      %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
     // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -43,6 +44,28 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
     // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[C1:.*]] = arith.constant 1 : index
+    %c1 = arith.constant 1 : index
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    %size_x = arith.constant 64 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    %BLOCK_DMODEL = arith.constant 16 : index
+    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
+    // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+    // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
+    // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
+    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+    // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
+    // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+    %dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
     gpu.return
   }
 }
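
Illustrative sketch (not part of the patch; module, function, and value names are hypothetical): with the source check relaxed from hasStaticShape() to hasRank(), a ranked memref with dynamic sizes can now be lowered by passing explicit shape and strides to xegpu.create_nd_tdesc, mirroring the test above, while an unranked memref still fails to match with "Expected ranked Memref.".

gpu.module @dyn_shape_example {
  gpu.func @accepted(%m: memref<?x?xf16>, %h: index, %w: index) kernel {
    %c1 = arith.constant 1 : index
    // Ranked but dynamically shaped: passes the relaxed hasRank() check,
    // since the runtime sizes/strides are supplied as operands.
    %t = xegpu.create_nd_tdesc %m, shape: [%h, %w], strides: [%w, %c1]
        : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
    gpu.return
  }
}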