-
Notifications
You must be signed in to change notification settings - Fork 556
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial StableHLO to Linalg lowering pass files (#12957)
This is a port of the MHLO to Linalg lowering from https://github.com/tensorflow/mlir-hlo. The iree-mhlo-fork commit used to port the conversion: 1f096a793ab7f73ae8f62deb8b6502c543763ca1. The imported files are relicensed under the [Google CLA](https://cla.developers.google.com/about/google-individual) from the Apache 2.0 license (Tensorflow) to the nearly-identical Apache 2.0 with the LLVM exceptions license (IREE). The initial import covers the lowering of StableHLO ops that can be trivially mapped to their MHLO counterparts. More complicated ops, like convolutions, gather, or rng, are not ported yet. In porting MHLO conversions and tests to operate on StableHLO ops, I changed all namespaces, header guards, and copyright headers, and formatted all files to match the conventions used by IREE. Any addition modifications were a non-goal. I plan to reorganizanize and clean this up further after the initial porting. Issue: #12678
- Loading branch information
Showing
23 changed files
with
7,426 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library") | ||
|
||
package( | ||
default_visibility = ["//visibility:public"], | ||
features = ["layering_check"], | ||
licenses = ["notice"], # Apache 2.0 | ||
) | ||
|
||
iree_gentbl_cc_library( | ||
name = "PassesIncGen", | ||
tbl_outs = [ | ||
( | ||
["--gen-pass-decls"], | ||
"Passes.h.inc", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "Passes.td", | ||
deps = [ | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) | ||
|
||
iree_compiler_cc_library( | ||
name = "PassHeaders", | ||
hdrs = [ | ||
"PassDetail.h", | ||
"Passes.h", | ||
"Passes.h.inc", | ||
"Rewriters.h", | ||
"TypeConversion.h", | ||
], | ||
deps = [ | ||
":PassesIncGen", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) | ||
|
||
iree_compiler_cc_library( | ||
name = "StableHLO", | ||
srcs = [ | ||
"LegalizeToLinalgUtils.cpp", | ||
"Passes.cpp", | ||
"StableHLOToLinalg.cpp", | ||
"TypeConversion.cpp", | ||
], | ||
hdrs = [ | ||
"LegalizeToLinalgUtils.h", | ||
"MapStableHLOToScalarOp.h", | ||
"Passes.h", | ||
], | ||
deps = [ | ||
":PassHeaders", | ||
":PassesIncGen", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:AffineDialect", | ||
"@llvm-project//mlir:AffineUtils", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:BufferizationDialect", | ||
"@llvm-project//mlir:ComplexDialect", | ||
"@llvm-project//mlir:ControlFlowDialect", | ||
"@llvm-project//mlir:DialectUtils", | ||
"@llvm-project//mlir:FuncDialect", | ||
"@llvm-project//mlir:FuncTransforms", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:LinalgDialect", | ||
"@llvm-project//mlir:LinalgTransforms", | ||
"@llvm-project//mlir:LinalgUtils", | ||
"@llvm-project//mlir:MLProgramDialect", | ||
"@llvm-project//mlir:MathDialect", | ||
"@llvm-project//mlir:MemRefDialect", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:ReconcileUnrealizedCasts", | ||
"@llvm-project//mlir:SCFDialect", | ||
"@llvm-project//mlir:SCFToControlFlow", | ||
"@llvm-project//mlir:SCFTransforms", | ||
"@llvm-project//mlir:ShapeDialect", | ||
"@llvm-project//mlir:ShapeToStandard", | ||
"@llvm-project//mlir:ShapeTransforms", | ||
"@llvm-project//mlir:SparseTensorDialect", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:TensorDialect", | ||
"@llvm-project//mlir:TensorUtils", | ||
"@llvm-project//mlir:Transforms", | ||
"@llvm-project//mlir:VectorDialect", | ||
"@mlir-hlo//stablehlo:broadcast_utils", | ||
"@mlir-hlo//stablehlo:chlo_ops", | ||
"@mlir-hlo//stablehlo:stablehlo_ops", | ||
], | ||
) |
89 changes: 89 additions & 0 deletions
89
compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
################################################################################ | ||
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # | ||
# compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel # | ||
# # | ||
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # | ||
# CMake-only content. # | ||
# # | ||
# To disable autogeneration for this file entirely, delete this header. # | ||
################################################################################ | ||
|
||
iree_add_all_subdirs() | ||
|
||
iree_tablegen_library( | ||
NAME | ||
PassesIncGen | ||
TD_FILE | ||
"Passes.td" | ||
OUTS | ||
--gen-pass-decls Passes.h.inc | ||
) | ||
|
||
iree_cc_library( | ||
NAME | ||
PassHeaders | ||
HDRS | ||
"PassDetail.h" | ||
"Passes.h" | ||
"Passes.h.inc" | ||
"Rewriters.h" | ||
"TypeConversion.h" | ||
DEPS | ||
::PassesIncGen | ||
MLIRPass | ||
MLIRTransforms | ||
PUBLIC | ||
) | ||
|
||
iree_cc_library( | ||
NAME | ||
StableHLO | ||
HDRS | ||
"LegalizeToLinalgUtils.h" | ||
"MapStableHLOToScalarOp.h" | ||
"Passes.h" | ||
SRCS | ||
"LegalizeToLinalgUtils.cpp" | ||
"Passes.cpp" | ||
"StableHLOToLinalg.cpp" | ||
"TypeConversion.cpp" | ||
DEPS | ||
::PassHeaders | ||
::PassesIncGen | ||
ChloOps | ||
LLVMSupport | ||
MLIRAffineDialect | ||
MLIRAffineUtils | ||
MLIRArithDialect | ||
MLIRBufferizationDialect | ||
MLIRComplexDialect | ||
MLIRControlFlowDialect | ||
MLIRFuncDialect | ||
MLIRFuncTransforms | ||
MLIRIR | ||
MLIRLinalgDialect | ||
MLIRLinalgTransforms | ||
MLIRLinalgUtils | ||
MLIRMLProgramDialect | ||
MLIRMathDialect | ||
MLIRMemRefDialect | ||
MLIRPass | ||
MLIRReconcileUnrealizedCasts | ||
MLIRSCFDialect | ||
MLIRSCFToControlFlow | ||
MLIRSCFTransforms | ||
MLIRShapeDialect | ||
MLIRShapeOpsTransforms | ||
MLIRShapeToStandard | ||
MLIRSparseTensorDialect | ||
MLIRSupport | ||
MLIRTensorDialect | ||
MLIRTensorUtils | ||
MLIRTransforms | ||
MLIRVectorDialect | ||
StablehloBroadcastUtils | ||
StablehloOps | ||
PUBLIC | ||
) | ||
|
||
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### |
132 changes: 132 additions & 0 deletions
132
compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
// Copyright 2022 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
// Implements utilities for lowering StableHLO/CHLO dialect to Linalg dialect. | ||
|
||
#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h" | ||
|
||
#include <algorithm> | ||
#include <numeric> | ||
#include <string> | ||
#include <utility> | ||
|
||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "stablehlo/dialect/ChloOps.h" | ||
#include "stablehlo/dialect/StablehloOps.h" | ||
|
||
namespace mlir::iree_compiler::stablehlo { | ||
namespace { | ||
bool hasIntegralShapeType(Operation* op) { | ||
auto stp = op->getOperand(0).getType().dyn_cast<ShapedType>(); | ||
return stp && stp.getElementType().isIntOrIndex(); | ||
} | ||
|
||
} // namespace | ||
|
||
SmallVector<utils::IteratorType, 3> getParallelAndReductionIterators( | ||
unsigned nLoops, unsigned nReduction) { | ||
SmallVector<utils::IteratorType, 3> res(nLoops - nReduction, | ||
utils::IteratorType::parallel); | ||
res.append(nReduction, utils::IteratorType::reduction); | ||
return res; | ||
} | ||
|
||
SmallVector<utils::IteratorType, 3> getNParallelLoopsAttrs( | ||
unsigned nParallelLoops) { | ||
return getParallelAndReductionIterators(nParallelLoops, 0); | ||
} | ||
|
||
Value getEmptySparseTensor(OpBuilder& b, Location loc, ShapedType type, | ||
ArrayRef<Value> dynSizes) { | ||
return b.create<bufferization::AllocTensorOp>(loc, type.cast<TensorType>(), | ||
dynSizes, | ||
/*copy=*/Value(), | ||
/*memory_space=*/IntegerAttr()); | ||
} | ||
|
||
Value getEmptyTensor(OpBuilder& b, Location loc, ShapedType type, | ||
ArrayRef<Value> dynSizes) { | ||
return b.create<tensor::EmptyOp>(loc, type.getShape(), type.getElementType(), | ||
dynSizes, | ||
type.cast<RankedTensorType>().getEncoding()); | ||
} | ||
|
||
Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType, | ||
Operation* op, ValueRange operands) { | ||
bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr; | ||
// Collect the sizes for a ranked tensor to be passed as parameter to a | ||
// new tensor initialization operation. This operation only needs the | ||
// dynamic sizes. | ||
SmallVector<Value> sizes; | ||
if (resultType.hasRank() && !resultType.hasStaticShape()) { | ||
// Ask the op for its output shape. | ||
auto shapeSource = cast<InferShapedTypeOpInterface>(op); | ||
SmallVector<Value, 1> reifiedShapes; | ||
(void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes); | ||
assert(reifiedShapes.size() == 1 && "Expected one reified result"); | ||
// Construct sizes for the required dimensions. | ||
for (const auto& en : llvm::enumerate(resultType.getShape())) { | ||
if (en.value() != ShapedType::kDynamic) continue; | ||
sizes.push_back(b.create<tensor::ExtractOp>( | ||
loc, reifiedShapes[0], | ||
ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())})); | ||
} | ||
} | ||
return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes) | ||
: getEmptyTensor(b, loc, resultType, sizes); | ||
} | ||
|
||
Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp, | ||
OpBuilder* b) { | ||
// Apply for semi-ring operations that lower to elaborate code | ||
// (any sign-op, or an integral abs-op). | ||
// TODO(peiming, ajcbik): these all can potentially be optimized by applying | ||
// value transform on sparse_tenosr.value memref | ||
if (isa<mlir::stablehlo::SignOp>(op) || isa<mlir::stablehlo::NegOp>(op) || | ||
(isa<mlir::stablehlo::AbsOp>(op) && hasIntegralShapeType(op)) || | ||
isa<chlo::AsinOp>(op) || isa<chlo::AsinhOp>(op) || | ||
isa<chlo::AtanOp>(op) || isa<chlo::AtanhOp>(op) || | ||
isa<chlo::BesselI1eOp>(op) || isa<chlo::SinhOp>(op) || | ||
isa<chlo::TanOp>(op)) { | ||
if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) && | ||
!sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType())) | ||
return Value(); | ||
Location loc = op->getLoc(); | ||
auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]); | ||
Type itp = values[0].getType(); | ||
Block* present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc); | ||
b->setInsertionPointToStart(&semiring.getPresentRegion().front()); | ||
values[0] = present->getArgument(0); | ||
return semiring; | ||
} | ||
return Value(); | ||
} | ||
|
||
Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b) { | ||
if (semiring) { | ||
b->create<sparse_tensor::YieldOp>(op->getLoc(), result); | ||
b->setInsertionPointAfter(semiring.getDefiningOp()); | ||
return semiring; | ||
} | ||
return result; | ||
} | ||
|
||
bool allOperandsAreScalarTensors(Operation* op) { | ||
return llvm::all_of(op->getOperands(), [](Value operand) { | ||
auto operandTy = operand.getType().dyn_cast<ShapedType>(); | ||
return operandTy && operandTy.getRank() == 0; | ||
}); | ||
} | ||
|
||
bool isInBodyOfLinalgOps(Operation* op) { | ||
auto* parentOp = op->getParentRegion()->getParentOp(); | ||
return parentOp->getDialect() == | ||
parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler::stablehlo |
Oops, something went wrong.