diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 299c8afffb2e5..0bc001b5d576a 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -725,6 +725,8 @@ struct ExtSIPattern final : public OpConversionPattern { assert(srcBW < dstBW); Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW, rewriter, op.getLoc()); + if (!shiftSize) + return rewriter.notifyMatchFailure(op, "unsupported type for shift"); // First shift left to sequeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's @@ -800,6 +802,8 @@ struct ExtUIPattern final : public OpConversionPattern { Value mask = getScalarOrVectorConstInt( dstType, llvm::maskTrailingOnes(bitwidth), rewriter, op.getLoc()); + if (!mask) + return rewriter.notifyMatchFailure(op, "unsupported type for mask"); rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn(), mask); } else { @@ -868,6 +872,8 @@ struct TruncIPattern final : public OpConversionPattern { unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); Value mask = getScalarOrVectorConstInt( dstType, llvm::maskTrailingOnes(bw), rewriter, op.getLoc()); + if (!mask) + return rewriter.notifyMatchFailure(op, "unsupported type for mask"); rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn(), mask); } else { diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index 9d7ab2be096ef..92b587d5ed1e4 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -122,6 +122,16 @@ func.func @unsupported_constant_tensor_2xf64_0() { // ----- +// Regression test: arith.trunci on tensor types should not crash +// (https://github.com/llvm/llvm-project/issues/178214). +func.func @trunci_tensor_no_crash(%arg0: tensor<1xi32>) -> tensor<1xi16> { + // expected-error @+1 {{failed to legalize operation 'arith.trunci'}} + %0 = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16> + return %0 : tensor<1xi16> +} + +// ----- + func.func @constant_dense_resource_non_existant() { // expected-error @+2 {{failed to legalize operation 'arith.constant'}} // expected-error @+1 {{could not find resource blob}}