Skip to content

Commit

Permalink
[spirv] Add support for function calls.
Browse files Browse the repository at this point in the history
Add spv.FunctionCall operation and (de)serialization.

Closes tensorflow/mlir#137

COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#137 from denis0x0D:sandbox/function_call_op e2e6f07d21e7f23e8b44c7df8a8ab784f3356ce4
PiperOrigin-RevId: 269437167
  • Loading branch information
denis0x0D authored and tensorflower-gardener committed Sep 16, 2019
1 parent 9619ba1 commit 8a34d5d
Show file tree
Hide file tree
Showing 7 changed files with 419 additions and 14 deletions.
17 changes: 9 additions & 8 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
Expand Up @@ -100,6 +100,7 @@ def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite",
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
Expand Down Expand Up @@ -161,13 +162,13 @@ def SPV_OpcodeAttr :
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
Expand Down Expand Up @@ -1113,7 +1114,7 @@ def SPV_SamplerUseAttr:
// Check that an op can only be used within the scope of a FuncOp.
def InFunctionScope : PredOpTrait<
"op must appear in a 'func' block",
CPred<"llvm::isa_and_nonnull<FuncOp>($_op.getParentOp())">>;
CPred<"($_op.getParentOfType<FuncOp>())">>;

// Check that an op can only be used within the scope of a SPIR-V ModuleOp.
def InModuleScope : PredOpTrait<
Expand Down
46 changes: 46 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
Expand Up @@ -151,6 +151,52 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {

// -----

def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> {
let summary = "Call a function.";

let description = [{
Result Type is the type of the return value of the function. It must be
the same as the Return Type operand of the Function Type operand of the
Function operand.

Function is an OpFunction instruction. This could be a forward
reference.

Argument N is the object to copy to parameter N of Function.

Note: A forward call is possible because there is no missing type
information: Result Type must match the Return Type of the function, and
the calling argument types must match the formal parameter types.

### Custom assembly form

``` {.ebnf}
function-call-op ::= `spv.FunctionCall` function-id `(` ssa-use-list `)`
`:` function-type
```

For example:

```
spv.FunctionCall @f_void(%arg0) : (i32) -> ()
%0 = spv.FunctionCall @f_iadd(%arg0, %arg1) : (i32, i32) -> i32
```
}];

let arguments = (ins
SymbolRefAttr:$callee,
Variadic<SPV_Type>:$arguments
);

let results = (outs
SPV_Optional<SPV_Type>:$result
);

let autogenSerialization = 0;
}

// -----

def SPV_LoopOp : SPV_Op<"loop"> {
let summary = "Define a structured loop.";

Expand Down
103 changes: 103 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -35,6 +35,7 @@ using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
Expand Down Expand Up @@ -912,6 +913,108 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
[&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
}

//===----------------------------------------------------------------------===//
// spv.FuncionCall
//===----------------------------------------------------------------------===//

static ParseResult parseFunctionCallOp(OpAsmParser *parser,
OperationState *state) {
SymbolRefAttr calleeAttr;
FunctionType type;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto loc = parser->getNameLoc();
if (parser->parseAttribute(calleeAttr, kCallee, state->attributes) ||
parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser->parseColonType(type)) {
return failure();
}

auto funcType = type.dyn_cast<FunctionType>();
if (!funcType) {
return parser->emitError(loc, "expected function type, but provided ")
<< type;
}

if (funcType.getNumResults() > 1) {
return parser->emitError(loc, "expected callee function to have 0 or 1 "
"result, but provided ")
<< funcType.getNumResults();
}

return failure(parser->addTypesToList(funcType.getResults(), state->types) ||
parser->resolveOperands(operands, funcType.getInputs(), loc,
state->operands));
}

static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter *printer) {
SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
Type functionType =
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());

*printer << spirv::FunctionCallOp::getOperationName() << ' '
<< functionCallOp.getAttr(kCallee) << '(';
printer->printOperands(functionCallOp.arguments());
*printer << ") : " << functionType;
}

static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
auto fnName = functionCallOp.callee();

auto moduleOp = functionCallOp.getParentOfType<spirv::ModuleOp>();
if (!moduleOp) {
return functionCallOp.emitOpError(
"must appear in a function inside 'spv.module'");
}

auto funcOp = moduleOp.lookupSymbol<FuncOp>(fnName);
if (!funcOp) {
return functionCallOp.emitOpError("callee function '")
<< fnName << "' not found in 'spv.module'";
}

auto functionType = funcOp.getType();

if (functionCallOp.getNumResults() > 1) {
return functionCallOp.emitOpError(
"expected callee function to have 0 or 1 result, but provided ")
<< functionCallOp.getNumResults();
}

if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
return functionCallOp.emitOpError(
"has incorrect number of operands for callee: expected ")
<< functionType.getNumInputs() << ", but provided "
<< functionCallOp.getNumOperands();
}

for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
if (functionCallOp.getOperand(i)->getType() != functionType.getInput(i)) {
return functionCallOp.emitOpError(
"operand type mismatch: expected operand type ")
<< functionType.getInput(i) << ", but provided "
<< functionCallOp.getOperand(i)->getType()
<< " for operand number " << i;
}
}

if (functionType.getNumResults() != functionCallOp.getNumResults()) {
return functionCallOp.emitOpError(
"has incorrect number of results has for callee: expected ")
<< functionType.getNumResults() << ", but provided "
<< functionCallOp.getNumResults();
}

if (functionCallOp.getNumResults() &&
(functionCallOp.getResult(0)->getType() != functionType.getResult(0))) {
return functionCallOp.emitOpError("result type mismatch: expected ")
<< functionType.getResult(0) << ", but provided "
<< functionCallOp.getResult(0)->getType();
}

return success();
}

//===----------------------------------------------------------------------===//
// spv.globalVariable
//===----------------------------------------------------------------------===//
Expand Down
62 changes: 58 additions & 4 deletions mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
Expand Up @@ -128,6 +128,11 @@ class Deserializer {
/// Gets the constant's attribute and type associated with the given <id>.
Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);

/// Returns a symbol to be used for the function name with the given
/// result <id>. This tries to use the function's OpName if
/// exists; otherwise creates one based on the <id>.
std::string getFunctionSymbol(uint32_t id);

/// Returns a symbol to be used for the specialization constant with the given
/// result <id>. This tries to use the specialization constant's OpName if
/// exists; otherwise creates one based on the <id>.
Expand Down Expand Up @@ -637,10 +642,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
<< functionType << " and return type " << resultType << " specified";
}

std::string fnName = nameMap.lookup(operands[1]).str();
if (fnName.empty()) {
fnName = "spirv_fn_" + std::to_string(operands[2]);
}
std::string fnName = getFunctionSymbol(operands[1]);
auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
ArrayRef<NamedAttribute>());
curFunction = funcMap[operands[1]] = funcOp;
Expand Down Expand Up @@ -762,6 +764,14 @@ Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
return constIt->getSecond();
}

std::string Deserializer::getFunctionSymbol(uint32_t id) {
auto funcName = nameMap.lookup(id).str();
if (funcName.empty()) {
funcName = "spirv_fn_" + std::to_string(id);
}
return funcName;
}

std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
Expand Down Expand Up @@ -1779,6 +1789,50 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
return success();
}

template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
if (operands.size() < 3) {
return emitError(unknownLoc,
"OpFunctionCall must have at least 3 operands");
}

Type resultType = getType(operands[0]);
if (!resultType) {
return emitError(unknownLoc, "undefined result type from <id> ")
<< operands[0];
}

auto resultID = operands[1];
auto functionID = operands[2];

auto functionName = getFunctionSymbol(functionID);

llvm::SmallVector<Value *, 4> arguments;
for (auto operand : llvm::drop_begin(operands, 3)) {
auto *value = getValue(operand);
if (!value) {
return emitError(unknownLoc, "unknown <id> ")
<< operand << " used by OpFunctionCall";
}
arguments.push_back(value);
}

SmallVector<Type, 1> resultTypes;
if (!isVoidType(resultType)) {
resultTypes.push_back(resultType);
}

auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName),
arguments);

if (!resultTypes.empty()) {
valueMap[resultID] = opFunctionCall.getResult(0);
}
return success();
}

// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
// various Deserializer::processOp<...>() specializations.
#define GET_DESERIALIZATION_FNS
Expand Down
47 changes: 45 additions & 2 deletions mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
Expand Up @@ -131,6 +131,10 @@ class Serializer {
return funcIDMap.lookup(fnName);
}

/// Gets the <id> for the function with the given name. Assigns the next
/// available <id> if the function haven't been deserialized.
uint32_t getOrCreateFunctionID(StringRef fnName);

void processCapability();

void processExtension();
Expand Down Expand Up @@ -392,6 +396,15 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
// Module structure
//===----------------------------------------------------------------------===//

uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
auto funcID = funcIDMap.lookup(fnName);
if (!funcID) {
funcID = getNextID();
funcIDMap[fnName] = funcID;
}
return funcID;
}

void Serializer::processCapability() {
auto caps = module.getAttrOfType<ArrayAttr>("capabilities");
if (!caps)
Expand Down Expand Up @@ -537,8 +550,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
return failure();
}
operands.push_back(resTypeID);
auto funcID = getNextID();
funcIDMap[op.getName()] = funcID;
auto funcID = getOrCreateFunctionID(op.getName());
operands.push_back(funcID);
// TODO : Support other function control options.
operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
Expand Down Expand Up @@ -1461,6 +1473,37 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
operands);
}

template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
auto funcName = op.callee();
uint32_t resTypeID = 0;

llvm::SmallVector<Type, 1> resultTypes(op.getResultTypes());
if (failed(processType(op.getLoc(),
(resultTypes.empty() ? getVoidType() : resultTypes[0]),
resTypeID))) {
return failure();
}

auto funcID = getOrCreateFunctionID(funcName);
auto funcCallID = getNextID();
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};

for (auto *value : op.arguments()) {
auto valueID = findValueID(value);
assert(valueID && "cannot find a value for spv.FunctionCall");
operands.push_back(valueID);
}

if (!resultTypes.empty()) {
valueIDMap[op.getResult(0)] = funcCallID;
}

return encodeInstructionInto(functions, spirv::Opcode::OpFunctionCall,
operands);
}

// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
// various Serializer::processOp<...>() specializations.
#define GET_SERIALIZATION_FNS
Expand Down

0 comments on commit 8a34d5d

Please sign in to comment.