diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae0989bed26..09f10b3ac952d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -122,8 +122,27 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { return std::nullopt; } if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { - LDBG() << "--no unrolling needed -> SKIP"; - return std::nullopt; + // If maybeShapeRatio are all 1s, only allow unrolling for leading unit + // dimension removal: [1,1,...,n] -> [n] + if (maybeUnrollShape->size() <= targetShape->size()) { + LDBG() << "--no dimension reduction -> SKIP"; + return std::nullopt; + } + + size_t dimDiff = maybeUnrollShape->size() - targetShape->size(); + ArrayRef srcShape = *maybeUnrollShape; + ArrayRef tgtShape = *targetShape; + + // Check leading dimensions are 1s and remaining matches target + bool isValidRemoval = llvm::all_of(srcShape.slice(0, dimDiff), + [](int64_t dim) { return dim == 1; }) && + srcShape.slice(dimDiff) == tgtShape; + + if (!isValidRemoval) { + LDBG() << "--not a valid leading unit dimension removal -> SKIP"; + return std::nullopt; + } + LDBG() << "--leading unit dimension removal -> CONTINUE"; } LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS"; return targetShape; diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5c67f33..9fd77645b78b5 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,18 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + + +func.func @elementwise_leading_unit_dim(%v1: vector<1x2x2xf32>, %v2: vector<1x2x2xf32>) -> vector<1x2x2xf32> { + %0 = arith.addf %v1, %v2 : vector<1x2x2xf32> + return %0 : vector<1x2x2xf32> +} + +// CHECK-LABEL: func @elementwise_leading_unit_dim +// CHECK-SAME: (%[[ARG0:.*]]: vector<1x2x2xf32>, %[[ARG1:.*]]: vector<1x2x2xf32>) -> vector<1x2x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x2x2xf32> +// CHECK: %[[S_LHS:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[S_RHS:.*]] = vector.shape_cast %[[ARG1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[ADD:.*]] = arith.addf %[[S_LHS]], %[[S_RHS]] : vector<2x2xf32> +// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ADD]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<1x2x2xf32> +// CHECK: return %[[INS]] : vector<1x2x2xf32>