From 4c0f2a78d74bfcedfc7e3a356e0d0b48c51c3c99 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 17 Sep 2025 15:07:02 +0100 Subject: [PATCH] [mlir][vector] Convert vector.transfer_read to scalar load and broadcast If we use vector.transfer_read to read from a 0-d value, we can convert it to memref.load from the 0-d value then broadcast the value to the target vector type. It can avoid generating vector operations breaking the requirements of convertVectorToMMAOps. The patterns in convertVectorToMMAOps expect all vector.transfer_read have 2-D vector types. Instead of %s0 = vector.transfer_read %base[] : memref to vector %s1 = vector.broadcast %s0 : vector to vector Use %s0 = memref.load %base[] : memref %s1 = vector.broadcast %s0 : dtype to vector --- .../Vector/Transforms/LowerVectorTransfer.cpp | 40 ++++++++++++++----- .../vector-transfer-permutation-lowering.mlir | 18 +++++++++ 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 2cf8f0beaa4de..4f62b6a7f2fde 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -360,17 +360,35 @@ struct TransferOpReduceRank SmallVector newScalableDims( originalVecType.getScalableDims().take_back(reducedShapeRank)); - VectorType newReadType = VectorType::get( - newShape, originalVecType.getElementType(), newScalableDims); - ArrayAttr newInBoundsAttr = - op.getInBounds() - ? rewriter.getArrayAttr( - op.getInBoundsAttr().getValue().take_back(reducedShapeRank)) - : ArrayAttr(); - Value newRead = vector::TransferReadOp::create( - rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), - AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), - newInBoundsAttr); + Value newRead; + if (newShape.size() == 0 && newScalableDims.size() == 0) { + // Handle the scalar case. + // Convert + // %val = vector.transfer_read %base[] : memref to + // vector + // into + // %scalar = memref.load %base[] : memref + // %val = vector.broadcast %scalar : dtype to vector + Type baseType = op.getBase().getType(); + if (isa(baseType)) { + newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(), + op.getIndices()); + } + } + + if (!newRead) { + VectorType newReadType = VectorType::get( + newShape, originalVecType.getElementType(), newScalableDims); + ArrayAttr newInBoundsAttr = + op.getInBounds() + ? rewriter.getArrayAttr( + op.getInBoundsAttr().getValue().take_back(reducedShapeRank)) + : ArrayAttr(); + newRead = vector::TransferReadOp::create( + rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), + AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), + newInBoundsAttr); + } return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType, newRead) .getVector(); diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index 3ae18835c8367..16104aa76e692 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -388,6 +388,24 @@ func.func @xfer_read_minor_identitiy_bcast_dims( return %res : vector<8x4x2x3xf32> } +// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_scalar +// CHECK-SAME: %[[MEM:.*]]: memref) -> vector<8x4x2x3xf32> { +// CHECK: %[[LOAD:.*]] = memref.load %[[MEM]][] : memref +// CHECK: %[[BC:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<8x4x2x3xf32> +// CHECK: return %[[BC]] : vector<8x4x2x3xf32> +func.func @xfer_read_minor_identitiy_bcast_scalar( + %mem: memref) -> vector<8x4x2x3xf32> { + + %pad = arith.constant 0.000000e+00 : f32 + + %res = vector.transfer_read %mem[], %pad { + in_bounds = [true, true, true, true], + permutation_map = affine_map<() -> (0, 0, 0, 0)> + } : memref, vector<8x4x2x3xf32> + + return %res : vector<8x4x2x3xf32> +} + // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable // CHECK-SAME: %[[MEM:.*]]: memref, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> { // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref, vector<[4]x2x3xf32>