Add initial StableHLO to Linalg lowering pass files (#12957)
This is a port of the MHLO to Linalg lowering from
https://github.com/tensorflow/mlir-hlo.
The iree-mhlo-fork commit used to port the conversion:
1f096a793ab7f73ae8f62deb8b6502c543763ca1.
The imported files are relicensed under the [Google
CLA](https://cla.developers.google.com/about/google-individual) from the
Apache 2.0 license (TensorFlow) to the nearly-identical Apache 2.0 with
LLVM Exceptions license (IREE).

The initial import covers the lowering of StableHLO ops that can be
trivially mapped to their MHLO counterparts. More complicated ops, such as
convolution, gather, or rng, are not yet ported.
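
To illustrate the shape of these trivially mapped conversions, here is a
minimal, hand-written sketch (not taken from the ported files; the pattern
name and the restriction to statically shaped, floating-point tensors are
simplifications for illustration): an elementwise stablehlo.add is rewritten
into a linalg.generic over a tensor.empty destination, with the matching
scalar arith.addf in the body.

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace {
// Illustrative only: rewrites stablehlo.add on statically shaped
// floating-point tensors into a linalg.generic with identity indexing maps
// and all-parallel iterators.
struct AddToLinalgSketch final
    : mlir::OpConversionPattern<mlir::stablehlo::AddOp> {
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::stablehlo::AddOp op, OpAdaptor adaptor,
      mlir::ConversionPatternRewriter& rewriter) const override {
    auto type = op.getType().dyn_cast<mlir::RankedTensorType>();
    if (!type || !type.hasStaticShape()) return mlir::failure();

    // Destination tensor holding the result.
    mlir::Location loc = op.getLoc();
    mlir::Value init = rewriter.create<mlir::tensor::EmptyOp>(
        loc, type.getShape(), type.getElementType());

    // Two inputs and one output, all indexed identically over parallel loops.
    int64_t rank = type.getRank();
    llvm::SmallVector<mlir::AffineMap> maps(
        3, rewriter.getMultiDimIdentityMap(rank));
    llvm::SmallVector<mlir::utils::IteratorType> iterators(
        rank, mlir::utils::IteratorType::parallel);

    rewriter.replaceOpWithNewOp<mlir::linalg::GenericOp>(
        op, type, adaptor.getOperands(), init, maps, iterators,
        [](mlir::OpBuilder& b, mlir::Location l, mlir::ValueRange args) {
          mlir::Value sum = b.create<mlir::arith::AddFOp>(l, args[0], args[1]);
          b.create<mlir::linalg::YieldOp>(l, sum);
        });
    return mlir::success();
  }
};
}  // namespace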

In porting MHLO conversions and tests to operate on StableHLO ops, I
changed all namespaces, header guards, and copyright headers, and
formatted all files to match the conventions used by IREE.

Any additional modifications were a non-goal. I plan to reorganize and
clean this up further after the initial port.

Issue: #12678
kuhar authored and jpienaar committed May 1, 2023
1 parent c790a97 commit 251866b
Showing 23 changed files with 7,426 additions and 0 deletions.
1 change: 1 addition & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -114,6 +114,7 @@ def __init__(self, repo_map: Dict[str, str]):
"MhloPasses",
],
"@mlir-hlo//stablehlo:chlo_ops": ["ChloOps",],
"@mlir-hlo//stablehlo:stablehlo_ops": ["StablehloOps",],
"@mlir-hlo//:stablehlo_legalize_to_hlo_pass": ["StablehloToMhlo",],
"@mlir-hlo//stablehlo:broadcast_utils": ["StablehloBroadcastUtils",],

1 change: 1 addition & 0 deletions compiler/src/iree/compiler/InputConversion/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(Common)

if(IREE_INPUT_MHLO)
  add_subdirectory(MHLO)
  add_subdirectory(StableHLO)
endif()
if(IREE_INPUT_TORCH)
  add_subdirectory(TMTensor)
97 changes: 97 additions & 0 deletions compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
@@ -0,0 +1,97 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")

package(
    default_visibility = ["//visibility:public"],
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)

iree_gentbl_cc_library(
    name = "PassesIncGen",
    tbl_outs = [
        (
            ["--gen-pass-decls"],
            "Passes.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "Passes.td",
    deps = [
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

iree_compiler_cc_library(
    name = "PassHeaders",
    hdrs = [
        "PassDetail.h",
        "Passes.h",
        "Passes.h.inc",
        "Rewriters.h",
        "TypeConversion.h",
    ],
    deps = [
        ":PassesIncGen",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Transforms",
    ],
)

iree_compiler_cc_library(
    name = "StableHLO",
    srcs = [
        "LegalizeToLinalgUtils.cpp",
        "Passes.cpp",
        "StableHLOToLinalg.cpp",
        "TypeConversion.cpp",
    ],
    hdrs = [
        "LegalizeToLinalgUtils.h",
        "MapStableHLOToScalarOp.h",
        "Passes.h",
    ],
    deps = [
        ":PassHeaders",
        ":PassesIncGen",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:AffineUtils",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:BufferizationDialect",
        "@llvm-project//mlir:ComplexDialect",
        "@llvm-project//mlir:ControlFlowDialect",
        "@llvm-project//mlir:DialectUtils",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FuncTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgDialect",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:LinalgUtils",
        "@llvm-project//mlir:MLProgramDialect",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ReconcileUnrealizedCasts",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:SCFToControlFlow",
        "@llvm-project//mlir:SCFTransforms",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:ShapeToStandard",
        "@llvm-project//mlir:ShapeTransforms",
        "@llvm-project//mlir:SparseTensorDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TensorUtils",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorDialect",
        "@mlir-hlo//stablehlo:broadcast_utils",
        "@mlir-hlo//stablehlo:chlo_ops",
        "@mlir-hlo//stablehlo:stablehlo_ops",
    ],
)
89 changes: 89 additions & 0 deletions compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
@@ -0,0 +1,89 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
# compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel             #
#                                                                              #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
# CMake-only content.                                                          #
#                                                                              #
# To disable autogeneration for this file entirely, delete this header.        #
################################################################################

iree_add_all_subdirs()

iree_tablegen_library(
  NAME
    PassesIncGen
  TD_FILE
    "Passes.td"
  OUTS
    --gen-pass-decls Passes.h.inc
)

iree_cc_library(
  NAME
    PassHeaders
  HDRS
    "PassDetail.h"
    "Passes.h"
    "Passes.h.inc"
    "Rewriters.h"
    "TypeConversion.h"
  DEPS
    ::PassesIncGen
    MLIRPass
    MLIRTransforms
  PUBLIC
)

iree_cc_library(
  NAME
    StableHLO
  HDRS
    "LegalizeToLinalgUtils.h"
    "MapStableHLOToScalarOp.h"
    "Passes.h"
  SRCS
    "LegalizeToLinalgUtils.cpp"
    "Passes.cpp"
    "StableHLOToLinalg.cpp"
    "TypeConversion.cpp"
  DEPS
    ::PassHeaders
    ::PassesIncGen
    ChloOps
    LLVMSupport
    MLIRAffineDialect
    MLIRAffineUtils
    MLIRArithDialect
    MLIRBufferizationDialect
    MLIRComplexDialect
    MLIRControlFlowDialect
    MLIRFuncDialect
    MLIRFuncTransforms
    MLIRIR
    MLIRLinalgDialect
    MLIRLinalgTransforms
    MLIRLinalgUtils
    MLIRMLProgramDialect
    MLIRMathDialect
    MLIRMemRefDialect
    MLIRPass
    MLIRReconcileUnrealizedCasts
    MLIRSCFDialect
    MLIRSCFToControlFlow
    MLIRSCFTransforms
    MLIRShapeDialect
    MLIRShapeOpsTransforms
    MLIRShapeToStandard
    MLIRSparseTensorDialect
    MLIRSupport
    MLIRTensorDialect
    MLIRTensorUtils
    MLIRTransforms
    MLIRVectorDialect
    StablehloBroadcastUtils
    StablehloOps
  PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
132 changes: 132 additions & 0 deletions compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
@@ -0,0 +1,132 @@
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Implements utilities for lowering StableHLO/CHLO dialect to Linalg dialect.

#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
namespace {
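// Returns true if the first operand of `op` has a shaped type whose element
// type is an integer or index type.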
bool hasIntegralShapeType(Operation* op) {
  auto stp = op->getOperand(0).getType().dyn_cast<ShapedType>();
  return stp && stp.getElementType().isIntOrIndex();
}

} // namespace

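// Returns an iterator-type array with `nLoops - nReduction` leading parallel
// iterators followed by `nReduction` reduction iterators.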
SmallVector<utils::IteratorType, 3> getParallelAndReductionIterators(
    unsigned nLoops, unsigned nReduction) {
  SmallVector<utils::IteratorType, 3> res(nLoops - nReduction,
                                          utils::IteratorType::parallel);
  res.append(nReduction, utils::IteratorType::reduction);
  return res;
}

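// Returns an iterator-type array of `nParallelLoops` parallel iterators.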
SmallVector<utils::IteratorType, 3> getNParallelLoopsAttrs(
    unsigned nParallelLoops) {
  return getParallelAndReductionIterators(nParallelLoops, 0);
}

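// Creates an empty sparse tensor of the given shape and dynamic sizes via
// `bufferization.alloc_tensor`.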
Value getEmptySparseTensor(OpBuilder& b, Location loc, ShapedType type,
                           ArrayRef<Value> dynSizes) {
  return b.create<bufferization::AllocTensorOp>(loc, type.cast<TensorType>(),
                                                dynSizes,
                                                /*copy=*/Value(),
                                                /*memory_space=*/IntegerAttr());
}

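// Creates an empty dense tensor of the given shape, encoding, and dynamic
// sizes via `tensor.empty`.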
Value getEmptyTensor(OpBuilder& b, Location loc, ShapedType type,
                     ArrayRef<Value> dynSizes) {
  return b.create<tensor::EmptyOp>(loc, type.getShape(), type.getElementType(),
                                   dynSizes,
                                   type.cast<RankedTensorType>().getEncoding());
}

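// Creates an empty tensor (sparse or dense, depending on the encoding of
// `resultType`) to serve as the destination of a lowered op, reifying any
// dynamic output dimensions from `op`.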
Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
                        Operation* op, ValueRange operands) {
  bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr;
  // Collect the sizes for a ranked tensor to be passed as parameter to a
  // new tensor initialization operation. This operation only needs the
  // dynamic sizes.
  SmallVector<Value> sizes;
  if (resultType.hasRank() && !resultType.hasStaticShape()) {
    // Ask the op for its output shape.
    auto shapeSource = cast<InferShapedTypeOpInterface>(op);
    SmallVector<Value, 1> reifiedShapes;
    (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
    assert(reifiedShapes.size() == 1 && "Expected one reified result");
    // Construct sizes for the required dimensions.
    for (const auto& en : llvm::enumerate(resultType.getShape())) {
      if (en.value() != ShapedType::kDynamic) continue;
      sizes.push_back(b.create<tensor::ExtractOp>(
          loc, reifiedShapes[0],
          ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())}));
    }
  }
  return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes)
                  : getEmptyTensor(b, loc, resultType, sizes);
}

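// For selected unary ops with sparse operands or results, wraps the scalar
// payload in a `sparse_tensor.unary` semi-ring op and redirects insertion
// into its "present" region. Returns the semi-ring value, or a null Value if
// no wrapping is needed.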
Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
                  OpBuilder* b) {
  // Apply for semi-ring operations that lower to elaborate code
  // (any sign-op, or an integral abs-op).
  // TODO(peiming, ajcbik): these can all potentially be optimized by applying
  // a value transform on the sparse_tensor.values memref
  if (isa<mlir::stablehlo::SignOp>(op) || isa<mlir::stablehlo::NegOp>(op) ||
      (isa<mlir::stablehlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
      isa<chlo::AsinOp>(op) || isa<chlo::AsinhOp>(op) ||
      isa<chlo::AtanOp>(op) || isa<chlo::AtanhOp>(op) ||
      isa<chlo::BesselI1eOp>(op) || isa<chlo::SinhOp>(op) ||
      isa<chlo::TanOp>(op)) {
    if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
        !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
      return Value();
    Location loc = op->getLoc();
    auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]);
    Type itp = values[0].getType();
    Block* present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc);
    b->setInsertionPointToStart(&semiring.getPresentRegion().front());
    values[0] = present->getArgument(0);
    return semiring;
  }
  return Value();
}

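// Finalizes the semi-ring wrapping started by `preSparsify`: yields `result`
// inside the "present" region and resets the insertion point after the
// semi-ring op.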
Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b) {
  if (semiring) {
    b->create<sparse_tensor::YieldOp>(op->getLoc(), result);
    b->setInsertionPointAfter(semiring.getDefiningOp());
    return semiring;
  }
  return result;
}

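// Returns true if all operands of `op` are rank-0 (scalar) tensors.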
bool allOperandsAreScalarTensors(Operation* op) {
  return llvm::all_of(op->getOperands(), [](Value operand) {
    auto operandTy = operand.getType().dyn_cast<ShapedType>();
    return operandTy && operandTy.getRank() == 0;
  });
}

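// Returns true if `op` is nested directly inside the body region of a
// linalg op.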
bool isInBodyOfLinalgOps(Operation* op) {
  auto* parentOp = op->getParentRegion()->getParentOp();
  return parentOp->getDialect() ==
         parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
}

} // namespace mlir::iree_compiler::stablehlo