diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index d879b93586899..63658518dd4a3 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -8,6 +8,7 @@ #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc" #include @@ -28,6 +29,7 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; } const APInt &ConstantIntRanges::smax() const { return smaxVal; } unsigned ConstantIntRanges::getStorageBitwidth(Type type) { + type = getElementTypeOrSelf(type); if (type.isIndex()) return IndexType::kInternalStorageBitWidth; if (auto integerType = dyn_cast(type)) diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 29282423089ba..09dfe932a5232 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -96,7 +96,7 @@ func.func @vector_insertelement() -> vector<4xindex> { // CHECK-LABEL: func @test_loaded_vector_extract // No bounds -// CHECK: test.reflect_bounds %{{.*}} : i32 +// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32 func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 { %c0 = arith.constant 0 : index %v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32> @@ -104,3 +104,12 @@ func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 { %bounds = test.reflect_bounds %e : i32 func.return %bounds : i32 } + +// CHECK-LABEL: func @test_vector_extsi +// CHECK: test.reflect_bounds {smax = 5 : si32, smin = 1 : si32, umax = 5 : ui32, umin = 1 : ui32} +func.func @test_vector_extsi() -> vector<2xi32> { + %0 = test.with_bounds {smax = 5 : si8, smin = 1 : si8, umax = 5 : ui8, umin = 1 : ui8 } : vector<2xi8> + %1 = arith.extsi %0 : vector<2xi8> to vector<2xi32> + %2 = test.reflect_bounds %1 : vector<2xi32> + func.return %2 : vector<2xi32> +}