diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb0ba4c9497d0..87ef39272aec9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4239,10 +4239,14 @@ OpFoldResult BitCastOp::fold(ArrayRef operands) { return getSource(); // Canceling bitcasts. - if (auto otherOp = getSource().getDefiningOp()) + if (auto otherOp = getSource().getDefiningOp()) { if (getResult().getType() == otherOp.getSource().getType()) return otherOp.getSource(); + setOperand(otherOp.getSource()); + return getResult(); + } + Attribute sourceConstant = operands.front(); if (!sourceConstant) return {}; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 152d0d241685c..e066e19eb9137 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1605,3 +1605,14 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 { %s = vector.reduction , %a : vector<4xf32> into f32 return %s : f32 } + +// ----- + +// CHECK-LABEL: func @bitcast( +// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> { +// CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16> +func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> { + %0 = vector.bitcast %a : vector<4x8xf32> to vector<4x8xi32> + %1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16> + return %1 : vector<4x16xi16> +}