diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 07546c0fd51ff..7d9febec632ca 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -950,7 +950,12 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef operands) { Attribute src = operands[0]; Attribute pos = operands[1]; - if (!src || !pos) + + // Fold extractelement (splat X) -> X. + if (auto splat = getVector().getDefiningOp()) + return splat.getInput(); + + if (!pos || !src) return {}; auto srcElements = src.cast().getValues(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8b6640bb06784..033f17ae2fe12 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1409,3 +1409,13 @@ func @extract_element_fold() -> i32 { %1 = vector.extractelement %v[%i : i32] : vector<4xi32> return %1 : i32 } + +// CHECK-LABEL: func @extract_element_splat_fold +// CHECK-SAME: (%[[ARG:.+]]: i32) +// CHECK: return %[[ARG]] +func @extract_element_splat_fold(%a : i32) -> i32 { + %v = vector.splat %a : vector<4xi32> + %i = arith.constant 2 : i32 + %1 = vector.extractelement %v[%i : i32] : vector<4xi32> + return %1 : i32 +}