diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 613adeb5eeaaf..25c2fe71f5ff4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6743,7 +6743,7 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) { if (intPack.isSplat()) { auto splat = intPack.getSplatValue(); - if (llvm::isa(dstElemType)) { + if (llvm::isa(dstElemType) && srcElemType.isIntOrFloat()) { uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth(); uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8126389212ce6..82b2cb633d1c9 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1371,6 +1371,19 @@ func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) { // ----- +// Verify that bitcast with index source element type does not crash (the fold +// must not call getIntOrFloatBitWidth on a non-integer/float type). +// CHECK-LABEL: func @bitcast_index_no_fold +// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<16xindex> +// CHECK: vector.bitcast %[[CST]] : vector<16xindex> to vector<16xi64> +func.func @bitcast_index_no_fold() -> vector<16xi64> { + %cst = arith.constant dense<0> : vector<16xindex> + %0 = vector.bitcast %cst : vector<16xindex> to vector<16xi64> + return %0 : vector<16xi64> +} + +// ----- + // CHECK-LABEL: broadcast_poison // CHECK: %[[POISON:.*]] = ub.poison : vector<4x6xi8> // CHECK: return %[[POISON]] : vector<4x6xi8>