diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index e79f6a8aec1cf..70b56ca77b2da 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -26,6 +26,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Support/DebugStringHelper.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" @@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { else dialect = value.getParentBlock()->getParentOp()->getDialect(); - Type type = getElementTypeOrSelf(value); - solver->propagateIfChanged( - cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect))); + Attribute cstAttr; + if (isa(value.getType())) { + cstAttr = IntegerAttr::get(value.getType(), *constant); + } else if (auto shapedTy = dyn_cast(value.getType())) { + cstAttr = SplatElementsAttr::get(shapedTy, *constant); + } else { + llvm::report_fatal_error( + Twine("FIXME: Don't know how to create a constant for this type: ") + + mlir::debugString(value.getType())); + } + solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect))); } LogicalResult IntegerRangeAnalysis::visitOperation( diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 777ff0ecaa314..2017905587b26 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -8,6 +8,7 @@ #include +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" @@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final MLIRContext *ctx = op->getContext(); DataFlowSolver solver; solver.load(); + solver.load(); solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir index ea5969a100258..e6e48d30cece5 100644 --- a/mlir/test/Dialect/Arith/int-range-opts.mlir +++ b/mlir/test/Dialect/Arith/int-range-opts.mlir @@ -132,3 +132,19 @@ func.func @wraps() -> i8 { %mod = arith.remsi %val, %c64 : i8 return %mod : i8 } + +// ----- + +// CHECK-LABEL: @analysis_crash +func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<-1> : tensor<128xi32> + %splat = tensor.splat %arg0 : tensor<128xi32> + %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 { + scf.yield %arg3 : tensor<128xi32> + } + %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32> + // Make sure the analysis doesn't crash when materializing the range as a tensor constant. + %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64> + return %2 : tensor<128xi64> +}