Skip to content

Commit

Permalink
[mlir] Add a shape function library op
Browse files Browse the repository at this point in the history
Op with mapping from ops to corresponding shape functions for those op
in the library and mechanism to associate shape functions to functions.
The mapping of operand to shape function is kept separate from the shape
functions themselves as the operation is associated to the shape
function and not vice versa, and one could have a common library of
shape functions that can be used in different contexts.

Use fully qualified names and require a name for shape fn lib ops for
now and an explicit print/parse (based around the generated one & GPU
module op ones).

This commit reverts d9da4c3. Fixes
missing headers (don't know how that was working locally).

Differential Revision: https://reviews.llvm.org/D91672
  • Loading branch information
jpienaar committed Nov 29, 2020
1 parent 5408fdc commit e534cee
Show file tree
Hide file tree
Showing 10 changed files with 242 additions and 1 deletion.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Shape/IR/Shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#ifndef MLIR_SHAPE_IR_SHAPE_H
#define MLIR_SHAPE_IR_SHAPE_H

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand Down
61 changes: 60 additions & 1 deletion mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// Shape op definitions
Expand Down Expand Up @@ -492,7 +493,7 @@ def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> {
}

def Shape_YieldOp : Shape_Op<"yield",
[HasParent<"ReduceOp">,
[HasParent<"ReduceOp, FunctionLibraryOp">,
NoSideEffect,
ReturnLike,
Terminator]> {
Expand Down Expand Up @@ -780,4 +781,62 @@ def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// Shape collection ops.
//===----------------------------------------------------------------------===//

def Shape_FunctionLibraryOp : Shape_Op<"function_library",
[AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
SingleBlockImplicitTerminator<"ShapeFunctionLibraryTerminatorOp">]> {
let summary = "Represents shape functions and corresponding ops";
let description = [{
Represents a list of shape functions and the ops whose shape transfer
functions they represent.

Example:

```mlir
shape.function_library {
func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
return %0 : !shape.shape
}
} mapping {
std.atan = @same_result_shape
}
```
}];

let arguments = (ins SymbolNameAttr:$sym_name,
OptionalAttr<StrAttr>:$sym_visibility);
let arguments = (ins DictionaryAttr:$mapping);
let regions = (region AnyRegion:$body);

let extraClassDeclaration = [{
/// Returns an associated shape function for an operation if defined.
FuncOp getShapeFunction(Operation *op);
}];

let builders = [OpBuilderDAG<(ins "StringRef":$name)>];
let skipDefaultBuilders = 1;

let printer = [{ ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}

//===----------------------------------------------------------------------===//
// ShapeFunctionLibraryTerminatorOp
//===----------------------------------------------------------------------===//

def ShapeFunctionLibraryTerminatorOp : Shape_Op<"fn_lib_terminator",
[Terminator, HasParent<"FunctionLibraryOp">]> {
let summary = "A pseudo op that marks the end of a shape function library";
let description = [{
`shape_fn_lib_terminator` is a special pseudo terminator operation for the
shape function library. It has no semantic meaning beyond keeping the body
well-formed.
}];
let assemblyFormat = "attr-dict";
}

#endif // SHAPE_OPS
59 changes: 59 additions & 0 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,65 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
StringRef name) {
ensureTerminator(*result.addRegion(), builder, result.location);
result.attributes.push_back(builder.getNamedAttr(
::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
auto attr = mapping()
.get(op->getName().getIdentifier())
.dyn_cast_or_null<FlatSymbolRefAttr>();
if (!attr)
return nullptr;
return lookupSymbol<FuncOp>(attr);
}

ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
OperationState &result) {
// Parse the op name.
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();

if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();

auto *bodyRegion = result.addRegion();
if (parser.parseRegion(*bodyRegion))
return failure();

FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
result.location);
if (parser.parseKeyword("mapping"))
return failure();

DictionaryAttr mappingAttr;
if (parser.parseAttribute(mappingAttr,
parser.getBuilder().getType<NoneType>(), "mapping",
result.attributes))
return failure();
return success();
}

void print(OpAsmPrinter &p, FunctionLibraryOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.getName());
p.printOptionalAttrDictWithKeyword(
op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
p << " mapping ";
p.printAttributeWithoutType(op.mappingAttr());
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Analysis/test-shape-fn-report.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics

// expected-remark@+1 {{associated shape function: same_result_shape}}
func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
attributes {shape.function = @shape_lib::@same_result_shape} {
// expected-remark@+1 {{no associated way}}
%0 = tanh %arg : tensor<10x20xf32>
// expected-remark@+1 {{associated shape function: same_result_shape}}
%1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
return %1 : tensor<10x20xf32>
}

// The shape function library with some local functions.
shape.function_library @shape_lib {
// Test shape function that returns the shape of input arg as result shape.
func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
return %0 : !shape.shape
}
} mapping {
test.same_operand_result_type = @same_result_shape
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(Test)
add_subdirectory(Tosa)
16 changes: 16 additions & 0 deletions mlir/test/lib/Dialect/Shape/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRShapeTestPasses
TestShapeFunctions.cpp

EXCLUDE_FROM_LIBMLIR

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRShape
MLIRSupport
)
73 changes: 73 additions & 0 deletions mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//===- TestShapeFunctions.cpp - Passes to test shape function ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <queue>

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
/// This is a pass that reports shape functions associated with ops.
struct ReportShapeFnPass
: public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
} // end anonymous namespace

void ReportShapeFnPass::runOnOperation() {
auto module = getOperation();

// Lookup shape function library.
shape::FunctionLibraryOp shapeFnLib = nullptr;
for (auto lib : module.getOps<shape::FunctionLibraryOp>()) {
if (shapeFnLib) {
lib.emitError("duplicate shape library op")
.attachNote(shapeFnLib.getLoc())
<< "previous mapping";
return signalPassFailure();
}
shapeFnLib = lib;
};

// Report the shape function available to refine the op.
auto shapeFnId = Identifier::get("shape.function", &getContext());
auto remarkShapeFn = [&](Operation *op) {
if (op->isKnownTerminator())
return;
if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
op->emitRemark() << "implements InferType op interface";
} else if (auto fn = shapeFnLib.getShapeFunction(op)) {
op->emitRemark() << "associated shape function: " << fn.getName();
} else if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
op->emitRemark() << "associated shape function: " << fn.getName();
} else {
op->emitRemark() << "no associated way to refine shape";
}
};

module.getBodyRegion().walk([&](FuncOp func) {
// Skip ops in the shape function library.
if (isa<shape::FunctionLibraryOp>(func.getParentOp()))
return;

func.walk([&](Operation *op) { remarkShapeFn(op); });
});
}

namespace mlir {
void registerShapeFunctionTestPasses() {
PassRegistration<ReportShapeFnPass>(
"test-shape-function-report",
"Test pass to report associated shape functions");
}
} // namespace mlir
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def VariadicWithSameOperandsResult :
let results = (outs AnySignlessInteger:$result);
}

def SameOperandsResultType : TEST_Op<
"same_operand_result_type", [SameOperandsAndResultType]> {
let arguments = (ins AnyTensor:$operand);
let results = (outs AnyTensor:$result);
}

//===----------------------------------------------------------------------===//
// Test Results
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/tools/mlir-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ set(LLVM_LINK_COMPONENTS
if(MLIR_INCLUDE_TESTS)
set(test_libs
MLIRAffineTransformsTestPasses
MLIRShapeTestPasses
MLIRSPIRVTestPasses
MLIRTestDialect
MLIRTestIR
Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace mlir {
void registerConvertToTargetEnvPass();
void registerPassManagerTestPass();
void registerPrintOpAvailabilityPass();
void registerShapeFunctionTestPasses();
void registerSideEffectTestPasses();
void registerSliceAnalysisTestPass();
void registerSymbolTestPasses();
Expand Down Expand Up @@ -98,6 +99,7 @@ void registerTestPasses() {
registerConvertToTargetEnvPass();
registerPassManagerTestPass();
registerPrintOpAvailabilityPass();
registerShapeFunctionTestPasses();
registerSideEffectTestPasses();
registerSliceAnalysisTestPass();
registerSymbolTestPasses();
Expand Down

0 comments on commit e534cee

Please sign in to comment.