-
Notifications
You must be signed in to change notification settings - Fork 11.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir] Add a shape function library op
Op with mapping from ops to corresponding shape functions for those op in the library and mechanism to associate shape functions to functions. The mapping of operand to shape function is kept separate from the shape functions themselves as the operation is associated to the shape function and not vice versa, and one could have a common library of shape functions that can be used in different contexts. Use fully qualified names and require a name for shape fn lib ops for now and an explicit print/parse (based around the generated one & GPU module op ones). This commit reverts d9da4c3. Fixes missing headers (don't know how that was working locally). Differential Revision: https://reviews.llvm.org/D91672
- Loading branch information
Showing
10 changed files
with
242 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics | ||
|
||
// expected-remark@+1 {{associated shape function: same_result_shape}} | ||
func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32> | ||
attributes {shape.function = @shape_lib::@same_result_shape} { | ||
// expected-remark@+1 {{no associated way}} | ||
%0 = tanh %arg : tensor<10x20xf32> | ||
// expected-remark@+1 {{associated shape function: same_result_shape}} | ||
%1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32> | ||
return %1 : tensor<10x20xf32> | ||
} | ||
|
||
// The shape function library with some local functions. | ||
shape.function_library @shape_lib { | ||
// Test shape function that returns the shape of input arg as result shape. | ||
func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { | ||
%0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape | ||
return %0 : !shape.shape | ||
} | ||
} mapping { | ||
test.same_operand_result_type = @same_result_shape | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
add_subdirectory(Affine) | ||
add_subdirectory(Shape) | ||
add_subdirectory(SPIRV) | ||
add_subdirectory(Test) | ||
add_subdirectory(Tosa) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Exclude tests from libMLIR.so | ||
add_mlir_library(MLIRShapeTestPasses | ||
TestShapeFunctions.cpp | ||
|
||
EXCLUDE_FROM_LIBMLIR | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRPass | ||
MLIRShape | ||
MLIRSupport | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
//===- TestShapeFunctions.cpp - Passes to test shape function ------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include <queue> | ||
|
||
#include "mlir/Dialect/Shape/IR/Shape.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/Interfaces/InferTypeOpInterface.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
/// This is a pass that reports shape functions associated with ops. | ||
struct ReportShapeFnPass | ||
: public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { | ||
void runOnOperation() override; | ||
}; | ||
} // end anonymous namespace | ||
|
||
void ReportShapeFnPass::runOnOperation() { | ||
auto module = getOperation(); | ||
|
||
// Lookup shape function library. | ||
shape::FunctionLibraryOp shapeFnLib = nullptr; | ||
for (auto lib : module.getOps<shape::FunctionLibraryOp>()) { | ||
if (shapeFnLib) { | ||
lib.emitError("duplicate shape library op") | ||
.attachNote(shapeFnLib.getLoc()) | ||
<< "previous mapping"; | ||
return signalPassFailure(); | ||
} | ||
shapeFnLib = lib; | ||
}; | ||
|
||
// Report the shape function available to refine the op. | ||
auto shapeFnId = Identifier::get("shape.function", &getContext()); | ||
auto remarkShapeFn = [&](Operation *op) { | ||
if (op->isKnownTerminator()) | ||
return; | ||
if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { | ||
op->emitRemark() << "implements InferType op interface"; | ||
} else if (auto fn = shapeFnLib.getShapeFunction(op)) { | ||
op->emitRemark() << "associated shape function: " << fn.getName(); | ||
} else if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { | ||
auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); | ||
op->emitRemark() << "associated shape function: " << fn.getName(); | ||
} else { | ||
op->emitRemark() << "no associated way to refine shape"; | ||
} | ||
}; | ||
|
||
module.getBodyRegion().walk([&](FuncOp func) { | ||
// Skip ops in the shape function library. | ||
if (isa<shape::FunctionLibraryOp>(func.getParentOp())) | ||
return; | ||
|
||
func.walk([&](Operation *op) { remarkShapeFn(op); }); | ||
}); | ||
} | ||
|
||
namespace mlir { | ||
void registerShapeFunctionTestPasses() { | ||
PassRegistration<ReportShapeFnPass>( | ||
"test-shape-function-report", | ||
"Test pass to report associated shape functions"); | ||
} | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters