Skip to content

Commit

Permalink
Add some utility builder functions for SPIR-V operations.
Browse files Browse the repository at this point in the history
Add builder functions for spv._address_of, spv.EntryPoint,
spv.ExecutionMode and spv.Load to make it easier to create these
operations.
Fix a minor bug in printing of spv.EntryPoint
Add a utility function to get the attribute name associated with a
decoration.

PiperOrigin-RevId: 272952846
  • Loading branch information
Mahesh Ravishankar authored and tensorflower-gardener committed Oct 4, 2019
1 parent 754ea72 commit 77a809d
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 12 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
Expand Up @@ -27,6 +27,8 @@
namespace mlir {
namespace spirv {

enum class Decoration : uint32_t;

class SPIRVDialect : public Dialect {
public:
explicit SPIRVDialect(MLIRContext *context);
Expand All @@ -36,6 +38,10 @@ class SPIRVDialect : public Dialect {
/// Checks if the given `type` is valid in SPIR-V dialect.
static bool isValidType(Type type);

/// Returns the attribute name to use when specifying decorations on results
/// of operations.
static std::string getAttributeName(Decoration decoration);

/// Parses a type registered to this dialect.
Type parseType(llvm::StringRef spec, Location loc) const override;

Expand Down
11 changes: 10 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
Expand Up @@ -324,14 +324,19 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
let arguments = (ins
SymbolRefAttr:$fn,
SPV_ExecutionModeAttr:$execution_mode,
OptionalAttr<I32ArrayAttr>:$values
I32ArrayAttr:$values
);

let results = (outs);

let verifier = [{ return success(); }];

let autogenSerialization = 0;

let builders = [OpBuilder<[{Builder *builder, OperationState &state,
FuncOp function,
spirv::ExecutionMode executionMode,
ArrayRef<int32_t> params}]>];
}

// -----
Expand Down Expand Up @@ -380,6 +385,10 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
let results = (outs
SPV_Type:$value
);

let builders = [OpBuilder<[{Builder *builder, OperationState &state,
Value *basePtr, /*optional*/IntegerAttr memory_access,
/*optional*/IntegerAttr alignment}]>];
}

// -----
Expand Down
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
Expand Up @@ -67,6 +67,9 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
let hasOpcode = 0;

let autogenSerialization = 0;

let builders = [OpBuilder<[{Builder *builder, OperationState &state,
spirv::GlobalVariableOp var}]>];
}

def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
Expand Down Expand Up @@ -174,12 +177,17 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
let arguments = (ins
SPV_ExecutionModelAttr:$execution_model,
SymbolRefAttr:$fn,
OptionalAttr<SymbolRefArrayAttr>:$interface
SymbolRefArrayAttr:$interface
);

let results = (outs);

let autogenSerialization = 0;

let builders = [OpBuilder<[{Builder *builder, OperationState &state,
spirv::ExecutionModel executionModel,
FuncOp function,
ArrayRef<Attribute> interfaceVars}]>];
}


Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
Expand Down Expand Up @@ -51,6 +52,10 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
allowUnknownOperations();
}

std::string SPIRVDialect::getAttributeName(Decoration decoration) {
return convertToSnakeCase(stringifyDecoration(decoration));
}

//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//
Expand Down
53 changes: 43 additions & 10 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -545,6 +545,11 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
// spv._address_of
//===----------------------------------------------------------------------===//

void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
spirv::GlobalVariableOp var) {
build(builder, state, var.type(), builder->getSymbolRefAttr(var));
}

static ParseResult parseAddressOfOp(OpAsmParser &parser,
OperationState &state) {
SymbolRefAttr varRefAttr;
Expand Down Expand Up @@ -981,11 +986,22 @@ static void print(spirv::ControlBarrierOp op, OpAsmPrinter &printer) {
// spv.EntryPoint
//===----------------------------------------------------------------------===//

void spirv::EntryPointOp::build(Builder *builder, OperationState &state,
spirv::ExecutionModel executionModel,
FuncOp function,
ArrayRef<Attribute> interfaceVars) {
build(builder, state,
builder->getI32IntegerAttr(static_cast<int32_t>(executionModel)),
builder->getSymbolRefAttr(function),
builder->getArrayAttr(interfaceVars));
}

static ParseResult parseEntryPointOp(OpAsmParser &parser,
OperationState &state) {
spirv::ExecutionModel execModel;
SmallVector<OpAsmParser::OperandType, 0> identifiers;
SmallVector<Type, 0> idTypes;
SmallVector<Attribute, 4> interfaceVars;

SymbolRefAttr fn;
if (parseEnumAttribute(execModel, parser, state) ||
Expand All @@ -995,7 +1011,6 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,

if (!parser.parseOptionalComma()) {
// Parse the interface variables
SmallVector<Attribute, 4> interfaceVars;
do {
// The name of the interface variable attribute isnt important
auto attrName = "var_symbol";
Expand All @@ -1006,19 +1021,20 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
}
interfaceVars.push_back(var);
} while (!parser.parseOptionalComma());
state.addAttribute(kInterfaceAttrName,
parser.getBuilder().getArrayAttr(interfaceVars));
}
state.addAttribute(kInterfaceAttrName,
parser.getBuilder().getArrayAttr(interfaceVars));
return success();
}

static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
printer << spirv::EntryPointOp::getOperationName() << " \""
<< stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
<< entryPointOp.fn();
if (auto interface = entryPointOp.interface()) {
auto interfaceVars = entryPointOp.interface().getValue();
if (!interfaceVars.empty()) {
printer << ", ";
interleaveComma(interface.getValue().getValue(), printer);
interleaveComma(interfaceVars, printer);
}
}

Expand All @@ -1032,6 +1048,15 @@ static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
// spv.ExecutionMode
//===----------------------------------------------------------------------===//

void spirv::ExecutionModeOp::build(Builder *builder, OperationState &state,
FuncOp function,
spirv::ExecutionMode executionMode,
ArrayRef<int32_t> params) {
build(builder, state, builder->getSymbolRefAttr(function),
builder->getI32IntegerAttr(static_cast<int32_t>(executionMode)),
builder->getI32ArrayAttr(params));
}

static ParseResult parseExecutionModeOp(OpAsmParser &parser,
OperationState &state) {
spirv::ExecutionMode execMode;
Expand Down Expand Up @@ -1061,13 +1086,13 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
<< execModeOp.fn() << " \""
<< stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
auto values = execModeOp.values();
if (!values) {
if (!values.size()) {
return;
}
printer << ", ";
interleaveComma(
values.getValue().cast<ArrayAttr>(), printer,
[&](Attribute a) { printer << a.cast<IntegerAttr>().getInt(); });
interleaveComma(values, printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt();
});
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1261,6 +1286,14 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
// spv.LoadOp
//===----------------------------------------------------------------------===//

void spirv::LoadOp::build(Builder *builder, OperationState &state,
Value *basePtr, IntegerAttr memory_access,
IntegerAttr alignment) {
auto ptrType = basePtr->getType().cast<spirv::PointerType>();
build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
alignment);
}

static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
// Parse the storage class specification
spirv::StorageClass storageClass;
Expand Down Expand Up @@ -1598,7 +1631,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
<< entryPointOp.fn() << "' not found in 'spv.module'";
}
if (auto interface = entryPointOp.interface()) {
for (auto varRef : interface.getValue().getValue()) {
for (Attribute varRef : interface) {
auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
if (!varSymRef) {
return entryPointOp.emitError(
Expand Down

0 comments on commit 77a809d

Please sign in to comment.