Skip to content

Commit

Permalink
Add a new OpAsmOpInterface to allow for ops to directly hook into the…
Browse files Browse the repository at this point in the history
… AsmPrinter.

This interface provides more fine-grained hooks into the AsmPrinter than the dialect interface, allowing for operations to define the asm name to use for results directly on the operations themselves. The hook is also expanded to enable defining named result "groups". Get a special name to use when printing the results of this operation.
The given callback is invoked with a specific result value that starts a
result "pack", and the name to give this result pack. To signal that a
result pack should use the default naming scheme, a None can be passed
in instead of the name.

For example, if you have an operation that has four results and you want
to split these into three distinct groups you could do the following:

  setNameFn(getResult(0), "first_result");
  setNameFn(getResult(1), "middle_results");
  setNameFn(getResult(3), ""); // use the default numbering.

This would print the operation as follows:

  %first_result, %middle_results:2, %0 = "my.op" ...

PiperOrigin-RevId: 281546873
  • Loading branch information
River707 authored and tensorflower-gardener committed Nov 20, 2019
1 parent 3c05595 commit eb41855
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 93 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(EDSC)
add_subdirectory(IR)
add_subdirectory(Transforms)
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/StandardOps/Ops.h
Expand Up @@ -24,10 +24,9 @@
#define MLIR_DIALECT_STANDARDOPS_OPS_H

#include "mlir/Analysis/CallInterfaces.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"

// Pull in all enum type definitions and utility function declarations.
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/StandardOps/Ops.td
Expand Up @@ -26,6 +26,7 @@
include "mlir/IR/OpBase.td"
#endif // OP_BASE

include "mlir/IR/OpAsmInterface.td"
include "mlir/Analysis/CallInterfaces.td"

def Std_Dialect : Dialect {
Expand Down Expand Up @@ -580,7 +581,8 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
let hasCanonicalizer = 1;
}

def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
def ConstantOp : Std_Op<"constant",
[NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "constant";

let arguments = (ins AnyAttr:$value);
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIROpAsmInterfacesIncGen)
65 changes: 65 additions & 0 deletions mlir/include/mlir/IR/OpAsmInterface.td
@@ -0,0 +1,65 @@
//===- OpAsmInterface.td - Asm Interfaces for opse ---------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains Interfaces for interacting with the AsmParser and
// AsmPrinter.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_OPASMINTERFACE
#define MLIR_OPASMINTERFACE

#ifndef OP_BASE
include "mlir/IR/OpBase.td"
#endif // OP_BASE

/// Interface for hooking into the OpAsmPrinter and OpAsmParser.
def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
let description = [{
This interface provides hooks to interact with the AsmPrinter and AsmParser
classes.
}];

let methods = [
InterfaceMethod<[{
Get a special name to use when printing the results of this operation.
The given callback is invoked with a specific result value that starts a
result "pack", and the name to give this result pack. To signal that a
result pack should use the default naming scheme, a None can be passed
in instead of the name.

For example, if you have an operation that has four results and you want
to split these into three distinct groups you could do the following:

```c++
setNameFn(getResult(0), "first_result");
setNameFn(getResult(1), "middle_results");
setNameFn(getResult(3), ""); // use the default numbering.
```

This would print the operation as follows:

```mlir
%first_result, %middle_results:2, %0 = "my.op" ...
```
}],
"void", "getAsmResultNames", (ins "OpAsmSetValueNameFn":$setNameFn)
>,
];
}

#endif // MLIR_OPASMINTERFACE
18 changes: 15 additions & 3 deletions mlir/include/mlir/IR/OpImplementation.h
Expand Up @@ -600,6 +600,10 @@ class OpAsmParser {
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//

/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value *, StringRef)>;

class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
Expand All @@ -621,11 +625,19 @@ class OpAsmDialectInterface
virtual void
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {}

/// Get a special name to use when printing the given operation. The desired
/// name should be streamed into 'os'.
virtual void getOpResultName(Operation *op, raw_ostream &os) const {}
/// Get a special name to use when printing the given operation. See
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
virtual void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const {}
};

//===--------------------------------------------------------------------===//
// Operation OpAsm interface.
//===--------------------------------------------------------------------===//

/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.h.inc"

} // end namespace mlir

#endif
58 changes: 26 additions & 32 deletions mlir/lib/Dialect/StandardOps/Ops.cpp
Expand Up @@ -44,37 +44,6 @@ using namespace mlir;
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct StdOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;

/// Get a special name to use when printing the given operation. The desired
/// name should be streamed into 'os'.
void getOpResultName(Operation *op, raw_ostream &os) const final {
if (ConstantOp constant = dyn_cast<ConstantOp>(op))
return getConstantOpResultName(constant, os);
}

/// Get a special name to use when printing the given constant.
static void getConstantOpResultName(ConstantOp op, raw_ostream &os) {
Type type = op.getType();
Attribute value = op.getValue();
if (auto intCst = value.dyn_cast<IntegerAttr>()) {
if (type.isIndex()) {
os << 'c' << intCst.getInt();
} else if (type.cast<IntegerType>().isInteger(1)) {
// i1 constants get special names.
os << (intCst.getInt() ? "true" : "false");
} else {
os << 'c' << intCst.getInt() << '_' << type;
}
} else if (type.isa<FunctionType>()) {
os << 'f';
} else {
os << "cst";
}
}
};

/// This class defines the interface for handling inlining with standard
/// operations.
struct StdInlinerInterface : public DialectInlinerInterface {
Expand Down Expand Up @@ -191,7 +160,7 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/Ops.cpp.inc"
>();
addInterfaces<StdInlinerInterface, StdOpAsmInterface>();
addInterfaces<StdInlinerInterface>();
}

void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
Expand Down Expand Up @@ -1183,6 +1152,31 @@ OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValue();
}

void ConstantOp::getAsmResultNames(
function_ref<void(Value *, StringRef)> setNameFn) {
Type type = getType();
if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
IntegerType intTy = type.dyn_cast<IntegerType>();

// Sugar i1 constants with 'true' and 'false'.
if (intTy && intTy.getWidth() == 1)
return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

// Otherwise, build a complex name with the value and type.
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << 'c' << intCst.getInt();
if (intTy)
specialName << '_' << type;
setNameFn(getResult(), specialName.str());

} else if (type.isa<FunctionType>()) {
setNameFn(getResult(), "f");
} else {
setNameFn(getResult(), "cst");
}
}

/// Returns true if a constant operation can be built with the given value and
/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
Expand Down

0 comments on commit eb41855

Please sign in to comment.