191 changes: 191 additions & 0 deletions mlir/include/mlir/IR/SideEffects.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
//===-- SideEffects.td - Side Effect Interfaces ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to define information
// about what effects are applied by an operation.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_SIDEEFFECTS
#define MLIR_IR_SIDEEFFECTS

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// EffectOpInterface
//===----------------------------------------------------------------------===//

// A base interface used to query information about the side effects applied to
// an operation. This template class takes the name of the derived interface
// class, as well as the name of the base effect class.
class EffectOpInterfaceBase<string name, string baseEffect>
: OpInterface<name> {
let methods = [
InterfaceMethod<[{
Collects all of the operation's effects into `effects`.
}],
"void", "getEffects",
(ins "SmallVectorImpl<SideEffects::EffectInstance<"
# baseEffect # ">> &":$effects)
>,
InterfaceMethod<[{
Collects all of the operation's effects into `effects`.
}],
"void", "getEffectsOnValue",
(ins "Value":$value,
"SmallVectorImpl<SideEffects::EffectInstance<"
# baseEffect # ">> &":$effects), [{
op.getEffects(effects);
llvm::erase_if(effects, [&](auto &it) {
return it.getValue() != value;
});
}]
>,
InterfaceMethod<[{
Collects all of the effects that are exhibited by this operation on the
given resource and place them in 'effects'.
}],
"void", "getEffectsOnResource",
(ins "SideEffects::Resource *":$resource,
"SmallVectorImpl<SideEffects::EffectInstance<"
# baseEffect # ">> &":$effects), [{
op.getEffects(effects);
llvm::erase_if(effects, [&](auto &it) {
return it.getResource() != resource;
});
}]
>
];

let extraClassDeclaration = [{
/// Collect all of the effect instances that correspond to the given
/// `Effect` and place them in 'effects'.
template <typename Effect> void getEffects(
SmallVectorImpl<SideEffects::EffectInstance<
}] # baseEffect # [{>> &effects) {
getEffects(effects);
llvm::erase_if(effects, [&](auto &it) {
return !llvm::isa<Effect>(it.getEffect());
});
}

/// Returns true if this operation exhibits the given effect.
template <typename Effect> bool hasEffect() {
SmallVector<SideEffects::EffectInstance<}] # baseEffect # [{>, 4> effects;
getEffects(effects);
return llvm::any_of(effects, [](const auto &it) {
return llvm::isa<Effect>(it.getEffect());
});
}

/// Returns if this operation only has the given effect.
template <typename Effect> bool onlyHasEffect() {
SmallVector<SideEffects::EffectInstance<}] # baseEffect # [{>, 4> effects;
getEffects(effects);
return !effects.empty() && llvm::all_of(effects, [](const auto &it) {
return isa<Effect>(it.getEffect());
});
}

/// Returns if this operation has no effects.
bool hasNoEffect() {
SmallVector<SideEffects::EffectInstance<}] # baseEffect # [{>, 4> effects;
getEffects(effects);
return effects.empty();
}
}];

// The base effect name of this interface.
string baseEffectName = baseEffect;
}

// This class is the general base side effect class. This is used by derived
// effect interfaces to define their effects.
class SideEffect<EffectOpInterfaceBase interface, string effectName,
string resourceName> : OpVariableDecorator {
/// The parent interface that the effect belongs to.
string interfaceTrait = interface.trait;

/// The name of the base effects class.
string baseEffect = interface.baseEffectName;

/// The derived effect that is being applied.
string effect = effectName;

/// The resource that the effect is being applied to.
string resource = resourceName;
}

// This class is the base used for specifying effects applied to an operation.
class SideEffectsTraitBase<EffectOpInterfaceBase parentInterface,
list<SideEffect> staticEffects>
: OpInterfaceTrait<""> {
/// The name of the interface trait to use.
let trait = parentInterface.trait;

/// The derived effects being applied.
list<SideEffect> effects = staticEffects;
}

//===----------------------------------------------------------------------===//
// MemoryEffects
//===----------------------------------------------------------------------===//

// This def represents the definition for the memory effects interface. Users
// should generally not use this directly, and should instead use
// `MemoryEffects`.
def MemoryEffectsOpInterface
: EffectOpInterfaceBase<"MemoryEffectOpInterface",
"MemoryEffects::Effect"> {
let description = [{
An interface used to query information about the memory effects applied by
an operation.
}];
}

// The base class for defining specific memory effects.
class MemoryEffect<string effectName, string resourceName>
: SideEffect<MemoryEffectsOpInterface, effectName, resourceName>;

// This class represents the trait for memory effects that may be placed on
// operations.
class MemoryEffects<list<MemoryEffect> effects = []>
: SideEffectsTraitBase<MemoryEffectsOpInterface, effects>;

//===----------------------------------------------------------------------===//
// Effects

// The following effect indicates that the operation allocates from some
// resource. An 'allocate' effect implies only allocation of the resource, and
// not any visible mutation or dereference.
class MemAlloc<string resourceName>
: MemoryEffect<"MemoryEffects::Allocate", resourceName>;
def MemAlloc : MemAlloc<"">;

// The following effect indicates that the operation frees some resource that
// has been allocated. A 'free' effect implies only de-allocation of the
// resource, and not any visible allocation, mutation or dereference.
class MemFree<string resourceName>
: MemoryEffect<"MemoryEffects::Free", resourceName>;
def MemFree : MemFree<"">;

// The following effect indicates that the operation reads from some
// resource. A 'read' effect implies only dereferencing of the resource, and
// not any visible mutation.
class MemRead<string resourceName>
: MemoryEffect<"MemoryEffects::Read", resourceName>;
def MemRead : MemRead<"">;

// The following effect indicates that the operation writes to some
// resource. A 'write' effect implies only mutating a resource, and not any
// visible dereference or read.
class MemWrite<string resourceName>
: MemoryEffect<"MemoryEffects::Write", resourceName>;
def MemWrite : MemWrite<"">;

#endif // MLIR_IR_SIDEEFFECTS
31 changes: 31 additions & 0 deletions mlir/include/mlir/TableGen/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,34 @@ class Operator {
// Returns this op's C++ class name prefixed with namespaces.
std::string getQualCppClassName() const;

/// A class used to represent the decorators of an operator variable, i.e.
/// argument or result.
struct VariableDecorator {
public:
explicit VariableDecorator(const llvm::Record *def) : def(def) {}
const llvm::Record &getDef() const { return *def; }

protected:
// The TableGen definition of this decorator.
const llvm::Record *def;
};

// A utility iterator over a list of variable decorators.
struct VariableDecoratorIterator
: public llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)> {
using reference = VariableDecorator;

/// Initializes the iterator to the specified iterator.
VariableDecoratorIterator(llvm::Init *const *it)
: llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)>(it,
&unwrap) {}
static VariableDecorator unwrap(llvm::Init *init);
};
using var_decorator_iterator = VariableDecoratorIterator;
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;

using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;

Expand Down Expand Up @@ -84,6 +112,8 @@ class Operator {
TypeConstraint getResultTypeConstraint(int index) const;
// Returns the `index`-th result's name.
StringRef getResultName(int index) const;
// Returns the `index`-th result's decorators.
var_decorator_range getResultDecorators(int index) const;

// Returns the number of variadic results in this operation.
unsigned getNumVariadicResults() const;
Expand Down Expand Up @@ -128,6 +158,7 @@ class Operator {
// Op argument (attribute or operand) accessors.
Argument getArg(int index) const;
StringRef getArgName(int index) const;
var_decorator_range getArgDecorators(int index) const;

// Returns the trait wrapper for the given MLIR C++ `trait`.
// TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
Expand Down
55 changes: 55 additions & 0 deletions mlir/include/mlir/TableGen/SideEffects.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Wrapper around side effect related classes defined in TableGen.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
#define MLIR_TABLEGEN_SIDEEFFECTS_H_

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Operator.h"

namespace mlir {
namespace tblgen {

// This class represents a specific instance of an effect that is being
// exhibited.
class SideEffect : public Operator::VariableDecorator {
public:
// Return the name of the C++ effect.
StringRef getName() const;

// Return the name of the base C++ effect.
StringRef getBaseName() const;

// Return the name of the parent interface trait.
StringRef getInterfaceTrait() const;

// Return the name of the resource class.
StringRef getResource() const;

static bool classof(const Operator::VariableDecorator *var);
};

// This class represents an instance of a side effect interface applied to an
// operation. This is a wrapper around an OpInterfaceTrait that also includes
// the effects that are applied.
class SideEffectTrait : public InterfaceOpTrait {
public:
// Return the effects that are attached to the side effect interface.
Operator::var_decorator_range getEffects() const;

static bool classof(const OpTrait *t);
};

} // end namespace tblgen
} // end namespace mlir

#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_
1 change: 1 addition & 0 deletions mlir/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_mlir_library(MLIRIR

DEPENDS
MLIRCallOpInterfacesIncGen
MLIRSideEffectOpInterfacesIncGen
MLIROpAsmInterfacesIncGen
)
target_link_libraries(MLIRIR
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/IR/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,17 @@ LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
return verifyValueSizeAttr(op, attrName, /*isOperand=*/false);
}

//===----------------------------------------------------------------------===//
// SideEffect Interfaces

/// Include the definitions of the side effect interfaces.
#include "mlir/IR/SideEffectInterfaces.cpp.inc"

bool MemoryEffects::Effect::classof(const SideEffects::Effect *effect) {
return isa<Allocate>(effect) || isa<Free>(effect) || isa<Read>(effect) ||
isa<Write>(effect);
}

//===----------------------------------------------------------------------===//
// BinaryOp implementation
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/TableGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_llvm_library(LLVMMLIRTableGen
OpTrait.cpp
Pattern.cpp
Predicate.cpp
SideEffects.cpp
Successor.cpp
Type.cpp

Expand Down
37 changes: 34 additions & 3 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ StringRef tblgen::Operator::getResultName(int index) const {
return results->getArgNameStr(index);
}

auto tblgen::Operator::getResultDecorators(int index) const
-> var_decorator_range {
Record *result =
cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
if (!result->isSubClassOf("OpVariable"))
return var_decorator_range(nullptr, nullptr);
return *result->getValueAsListInit("decorators");
}

unsigned tblgen::Operator::getNumVariadicResults() const {
return std::count_if(
results.begin(), results.end(),
Expand Down Expand Up @@ -138,6 +147,15 @@ StringRef tblgen::Operator::getArgName(int index) const {
return argumentValues->getArgName(index)->getValue();
}

auto tblgen::Operator::getArgDecorators(int index) const
-> var_decorator_range {
Record *arg =
cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
if (!arg->isSubClassOf("OpVariable"))
return var_decorator_range(nullptr, nullptr);
return *arg->getValueAsListInit("decorators");
}

const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
for (const auto &t : traits) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
Expand Down Expand Up @@ -226,6 +244,7 @@ void tblgen::Operator::populateOpStructure() {
auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
auto attrClass = recordKeeper.getClass("Attr");
auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
auto opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;

DagInit *argumentValues = def.getValueAsDag("arguments");
Expand All @@ -240,10 +259,12 @@ void tblgen::Operator::populateOpStructure() {
PrintFatalError(def.getLoc(),
Twine("undefined type for argument #") + Twine(i));
Record *argDef = argDefInit->getDef();
if (argDef->isSubClassOf(opVarClass))
argDef = argDef->getValueAsDef("constraint");

if (argDef->isSubClassOf(typeConstraintClass)) {
operands.push_back(
NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
NamedTypeConstraint{givenName, TypeConstraint(argDef)});
} else if (argDef->isSubClassOf(attrClass)) {
if (givenName.empty())
PrintFatalError(argDef->getLoc(), "attributes must be named");
Expand Down Expand Up @@ -285,6 +306,8 @@ void tblgen::Operator::populateOpStructure() {
int operandIndex = 0, attrIndex = 0;
for (unsigned i = 0; i != numArgs; ++i) {
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
if (argDef->isSubClassOf(opVarClass))
argDef = argDef->getValueAsDef("constraint");

if (argDef->isSubClassOf(typeConstraintClass)) {
arguments.emplace_back(&operands[operandIndex++]);
Expand All @@ -303,11 +326,14 @@ void tblgen::Operator::populateOpStructure() {
// Handle results.
for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
auto name = resultsDag->getArgNameStr(i);
auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
if (!resultDef) {
auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
if (!resultInit) {
PrintFatalError(def.getLoc(),
Twine("undefined type for result #") + Twine(i));
}
auto *resultDef = resultInit->getDef();
if (resultDef->isSubClassOf(opVarClass))
resultDef = resultDef->getValueAsDef("constraint");
results.push_back({name, TypeConstraint(resultDef)});
}

Expand Down Expand Up @@ -394,3 +420,8 @@ void tblgen::Operator::print(llvm::raw_ostream &os) const {
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
}
}

auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
-> VariableDecorator {
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
}
51 changes: 51 additions & 0 deletions mlir/lib/TableGen/SideEffects.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===- SideEffects.cpp - SideEffect classes -------------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/SideEffects.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

//===----------------------------------------------------------------------===//
// SideEffect
//===----------------------------------------------------------------------===//

StringRef SideEffect::getName() const {
return def->getValueAsString("effect");
}

StringRef SideEffect::getBaseName() const {
return def->getValueAsString("baseEffect");
}

StringRef SideEffect::getInterfaceTrait() const {
return def->getValueAsString("interfaceTrait");
}

StringRef SideEffect::getResource() const {
auto value = def->getValueAsString("resource");
return value.empty() ? "::mlir::SideEffects::DefaultResource" : value;
}

bool SideEffect::classof(const Operator::VariableDecorator *var) {
return var->getDef().isSubClassOf("SideEffect");
}

//===----------------------------------------------------------------------===//
// SideEffectsTrait
//===----------------------------------------------------------------------===//

Operator::var_decorator_range SideEffectTrait::getEffects() const {
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("effects"));
return {listInit->begin(), listInit->end()};
}

bool SideEffectTrait::classof(const OpTrait *t) {
return t->getDef().isSubClassOf("SideEffectsTraitBase");
}
1 change: 1 addition & 0 deletions mlir/lib/Translation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ add_mlir_library(MLIRTranslation
target_link_libraries(MLIRTranslation
PUBLIC
LLVMSupport
MLIRIR
)
20 changes: 20 additions & 0 deletions mlir/test/IR/test-side-effects.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: mlir-opt %s -test-side-effects -verify-diagnostics

// expected-remark@+1 {{operation has no memory effects}}
%0 = "test.side_effect_op"() {} : () -> i32

// expected-remark@+2 {{found an instance of 'read' on resource '<Default>'}}
// expected-remark@+1 {{found an instance of 'free' on resource '<Default>'}}
%1 = "test.side_effect_op"() {effects = [
{effect="read"}, {effect="free"}
]} : () -> i32

// expected-remark@+1 {{found an instance of 'write' on resource '<Test>'}}
%2 = "test.side_effect_op"() {effects = [
{effect="write", test_resource}
]} : () -> i32

// expected-remark@+1 {{found an instance of 'allocate' on a value, on resource '<Test>'}}
%3 = "test.side_effect_op"() {effects = [
{effect="allocate", on_result, test_resource}
]} : () -> i32
1 change: 1 addition & 0 deletions mlir/test/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_llvm_library(MLIRTestIR
TestFunc.cpp
TestMatchers.cpp
TestSideEffects.cpp
TestSymbolUses.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
58 changes: 58 additions & 0 deletions mlir/test/lib/IR/TestSideEffects.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//===- TestSidEffects.cpp - Pass to test side effects ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
struct SideEffectsPass : public ModulePass<SideEffectsPass> {
void runOnModule() override {
auto module = getModule();

// Walk operations detecting side effects.
SmallVector<MemoryEffects::EffectInstance, 8> effects;
module.walk([&](MemoryEffectOpInterface op) {
effects.clear();
op.getEffects(effects);

// Check to see if this operation has any memory effects.
if (effects.empty()) {
op.emitRemark() << "operation has no memory effects";
return;
}

for (MemoryEffects::EffectInstance instance : effects) {
auto diag = op.emitRemark() << "found an instance of ";

if (isa<MemoryEffects::Allocate>(instance.getEffect()))
diag << "'allocate'";
else if (isa<MemoryEffects::Free>(instance.getEffect()))
diag << "'free'";
else if (isa<MemoryEffects::Read>(instance.getEffect()))
diag << "'read'";
else if (isa<MemoryEffects::Write>(instance.getEffect()))
diag << "'write'";

if (instance.getValue())
diag << " on a value,";

diag << " on resource '" << instance.getResource()->getName() << "'";
}
});
}
};
} // end anonymous namespace

namespace mlir {
void registerSideEffectTestPasses() {
PassRegistration<SideEffectsPass>("test-side-effects",
"Test side effects interfaces");
}
} // namespace mlir
50 changes: 50 additions & 0 deletions mlir/test/lib/TestDialect/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,56 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
return success();
}

//===----------------------------------------------------------------------===//
// Test SideEffect interfaces
//===----------------------------------------------------------------------===//

namespace {
/// A test resource for side effects.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
StringRef getName() final { return "<Test>"; }
};
} // end anonymous namespace

void SideEffectOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Check for an effects attribute on the op instance.
ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
if (!effectsAttr)
return;

// If there is one, it is an array of dictionary attributes that hold
// information on the effects of this operation.
for (Attribute element : effectsAttr) {
DictionaryAttr effectElement = element.cast<DictionaryAttr>();

// Get the specific memory effect.
MemoryEffects::Effect *effect =
llvm::StringSwitch<MemoryEffects::Effect *>(
effectElement.get("effect").cast<StringAttr>().getValue())
.Case("allocate", MemoryEffects::Allocate::get())
.Case("free", MemoryEffects::Free::get())
.Case("read", MemoryEffects::Read::get())
.Case("write", MemoryEffects::Write::get());

// Check for a result to affect.
Value value;
if (effectElement.get("on_result"))
value = getResult();

// Check for a non-default resource to use.
SideEffects::Resource *resource = SideEffects::DefaultResource::get();
if (effectElement.get("test_resource"))
resource = TestResource::get();

effects.emplace_back(effect, value, resource);
}
}

//===----------------------------------------------------------------------===//
// Dialect Registration
//===----------------------------------------------------------------------===//

// Static initialization for Test dialect registration.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;

Expand Down
10 changes: 10 additions & 0 deletions mlir/test/lib/TestDialect/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SideEffects.td"
include "mlir/Analysis/ControlFlowInterfaces.td"
include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/InferTypeOpInterface.td"
Expand Down Expand Up @@ -1176,4 +1177,13 @@ def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
let assemblyFormat = "$targets attr-dict";
}

//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//

def SideEffectOp : TEST_Op<"side_effect_op",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let results = (outs AnyType:$result);
}

#endif // TEST_OPS
26 changes: 26 additions & 0 deletions mlir/test/mlir-tblgen/op-side-effects.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s

include "mlir/IR/SideEffects.td"

def TEST_Dialect : Dialect {
let name = "test";
}
class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TEST_Dialect, mnemonic, traits>;

def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>);
let results = (outs Res<AnyMemRef, "", [MemAlloc<"CustomResource">]>);
}

def SideEffectOpB : TEST_Op<"side_effect_op_b",
[MemoryEffects<[MemWrite<"CustomResource">]>]>;

// CHECK: void SideEffectOpA::getEffects
// CHECK: for (Value value : getODSOperands(0))
// CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
// CHECK: for (Value value : getODSResults(0))
// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());

// CHECK: void SideEffectOpB::getEffects
// CHECK: effects.emplace_back(MemoryEffects::Write::get(), CustomResource::get());
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ void registerMemRefBoundCheck();
void registerPassManagerTestPass();
void registerPatternsTestPass();
void registerPrintOpAvailabilityPass();
void registerSideEffectTestPasses();
void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
Expand Down Expand Up @@ -89,6 +90,7 @@ void registerTestPasses() {
registerPassManagerTestPass();
registerPatternsTestPass();
registerPrintOpAvailabilityPass();
registerSideEffectTestPasses();
registerSimpleParametricTilingPass();
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
Expand Down
74 changes: 74 additions & 0 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Signals.h"
Expand Down Expand Up @@ -280,6 +281,9 @@ class OpEmitter {
// Generate the OpInterface methods.
void genOpInterfaceMethods();

// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();

private:
// The TableGen record for this op.
// TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
Expand Down Expand Up @@ -321,6 +325,7 @@ OpEmitter::OpEmitter(const Operator &op)
genFolderDecls();
genOpInterfaceMethods();
generateOpFormat(op, opClass);
genSideEffectInterfaceMethods();
}

void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
Expand Down Expand Up @@ -1161,6 +1166,75 @@ void OpEmitter::genOpInterfaceMethods() {
}
}

void OpEmitter::genSideEffectInterfaceMethods() {
enum EffectKind { Operand, Result, Static };
struct EffectLocation {
/// The effect applied.
SideEffect effect;

/// The index if the kind is either operand or result.
unsigned index : 30;

/// The kind of the location.
EffectKind kind : 2;
};

StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
auto resolveDecorators = [&](Operator::var_decorator_range decorators,
unsigned index, EffectKind kind) {
for (auto decorator : decorators)
if (SideEffect *effect = dyn_cast<SideEffect>(&decorator))
interfaceEffects[effect->getInterfaceTrait()].push_back(
EffectLocation{*effect, index, kind});
};

// Collect effects that were specified via:
/// Traits.
for (const auto &trait : op.getTraits())
if (const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait))
resolveDecorators(opTrait->getEffects(), /*index=*/0, EffectKind::Static);
/// Operands.
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
if (op.getArg(i).is<NamedTypeConstraint *>()) {
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
++operandIt;
}
}
/// Results.
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);

for (auto &it : interfaceEffects) {
StringRef baseEffect = it.second.front().effect.getBaseName();
auto effectsParam =
llvm::formatv(
"SmallVectorImpl<SideEffects::EffectInstance<{0}>> &effects",
baseEffect)
.str();

// Generate the 'getEffects' method.
auto &getEffects = opClass.newMethod("void", "getEffects", effectsParam);
auto &body = getEffects.body();

// Add effect instances for each of the locations marked on the operation.
for (auto &location : it.second) {
if (location.kind != EffectKind::Static) {
body << " for (Value value : getODS"
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
<< "(" << location.index << "))\n ";
}

body << " effects.emplace_back(" << location.effect.getName()
<< "::get()";

// If the effect isn't static, it has a specific value attached to it.
if (location.kind != EffectKind::Static)
body << ", value";
body << ", " << location.effect.getResource() << "::get());\n";
}
}
}

void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser") ||
hasStringAttribute(def, "assemblyFormat"))
Expand Down