diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index ee18cf815e4a7..c27f9aa91332c 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -86,6 +86,13 @@ Value spirv::Deserializer::getValue(uint32_t id) { if (auto undef = getUndefType(id)) { return spirv::UndefOp::create(opBuilder, unknownLoc, undef); } + if (std::optional + graphConstantARMInfo = getGraphConstantARM(id)) { + IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID; + Type resultType = graphConstantARMInfo->resultType; + return spirv::GraphConstantARMOp::create(opBuilder, unknownLoc, resultType, + graphConstantID); + } return valueMap.lookup(id); } @@ -180,6 +187,7 @@ LogicalResult spirv::Deserializer::processInstruction( case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: case spirv::Opcode::OpTypeTensorARM: + case spirv::Opcode::OpTypeGraphARM: case spirv::Opcode::OpTypeCooperativeMatrixKHR: return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: @@ -208,12 +216,26 @@ LogicalResult spirv::Deserializer::processInstruction( return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantNull: return processConstantNull(operands); + case spirv::Opcode::OpGraphConstantARM: + return processGraphConstantARM(operands); case spirv::Opcode::OpDecorate: return processDecoration(operands); case spirv::Opcode::OpMemberDecorate: return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); + case spirv::Opcode::OpGraphEntryPointARM: + if (deferInstructions) { + deferredInstructions.emplace_back(opcode, operands); + return success(); + } + return processGraphEntryPointARM(operands); + case spirv::Opcode::OpGraphARM: + return processGraphARM(operands); + case spirv::Opcode::OpGraphSetOutputARM: + return processOpGraphSetOutputARM(operands); + case spirv::Opcode::OpGraphEndARM: + return processGraphEndARM(operands); case spirv::Opcode::OpLabel: return processLabel(operands); case spirv::Opcode::OpBranch: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 3625dd2eb7dd3..0c3e87a8dc1ef 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef operands) { return success(); } +LogicalResult +spirv::Deserializer::processGraphEntryPointARM(ArrayRef operands) { + if (operands.size() < 2) { + return emitError(unknownLoc, + "missing graph defintion in OpGraphEntryPointARM"); + } + + unsigned wordIndex = 0; + uint32_t graphID = operands[wordIndex++]; + if (!graphMap.contains(graphID)) { + return emitError(unknownLoc, + "missing graph definition/declaration with id ") + << graphID; + } + + spirv::GraphARMOp graphARM = graphMap[graphID]; + StringRef name = decodeStringLiteral(operands, wordIndex); + graphARM.setSymName(name); + graphARM.setEntryPoint(true); + + SmallVector interface; + for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) { + if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) { + interface.push_back(SymbolRefAttr::get(arg.getOperation())); + } else { + return emitError(unknownLoc, "undefined result ") + << operands[wordIndex] << " while decoding OpGraphEntryPoint"; + } + } + + // RAII guard to reset the insertion point to previous value when done. + OpBuilder::InsertionGuard insertionGuard(opBuilder); + opBuilder.setInsertionPoint(graphARM); + opBuilder.create( + unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name), + opBuilder.getArrayAttr(interface)); + + return success(); +} + +LogicalResult +spirv::Deserializer::processGraphARM(ArrayRef operands) { + if (curGraph) { + return emitError(unknownLoc, "found graph inside graph"); + } + // Get the result type. + if (operands.size() < 2) { + return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters"); + } + + Type type = getType(operands[0]); + if (!type || !isa(type)) { + return emitError(unknownLoc, "unknown graph type from ") + << operands[0]; + } + auto graphType = cast(type); + if (graphType.getNumResults() <= 0) { + return emitError(unknownLoc, "expected at least one result"); + } + + uint32_t graphID = operands[1]; + if (graphMap.count(graphID)) { + return emitError(unknownLoc, "duplicate graph definition/declaration"); + } + + std::string graphName = getGraphSymbol(graphID); + auto graphOp = + opBuilder.create(unknownLoc, graphName, graphType); + curGraph = graphMap[graphID] = graphOp; + Block *entryBlock = graphOp.addEntryBlock(); + LLVM_DEBUG({ + logger.startLine() + << "//===-------------------------------------------===//\n"; + logger.startLine() << "[graph] name: " << graphName << "\n"; + logger.startLine() << "[graph] type: " << graphType << "\n"; + logger.startLine() << "[graph] ID: " << graphID << "\n"; + logger.startLine() << "[graph] entry block: " << entryBlock << "\n"; + logger.indent(); + }); + + // Parse the op argument instructions. + for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) { + spirv::Opcode opcode; + ArrayRef operands; + if (failed(sliceInstruction(opcode, operands, + spirv::Opcode::OpGraphInputARM))) { + return failure(); + } + if (operands.size() != 3) { + return emitError(unknownLoc, "expected result type, result and " + "input index for OpGraphInputARM"); + } + + Type argDefinedType = getType(operands[0]); + if (!argDefinedType) { + return emitError(unknownLoc, "unknown operand type ") << operands[0]; + } + + if (argDefinedType != argType) { + return emitError(unknownLoc, + "mismatch in argument type between graph type " + "definition ") + << graphType << " and argument type definition " << argDefinedType + << " at argument " << index; + } + if (getValue(operands[1])) { + return emitError(unknownLoc, "duplicate definition of result ") + << operands[1]; + } + + IntegerAttr inputIndexAttr = getConstantInt(operands[2]); + if (!inputIndexAttr) { + return emitError(unknownLoc, + "unable to read inputIndex value from constant op ") + << operands[2]; + } + BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt()); + valueMap[operands[1]] = argValue; + } + + graphOutputs.resize(graphType.getNumResults()); + + // RAII guard to reset the insertion point to the module's region after + // deserializing the body of this function. + OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); + + blockMap[graphID] = entryBlock; + if (failed(createGraphBlock(graphID))) { + return failure(); + } + + // Process all the instructions in the graph until and including + // OpGraphEndARM. + spirv::Opcode opcode; + ArrayRef instOperands; + do { + if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) { + return failure(); + } + + if (failed(processInstruction(opcode, instOperands))) { + return failure(); + } + } while (opcode != spirv::Opcode::OpGraphEndARM); + + return success(); +} + +LogicalResult +spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef operands) { + if (operands.size() != 2) { + return emitError( + unknownLoc, + "expected value id and output index for OpGraphSetOutputARM"); + } + + uint32_t id = operands[0]; + Value value = getValue(id); + if (!value) { + return emitError(unknownLoc, "could not find result ") << id; + } + + IntegerAttr outputIndexAttr = getConstantInt(operands[1]); + if (!outputIndexAttr) { + return emitError(unknownLoc, + "unable to read outputIndex value from constant op ") + << operands[1]; + } + graphOutputs[outputIndexAttr.getInt()] = value; + return success(); +} + +LogicalResult +spirv::Deserializer::processGraphEndARM(ArrayRef operands) { + // Create GraphOutputsARM instruction. + opBuilder.create(unknownLoc, graphOutputs); + + // Process OpGraphEndARM. + if (!operands.empty()) { + return emitError(unknownLoc, "unexpected operands for OpGraphEndARM"); + } + + curBlock = nullptr; + curGraph = std::nullopt; + graphOutputs.clear(); + + LLVM_DEBUG({ + logger.unindent(); + logger.startLine() + << "//===-------------------------------------------===//\n"; + }); + return success(); +} + std::optional> spirv::Deserializer::getConstant(uint32_t id) { auto constIt = constantMap.find(id); @@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { return funcName; } +std::string spirv::Deserializer::getGraphSymbol(uint32_t id) { + std::string graphName = nameMap.lookup(id).str(); + if (graphName.empty()) { + graphName = "spirv_graph_" + std::to_string(id); + } + return graphName; +} + std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { auto constName = nameMap.lookup(id).str(); if (constName.empty()) { @@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, return op; } +std::optional +spirv::Deserializer::getGraphConstantARM(uint32_t id) { + auto graphConstIt = graphConstantMap.find(id); + if (graphConstIt == graphConstantMap.end()) + return std::nullopt; + return graphConstIt->getSecond(); +} + LogicalResult spirv::Deserializer::processGlobalVariable(ArrayRef operands) { unsigned wordIndex = 0; @@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, return processMatrixType(operands); case spirv::Opcode::OpTypeTensorARM: return processTensorARMType(operands); + case spirv::Opcode::OpTypeGraphARM: + return processGraphTypeARM(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } @@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef operands) { return success(); } +LogicalResult +spirv::Deserializer::processGraphTypeARM(ArrayRef operands) { + unsigned size = operands.size(); + if (size < 2) { + return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands " + "(result_id, num_inputs, (inout0_type, " + "inout1_type, ...))") + << size; + } + uint32_t numInputs = operands[1]; + SmallVector argTypes; + SmallVector returnTypes; + for (unsigned i = 2; i < size; ++i) { + Type inOutTy = getType(operands[i]); + if (!inOutTy) { + return emitError(unknownLoc, + "OpTypeGraphARM references undefined element type.") + << operands[i]; + } + if (i - 2 >= numInputs) { + returnTypes.push_back(inOutTy); + } else { + argTypes.push_back(inOutTy); + } + } + typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes); + return success(); +} + LogicalResult spirv::Deserializer::processTypeForwardPointer(ArrayRef operands) { if (operands.size() != 2) @@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef operands) { << resultType; } +LogicalResult +spirv::Deserializer::processGraphConstantARM(ArrayRef operands) { + if (operands.size() < 3) { + return emitError(unknownLoc) + << "OpGraphConstantARM must have at least 2 operands"; + } + + Type resultType = getType(operands[0]); + if (!resultType) { + return emitError(unknownLoc, "undefined result type from ") + << operands[0]; + } + + uint32_t resultID = operands[1]; + + if (!dyn_cast(resultType)) { + return emitError(unknownLoc, "result must be of type OpTypeTensorARM"); + } + + APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true); + Type i32Ty = opBuilder.getIntegerType(32); + IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id); + graphConstantMap.try_emplace( + resultID, GraphConstantARMOpMaterializationInfo{resultType, attr}); + + return success(); +} + //===----------------------------------------------------------------------===// // Control flow //===----------------------------------------------------------------------===// @@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef operands) { return success(); } +LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) { + if (!curGraph) { + return emitError(unknownLoc, "a graph block must appear inside a graph"); + } + + // We may have forward declared this block. + Block *block = getOrCreateBlock(graphID); + LLVM_DEBUG(logger.startLine() + << "[block] populating block " << block << "\n"); + // If we have seen this block, make sure it was just a forward declaration. + assert(block->empty() && "re-deserialize the same block!"); + + opBuilder.setInsertionPointToStart(block); + blockMap[graphID] = curBlock = block; + + return success(); +} + LogicalResult spirv::Deserializer::processSelectionMerge(ArrayRef operands) { if (!curBlock) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index db1cc3f8d79c2..6027f1ac94c23 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -106,6 +106,13 @@ struct SpecConstOperationMaterializationInfo { SmallVector enclosedOpOperands; }; +/// A struct that collects the info needed to materialize/emit a +/// GraphConstantARMOp. +struct GraphConstantARMOpMaterializationInfo { + Type resultType; + IntegerAttr graphConstantID; +}; + //===----------------------------------------------------------------------===// // Deserializer Declaration //===----------------------------------------------------------------------===// @@ -211,9 +218,14 @@ class Deserializer { /// exists; otherwise creates one based on the . std::string getFunctionSymbol(uint32_t id); - /// Returns a symbol to be used for the specialization constant with the given - /// result . This tries to use the specialization constant's OpName if + /// Returns a symbol to be used for the graph name with the given + /// result . This tries to use the graph's OpName if /// exists; otherwise creates one based on the . + std::string getGraphSymbol(uint32_t id); + + /// Returns a symbol to be used for the specialization constant with the + /// given result . This tries to use the specialization constant's + /// OpName if exists; otherwise creates one based on the . std::string getSpecConstantSymbol(uint32_t id); /// Gets the specialization constant with the given result . @@ -237,6 +249,11 @@ class Deserializer { spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue); + /// Gets the GraphConstantARM ID attribute and result type with the given + /// result . + std::optional + getGraphConstantARM(uint32_t id); + /// Processes the OpVariable instructions at current `offset` into `binary`. /// It is expected that this method is used for variables that are to be /// defined at module scope and will be deserialized into a @@ -306,6 +323,16 @@ class Deserializer { LogicalResult processTensorARMType(ArrayRef operands); + LogicalResult processGraphTypeARM(ArrayRef operands); + + LogicalResult processGraphEntryPointARM(ArrayRef operands); + + LogicalResult processGraphARM(ArrayRef operands); + + LogicalResult processOpGraphSetOutputARM(ArrayRef operands); + + LogicalResult processGraphEndARM(ArrayRef operands); + LogicalResult processTypeForwardPointer(ArrayRef operands); //===--------------------------------------------------------------------===// @@ -353,6 +380,10 @@ class Deserializer { /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); + /// Processes a SPIR-V OpGraphConstantARM instruction with the given + /// `operands`. + LogicalResult processGraphConstantARM(ArrayRef operands); + //===--------------------------------------------------------------------===// // Debug //===--------------------------------------------------------------------===// @@ -450,6 +481,9 @@ class Deserializer { /// blocks declared as selection/loop headers are handled. LogicalResult structurizeControlFlow(); + /// Creates a block for graph with the given graphID. + LogicalResult createGraphBlock(uint32_t graphID); + //===--------------------------------------------------------------------===// // Instruction //===--------------------------------------------------------------------===// @@ -546,6 +580,9 @@ class Deserializer { /// The current function under construction. std::optional curFunction; + /// The current graph under construction. + std::optional curGraph; + /// The current block under construction. Block *curBlock = nullptr; @@ -599,12 +636,19 @@ class Deserializer { DenseMap specConstOperationMap; + // Result to GraphConstantARM ID attribute and result type. + DenseMap + graphConstantMap; + // Result to variable mapping. DenseMap globalVariableMap; // Result to function mapping. DenseMap funcMap; + // Result to function mapping. + DenseMap graphMap; + // Result to block mapping. DenseMap blockMap; @@ -668,6 +712,9 @@ class Deserializer { /// Deserialization options. DeserializationOptions options; + /// List of IDs assigned to graph outputs. + SmallVector graphOutputs; + #ifndef NDEBUG /// A logger used to emit information during the deserialzation process. llvm::ScopedPrinter logger; diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index d62529b85b3aa..e9b180a70bb23 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -203,6 +203,16 @@ Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { return success(); } +LogicalResult +Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) { + if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(), + op.getGraphConstantIdAttr())) { + valueIDMap[op.getResult()] = resultID; + return success(); + } + return failure(); +} + LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { auto undefType = op.getType(); auto &id = undefValIDMap[undefType]; @@ -368,6 +378,118 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { return success(); } +LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) { + if (op.getNumResults() < 1) { + return op.emitError("cannot serialize graph with no return types"); + } + + LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n"); + assert(functionHeader.empty() && functionBody.empty()); + + uint32_t funcID = getOrCreateFunctionID(op.getName()); + uint32_t fnTypeID = 0; + // Generate type of the function. + if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID))) + return failure(); + encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM, + {fnTypeID, funcID}); + + // Declare the parameters. + for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { + uint32_t argTypeID = 0; + SmallVector inputOperands; + + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + return failure(); + } + + uint32_t argValueID = getNextID(); + valueIDMap[arg] = argValueID; + + auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx); + uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false); + + inputOperands.push_back(argTypeID); + inputOperands.push_back(argValueID); + inputOperands.push_back(indexID); + + encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM, + inputOperands); + } + + if (failed(processBlock(&op.front(), /*omitLabel=*/true))) + return failure(); + if (failed(visitInPrettyBlockOrder( + &op.front(), [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true))) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName() + << "' --\n"); + // Insert OpGraphEndARM. + encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {}); + + llvm::append_range(graphs, functionHeader); + llvm::append_range(graphs, functionBody); + functionHeader.clear(); + functionBody.clear(); + + return success(); +} + +LogicalResult +Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) { + SmallVector operands; + StringRef graph = op.getFn(); + // Add the graph . + uint32_t graphID = getOrCreateFunctionID(graph); + operands.push_back(graphID); + // Add the name of the graph. + spirv::encodeStringLiteralInto(operands, graph); + + // Add the interface values. + if (ArrayAttr interface = op.getInterface()) { + for (Attribute var : interface.getValue()) { + StringRef value = cast(var).getValue(); + if (uint32_t id = getVariableID(value)) { + operands.push_back(id); + } else { + return op.emitError( + "referencing undefined global variable." + "spirv.GraphEntryPointARM is at the end of spirv.module. All " + "referenced variables should already be defined"); + } + } + } + encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands); + return success(); +} + +LogicalResult +Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) { + for (auto [idx, value] : llvm::enumerate(op->getOperands())) { + SmallVector outputOperands; + + Type resType = value.getType(); + uint32_t resTypeID = 0; + if (failed(processType(op.getLoc(), resType, resTypeID))) { + return failure(); + } + + uint32_t outputID = getValueID(value); + auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx); + uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false); + + outputOperands.push_back(outputID); + outputOperands.push_back(indexID); + + encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM, + outputOperands); + } + return success(); +} + LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { SmallVector operands; SmallVector elidedAttrs; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 7fc779587f4f1..b56e7788625f5 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -136,7 +136,7 @@ void Serializer::collect(SmallVectorImpl &binary) { extensions.size() + extendedSets.size() + memoryModel.size() + entryPoints.size() + executionModes.size() + decorations.size() + - typesGlobalValues.size() + functions.size(); + typesGlobalValues.size() + functions.size() + graphs.size(); binary.clear(); binary.reserve(moduleSize); @@ -154,6 +154,7 @@ void Serializer::collect(SmallVectorImpl &binary) { binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); + binary.append(graphs.begin(), graphs.end()); } #ifndef NDEBUG @@ -509,6 +510,9 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, if ((isa(type) && succeeded(prepareFunctionType(loc, cast(type), typeEnum, operands))) || + (isa(type) && + succeeded( + prepareGraphType(loc, cast(type), typeEnum, operands))) || succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, deferSerialization, serializationCtx))) { if (deferSerialization) @@ -539,7 +543,7 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, return success(); } - return failure(); + return emitError(loc, "failed to process type: ") << type; } LogicalResult Serializer::prepareBasicType( @@ -875,6 +879,33 @@ Serializer::prepareFunctionType(Location loc, FunctionType type, return success(); } +LogicalResult +Serializer::prepareGraphType(Location loc, GraphType type, + spirv::Opcode &typeEnum, + SmallVectorImpl &operands) { + typeEnum = spirv::Opcode::OpTypeGraphARM; + assert(type.getNumResults() >= 1 && + "serialization requires at least a return value"); + + operands.push_back(type.getNumInputs()); + + for (Type argType : type.getInputs()) { + uint32_t argTypeID = 0; + if (failed(processType(loc, argType, argTypeID))) + return failure(); + operands.push_back(argTypeID); + } + + for (Type resType : type.getResults()) { + uint32_t resTypeID = 0; + if (failed(processType(loc, resType, resTypeID))) + return failure(); + operands.push_back(resTypeID); + } + + return success(); +} + //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// @@ -1135,6 +1166,41 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, return resultID; } +uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType, + IntegerAttr intAttr) { + // De-duplicate graph constants. + if (uint32_t id = getGraphConstantARMId(intAttr)) { + return id; + } + + // Process the type for this graph constant. + uint32_t typeID = 0; + if (failed(processType(loc, graphConstType, typeID))) { + return 0; + } + + uint32_t resultID = getNextID(); + APInt value = intAttr.getValue(); + unsigned bitwidth = value.getBitWidth(); + if (bitwidth > 32) { + emitError(loc, "Too wide attribute for OpGraphConstantARM: ") + << bitwidth << " bits"; + return 0; + } + bool isSigned = value.isSignedIntN(bitwidth); + + uint32_t word = 0; + if (isSigned) { + word = static_cast(value.getSExtValue()); + } else { + word = static_cast(value.getZExtValue()); + } + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM, + {typeID, resultID, word}); + graphConstIDMap[intAttr] = resultID; + return resultID; +} + uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec) { if (!isSpec) { @@ -1469,9 +1535,19 @@ LogicalResult Serializer::processOperation(Operation *opInst) { return processConstantCompositeReplicateOp(op); }) .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) + .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); }) + .Case([&](spirv::GraphEntryPointARMOp op) { + return processGraphEntryPointARMOp(op); + }) + .Case([&](spirv::GraphOutputsARMOp op) { + return processGraphOutputsARMOp(op); + }) .Case([&](spirv::GlobalVariableOp op) { return processGlobalVariableOp(op); }) + .Case([&](spirv::GraphConstantARMOp op) { + return processGraphConstantARMOp(op); + }) .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index fb2cecdff8e43..add372b19b5af 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -122,6 +122,8 @@ class Serializer { LogicalResult processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); + LogicalResult processGraphConstantARMOp(spirv::GraphConstantARMOp op); + /// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces a SSA /// value to use with other operations. The SPIR-V spec recommends that /// OpUndef be generated at module level. The serialization generates an @@ -135,6 +137,15 @@ class Serializer { LogicalResult processFuncOp(spirv::FuncOp op); LogicalResult processFuncParameter(spirv::FuncOp op); + /// Processes a SPIR-V GraphARM op. + LogicalResult processGraphARMOp(spirv::GraphARMOp op); + + /// Processes a SPIR-V GraphEntryPointARM op. + LogicalResult processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op); + + /// Processes a SPIR-V GraphOutputsARMOp op. + LogicalResult processGraphOutputsARMOp(spirv::GraphOutputsARMOp op); + LogicalResult processVariableOp(spirv::VariableOp op); /// Process a SPIR-V GlobalVariableOp @@ -189,6 +200,10 @@ class Serializer { spirv::Opcode &typeEnum, SmallVectorImpl &operands); + LogicalResult prepareGraphType(Location loc, GraphType type, + spirv::Opcode &typeEnum, + SmallVectorImpl &operands); + //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// @@ -238,6 +253,13 @@ class Serializer { uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec = false); + uint32_t getGraphConstantARMId(Attribute value) const { + return graphConstIDMap.lookup(value); + } + + uint32_t prepareGraphConstantId(Location loc, Type graphConstType, + IntegerAttr intAttr); + uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec = false); @@ -372,6 +394,7 @@ class Serializer { SmallVector decorations; SmallVector typesGlobalValues; SmallVector functions; + SmallVector graphs; /// Recursive struct references are serialized as OpTypePointer instructions /// to the recursive struct type. However, the OpTypePointer instruction @@ -388,15 +411,22 @@ class Serializer { recursiveStructInfos; /// `functionHeader` contains all the instructions that must be in the first - /// block in the function, and `functionBody` contains the rest. After - /// processing FuncOp, the encoded instructions of a function are appended to - /// `functions`. An example of instructions in `functionHeader` in order: + /// block in the function or graph, and `functionBody` contains the rest. + /// After processing FuncOp/GraphARMOp, the encoded instructions of a function + /// or graph are appended to `functions` or `graphs` respectively. Examples of + /// instructions in `functionHeader` in order: + /// + /// For a FuncOp: /// OpFunction ... /// OpFunctionParameter ... /// OpFunctionParameter ... /// OpLabel ... /// OpVariable ... /// OpVariable ... + /// + /// For a GraphARMOp + /// OpGraphARM ... + /// OpGraphInputARM ... SmallVector functionHeader; SmallVector functionBody; @@ -412,6 +442,9 @@ class Serializer { /// Map from specialization constant names to their s. llvm::StringMap specConstIDMap; + /// Map from graph constant ID value to their s. + DenseMap graphConstIDMap; + /// Map from GlobalVariableOps name to s. llvm::StringMap globalVarIDMap; diff --git a/mlir/test/Target/SPIRV/graph-ops.mlir b/mlir/test/Target/SPIRV/graph-ops.mlir new file mode 100644 index 0000000000000..c956157bfa6c1 --- /dev/null +++ b/mlir/test/Target/SPIRV/graph-ops.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %} + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce { +spirv.module Logical Vulkan requires #spirv.vce { + // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> + // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] + spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0 + // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> + %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> + // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16> + } + + // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = false} { + spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> + } +}