diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index c91d04ea323b..71f7bc207148 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -114,6 +114,7 @@ def __init__(self, repo_map: Dict[str, str]): "MhloPasses", ], "@mlir-hlo//stablehlo:chlo_ops": ["ChloOps",], + "@mlir-hlo//stablehlo:stablehlo_ops": ["StablehloOps",], "@mlir-hlo//:stablehlo_legalize_to_hlo_pass": ["StablehloToMhlo",], "@mlir-hlo//stablehlo:broadcast_utils": ["StablehloBroadcastUtils",], diff --git a/compiler/src/iree/compiler/InputConversion/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/CMakeLists.txt index 305c05ed7d97..474215f2d60d 100644 --- a/compiler/src/iree/compiler/InputConversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/InputConversion/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(Common) if(IREE_INPUT_MHLO) add_subdirectory(MHLO) + add_subdirectory(StableHLO) endif() if(IREE_INPUT_TORCH) add_subdirectory(TMTensor) diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel new file mode 100644 index 000000000000..7c6a8bf6e3d8 --- /dev/null +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel @@ -0,0 +1,97 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_gentbl_cc_library( + name = "PassesIncGen", + tbl_outs = [ + ( + ["--gen-pass-decls"], + "Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +iree_compiler_cc_library( + name = "PassHeaders", + hdrs = [ + "PassDetail.h", + "Passes.h", + "Passes.h.inc", + "Rewriters.h", + "TypeConversion.h", + ], + deps = [ + ":PassesIncGen", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +iree_compiler_cc_library( + name = "StableHLO", + srcs = [ + "LegalizeToLinalgUtils.cpp", + "Passes.cpp", + "StableHLOToLinalg.cpp", + "TypeConversion.cpp", + ], + hdrs = [ + "LegalizeToLinalgUtils.h", + "MapStableHLOToScalarOp.h", + "Passes.h", + ], + deps = [ + ":PassHeaders", + ":PassesIncGen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:LinalgUtils", + "@llvm-project//mlir:MLProgramDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:ShapeToStandard", + "@llvm-project//mlir:ShapeTransforms", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", + "@mlir-hlo//stablehlo:broadcast_utils", + "@mlir-hlo//stablehlo:chlo_ops", + "@mlir-hlo//stablehlo:stablehlo_ops", + ], +) diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt new file mode 100644 index 000000000000..78a3216e1487 --- /dev/null +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt @@ -0,0 +1,89 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +iree_cc_library( + NAME + PassHeaders + HDRS + "PassDetail.h" + "Passes.h" + "Passes.h.inc" + "Rewriters.h" + "TypeConversion.h" + DEPS + ::PassesIncGen + MLIRPass + MLIRTransforms + PUBLIC +) + +iree_cc_library( + NAME + StableHLO + HDRS + "LegalizeToLinalgUtils.h" + "MapStableHLOToScalarOp.h" + "Passes.h" + SRCS + "LegalizeToLinalgUtils.cpp" + "Passes.cpp" + "StableHLOToLinalg.cpp" + "TypeConversion.cpp" + DEPS + ::PassHeaders + ::PassesIncGen + ChloOps + LLVMSupport + MLIRAffineDialect + MLIRAffineUtils + MLIRArithDialect + MLIRBufferizationDialect + MLIRComplexDialect + MLIRControlFlowDialect + MLIRFuncDialect + MLIRFuncTransforms + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLinalgUtils + MLIRMLProgramDialect + MLIRMathDialect + MLIRMemRefDialect + MLIRPass + MLIRReconcileUnrealizedCasts + MLIRSCFDialect + MLIRSCFToControlFlow + MLIRSCFTransforms + MLIRShapeDialect + MLIRShapeOpsTransforms + MLIRShapeToStandard + MLIRSparseTensorDialect + MLIRSupport + MLIRTensorDialect + MLIRTensorUtils + MLIRTransforms + MLIRVectorDialect + StablehloBroadcastUtils + StablehloOps + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp new file mode 100644 index 000000000000..87f2f47de6c1 --- /dev/null +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp @@ -0,0 +1,132 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Implements utilities for lowering StableHLO/CHLO dialect to Linalg dialect. + +#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h" + +#include +#include +#include +#include + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir::iree_compiler::stablehlo { +namespace { +bool hasIntegralShapeType(Operation* op) { + auto stp = op->getOperand(0).getType().dyn_cast(); + return stp && stp.getElementType().isIntOrIndex(); +} + +} // namespace + +SmallVector getParallelAndReductionIterators( + unsigned nLoops, unsigned nReduction) { + SmallVector res(nLoops - nReduction, + utils::IteratorType::parallel); + res.append(nReduction, utils::IteratorType::reduction); + return res; +} + +SmallVector getNParallelLoopsAttrs( + unsigned nParallelLoops) { + return getParallelAndReductionIterators(nParallelLoops, 0); +} + +Value getEmptySparseTensor(OpBuilder& b, Location loc, ShapedType type, + ArrayRef dynSizes) { + return b.create(loc, type.cast(), + dynSizes, + /*copy=*/Value(), + /*memory_space=*/IntegerAttr()); +} + +Value getEmptyTensor(OpBuilder& b, Location loc, ShapedType type, + ArrayRef dynSizes) { + return b.create(loc, type.getShape(), type.getElementType(), + dynSizes, + type.cast().getEncoding()); +} + +Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType, + Operation* op, ValueRange operands) { + bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr; + // Collect the sizes for a ranked tensor to be passed as parameter to a + // new tensor initialization operation. This operation only needs the + // dynamic sizes. + SmallVector sizes; + if (resultType.hasRank() && !resultType.hasStaticShape()) { + // Ask the op for its output shape. + auto shapeSource = cast(op); + SmallVector reifiedShapes; + (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes); + assert(reifiedShapes.size() == 1 && "Expected one reified result"); + // Construct sizes for the required dimensions. + for (const auto& en : llvm::enumerate(resultType.getShape())) { + if (en.value() != ShapedType::kDynamic) continue; + sizes.push_back(b.create( + loc, reifiedShapes[0], + ValueRange{b.create(loc, en.index())})); + } + } + return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes) + : getEmptyTensor(b, loc, resultType, sizes); +} + +Value preSparsify(Operation* op, llvm::SmallVector& values, Type rtp, + OpBuilder* b) { + // Apply for semi-ring operations that lower to elaborate code + // (any sign-op, or an integral abs-op). + // TODO(peiming, ajcbik): these all can potentially be optimized by applying + // value transform on sparse_tenosr.value memref + if (isa(op) || isa(op) || + (isa(op) && hasIntegralShapeType(op)) || + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op)) { + if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) && + !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType())) + return Value(); + Location loc = op->getLoc(); + auto semiring = b->create(loc, rtp, values[0]); + Type itp = values[0].getType(); + Block* present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc); + b->setInsertionPointToStart(&semiring.getPresentRegion().front()); + values[0] = present->getArgument(0); + return semiring; + } + return Value(); +} + +Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b) { + if (semiring) { + b->create(op->getLoc(), result); + b->setInsertionPointAfter(semiring.getDefiningOp()); + return semiring; + } + return result; +} + +bool allOperandsAreScalarTensors(Operation* op) { + return llvm::all_of(op->getOperands(), [](Value operand) { + auto operandTy = operand.getType().dyn_cast(); + return operandTy && operandTy.getRank() == 0; + }); +} + +bool isInBodyOfLinalgOps(Operation* op) { + auto* parentOp = op->getParentRegion()->getParentOp(); + return parentOp->getDialect() == + parentOp->getContext()->getLoadedDialect(); +} + +} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h new file mode 100644 index 000000000000..5f36df166b9c --- /dev/null +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h @@ -0,0 +1,178 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Utils for lowering of the StableHLO dialect to the Linalg dialect. + +#ifndef IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_ +#define IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir::iree_compiler::stablehlo { + +/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes +/// are "parallel" except the last `nReduction` elements, where are "reduction" +/// attributes. +SmallVector getParallelAndReductionIterators( + unsigned nLoops, unsigned nReduction); + +/// Returns an ArrayAttr that contains `nParallelLoops` "parallel" attributes. +SmallVector getNParallelLoopsAttrs( + unsigned nParallelLoops); + +/// Generates an init sparse tensor. +Value getEmptySparseTensor(OpBuilder& b, Location loc, ShapedType type, + ArrayRef dynSizes); + +/// Generates a tensor.empty op. +Value getEmptyTensor(OpBuilder& b, Location loc, ShapedType type, + ArrayRef dynSizes); + +/// Generates an empty tensor for the result of the operation, which could be a +/// dense tensor or a sparse tensor. +Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType, + Operation* op, ValueRange operands); + +/// Sparsifies a (block of) operation(s) that cannot be handled directly +/// by the sparse compiler but has well-known semi-ring semantics. +/// +/// This yields something of the following form: +/// +/// %result = sparse_tensor.unary %values[0] +/// present={ +/// ^bb1(%val): +/// ... codegen proceeds here using %val .... +/// sparse_tensor.yield +/// } +/// absent={} +/// linalg.yield %result +Value preSparsify(Operation* op, llvm::SmallVector& values, Type rtp, + OpBuilder* b); + +/// Finalizes sparse semi-ring construction. +Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b); + +/// Returns true if all operands are tensors with rank 0. +bool allOperandsAreScalarTensors(Operation* op); + +/// Returns true if parent op is linalg. +bool isInBodyOfLinalgOps(Operation* op); + +/// Converts a HLO operation to a linalg.generic op that contains the +/// corresponding scalar operations. +template +class PointwiseToLinalgConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + // Find maximum rank / number of loops. + auto getRank = [](Value v) { + return v.getType().cast().getRank(); + }; + auto isScalar = [&](Value v) { return getRank(v) == 0; }; + auto it = llvm::find_if_not(adaptor.getOperands(), isScalar); + Value maxRankArg = + it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front(); + int64_t nloops = getRank(maxRankArg); + + // Apply only if all operands are scalar or have the same rank. Some ops, + // like `mhlo.select`, support implicit broadcasting of scalars. + if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { + int64_t r = getRank(v); + return r == 0 || r == nloops; + })) { + return rewriter.notifyMatchFailure( + op, "Operands must be os same rank or scalar."); + } + + // Find result type, if on tensors. + std::optional resultTy; + resultTy = this->typeConverter->convertType(op->getResultTypes().front()) + .template dyn_cast(); + + // Check result type compatibility. + if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops || + !(resultTy->getElementType().isSignlessIntOrFloat() || + resultTy->getElementType().isa())) { + return rewriter.notifyMatchFailure( + op, "mismatched operand/result types or iterator count"); + } + + if (allOperandsAreScalarTensors(op) && isInBodyOfLinalgOps(op)) + return failure(); + + // Find input/output values and types. + ValueRange inputs = adaptor.getOperands(); + Value output = + getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands()); + + // Create indexing maps. + AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext()); + AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops); + SmallVector maps; + for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap); + maps.push_back(idMap); + + // Build `linalg.generic` op. + bool failed = false; + auto linalgOp = rewriter.create( + loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps, + getNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, + ValueRange args) { + Type innerResultTy = getElementTypeOrSelf(output); + auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); + auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter); + Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( + op, innerResultTy, argvec, &rewriter); + if (innerResult == nullptr) { + failed = true; + } else { + innerResult = postSparsify(op, semiring, innerResult, &rewriter); + nestedBuilder.create(loc, innerResult); + } + }, + linalg::getPrunedAttributeList(op)); + if (failed) return failure(); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } +}; + +} // namespace mlir::iree_compiler::stablehlo + +#endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_ diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h b/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h new file mode 100644 index 000000000000..43e35926bc0c --- /dev/null +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h @@ -0,0 +1,1319 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_INPUTCONVERSION_STABLEHLO_MAP_STABLEHLO_TO_SCALAR_OP_H +#define IREE_COMPILER_INPUTCONVERSION_STABLEHLO_MAP_STABLEHLO_TO_SCALAR_OP_H + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir { +namespace stablehlo { +namespace impl { + +// A struct to map StableHloBinaryOpTy type to the corresponding floating-point +// and integer scalar operation types. +template +struct StableHloToScalarOp { + using FOp = void; + using IOp = void; + using UOp = void; + using COp = void; +}; + +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::arith::AddFOp; + using IOp = ::mlir::arith::AddIOp; + using UOp = ::mlir::arith::AddIOp; + using COp = ::mlir::complex::AddOp; +}; +template <> +struct StableHloToScalarOp { + using IOp = ::mlir::arith::AndIOp; + using UOp = ::mlir::arith::AndIOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::CbrtOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::arith::CmpFOp; + using IOp = ::mlir::arith::CmpIOp; + using UOp = ::mlir::arith::CmpIOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::CeilOp; +}; +template <> +struct StableHloToScalarOp { + using IOp = ::mlir::math::CountLeadingZerosOp; + using UOp = ::mlir::math::CountLeadingZerosOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::CosOp; + using COp = ::mlir::complex::CosOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::ExpOp; + using COp = ::mlir::complex::ExpOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::ExpM1Op; + using COp = ::mlir::complex::Expm1Op; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::FloorOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::LogOp; + using COp = ::mlir::complex::LogOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::Log1pOp; + using COp = ::mlir::complex::Log1pOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::arith::MulFOp; + using IOp = ::mlir::arith::MulIOp; + using UOp = ::mlir::arith::MulIOp; + using COp = ::mlir::complex::MulOp; +}; +template <> +struct StableHloToScalarOp { + using IOp = ::mlir::arith::OrIOp; + using UOp = ::mlir::arith::OrIOp; +}; +template <> +struct StableHloToScalarOp { + using IOp = ::mlir::math::CtPopOp; + using UOp = ::mlir::math::CtPopOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::RsqrtOp; + using COp = ::mlir::complex::RsqrtOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::RoundEvenOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::RoundOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::arith::SubFOp; + using IOp = ::mlir::arith::SubIOp; + using UOp = ::mlir::arith::SubIOp; + using COp = ::mlir::complex::SubOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::SqrtOp; + using COp = ::mlir::complex::SqrtOp; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::SinOp; + using COp = ::mlir::complex::SinOp; +}; +// FIXME(Jakub) +/* +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::TanOp; + using COp = ::mlir::complex::TanOp; +}; +*/ +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::Atan2Op; + using COp = ::mlir::complex::Atan2Op; +}; +template <> +struct StableHloToScalarOp { + using FOp = ::mlir::math::TanhOp; + using COp = ::mlir::complex::TanhOp; +}; +template <> +struct StableHloToScalarOp { + using IOp = ::mlir::arith::XOrIOp; + using UOp = ::mlir::arith::XOrIOp; +}; + +// Alias for the map from StableHLO binary op type to STD floating-point op +// type. +template +using ScalarFOp = typename StableHloToScalarOp::FOp; +// Alias for the map from StableHLO binary op type to STD signed integer op +// type. +template +using ScalarIOp = typename StableHloToScalarOp::IOp; +// Alias for the map from StableHLO binary op type to STD unsigned integer op +// type. +template +using ScalarUOp = typename StableHloToScalarOp::UOp; +// Alias for the map from StableHLO binary op type to STD complex op type. +template +using ScalarCOp = typename StableHloToScalarOp::COp; + +template +struct MapStableHloOpToScalarOpImpl { + Value operator()(Location /*loc*/, ArrayRef /*ResultTypes*/, + ArrayRef /*argTypes*/, ValueRange /*args*/, + OpBuilder* /*b*/) { + return nullptr; + } +}; + +template +struct MapStableHloOpToScalarOpImpl { + Value operator()(Location loc, ArrayRef resultTypes, + ArrayRef /*argTypes*/, ValueRange args, OpBuilder* b) { + return b->template create(loc, resultTypes, args, + std::nullopt); + } +}; + +template +struct MapStableHloOpToScalarOpImpl { + Value operator()(Location loc, ArrayRef resultTypes, + ArrayRef argTypes, ValueRange args, OpBuilder* b) { + Type elementType = getElementTypeOrSelf(argTypes.front()); + if (SupportedType{}(elementType)) { + return b->template create(loc, resultTypes, args, + std::nullopt); + } + return MapStableHloOpToScalarOpImpl{}(loc, resultTypes, argTypes, + args, b); + } +}; + +template +struct MapStableHloOpToScalarOpImpl { + Value operator()(Location loc, ArrayRef resultTypes, + ArrayRef argTypes, ValueRange args, OpBuilder* b) { + return MapStableHloOpToScalarOpImpl{}(loc, resultTypes, argTypes, + args, b); + } +}; + +struct IsAnyIntegerType { + bool operator()(Type t) { return t.isa(); } +}; + +struct IsSignedIntegerType { + bool operator()(Type t) { + // Pretend that signless is signed. This will change eventually. + return t.isa() && !t.isUnsignedInteger() && + !t.isSignlessInteger(1); + } +}; + +struct IsUnsignedIntegerType { + bool operator()(Type t) { + return t.isUnsignedInteger() || t.isSignlessInteger(1); + } +}; + +struct IsFloatType { + bool operator()(Type t) { return t.isa(); } +}; + +struct IsComplexType { + bool operator()(Type t) { return t.isa(); } +}; + +template