diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
index 4dc5ea4f7bb24..1642e9829a79d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -282,9 +282,9 @@ class XeGPUCreateNdDescOpPattern final
                       modifiedStrides[modifiedStrides.size() - 2]),
         innerLaneData);
-    // If the source is a static memref, we need to extract the pointer to
+    // If the source is a memref, we need to extract the pointer to
     // base address.
-    if (memrefType && memrefType.hasStaticShape()) {
+    if (memrefType) {
       auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
           rewriter, createNdOp.getLoc(), source);
       source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
similarity index 91%
rename from mlir/test/Dialect/XeGPU/optimize-transpose.mlir
rename to mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
index 24a0de6ed48a5..526adc5a95d10 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
@@ -278,3 +278,32 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
   gpu.return
 }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @dynamic_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-NEXT: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?xf16> -> index
+// CHECK-NEXT: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]] {layout_result_0 =
+// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x8xi32,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T2]] {layout_result_0 =
+// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @dynamic_memref(%arg0: memref<?x?xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %0 = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+  %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+  %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  gpu.return %6 : vector<8x16xf32>
+}
+}