diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h new file mode 100644 index 0000000000000..3a877647490a3 --- /dev/null +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h @@ -0,0 +1,64 @@ +//===- IntegerDivisibilityAnalysis.h - Integer divisibility -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the dataflow analysis class for integer divisibility +// inference. Operations participate in the analysis by implementing +// `InferIntDivisibilityOpInterface`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H +#define MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h" + +namespace mlir::dataflow { + +/// This lattice element represents the integer divisibility of an SSA value. +class IntegerDivisibilityLattice : public Lattice { +public: + using Lattice::Lattice; +}; + +/// Integer divisibility analysis determines, for each integer-typed SSA +/// value, a divisor that the value is guaranteed to be a multiple of. It +/// uses operations that implement `InferIntDivisibilityOpInterface` and +/// also sets the divisibility of induction variables of loops with known +/// lower bounds and steps. +/// +/// This analysis depends on DeadCodeAnalysis, and will be a silent no-op +/// if DeadCodeAnalysis is not loaded in the same solver context. +class IntegerDivisibilityAnalysis + : public SparseForwardDataFlowAnalysis { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + /// At an entry point, set the lattice to the most pessimistic state, + /// indicating that no further reasoning can be done. + void setToEntryState(IntegerDivisibilityLattice *lattice) override; + + /// Visit an operation, invoking the transfer function. + LogicalResult + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; + + /// Visit block arguments or operation results of an operation with region + /// control-flow for which values are not defined by region control-flow. This + /// function tries to infer the divisibility of loop induction variables based + /// on known loop bounds and steps. + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ValueRange successorInputs, + ArrayRef argLattices) override; +}; + +} // namespace mlir::dataflow + +#endif // MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index b2a4cf7f488bd..3d7cbcc375d2a 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -16,6 +16,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferIntDivisibilityOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" @@ -43,7 +44,9 @@ def ImplicitAffineTerminator : SingleBlockImplicitTerminator<"AffineYieldOp">; def AffineApplyOp : Affine_Op<"apply", - [Pure, DeclareOpInterfaceMethods]> { + [Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "affine apply operation"; let description = [{ The `affine.apply` operation applies an [affine mapping](#affine-maps) @@ -570,7 +573,8 @@ class AffineMinMaxOpBase traits = []> : let hasVerifier = 1; } -def AffineMinOp : AffineMinMaxOpBase<"min", [Pure]> { +def AffineMinOp : AffineMinMaxOpBase<"min", + [Pure, DeclareOpInterfaceMethods]> { let summary = "min operation"; let description = [{ Syntax: @@ -594,7 +598,8 @@ def AffineMinOp : AffineMinMaxOpBase<"min", [Pure]> { }]; } -def AffineMaxOp : AffineMinMaxOpBase<"max", [Pure]> { +def AffineMaxOp : AffineMinMaxOpBase<"max", + [Pure, DeclareOpInterfaceMethods]> { let summary = "max operation"; let description = [{ The `affine.max` operation computes the maximum value result from a multi-result @@ -1071,6 +1076,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> { def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure, Elementwise, + DeclareOpInterfaceMethods, // Infer linear_index type from the first result type during parsing. TypesMatchWith<"linear_index type must match result types", "multi_index", "linear_index", "$_self[0]"> diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 0fc3db8e993d8..bf6eb18df2e8a 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -15,6 +15,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index fa85b840e2707..1f8b07aed3f0d 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferIntDivisibilityOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -223,6 +224,7 @@ def Arith_ConstantOp : Op, AllTypesMatch<["value", "result"]>, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "integer or floating point constant"; let description = [{ @@ -270,7 +272,8 @@ def Arith_ConstantOp : Op { +def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", + [Commutative, DeclareOpInterfaceMethods]> { let summary = "integer addition operation"; let description = [{ Performs N-bit addition on the operands. The operands are interpreted as @@ -416,7 +419,8 @@ def Arith_SubUIExtendedOp : Arith_Op<"subui_extended", [Pure, // SubIOp //===----------------------------------------------------------------------===// -def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> { +def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi", + [DeclareOpInterfaceMethods]> { let summary = [{ Integer subtraction operation. }]; @@ -461,7 +465,9 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> { //===----------------------------------------------------------------------===// def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", - [Commutative, DeclareOpInterfaceMethods] + [Commutative, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods] > { let summary = [{ Integer multiplication operation. @@ -593,7 +599,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative, //===----------------------------------------------------------------------===// def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui", - [ConditionallySpeculatable]> { + [ConditionallySpeculatable, + DeclareOpInterfaceMethods]> { let summary = "unsigned integer division operation"; let description = [{ Unsigned integer division. Rounds towards zero. Treats the leading bit as @@ -1191,7 +1198,8 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> { // MaxSIOp //===----------------------------------------------------------------------===// -def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> { +def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", + [Commutative, DeclareOpInterfaceMethods]> { let summary = "signed integer maximum operation"; let hasFolder = 1; } @@ -1200,7 +1208,8 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> { // MaxUIOp //===----------------------------------------------------------------------===// -def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> { +def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", + [Commutative, DeclareOpInterfaceMethods]> { let summary = "unsigned integer maximum operation"; let hasFolder = 1; } @@ -1250,7 +1259,8 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> { // MinSIOp //===----------------------------------------------------------------------===// -def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> { +def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", + [Commutative, DeclareOpInterfaceMethods]> { let summary = "signed integer minimum operation"; let hasFolder = 1; } @@ -1259,7 +1269,8 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> { // MinUIOp //===----------------------------------------------------------------------===// -def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> { +def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", + [Commutative, DeclareOpInterfaceMethods]> { let summary = "unsigned integer minimum operation"; let hasFolder = 1; } @@ -2004,6 +2015,7 @@ class BooleanConditionOrMatchingShape : def SelectOp : Arith_Op<"select", [Pure, AllTypesMatch<["true_value", "false_value", "result"]>, BooleanConditionOrMatchingShape<"condition", "result">, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "select operation"; diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index 3cbc9df05f3d7..6461c68423c73 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface) add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(FunctionInterfaces) add_mlir_interface(IndexingMapOpInterface) +add_mlir_interface(InferIntDivisibilityOpInterface) add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferStridedMetadataInterface) add_mlir_interface(InferTypeOpInterface) diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h new file mode 100644 index 0000000000000..374acee05cb10 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h @@ -0,0 +1,120 @@ +//===- InferIntDivisibilityOpInterface.h - Integer Divisibility -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the integer divisibility inference +// interface defined in `InferIntDivisibilityOpInterface.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H +#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H + +#include "mlir/IR/OpDefinition.h" +#include +#include + +namespace mlir { + +/// Statically known divisibility information for an integer SSA value. +/// Tracks separate divisors for the unsigned and signed interpretations of +/// the value so that subsequent analyses can use whichever is more precise. +class ConstantIntDivisibility { +public: + ConstantIntDivisibility() = default; + ConstantIntDivisibility(uint64_t udiv, uint64_t sdiv) + : udivVal(udiv), sdivVal(sdiv) {} + + bool operator==(const ConstantIntDivisibility &other) const { + return udivVal == other.udivVal && sdivVal == other.sdivVal; + } + + uint64_t udiv() const { return this->udivVal; } + uint64_t sdiv() const { return this->sdivVal; } + + // Returns the union (computed separately for signed and unsigned bounds) + // for this divisibility and `other`. + ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const { + return ConstantIntDivisibility( + /*udiv=*/std::gcd(udiv(), other.udiv()), + /*sdiv=*/std::gcd(sdiv(), other.sdiv())); + } + +private: + uint64_t udivVal; + uint64_t sdivVal; + + friend raw_ostream &operator<<(raw_ostream &os, + const ConstantIntDivisibility &div); +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const ConstantIntDivisibility &div) { + os << "ConstantIntDivisibility(udiv = " << div.udivVal + << ", sdiv = " << div.sdivVal << ")"; + return os; +} + +/// This lattice value represents the integer divisibility of an SSA value. +class IntegerDivisibility { +public: + IntegerDivisibility(ConstantIntDivisibility value) + : value(std::move(value)) {} + explicit IntegerDivisibility( + std::optional value = std::nullopt) + : value(std::move(value)) {} + // Gets the minimum divisibility of 1 that is used to indicate that the value + // cannot be analyzed further. + static IntegerDivisibility getMinDivisibility() { + return IntegerDivisibility(ConstantIntDivisibility(1, 1)); + } + + bool isUninitialized() const { return !value.has_value(); } + const ConstantIntDivisibility &getValue() const { + assert(!isUninitialized()); + return *value; + } + + bool operator==(const IntegerDivisibility &rhs) const { + return value == rhs.value; + } + + static IntegerDivisibility join(const IntegerDivisibility &lhs, + const IntegerDivisibility &rhs) { + if (lhs.isUninitialized()) { + return rhs; + } + if (rhs.isUninitialized()) { + return lhs; + } + return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue())); + } + + void print(raw_ostream &os) const { os << value; } + +private: + std::optional value; +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const IntegerDivisibility &div) { + div.print(os); + return os; +} + +/// The type of the `setResultDivs` callback provided to ops implementing +/// InferIntDivisibilityOpInterface. It should be called once for each integer +/// result value and be passed the ConstantIntDivisibility corresponding to +/// that value. +using SetIntDivisibilityFn = + llvm::function_ref; + +} // end namespace mlir + +#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h.inc" + +#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td new file mode 100644 index 0000000000000..c665475e0fd7f --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td @@ -0,0 +1,41 @@ +//===- InferIntDivisibilityOpInterface.td - Integer Divisibility -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for divisibility analysis on scalar integers. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE +#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE + +include "mlir/IR/OpBase.td" + +def InferIntDivisibilityOpInterface : + OpInterface<"InferIntDivisibilityOpInterface"> { + let description = [{ + Allows operations to participate in integer divisibility analysis. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Infer the divisibility of the results of this op given the + divisibility of its arguments. For each result value, the method + should call `setResultDivs` with that `Value` as an argument. + }], + /*retTy=*/"void", + /*methodName=*/"inferResultDivisibility", + /*args=*/(ins + "::llvm::ArrayRef<::mlir::IntegerDivisibility>":$argDivs, + "::mlir::SetIntDivisibilityFn":$setResultDivs) + > + ]; +} + +#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index db10ebcf2c311..596ffaff428b5 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES DataFlow/ConstantPropagationAnalysis.cpp DataFlow/DeadCodeAnalysis.cpp DataFlow/DenseAnalysis.cpp + DataFlow/IntegerDivisibilityAnalysis.cpp DataFlow/IntegerRangeAnalysis.cpp DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp @@ -37,6 +38,7 @@ add_mlir_library(MLIRAnalysis DataFlow/ConstantPropagationAnalysis.cpp DataFlow/DeadCodeAnalysis.cpp DataFlow/DenseAnalysis.cpp + DataFlow/IntegerDivisibilityAnalysis.cpp DataFlow/IntegerRangeAnalysis.cpp DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp @@ -53,6 +55,7 @@ add_mlir_library(MLIRAnalysis MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRFunctionInterfaces + MLIRInferIntDivisibilityOpInterface MLIRInferIntRangeInterface MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface diff --git a/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp new file mode 100644 index 0000000000000..ba10a8b5a0060 --- /dev/null +++ b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp @@ -0,0 +1,135 @@ +//===- IntegerDivisibilityAnalysis.cpp - Integer divisibility ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer divisibility +// inference. Operations participate in the analysis by implementing +// `InferIntDivisibilityOpInterface`. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h" + +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-divisibility-analysis" + +using llvm::dbgs; + +namespace mlir::dataflow { + +void IntegerDivisibilityAnalysis::setToEntryState( + IntegerDivisibilityLattice *lattice) { + propagateIfChanged(lattice, + lattice->join(IntegerDivisibility::getMinDivisibility())); +} + +LogicalResult IntegerDivisibilityAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + auto inferrable = dyn_cast(op); + if (!inferrable) { + setAllToEntryStates(results); + return success(); + } + + LLVM_DEBUG(dbgs() << "Inferring divisibility for " << *op << "\n"); + auto argDivs = llvm::map_to_vector( + operands, [](const IntegerDivisibilityLattice *lattice) { + return lattice->getValue(); + }); + auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) { + auto result = dyn_cast(v); + if (!result) { + return; + } + assert(llvm::is_contained(op->getResults(), result)); + + LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n"); + IntegerDivisibilityLattice *lattice = results[result.getResultNumber()]; + IntegerDivisibility oldDiv = lattice->getValue(); + + ChangeResult changed = lattice->join(newDiv); + + // Catch loop results with loop-variant divisibility and conservatively + // set them to divisibility 1 (no information) so we don't ratchet + // indefinitely (the dataflow analysis in MLIR doesn't attempt to work + // out trip counts and often can't). + bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedResult && !oldDiv.isUninitialized() && + !(lattice->getValue() == oldDiv)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + changed |= lattice->join(IntegerDivisibility::getMinDivisibility()); + } + propagateIfChanged(lattice, changed); + }; + + inferrable.inferResultDivisibility(argDivs, joinCallback); + return success(); +} + +void IntegerDivisibilityAnalysis::visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, ValueRange successorInputs, + ArrayRef argLattices) { + // Get the constant divisibility, or query the lattice for Values. + auto getDivFromOfr = [&](std::optional ofr, Block *block, + bool isUnsigned) -> uint64_t { + if (ofr.has_value()) { + if (auto constBound = getConstantIntValue(*ofr)) { + return constBound.value(); + } + auto value = cast(ofr.value()); + const IntegerDivisibilityLattice *lattice = + getLatticeElementFor(getProgramPointBefore(block), value); + if (lattice != nullptr && !lattice->getValue().isUninitialized()) { + return isUnsigned ? lattice->getValue().getValue().udiv() + : lattice->getValue().getValue().sdiv(); + } + } + return isUnsigned + ? IntegerDivisibility::getMinDivisibility().getValue().udiv() + : IntegerDivisibility::getMinDivisibility().getValue().sdiv(); + }; + + // Infer bounds for loop arguments that have static bounds + if (auto loop = dyn_cast(op)) { + std::optional> ivs = loop.getLoopInductionVars(); + std::optional> lbs = loop.getLoopLowerBounds(); + std::optional> steps = loop.getLoopSteps(); + if (!ivs || !lbs || !steps) { + return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( + op, successor, successorInputs, argLattices); + } + for (auto [iv, lb, step] : llvm::zip_equal(*ivs, *lbs, *steps)) { + IntegerDivisibilityLattice *ivEntry = getLatticeElement(iv); + Block *block = iv.getParentBlock(); + uint64_t stepUDiv = getDivFromOfr(step, block, /*unsigned=*/true); + uint64_t stepSDiv = getDivFromOfr(step, block, /*unsigned=*/false); + uint64_t lbUDiv = getDivFromOfr(lb, block, /*unsigned=*/true); + uint64_t lbSDiv = getDivFromOfr(lb, block, /*unsigned=*/false); + ConstantIntDivisibility lbDiv(lbUDiv, lbSDiv); + ConstantIntDivisibility stepDiv(stepUDiv, stepSDiv); + + // Loop induction variables are computed as `lb + i * step`. The + // divisibility for `i * step` is just the divisibility of `step`, so + // the total divisibility is obtained by unioning the step divisibility + // with the lower bound divisibility, which takes the GCD of the two. + ConstantIntDivisibility ivDiv = stepDiv.getUnion(lbDiv); + propagateIfChanged(ivEntry, ivEntry->join(ivDiv)); + } + return; + } + + return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( + op, successor, successorInputs, argLattices); +} + +} // namespace mlir::dataflow diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt index 566bc060e5d38..1caf2fa396797 100644 --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAffineDialect AffineMemoryOpInterfaces.cpp AffineOps.cpp AffineValueMap.cpp + InferIntDivisibilityOpInterfaceImpl.cpp InferIntRangeInterfaceImpls.cpp ValueBoundsOpInterfaceImpl.cpp @@ -16,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineDialect MLIRArithDialect MLIRDialectUtils MLIRIR + MLIRInferIntDivisibilityOpInterface MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRLoopLikeInterface diff --git a/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp new file mode 100644 index 0000000000000..451ec98ef3737 --- /dev/null +++ b/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp @@ -0,0 +1,312 @@ +//===- InferIntDivisibilityOpInterfaceImpl.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Direct implementations of `InferIntDivisibilityOpInterface` for affine ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h" + +#include +#include + +using namespace mlir; +using namespace mlir::affine; + +namespace { + +static ConstantIntDivisibility +getDivisibilityOfOperand(Value v, IntegerDivisibility divisibility) { + if (!divisibility.isUninitialized()) + return divisibility.getValue(); + APInt intVal; + if (matchPattern(v, m_ConstantInt(&intVal))) { + uint64_t udiv = intVal.getZExtValue(); + uint64_t sdiv = std::abs(intVal.getSExtValue()); + return ConstantIntDivisibility(udiv, sdiv); + } + return ConstantIntDivisibility(1, 1); +} + +/// Visits affine expressions and recursively calculates the divisibilities of +/// each subexpression. The final divisibilities of the expression and its +/// subexpressions will be stored in the map for which a reference is provided +/// to the AffineExprDivisibilityFinder (i.e., `divisibilityMap`). +class AffineExprDivisibilityFinder + : public AffineExprVisitor { +public: + using ExprDivisibilityMap = + llvm::DenseMap; + AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap) + : divisibilityMap(divisibilityMap) {} + + ConstantIntDivisibility visitConstantExpr(AffineConstantExpr expr) { + // Constant expressions are trivial, since they are always static. + uint64_t constValue = std::abs(expr.getValue()); + return ConstantIntDivisibility(constValue, constValue); + } + + ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) { + // Dim expressions cannot be analyzed further, so return the divisibility + // in `divisibilityMap` if it has been populated by the caller, or fallback + // to the minimum divisibility. + if (divisibilityMap.contains(expr)) + return divisibilityMap[expr]; + return IntegerDivisibility::getMinDivisibility().getValue(); + } + + ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) { + // Symbol expressions cannot be analyzed further, so return the divisibility + // in `divisibilityMap` if it has been populated by the caller, or fallback + // to the minimum divisibility. + if (divisibilityMap.contains(expr)) + return divisibilityMap[expr]; + return IntegerDivisibility::getMinDivisibility().getValue(); + } + + /// Infer the divisibility of an addition or subtraction expression by + /// recursively visiting the LHS and RHS, and then unioning the results. + ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) + return divisibilityMap[expr]; + // The divisibility of an addition is the GCD of its constituents' + // divisibilities. + ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + ConstantIntDivisibility rhsDiv = visit(expr.getRHS()); + return lhsDiv.getUnion(rhsDiv); + } + + /// Infer the divisibility of a multiplication expression by recursively + /// visiting the LHS and RHS, and then multiplying the results. + ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) + return divisibilityMap[expr]; + // The divisibility of a multiplication is the product of its constituents' + // divisibilities. + ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + ConstantIntDivisibility rhsDiv = visit(expr.getRHS()); + return ConstantIntDivisibility(lhsDiv.udiv() * rhsDiv.udiv(), + lhsDiv.sdiv() * rhsDiv.sdiv()); + } + + ConstantIntDivisibility visitFloorDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr); + } + + ConstantIntDivisibility visitCeilDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr); + } + + /// Infer the divisibility of a mod expression. If the RHS is a constant, + /// the result divisibility is gcd(lhs_divisibility, rhs_constant), since + /// (d * k) mod c is always divisible by gcd(d, c). Furthermore, if the + /// LHS divisibility is itself divisible by the constant (i.e., d % c == 0), + /// then (d * k) mod c is always zero, represented as divisibility 0. + ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) + return divisibilityMap[expr]; + auto constRhs = dyn_cast(expr.getRHS()); + if (!constRhs || constRhs.getValue() == 0) + return ConstantIntDivisibility(1, 1); + auto constValue = static_cast(std::abs(constRhs.getValue())); + ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + // If the LHS is always a multiple of constValue, x mod constValue is + // always zero. Divisibility 0 is the lattice top ("divides everything"). + uint64_t modUDiv = (lhsDiv.udiv() % constValue == 0) + ? 0 + : std::gcd(lhsDiv.udiv(), constValue); + uint64_t modSDiv = (lhsDiv.sdiv() % constValue == 0) + ? 0 + : std::gcd(lhsDiv.sdiv(), constValue); + return ConstantIntDivisibility(modUDiv, modSDiv); + } + +private: + ConstantIntDivisibility visitInvalidExpr(AffineBinaryOpExpr expr) { + return IntegerDivisibility::getMinDivisibility().getValue(); + } + + /// Helper shared by ceildiv and floordiv implementations. Returns the minimum + /// divisibility as a fallback if the divisor is not a constant, because the + /// divisibility cannot be inferred in this case. If the divisor is a + /// constant, then this function recursively visits the dividend, and returns + /// the quotient of the dividend's divisibility with the divisor. + ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) + return divisibilityMap[expr]; + auto constRhs = dyn_cast(expr.getRHS()); + // Division by zero is undefined, so return the minimum divisibility. + if (!constRhs || constRhs.getValue() == 0) + return ConstantIntDivisibility(1, 1); + auto constValue = static_cast(std::abs(constRhs.getValue())); + ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + uint64_t divUDiv = + lhsDiv.udiv() % constValue == 0 ? lhsDiv.udiv() / constValue : 1; + uint64_t divSDiv = + lhsDiv.sdiv() % constValue == 0 ? lhsDiv.sdiv() / constValue : 1; + return ConstantIntDivisibility(divUDiv, divSDiv); + } + + ExprDivisibilityMap &divisibilityMap; +}; + +/// Returns the divisibilities of each AffineMap result based on the +/// divisibilities of its dims and symbols. The `dimAndSymbolDivisibilities` +/// should contain the divisibilities of the dims, followed by the +/// divisibilities of the symbols in ascending order by their positions. +SmallVector getResultDivisibilities( + AffineMap map, + ArrayRef dimAndSymbolDivisibilities) { + // Seed the AffineExprDivisibilityFinder with the dimAndSymbolDivisibilities. + llvm::DenseMap exprDivisibilityMap; + SmallVector inputExprs; + inputExprs.append(llvm::map_to_vector( + llvm::seq(map.getNumDims()), + [&](int64_t dim) { return getAffineDimExpr(dim, map.getContext()); })); + inputExprs.append(llvm::map_to_vector( + llvm::seq(map.getNumSymbols()), + [&](int64_t sym) { return getAffineSymbolExpr(sym, map.getContext()); })); + for (auto [expr, divisibility] : + llvm::zip_equal(inputExprs, dimAndSymbolDivisibilities)) { + exprDivisibilityMap[expr] = divisibility; + } + AffineExprDivisibilityFinder divisibilityFinder(exprDivisibilityMap); + + // Walk each result expression and compute their divisibilities. + SmallVector resultDivisibilities; + for (AffineExpr resultExpr : map.getResults()) + resultDivisibilities.push_back(divisibilityFinder.visit(resultExpr)); + return resultDivisibilities; +} + +/// Infer the result divisibility of an affine.min or affine.max operation +/// based on its operand divisibilities. The result divisibility is the GCD +/// of the divisibilities of each of the affine map results, because the result +/// of the affine.min/max op could be any of these results. +template +void inferAffineMinOrMaxResultDivisibility( + MinOrMaxTy minOrMaxOp, ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + static_assert(llvm::is_one_of::value, + "MinOrMaxTy must be AffineMinOp or AffineMaxOp"); + SmallVector operandDivisibilities; + for (auto [operand, divisibility] : + llvm::zip(minOrMaxOp.getOperands(), argDivs)) { + operandDivisibilities.push_back( + getDivisibilityOfOperand(operand, divisibility)); + } + + SmallVector resultDivisibilities = + getResultDivisibilities(minOrMaxOp.getMap(), operandDivisibilities); + + ConstantIntDivisibility resultDivisibility = + resultDivisibilities.pop_back_val(); + for (auto divisibility : resultDivisibilities) + resultDivisibility = resultDivisibility.getUnion(divisibility); + setResultDivs(minOrMaxOp.getResult(), resultDivisibility); +} + +} // namespace + +void AffineApplyOp::inferResultDivisibility( + ArrayRef argDivs, SetIntDivisibilityFn setResultDivs) { + SmallVector operandDivisibilities; + for (auto [operand, divisibility] : llvm::zip(getOperands(), argDivs)) { + operandDivisibilities.push_back( + getDivisibilityOfOperand(operand, divisibility)); + } + + SmallVector resultDivisibilities = + getResultDivisibilities(getMap(), operandDivisibilities); + for (auto [result, divisibility] : + llvm::zip_equal(getOperation()->getResults(), resultDivisibilities)) { + setResultDivs(result, divisibility); + } +} + +void AffineMinOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferAffineMinOrMaxResultDivisibility(*this, argDivs, setResultDivs); +} + +void AffineMaxOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferAffineMinOrMaxResultDivisibility(*this, argDivs, setResultDivs); +} + +void AffineDelinearizeIndexOp::inferResultDivisibility( + ArrayRef argDivs, SetIntDivisibilityFn setResultDivs) { + MLIRContext *ctx = getContext(); + + // Operands are: [linear_index, dynamic_basis_values...] + ConstantIntDivisibility linearDiv = + getDivisibilityOfOperand(getLinearIndex(), argDivs[0]); + + ArrayRef staticBasis = getStaticBasis(); + int64_t numResults = getNumResults(); + + // Build affine expressions for each result. + // Dim 0 = linear index, symbols = dynamic basis values. + AffineExpr linearExpr = getAffineDimExpr(0, ctx); + + // Collect operand divisibilities: [linear_index_div, dynamic_basis_divs...] + SmallVector operandDivs; + operandDivs.push_back(linearDiv); + + // Map static/dynamic basis values to affine expressions. + int64_t dynIdx = 0; + SmallVector basisExprs; + for (int64_t i = 0, e = static_cast(staticBasis.size()); i < e; + ++i) { + if (ShapedType::isDynamic(staticBasis[i])) { + basisExprs.push_back(getAffineSymbolExpr(dynIdx, ctx)); + operandDivs.push_back(getDivisibilityOfOperand(getDynamicBasis()[dynIdx], + argDivs[1 + dynIdx])); + dynIdx++; + } else { + basisExprs.push_back(getAffineConstantExpr(staticBasis[i], ctx)); + } + } + + // The computation basis skips the outer bound if present. + bool hasOuter = hasOuterBound(); + int64_t basisStart = hasOuter ? 1 : 0; + + // Each result[i] can be expressed as an affine expression of the linear + // index using the effective basis (after dropping outer bound if present). + // Effective basis B[k] = basisExprs[basisStart + k], for k = 0..N-2. + // Stride s[i] = product of B[i..N-2] = product of + // basisExprs[basisStart+i .. end]. + // + // result[0] = x floordiv s[0] + // result[i>0] = (x floordiv s[i]) mod B[i-1] + // For i=N-1, s[N-1]=1, so result[N-1] = x mod B[N-2]. + + AffineExpr stride = getAffineConstantExpr(1, ctx); + for (int64_t i = numResults - 1; i >= 0; --i) { + AffineExpr resultExpr; + if (i == 0) { + resultExpr = linearExpr.floorDiv(stride); + } else { + resultExpr = + (linearExpr.floorDiv(stride)) % basisExprs[basisStart + i - 1]; + } + + AffineMap resultMap = AffineMap::get(1, dynIdx, resultExpr, ctx); + SmallVector divs = + getResultDivisibilities(resultMap, operandDivs); + setResultDivs(getResult(i), divs[0]); + + if (i > 0) + stride = basisExprs[basisStart + i - 1] * stride; + } +} diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt index 4beb99ccfdfba..3423e11a7d0f0 100644 --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_OPTIONAL_SOURCES ArithOps.cpp ArithDialect.cpp + InferIntDivisibilityOpInterfaceImpl.cpp InferIntRangeInterfaceImpls.cpp ValueBoundsOpInterfaceImpl.cpp ) @@ -12,6 +13,7 @@ add_public_tablegen_target(MLIRArithCanonicalizationIncGen) add_mlir_dialect_library(MLIRArithDialect ArithOps.cpp ArithDialect.cpp + InferIntDivisibilityOpInterfaceImpl.cpp InferIntRangeInterfaceImpls.cpp ADDITIONAL_HEADER_DIRS @@ -24,6 +26,7 @@ add_mlir_dialect_library(MLIRArithDialect LINK_LIBS PUBLIC MLIRCastInterfaces MLIRDialect + MLIRInferIntDivisibilityOpInterface MLIRInferIntRangeCommon MLIRInferIntRangeInterface MLIRInferTypeOpInterface diff --git a/mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp new file mode 100644 index 0000000000000..b23ee108ca4a3 --- /dev/null +++ b/mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp @@ -0,0 +1,122 @@ +//===- InferIntDivisibilityOpInterfaceImpl.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Direct implementations of `InferIntDivisibilityOpInterface` for arith ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h" + +#include + +using namespace mlir; +using namespace mlir::arith; + +static ConstantIntDivisibility +getDivisibilityOfOperand(Value v, IntegerDivisibility divisibility) { + if (!divisibility.isUninitialized()) + return divisibility.getValue(); + APInt intVal; + if (matchPattern(v, m_ConstantInt(&intVal))) { + uint64_t udiv = intVal.getZExtValue(); + uint64_t sdiv = std::abs(intVal.getSExtValue()); + return ConstantIntDivisibility(udiv, sdiv); + } + return ConstantIntDivisibility(1, 1); +} + +// Result divisibility is the GCD (union) of the operand divisibilities. +template +static void +inferBinaryGCDResultDivisibility(OpTy op, ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + auto lhsDiv = getDivisibilityOfOperand(op.getLhs(), argDivs[0]); + auto rhsDiv = getDivisibilityOfOperand(op.getRhs(), argDivs[1]); + setResultDivs(op.getResult(), lhsDiv.getUnion(rhsDiv)); +} + +void ConstantOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + auto constAttr = dyn_cast_if_present(getValue()); + if (!constAttr) + return; + const APInt &value = constAttr.getValue(); + uint64_t udiv = value.getZExtValue(); + uint64_t sdiv = std::abs(value.getSExtValue()); + setResultDivs(getResult(), ConstantIntDivisibility(udiv, sdiv)); +} + +void AddIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs); +} + +void SubIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs); +} + +void MinUIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs); +} + +void MaxUIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs); +} + +void MinSIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs); +} + +void MaxSIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs); +} + +void MulIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + auto lhsDivisibility = getDivisibilityOfOperand(getLhs(), argDivs[0]); + auto rhsDivisibility = getDivisibilityOfOperand(getRhs(), argDivs[1]); + + uint64_t mulUDiv = lhsDivisibility.udiv() * rhsDivisibility.udiv(); + uint64_t mulSDiv = lhsDivisibility.sdiv() * rhsDivisibility.sdiv(); + + setResultDivs(getResult(), ConstantIntDivisibility(mulUDiv, mulSDiv)); +} + +void DivUIOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + APInt intVal; + if (!matchPattern(getRhs(), m_ConstantInt(&intVal))) + return; + + auto lhsDivisibility = getDivisibilityOfOperand(getLhs(), argDivs[0]); + + uint64_t divUDiv = lhsDivisibility.udiv() % intVal.getZExtValue() == 0 + ? lhsDivisibility.udiv() / intVal.getZExtValue() + : 1; + uint64_t divSDiv = + lhsDivisibility.sdiv() % std::abs(intVal.getSExtValue()) == 0 + ? lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()) + : 1; + + setResultDivs(getResult(), ConstantIntDivisibility(divUDiv, divSDiv)); +} + +void SelectOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + // argDivs[0] is the condition (i1), argDivs[1] is true, argDivs[2] is false. + auto trueDiv = getDivisibilityOfOperand(getTrueValue(), argDivs[1]); + auto falseDiv = getDivisibilityOfOperand(getFalseValue(), argDivs[2]); + setResultDivs(getResult(), trueDiv.getUnion(falseDiv)); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 41e890cb408ba..d20d290c45c01 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES FunctionImplementation.cpp FunctionInterfaces.cpp IndexingMapOpInterface.cpp + InferIntDivisibilityOpInterface.cpp InferIntRangeInterface.cpp InferStridedMetadataInterface.cpp InferTypeOpInterface.cpp @@ -66,6 +67,7 @@ add_mlir_library(MLIRFunctionInterfaces ) add_mlir_interface_library(IndexingMapOpInterface) +add_mlir_interface_library(InferIntDivisibilityOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_library(MLIRInferStridedMetadataInterface diff --git a/mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp b/mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp new file mode 100644 index 0000000000000..acd7cd9530b5c --- /dev/null +++ b/mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp @@ -0,0 +1,11 @@ +//===- InferIntDivisibilityOpInterface.cpp - Integer divisibility inference ==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h" + +#include "mlir/Interfaces/InferIntDivisibilityOpInterface.cpp.inc" diff --git a/mlir/test/Analysis/DataFlow/integer-divisibility.mlir b/mlir/test/Analysis/DataFlow/integer-divisibility.mlir new file mode 100644 index 0000000000000..7f9466e949d4c --- /dev/null +++ b/mlir/test/Analysis/DataFlow/integer-divisibility.mlir @@ -0,0 +1,152 @@ +// RUN: mlir-opt --split-input-file --test-int-divisibility-analysis --allow-unregistered-dialect %s | FileCheck %s + +// CHECK-LABEL: @constant +func.func @constant() -> index { + %0 = arith.constant 8 : index + // CHECK: divisibility = "udiv = 8, sdiv = 8" + %1 = "test.int_divisibility"(%0) : (index) -> index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @muli_constant +func.func @muli_constant(%arg0 : index) -> index { + %c4 = arith.constant 4 : index + %0 = arith.muli %arg0, %c4 : index + // CHECK: divisibility = "udiv = 4, sdiv = 4" + %1 = "test.int_divisibility"(%0) : (index) -> index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @addi_gcd_of_muli_operands +func.func @addi_gcd_of_muli_operands(%arg0 : index, %arg1 : index) -> index { + %c8 = arith.constant 8 : index + %c12 = arith.constant 12 : index + %a = arith.muli %arg0, %c8 : index + %b = arith.muli %arg1, %c12 : index + %0 = arith.addi %a, %b : index + // gcd(8, 12) = 4. + // CHECK: divisibility = "udiv = 4, sdiv = 4" + %1 = "test.int_divisibility"(%0) : (index) -> index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @addi_same_divisibility +func.func @addi_same_divisibility(%arg0 : index, %arg1 : index) -> index { + %c16 = arith.constant 16 : index + %a = arith.muli %arg0, %c16 : index + %b = arith.muli %arg1, %c16 : index + %0 = arith.addi %a, %b : index + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %1 = "test.int_divisibility"(%0) : (index) -> index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_mul +func.func @affine_apply_mul(%arg0 : index) -> index { + %c2 = arith.constant 2 : index + %seed = arith.muli %arg0, %c2 : index + %0 = affine.apply affine_map<(d0) -> (d0 * 16)>(%seed) + // 2 * 16 = 32. + // CHECK: divisibility = "udiv = 32, sdiv = 32" + %1 = "test.int_divisibility"(%0) : (index) -> index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_mul_then_floordiv +func.func @affine_apply_mul_then_floordiv(%arg0 : index) -> index { + %0 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg0) + %1 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%0) + // 16 floordiv 4 = 4. + // CHECK: divisibility = "udiv = 4, sdiv = 4" + %2 = "test.int_divisibility"(%1) : (index) -> index + return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_mod_zero +func.func @affine_apply_mod_zero(%arg0 : index) -> index { + %0 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg0) + %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%0) + // 16 % 16 == 0, so x mod 16 is always 0 -> divisibility 0 (lattice top). + // CHECK: divisibility = "udiv = 0, sdiv = 0" + %2 = "test.int_divisibility"(%1) : (index) -> index + return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_constant +func.func @affine_apply_constant() -> index { + %0 = affine.apply affine_map<() -> (64)>() + // CHECK: divisibility = "udiv = 64, sdiv = 64" + %1 = "test.int_divisibility"(%0) : (index) -> index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @scf_for_constant_step +func.func @scf_for_constant_step() { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c8 = arith.constant 8 : index + scf.for %iv = %c0 to %c64 step %c8 { + // CHECK: divisibility = "udiv = 8, sdiv = 8" + %0 = "test.int_divisibility"(%iv) : (index) -> index + } + return +} + +// ----- + +// CHECK-LABEL: @scf_for_nontrivial_gcd +func.func @scf_for_nontrivial_gcd() { + %c12 = arith.constant 12 : index + %c100 = arith.constant 100 : index + %c18 = arith.constant 18 : index + scf.for %iv = %c12 to %c100 step %c18 { + // gcd(12, 18) = 6. + // CHECK: divisibility = "udiv = 6, sdiv = 6" + %0 = "test.int_divisibility"(%iv) : (index) -> index + } + return +} + +// ----- + +// CHECK-LABEL: @scf_for_coprime +func.func @scf_for_coprime() { + %c15 = arith.constant 15 : index + %c100 = arith.constant 100 : index + %c8 = arith.constant 8 : index + scf.for %iv = %c15 to %c100 step %c8 { + // gcd(15, 8) = 1. + // CHECK: divisibility = "udiv = 1, sdiv = 1" + %0 = "test.int_divisibility"(%iv) : (index) -> index + } + return +} + +// ----- + +// CHECK-LABEL: @affine_apply_mul_plus_const +func.func @affine_apply_mul_plus_const(%arg0 : index) -> index { + %c4 = arith.constant 4 : index + %seed = arith.muli %arg0, %c4 : index + %0 = affine.apply affine_map<(d0) -> (d0 * 8 + 16)>(%seed) + // seed has udiv = 4, multiplied by 8 -> 32, then +16. gcd(32,16) = 16. + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %1 = "test.int_divisibility"(%0) : (index) -> index + return %1 : index +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt index c37671ade37b3..d86af5017f24b 100644 --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_library(MLIRTestAnalysis DataFlow/TestDeadCodeAnalysis.cpp DataFlow/TestDenseBackwardDataFlowAnalysis.cpp DataFlow/TestDenseForwardDataFlowAnalysis.cpp + DataFlow/TestIntegerDivisibilityAnalysis.cpp DataFlow/TestLivenessAnalysis.cpp DataFlow/TestSparseBackwardDataFlowAnalysis.cpp DataFlow/TestStridedMetadataRangeAnalysis.cpp @@ -27,6 +28,7 @@ add_mlir_library(MLIRTestAnalysis mlir_target_link_libraries(MLIRTestAnalysis PUBLIC MLIRAffineDialect MLIRAnalysis + MLIRArithDialect MLIRFunctionInterfaces MLIRMemRefDialect MLIRPass diff --git a/mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp new file mode 100644 index 0000000000000..626cbc0fac7aa --- /dev/null +++ b/mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp @@ -0,0 +1,93 @@ +//===- TestIntegerDivisibilityAnalysis.cpp - Test int divisibility --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::dataflow; + +namespace { +struct TestIntegerDivisibilityAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestIntegerDivisibilityAnalysisPass) + + StringRef getArgument() const override { + return "test-int-divisibility-analysis"; + } + StringRef getDescription() const override { + return "Test integer divisibility analysis by annotating " + "'test.int_divisibility' ops with the divisibility of their " + "operand."; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + Operation *rootOp = getOperation(); + MLIRContext *context = &getContext(); + + // The pass is rooted on `test.int_divisibility` ops, which are expected + // to have a single operand for which to annotate divisibility information. + SmallVector> queryOps; + rootOp->walk([&](Operation *op) { + if (op->getName().getStringRef() == "test.int_divisibility" && + op->getNumOperands() == 1) + queryOps.emplace_back(op, op->getOperand(0)); + }); + + DataFlowSolver solver; + // DeadCodeAnalysis is the base analysis that allows the solver to traverse + // control flow. It is required by IntegerDivisibilityAnalysis. + solver.load(); + // SparseConstantPropagation allows the solver to call + // visitNonControlFlowArguments and analyze arguments like loop induction + // variables. + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(rootOp))) + return signalPassFailure(); + + for (auto &[op, value] : queryOps) { + const auto *lattice = + solver.lookupState(value); + if (!lattice || lattice->getValue().isUninitialized()) { + op->setAttr("divisibility", StringAttr::get(context, "uninitialized")); + continue; + } + + // Format for the divisibility information is "udiv = X, sdiv = Y". + const auto &div = lattice->getValue().getValue(); + std::string result; + llvm::raw_string_ostream os(result); + os << "udiv = " << div.udiv() << ", sdiv = " << div.sdiv(); + op->setAttr("divisibility", StringAttr::get(context, result)); + } + } +}; +} // end anonymous namespace + +namespace mlir::test { +void registerTestIntegerDivisibilityAnalysisPass() { + PassRegistration(); +} +} // end namespace mlir::test diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index c4754b3a08551..13c0934f34656 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -104,6 +104,7 @@ void registerTestComposeSubView(); void registerTestMultiBuffering(); void registerTestIRVisitorsPass(); void registerTestGenericIRVisitorsPass(); +void registerTestIntegerDivisibilityAnalysisPass(); void registerTestInterfaces(); void registerTestIRVisitorsPass(); void registerTestLastModifiedPass(); @@ -253,6 +254,7 @@ static void registerTestPasses() { mlir::test::registerTestMultiBuffering(); mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestGenericIRVisitorsPass(); + mlir::test::registerTestIntegerDivisibilityAnalysisPass(); mlir::test::registerTestInterfaces(); mlir::test::registerTestIrdlTestDialectConversionPass(); mlir::test::registerTestIRVisitorsPass();