diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 910418f1706a6..ce4bb6c2e4934 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -423,6 +423,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>; def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>; +def SPV_ARM_graph : I32EnumAttrCase<"SPV_ARM_graph", 6001>; def SPIRV_ExtensionAttr : SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [ @@ -447,7 +448,7 @@ def SPIRV_ExtensionAttr : SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add, SPV_EXT_mesh_shader, - SPV_ARM_tensors, + SPV_ARM_tensors, SPV_ARM_graph, SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot, SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask, SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod, @@ -1332,6 +1333,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora Extension<[SPV_ARM_tensors]> ]; } +def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> { + list implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel]; + list availability = [ + Extension<[SPV_ARM_graph]> + ]; +} def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> { list implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR]; list availability = [ @@ -1545,7 +1552,7 @@ def SPIRV_CapabilityAttr : SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect, SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport, SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT, - SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, + SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM, SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers, SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV, SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV, @@ -4245,6 +4252,7 @@ def SPIRV_AnyTensorArm : DialectType; def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>; + def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>; def SPIRV_Composite : AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, @@ -4551,6 +4559,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUnifo def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>; def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>; def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>; +def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>; +def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>; +def SPIRV_OC_OpGraphARM : I32EnumAttrCase<"OpGraphARM", 4183>; +def SPIRV_OC_OpGraphInputARM : I32EnumAttrCase<"OpGraphInputARM", 4184>; +def SPIRV_OC_OpGraphSetOutputARM : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>; +def SPIRV_OC_OpGraphEndARM : I32EnumAttrCase<"OpGraphEndARM", 4186>; +def SPIRV_OC_OpTypeGraphARM : I32EnumAttrCase<"OpTypeGraphARM", 4190>; def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>; def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>; @@ -4666,6 +4681,9 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr, SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpTypeTensorARM, + SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM, + SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM, + SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM, SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, @@ -4836,6 +4854,11 @@ class SPIRV_NvVendorOp traits = []> : SPIRV_VendorOp { } +class SPIRV_ArmVendorOp traits = []> : + SPIRV_VendorOp { +} + + def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">; def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>; def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td new file mode 100644 index 0000000000000..38fb4b2eff414 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td @@ -0,0 +1,201 @@ +//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of Graph extension ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS +#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS + +include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" + +//===----------------------------------------------------------------------===// +// SPIR-V Graph opcode specification. +//===----------------------------------------------------------------------===// + +// Base class for all Graph ops. +class SPIRV_GraphARMOp traits = []> : + SPIRV_ArmVendorOp { + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>, + Capability<[SPIRV_C_GraphARM]> + ]; +} + +def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> { + let summary = "Declare a graph constant."; + + let description = [{ + Declare a graph constant. + Result Type must be an OpTypeTensorARM. + GraphConstantID must be a 32-bit integer literal. + }]; + + let arguments = (ins + I32Attr: $graph_constant_id + ); + + let results = (outs + SPIRV_AnyTensorArm:$output + ); + + let hasVerifier = 0; + + let autogenSerialization = 0; + + let assemblyFormat = [{ + attr-dict `:` type($output) + }]; +} + +// ----- + +def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [ + AutomaticAllocationScope, DeclareOpInterfaceMethods, + FunctionOpInterface, InModuleScope, IsolatedFromAbove + ]> { + + let summary = "Declare or define a SPIR-V graph"; + + let description = [{ + This op declares or defines a SPIR-V graph using one region, which + contains one or more blocks. + + Different from the SPIR-V binary format, this op is not allowed to + implicitly capture global values, and all external references must use + function arguments or symbol references. This op itself defines a symbol + that is unique in the enclosing module op. + + This op itself takes no operands and generates no results. Its region + can take zero or more arguments and return zero or more values. + + ``` + spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature + region + ``` + }]; + + let arguments = (ins + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + OptionalAttr:$entry_point, + StrAttr:$sym_name + ); + + let results = (outs); + + let regions = (region AnyRegion:$body); + + let hasVerifier = 0; + + let builders = [ + OpBuilder<(ins "StringRef":$name, "GraphType":$type, + CArg<"ArrayRef", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>]; + + let hasOpcode = 0; + + let autogenSerialization = 0; + + let extraClassDeclaration = [{ + /// Hook for FunctionOpInterface, called after verifying that the 'type' + /// attribute is present and checks if it holds a function type. Ensures + /// getType, getNumArguments, and getNumResults can be called safely + LogicalResult verifyType(); + + /// Hook for FunctionOpInterface, called after verifying the function + /// type and the presence of the (potentially empty) function body. + /// Ensures SPIR-V specific semantics. + LogicalResult verifyBody(); + }]; +} + +// Check that an op can only be used within the scope of a spirv.ARM.Graph op. +def InGraphScope : PredOpTrait< + "op must appear in a spirv.ARM.Graph op's block", + CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>; + +// ----- + +def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> { + let summary = [{ + Declare a graph entry point and its interface. + }]; + + let description = [{ + Graph Entry Point must be the Result of an OpGraphARM instruction. + + Name is a name string for the graphentry point. A module cannot have two + OpGraphEntryPointARM instructions with the same Name string. + + Interface is a list of symbol references to `spirv.GlobalVariable` + operations. These declare the set of global variables from a + module that form the interface of this entry point. The set of + Interface symbols must be equal to or a superset of the + `spirv.GlobalVariable`s referenced by the entry point’s static call + tree, within the interface’s storage classes. + + ``` + entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint` + symbol-reference (`, ` symbol-reference)* + ``` + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + SymbolRefArrayAttr:$interface + ); + + let results = (outs); + + let autogenSerialization = 0; + + let builders = [ + OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef":$interfaceVars)>]; +} + +// ----- + +def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure, + Terminator]> { + + let summary = "Define graph outputs."; + + let description = [{ + Values are the graph outputs values and must match the GraphOutputs Type + operand of the OpTypeGraphARM type of the OpGraphARM body this + instruction is in. + + This instruction must be the last instruction in a block. + + ``` + graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens + ``` + }]; + + let arguments = (ins + Variadic:$value + ); + + let results = (outs); + + let autogenSerialization = 0; + + let hasOpcode = 0; + + let assemblyFormat = "$value attr-dict `:` type($value)"; +} + +#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td index 0fa1bb9d5bd01..96ef035eda37a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td @@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td" +include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td" diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index ad59ea63a6901..aa7d30b87db14 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -24,6 +24,7 @@ class Type; class IntegerType; class FloatType; class FunctionType; +class GraphType; class IndexType; class MemRefType; class VectorType; @@ -81,6 +82,7 @@ class Builder { IntegerType getIntegerType(unsigned width); IntegerType getIntegerType(unsigned width, bool isSigned); FunctionType getFunctionType(TypeRange inputs, TypeRange results); + GraphType getGraphType(TypeRange inputs, TypeRange results); TupleType getTupleType(TypeRange elementTypes); NoneType getNoneType(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index a0c8acea91dc5..08847dd11c685 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> { // FunctionType //===----------------------------------------------------------------------===// -def Builtin_Function : Builtin_Type<"Function", "function"> { +class Builtin_FunctionLike : Builtin_Type { let summary = "Map from a list of inputs to a list of results"; let description = [{ Syntax: @@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> { }]> ]; let skipDefaultBuilders = 1; + let storageClass = "FunctionTypeStorage"; let genStorageClass = 0; let extraClassDeclaration = [{ /// Input types. @@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> { unsigned getNumResults() const; Type getResult(unsigned i) const { return getResults()[i]; } - /// Returns a clone of this function type with the given argument + /// Returns a clone of this function-like type with the given argument /// and result types. - FunctionType clone(TypeRange inputs, TypeRange results) const; + }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const; - /// Returns a new function type with the specified arguments and results + /// Returns a new function-like type with the specified arguments and results /// inserted. - FunctionType getWithArgsAndResults(ArrayRef argIndices, + }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef argIndices, TypeRange argTypes, ArrayRef resultIndices, TypeRange resultTypes); - /// Returns a new function type without the specified arguments and results. - FunctionType getWithoutArgsAndResults(const BitVector &argIndices, + /// Returns a new function-like type without the specified arguments and results. + }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices, const BitVector &resultIndices); }]; } +def Builtin_Function : Builtin_FunctionLike<"Function", "function">; +def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">; + //===----------------------------------------------------------------------===// // IndexType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 45ec1846580f2..aab1b01c5cff9 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -387,6 +387,13 @@ class OpaqueType def FunctionType : Type($_self)">, "function type", "::mlir::FunctionType">; +// Graph Type + +// Any graph type. +def GraphType : Type($_self)">, + "graph type", "::mlir::GraphType">; + + // A container type is a type that has another type embedded within it. class ContainerType : diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 88c7adf3dfcb3..4b8ed08249b3a 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -1019,8 +1019,15 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, return verifyRegionAttribute(op->getLoc(), argType, attribute); } -LogicalResult SPIRVDialect::verifyRegionResultAttribute( - Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, - NamedAttribute attribute) { - return op->emitError("cannot attach SPIR-V attributes to region result"); +LogicalResult +SPIRVDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, + unsigned resultIndex, + NamedAttribute attribute) { + auto funcOp = dyn_cast(op); + if (!funcOp) + return op->emitError( + "cannot attach SPIR-V attributes to region result which is " + "not a FunctionOpInterface type"); + return verifyRegionAttribute(op->getLoc(), + funcOp.getResultTypes()[resultIndex], attribute); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index d8dfe164458e2..2f3a28ff16173 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) { return isNestedInFunctionOpInterface(op->getParentOp()); } +/// Returns true if the given op is a GraphARM op or nested in a +/// GraphARM op without a module-like op in the middle. +static bool isNestedInGraphARMOpInterface(Operation *op) { + if (!op) + return false; + if (op->hasTrait()) + return false; + if (isa(op)) + return true; + return isNestedInGraphARMOpInterface(op->getParentOp()); +} + /// Returns true if the given op is an module-like op that maintains a symbol /// table. static bool isDirectInModuleLikeOp(Operation *op) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index eb2974d62fdd1..17cbab189588f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1084,6 +1084,236 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, state.addRegion(); } +//===----------------------------------------------------------------------===// +// spirv.GraphEntryPointARM +//===----------------------------------------------------------------------===// + +void spirv::GraphEntryPointARMOp::build(OpBuilder &builder, + OperationState &state, + spirv::GraphARMOp graph, + ArrayRef interfaceVars) { + build(builder, state, SymbolRefAttr::get(graph), + builder.getArrayAttr(interfaceVars)); +} + +ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector idTypes; + SmallVector interfaceVars; + + FlatSymbolRefAttr fn; + if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { + return failure(); + } + + if (!parser.parseOptionalComma()) { + // Parse the interface variables + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + // The name of the interface variable attribute isnt important + FlatSymbolRefAttr var; + NamedAttrList attrs; + if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) + return failure(); + interfaceVars.push_back(var); + return success(); + })) + return failure(); + } + result.addAttribute("interface", + parser.getBuilder().getArrayAttr(interfaceVars)); + return success(); +} + +void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getFn()); + auto interfaceVars = getInterface().getValue(); + if (!interfaceVars.empty()) { + printer << ", "; + llvm::interleaveComma(interfaceVars, printer); + } +} + +LogicalResult spirv::GraphEntryPointARMOp::verify() { + // Checks for fn and interface symbol reference are done in spirv::ModuleOp + // verification. + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphARM +//===----------------------------------------------------------------------===// + +ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + bool isVariadic = false; + if (function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + SmallVector argTypes; + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); + auto grType = builder.getGraphType(argTypes, resultTypes); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(grType)); + + // If additional attributes are present, parse them. + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + call_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); + + // Parse the optional function body. + auto *body = result.addRegion(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs); + return failure(parseResult.has_value() && failed(*parseResult)); +} + +void spirv::GraphARMOp::print(OpAsmPrinter &printer) { + // Print graph name, signature, and control. + printer << " "; + printer.printSymbolName(getSymName()); + auto grType = getFunctionType(); + function_interface_impl::printFunctionSignature( + printer, *this, grType.getInputs(), + /*isVariadic=*/false, grType.getResults()); + function_interface_impl::printFunctionAttributes(printer, *this, + {getFunctionTypeAttrName(), + getArgAttrsAttrName(), + getResAttrsAttrName()}); + + // Print the body. + Region &body = this->getBody(); + if (!body.empty()) { + printer << ' '; + printer.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +LogicalResult spirv::GraphARMOp::verifyType() { + if (getFunctionType().getNumResults() < 1) + return emitOpError("there should be at least one result"); + return success(); +} + +LogicalResult spirv::GraphARMOp::verifyBody() { + GraphType grType = getFunctionType(); + if (!isExternal()) { + Block &entryBlock = front(); + + unsigned numArguments = this->getNumArguments(); + if (entryBlock.getNumArguments() != numArguments) + return emitOpError("entry block must have ") + << numArguments << " arguments to match graph signature"; + + for (auto [index, grArgType, blockArgType] : + llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) { + if (blockArgType != grArgType) { + return emitOpError("type of entry block argument #") + << index << '(' << blockArgType + << ") must match the type of the corresponding argument in " + << "graph signature(" << grArgType << ')'; + } + } + } + + auto walkResult = walk([grType](Operation *op) -> WalkResult { + if (auto graphOutputsARMOp = dyn_cast(op)) { + if (grType.getNumResults() != graphOutputsARMOp.getNumOperands()) + return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ") + << graphOutputsARMOp.getNumOperands() + << "value(s) but enclosing graph requires " + << grType.getNumResults() << " results"; + + auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType(); + for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) { + auto graphOutputOperandType = graphOutputOperandTypes[i]; + auto grResultType = grType.getResult(i); + if (graphOutputOperandType != grResultType) + return graphOutputsARMOp.emitError("type of return operand ") + << i << " (" << graphOutputOperandType + << ") doesn't match graph result type (" << grResultType + << ")"; + } + } + return WalkResult::advance(); + }); + + return failure(walkResult.wasInterrupted()); +} + +void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state, + StringRef name, GraphType type, + ArrayRef attrs, bool entryPoint) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addAttribute(getEntryPointAttrName(state.name), + builder.getBoolAttr(entryPoint)); + state.addRegion(); +} + +// Returns the argument types of this function. +ArrayRef spirv::GraphARMOp::getArgumentTypes() { + return getFunctionType().getInputs(); +} + +// Returns the result types of this function. +ArrayRef spirv::GraphARMOp::getResultTypes() { + return getFunctionType().getResults(); +} + +// CallableOpInterface +Region *spirv::GraphARMOp::getCallableRegion() { + return isExternal() ? nullptr : &getBody(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphOutputsARM +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GraphOutputsARMOp::verify() { + auto graph = cast((*this)->getParentOp()); + + // The operand number and types must match the graph signature. + const auto &results = graph.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing graph (@" + << graph.getName() << ") returns " << results.size(); + + for (unsigned i = 0; i < results.size(); i++) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match graph result type (" << results[i] + << ")" + << " in graph @" << graph.getName(); + + return success(); +} + //===----------------------------------------------------------------------===// // spirv.GLFClampOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 6fd20466e36e3..40a85dca60939 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -76,10 +76,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, abiInfo.getBinding()); } +/// Creates a global variable for an argument or result based on the ABI info. +static spirv::GlobalVariableOp +createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp, + unsigned index, bool isArg, + spirv::InterfaceVarABIAttr abiInfo) { + auto spirvModule = graphOp->getParentOfType(); + if (!spirvModule) + return nullptr; + + OpBuilder::InsertionGuard moduleInsertionGuard(builder); + builder.setInsertionPoint(graphOp.getOperation()); + std::string varName = graphOp.getName().str() + (isArg ? "_arg_" : "_res_") + + std::to_string(index); + + auto varType = isArg ? graphOp.getFunctionType().getInput(index) + : graphOp.getFunctionType().getResult(index); + + auto pointerType = spirv::PointerType::get( + varType, + abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant)); + + return builder.create( + graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(), + abiInfo.getBinding()); +} + /// Gets the global variables that need to be specified as interface variable /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so. static LogicalResult -getInterfaceVariables(spirv::FuncOp funcOp, +getInterfaceVariables(mlir::FunctionOpInterface funcOp, SmallVectorImpl &interfaceVars) { auto module = funcOp->getParentOfType(); if (!module) { @@ -215,6 +241,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override; }; +/// A pattern to convert graph signature according to interface variable ABI +/// attributes. +/// +/// Specifically, this pattern creates global variables according to interface +/// variable ABI attributes attached to graph arguments and results. +class ProcessGraphInterfaceVarABI final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Pass to implement the ABI information specified as attributes. class LowerABIAttributesPass final : public spirv::impl::SPIRVLowerABIAttributesPassBase< @@ -288,6 +329,89 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( return success(); } +namespace { + +/// Lowers the graph entry point +LogicalResult lowerGraphEntryPoint(OpBuilder &builder, + spirv::GraphARMOp graphOp, + ArrayRef interfaceVars) { + if (!graphOp.getEntryPoint().value_or(false)) { + return failure(); + } + + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPoint(graphOp); + builder.create(graphOp.getLoc(), graphOp, + interfaceVars); + return success(); +} +} // namespace + +LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite( + spirv::GraphARMOp graphOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (!graphOp.getEntryPoint().value_or(false)) { + // Non-entry point graphs are not handled. + return failure(); + } + TypeConverter::SignatureConversion signatureConverter( + graphOp.getFunctionType().getNumInputs()); + + auto attrName = spirv::getInterfaceVarABIAttrName(); + + SmallVector interfaceVars; + + // Convert arguments + for (const auto &argType : + llvm::enumerate(graphOp.getFunctionType().getInputs())) { + auto abiInfo = graphOp.getArgAttrOfType( + argType.index(), attrName); + if (!abiInfo) { + // Non-entry point graphs are not handled in this ABI lowering and will + // produce an error. + return failure(); + } + spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint( + rewriter, graphOp, argType.index(), true, abiInfo); + if (!var) + return failure(); + interfaceVars.push_back( + SymbolRefAttr::get(rewriter.getContext(), var.getSymName())); + } + + for (const auto &resType : + llvm::enumerate(graphOp.getFunctionType().getResults())) { + auto abiInfo = graphOp.getResultAttrOfType( + resType.index(), attrName); + if (!abiInfo) { + // Non-entry point graphs are not handled in this ABI lowering and will + // produce an error. + return failure(); + } + spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint( + rewriter, graphOp, resType.index(), false, abiInfo); + if (!var) + return failure(); + interfaceVars.push_back( + SymbolRefAttr::get(rewriter.getContext(), var.getSymName())); + } + + // Creates a new function with the update signature. + rewriter.modifyOpInPlace(graphOp, [&] { + for (const auto &argType : + llvm::enumerate(graphOp.getFunctionType().getInputs())) { + graphOp.removeArgAttr(argType.index(), attrName); + } + for (const auto &resType : + llvm::enumerate(graphOp.getFunctionType().getResults())) { + graphOp.removeResultAttr(resType.index(), + rewriter.getStringAttr(attrName)); + } + }); + + return lowerGraphEntryPoint(rewriter, graphOp, interfaceVars); +} + void LowerABIAttributesPass::runOnOperation() { // Uses the signature conversion methodology of the dialect conversion // framework to implement the conversion. @@ -314,6 +438,7 @@ void LowerABIAttributesPass::runOnOperation() { RewritePatternSet patterns(context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); ConversionTarget target(*context); // "Legal" function ops should have no interface variable ABI attributes. @@ -324,6 +449,17 @@ void LowerABIAttributesPass::runOnOperation() { return false; return true; }); + target.addDynamicallyLegalOp([&](spirv::GraphARMOp op) { + StringRef attrName = spirv::getInterfaceVarABIAttrName(); + for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) + if (op.getArgAttr(i, attrName)) + return false; + for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) + if (op.getResultAttr(i, attrName)) + return false; + return true; + }); + // All other SPIR-V ops are legal. target.markUnknownOpDynamicallyLegal([](Operation *op) { return op->getDialect()->getNamespace() == diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 095db6b815f51..d636ea29fe019 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -154,6 +154,14 @@ void UpdateVCEPass::runOnOperation() { if (auto globalVar = dyn_cast(op)) valueTypes.push_back(globalVar.getType()); + // If the op is FunctionLike make sure to process input and result types + if (auto funcOpInterface = dyn_cast(op)) { + auto inputTypes = funcOpInterface.getArgumentTypes(); + auto resultTypes = funcOpInterface.getResultTypes(); + valueTypes.append(inputTypes.begin(), inputTypes.end()); + valueTypes.append(resultTypes.begin(), resultTypes.end()); + } + // Requirements from values' types SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f95ad290a1981..58fa14c0f5251 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -104,7 +104,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) { // it is a function (avoiding a grammar ambiguity). bool wrapped = op->getNumResults() != 1; if (!wrapped && op->getResult(0).getType() && - llvm::isa(op->getResult(0).getType())) + (llvm::isa(op->getResult(0).getType()) || + llvm::isa(op->getResult(0).getType()))) wrapped = true; if (wrapped) @@ -2837,6 +2838,20 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { os << '>'; }) .Case([&](Type) { os << "none"; }) + .Case([&](GraphType graphTy) { + os << '('; + interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); }); + os << ") -> "; + ArrayRef results = graphTy.getResults(); + if (results.size() == 1 && !(llvm::isa(results[0]) || + llvm::isa(results[0]))) { + printType(results[0]); + } else { + os << '('; + interleaveComma(results, [&](Type ty) { printType(ty); }); + os << ')'; + } + }) .Default([&](Type type) { return printDialectType(type); }); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index f657db142eeb9..3d366276b4375 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -76,6 +76,10 @@ FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) { return FunctionType::get(context, inputs, results); } +GraphType Builder::getGraphType(TypeRange inputs, TypeRange results) { + return GraphType::get(context, inputs, results); +} + TupleType Builder::getTupleType(TypeRange elementTypes) { return TupleType::get(context, elementTypes); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 1604ebba190a1..ce47c60c9b932 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -179,6 +179,45 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, return clone(newArgTypes, newResultTypes); } +//===----------------------------------------------------------------------===// +// GraphType +//===----------------------------------------------------------------------===// + +unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; } + +ArrayRef GraphType::getInputs() const { return getImpl()->getInputs(); } + +unsigned GraphType::getNumResults() const { return getImpl()->numResults; } + +ArrayRef GraphType::getResults() const { return getImpl()->getResults(); } + +GraphType GraphType::clone(TypeRange inputs, TypeRange results) const { + return get(getContext(), inputs, results); +} + +/// Returns a new function type with the specified arguments and results +/// inserted. +GraphType GraphType::getWithArgsAndResults(ArrayRef argIndices, + TypeRange argTypes, + ArrayRef resultIndices, + TypeRange resultTypes) { + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = + insertTypesInto(getInputs(), argIndices, argTypes, argStorage); + TypeRange newResultTypes = + insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage); + return clone(newArgTypes, newResultTypes); +} + +/// Returns a new function type without the specified arguments and results. +GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices, + const BitVector &resultIndices) { + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage); + TypeRange newResultTypes = + filterTypesOut(getResults(), resultIndices, resultStorage); + return clone(newArgTypes, newResultTypes); +} //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index 55d6a380d0bff..abe6d4bc7040b 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -71,6 +71,12 @@ Value spirv::Deserializer::getValue(uint32_t id) { if (auto undef = getUndefType(id)) { return opBuilder.create(unknownLoc, undef); } + if (auto graphConstantARMInfo = getGraphConstantARM(id)) { + auto graphConstantID = graphConstantARMInfo->graphConstantID; + auto resultType = graphConstantARMInfo->resultType; + return opBuilder.create(unknownLoc, resultType, + graphConstantID); + } return valueMap.lookup(id); } @@ -165,6 +171,7 @@ LogicalResult spirv::Deserializer::processInstruction( case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: case spirv::Opcode::OpTypeTensorARM: + case spirv::Opcode::OpTypeGraphARM: case spirv::Opcode::OpTypeCooperativeMatrixKHR: return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: @@ -189,12 +196,26 @@ LogicalResult spirv::Deserializer::processInstruction( return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantNull: return processConstantNull(operands); + case spirv::Opcode::OpGraphConstantARM: + return processGraphConstantARM(operands); case spirv::Opcode::OpDecorate: return processDecoration(operands); case spirv::Opcode::OpMemberDecorate: return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); + case spirv::Opcode::OpGraphEntryPointARM: + if (deferInstructions) { + deferredInstructions.emplace_back(opcode, operands); + return success(); + } + return processGraphEntryPointARM(operands); + case spirv::Opcode::OpGraphARM: + return processGraphARM(operands); + case spirv::Opcode::OpGraphSetOutputARM: + return processOpGraphSetOutputARM(operands); + case spirv::Opcode::OpGraphEndARM: + return processGraphARMEnd(operands); case spirv::Opcode::OpLabel: return processLabel(operands); case spirv::Opcode::OpBranch: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index b1abd8b3dffe9..de3dc19349642 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -670,6 +670,213 @@ spirv::Deserializer::processFunctionEnd(ArrayRef operands) { return success(); } +LogicalResult +spirv::Deserializer::processGraphEntryPointARM(ArrayRef operands) { + unsigned wordIndex = 0; + if (wordIndex >= operands.size()) { + return emitError(unknownLoc, + "missing graph defintion in OpGraphEntryPointARM"); + } + + uint32_t grID = operands[wordIndex++]; + if (!graphMap.count(grID)) { + return emitError(unknownLoc, + "missing graph definition/declaration with id ") + << grID; + } + + spirv::GraphARMOp graphARM = graphMap[grID]; + StringRef name = decodeStringLiteral(operands, wordIndex); + graphARM.setSymName(name); + graphARM.setEntryPoint(true); + + SmallVector interface; + while (wordIndex < operands.size()) { + auto arg = getGlobalVariable(operands[wordIndex]); + if (!arg) { + return emitError(unknownLoc, "undefined result ") + << operands[wordIndex] << " while decoding OpGraphEntryPoint"; + } + interface.push_back(SymbolRefAttr::get(arg.getOperation())); + wordIndex++; + } + + // RAII guard to reset the insertion point to previous value when done. + OpBuilder::InsertionGuard insertionGuard(opBuilder); + opBuilder.setInsertionPoint(graphARM); + opBuilder.create( + unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name), + opBuilder.getArrayAttr(interface)); + + return success(); +} + +LogicalResult +spirv::Deserializer::processGraphARM(ArrayRef operands) { + if (curGraph) { + return emitError(unknownLoc, "found graph inside graph"); + } + // Get the result type + if (operands.size() < 2) { + return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters"); + } + + Type grType = getType(operands[0]); + if (!grType || !llvm::isa(grType)) { + return emitError(unknownLoc, "unknown graph type from ") + << operands[0]; + } + auto graphType = llvm::cast(grType); + if (graphType.getNumResults() <= 0) { + return emitError(unknownLoc, "expected at least one result"); + } + + uint32_t grID = operands[1]; + if (graphMap.count(grID)) { + return emitError(unknownLoc, "duplicate graph definition/declaration"); + } + + std::string grName = getGraphSymbol(grID); + auto graphOp = + opBuilder.create(unknownLoc, grName, graphType); + curGraph = graphMap[grID] = graphOp; + auto *entryBlock = graphOp.addEntryBlock(); + LLVM_DEBUG({ + logger.startLine() + << "//===-------------------------------------------===//\n"; + logger.startLine() << "[graph] name: " << grName << "\n"; + logger.startLine() << "[graph] type: " << grType << "\n"; + logger.startLine() << "[graph] ID: " << grID << "\n"; + logger.startLine() << "[graph] entry block: " << entryBlock << "\n"; + logger.indent(); + }); + + // Parse the op argument instructions + if (graphType.getNumInputs()) { + for (size_t i = 0, e = graphType.getNumInputs(); i != e; ++i) { + auto argType = graphType.getInput(i); + spirv::Opcode opcode = spirv::Opcode::OpNop; + ArrayRef operands; + if (failed(sliceInstruction(opcode, operands, + spirv::Opcode::OpGraphInputARM))) { + return failure(); + } + if (opcode != spirv::Opcode::OpGraphInputARM) { + return emitError(unknownLoc, + "missing OpGraphInputARM instruction for argument ") + << i; + } + + if (operands.size() != 3) { + return emitError(unknownLoc, "expected result type, result and " + "input index for OpGraphInputARM"); + } + + auto argDefinedType = getType(operands[0]); + if (!argDefinedType) { + return emitError(unknownLoc, "unknown operand type ") + << operands[0]; + } + + if (argDefinedType != argType) { + return emitError(unknownLoc, + "mismatch in argument type between graph type " + "definition ") + << graphType << " and argument type definition " + << argDefinedType << " at argument " << i; + } + if (getValue(operands[1])) { + return emitError(unknownLoc, "duplicate definition of result ") + << operands[1]; + } + + auto inputIndexAttr = getConstantInt(operands[2]); + if (inputIndexAttr == nullptr) { + return emitError(unknownLoc, + "unable to read inputIndex value from constant op ") + << operands[2]; + } + auto argValue = graphOp.getArgument(inputIndexAttr.getInt()); + valueMap[operands[1]] = argValue; + } + } + + graphOutputs.resize(graphType.getNumResults()); + + // RAII guard to reset the insertion point to the module's region after + // deserializing the body of this function. + OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); + + spirv::Opcode opcode = spirv::Opcode::OpNop; + + blockMap[grID] = entryBlock; + if (failed(createGraphBlock(grID))) { + return failure(); + } + + // Process all the instructions in the graph until and including + // OpGraphEndARM. + ArrayRef instOperands; + do { + if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) { + return failure(); + } + + if (failed(processInstruction(opcode, instOperands))) { + return failure(); + } + } while (opcode != spirv::Opcode::OpGraphEndARM); + + return success(); +} + +LogicalResult +spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef operands) { + + if (operands.size() != 2) { + return emitError( + unknownLoc, + "expected value id and output index for OpGraphSetOutputARM"); + } + + auto id = operands[0]; + auto value = getValue(id); + if (!value) { + return emitError(unknownLoc, "could not find result ") << id; + } + + auto outputIndexAttr = getConstantInt(operands[1]); + if (outputIndexAttr == nullptr) { + return emitError(unknownLoc, + "unable to read outputIndex value from constant op ") + << operands[1]; + } + graphOutputs[outputIndexAttr.getInt()] = value; + return success(); +} + +LogicalResult +spirv::Deserializer::processGraphARMEnd(ArrayRef operands) { + // Create GraphOutputsARM instruction + opBuilder.create(unknownLoc, graphOutputs); + + // Process OpGraphEndARM. + if (!operands.empty()) { + return emitError(unknownLoc, "unexpected operands for OpGraphEndARM"); + } + + curBlock = nullptr; + curGraph = std::nullopt; + graphOutputs.clear(); + + LLVM_DEBUG({ + logger.unindent(); + logger.startLine() + << "//===-------------------------------------------===//\n"; + }); + return success(); +} + std::optional> spirv::Deserializer::getConstant(uint32_t id) { auto constIt = constantMap.find(id); @@ -694,6 +901,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { return funcName; } +std::string spirv::Deserializer::getGraphSymbol(uint32_t id) { + auto graphName = nameMap.lookup(id).str(); + if (graphName.empty()) { + graphName = "spirv_graph_" + std::to_string(id); + } + return graphName; +} + std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { auto constName = nameMap.lookup(id).str(); if (constName.empty()) { @@ -716,6 +931,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, return op; } +std::optional +spirv::Deserializer::getGraphConstantARM(uint32_t id) { + auto graphConstIt = graphConstantMap.find(id); + if (graphConstIt == graphConstantMap.end()) + return std::nullopt; + return graphConstIt->getSecond(); +} + LogicalResult spirv::Deserializer::processGlobalVariable(ArrayRef operands) { unsigned wordIndex = 0; @@ -937,6 +1160,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, return processMatrixType(operands); case spirv::Opcode::OpTypeTensorARM: return processTensorARMType(operands); + case spirv::Opcode::OpTypeGraphARM: + return processGraphTypeARM(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } @@ -1289,6 +1514,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef operands) { return success(); } +LogicalResult +spirv::Deserializer::processGraphTypeARM(ArrayRef operands) { + unsigned size = operands.size(); + if (size < 2) { + return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands " + "(result_id, num_inputs, (inout0_type, " + "inout1_type, ...))") + << size; + } + uint32_t numInputs = operands[1]; + SmallVector argTypes; + SmallVector returnTypes; + for (unsigned i = 2; i < size; i++) { + Type inOutTy = getType(operands[i]); + if (!inOutTy) { + return emitError(unknownLoc, + "OpTypeGraphARM references undefined element type.") + << operands[i]; + } + if (i - 2 >= numInputs) { + returnTypes.push_back(inOutTy); + } else { + argTypes.push_back(inOutTy); + } + } + typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes); + return success(); +} + LogicalResult spirv::Deserializer::processTypeForwardPointer(ArrayRef operands) { if (operands.size() != 2) @@ -1699,6 +1953,38 @@ spirv::Deserializer::processConstantNull(ArrayRef operands) { << resultType; } +LogicalResult +spirv::Deserializer::processGraphConstantARM(ArrayRef operands) { + if (operands.size() < 2) { + return emitError(unknownLoc) + << "OpGraphConstantARM must have type and result "; + } + if (operands.size() < 3) { + return emitError(unknownLoc) + << "OpGraphConstantARM must have at least 1 more parameter"; + } + + Type resultType = getType(operands[0]); + if (!resultType) { + return emitError(unknownLoc, "undefined result type from ") + << operands[0]; + } + + auto resultID = operands[1]; + + if (!llvm::dyn_cast(resultType)) { + return emitError(unknownLoc, "result must be of type OpTypeTensorARM"); + } + + APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true); + Type i32Ty = opBuilder.getIntegerType(32); + auto attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id); + graphConstantMap.try_emplace( + resultID, GraphConstantARMOpMaterializationInfo{resultType, attr}); + + return success(); +} + //===----------------------------------------------------------------------===// // Control flow //===----------------------------------------------------------------------===// @@ -1796,6 +2082,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef operands) { return success(); } +LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) { + if (!curGraph) { + return emitError(unknownLoc, "a graph block must appear inside a graph"); + } + + // We may have forward declared this block. + auto *block = getOrCreateBlock(graphID); + LLVM_DEBUG(logger.startLine() + << "[block] populating block " << block << "\n"); + // If we have seen this block, make sure it was just a forward declaration. + assert(block->empty() && "re-deserialize the same block!"); + + opBuilder.setInsertionPointToStart(block); + blockMap[graphID] = curBlock = block; + + return success(); +} + LogicalResult spirv::Deserializer::processSelectionMerge(ArrayRef operands) { if (!curBlock) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 1bc9e4a3c75d8..90740112c8d13 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -105,6 +105,13 @@ struct SpecConstOperationMaterializationInfo { SmallVector enclosedOpOperands; }; +/// A struct that collects the info needed to materialize/emit a +/// GraphConstantARMOp. +struct GraphConstantARMOpMaterializationInfo { + Type resultType; + IntegerAttr graphConstantID; +}; + //===----------------------------------------------------------------------===// // Deserializer Declaration //===----------------------------------------------------------------------===// @@ -205,9 +212,14 @@ class Deserializer { /// exists; otherwise creates one based on the . std::string getFunctionSymbol(uint32_t id); - /// Returns a symbol to be used for the specialization constant with the given - /// result . This tries to use the specialization constant's OpName if + /// Returns a symbol to be used for the graph name with the given + /// result . This tries to use the graph's OpName if /// exists; otherwise creates one based on the . + std::string getGraphSymbol(uint32_t id); + + /// Returns a symbol to be used for the specialization constant with the + /// given result . This tries to use the specialization constant's + /// OpName if exists; otherwise creates one based on the . std::string getSpecConstantSymbol(uint32_t id); /// Gets the specialization constant with the given result . @@ -224,6 +236,11 @@ class Deserializer { spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue); + /// Gets the GraphConstantARM ID attribute and result type with the given + /// result . + std::optional + getGraphConstantARM(uint32_t id); + /// Processes the OpVariable instructions at current `offset` into `binary`. /// It is expected that this method is used for variables that are to be /// defined at module scope and will be deserialized into a @@ -293,6 +310,16 @@ class Deserializer { LogicalResult processTensorARMType(ArrayRef operands); + LogicalResult processGraphTypeARM(ArrayRef operands); + + LogicalResult processGraphEntryPointARM(ArrayRef operands); + + LogicalResult processGraphARM(ArrayRef operands); + + LogicalResult processOpGraphSetOutputARM(ArrayRef operands); + + LogicalResult processGraphARMEnd(ArrayRef operands); + LogicalResult processTypeForwardPointer(ArrayRef operands); //===--------------------------------------------------------------------===// @@ -330,6 +357,10 @@ class Deserializer { /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); + /// Processes a SPIR-V OpGraphConstantARM instruction with the given + /// `operands`. + LogicalResult processGraphConstantARM(ArrayRef operands); + //===--------------------------------------------------------------------===// // Debug //===--------------------------------------------------------------------===// @@ -427,6 +458,9 @@ class Deserializer { /// blocks declared as selection/loop headers are handled. LogicalResult structurizeControlFlow(); + /// Creates a block for graph with the given graphID + LogicalResult createGraphBlock(uint32_t graphID); + //===--------------------------------------------------------------------===// // Instruction //===--------------------------------------------------------------------===// @@ -523,6 +557,9 @@ class Deserializer { /// The current function under construction. std::optional curFunction; + /// The current graph under construction. + std::optional curGraph; + /// The current block under construction. Block *curBlock = nullptr; @@ -560,12 +597,19 @@ class Deserializer { DenseMap specConstOperationMap; + // Result to GraphConstantARM ID attribute and result type. + DenseMap + graphConstantMap; + // Result to variable mapping. DenseMap globalVariableMap; // Result to function mapping. DenseMap funcMap; + // Result to function mapping. + DenseMap graphMap; + // Result to block mapping. DenseMap blockMap; @@ -629,6 +673,9 @@ class Deserializer { /// Deserialization options. DeserializationOptions options; + /// List of IDs assigned to graph outputs. + SmallVector graphOutputs; + #ifndef NDEBUG /// A logger used to emit information during the deserialzation process. llvm::ScopedPrinter logger; diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index ff3cc92ee8078..4a8b10001ec02 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -161,6 +161,16 @@ Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { return success(); } +LogicalResult +Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) { + if (auto resultID = prepareGraphConstantId(op.getLoc(), op.getType(), + op.getGraphConstantIdAttr())) { + valueIDMap[op.getResult()] = resultID; + return success(); + } + return failure(); +} + LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { auto undefType = op.getType(); auto &id = undefValIDMap[undefType]; @@ -326,6 +336,123 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { return success(); } +LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) { + + if (op.getNumResults() < 1) { + return op.emitError("cannot serialize graph with no return types"); + } + + LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n"); + assert(functionHeader.empty() && functionBody.empty()); + + uint32_t funcID = getOrCreateFunctionID(op.getName()); + uint32_t fnTypeID = 0; + // Generate type of the function. + if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID))) + return failure(); + encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM, + {fnTypeID, funcID}); + + // Declare the parameters. + for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { + uint32_t argTypeID = 0; + SmallVector inputOperands; + + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + return failure(); + } + + uint32_t argValueID = getNextID(); + valueIDMap[arg] = argValueID; + + auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx); + auto indexID = prepareConstantInt(op.getLoc(), attr, false); + + inputOperands.push_back(argTypeID); + inputOperands.push_back(argValueID); + inputOperands.push_back(indexID); + + encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM, + inputOperands); + } + + // Process the body. + if (op.isExternal()) { + return op.emitError("external function is unhandled"); + } + + if (failed(processBlock(&op.front(), /*omitLabel=*/true))) + return failure(); + if (failed(visitInPrettyBlockOrder( + &op.front(), [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true))) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName() + << "' --\n"); + // Insert OpFunctionEnd. + encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {}); + + graphs.append(functionHeader.begin(), functionHeader.end()); + graphs.append(functionBody.begin(), functionBody.end()); + functionHeader.clear(); + functionBody.clear(); + + return success(); +} + +LogicalResult +Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) { + SmallVector operands; + auto graph = op.getFn(); + // Add the graph . + uint32_t graphID = getOrCreateFunctionID(graph); + operands.push_back(graphID); + // Add the name of the graph. + spirv::encodeStringLiteralInto(operands, graph); + + // Add the interface values. + if (auto interface = op.getInterface()) { + for (auto var : interface.getValue()) { + auto id = getVariableID(llvm::cast(var).getValue()); + if (!id) { + return op.emitError( + "referencing undefined global variable." + "spirv.GraphEntryPointARM is at the end of spirv.module. All " + "referenced variables should already be defined"); + } + operands.push_back(id); + } + } + encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands); + return success(); +} + +LogicalResult +Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) { + for (auto [idx, value] : llvm::enumerate(op->getOperands())) { + SmallVector outputOperands; + + auto resType = value.getType(); + uint32_t resTypeID = 0; + if (failed(processType(op.getLoc(), resType, resTypeID))) { + return failure(); + } + + uint32_t outputID = getValueID(value); + auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx); + auto indexID = prepareConstantInt(op.getLoc(), attr, false); + + outputOperands.push_back(outputID); + outputOperands.push_back(indexID); + + encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM, + outputOperands); + } + return success(); +} + LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { SmallVector operands; SmallVector elidedAttrs; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index ebebd2d283afa..b5ba5c91885d9 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -115,7 +115,7 @@ void Serializer::collect(SmallVectorImpl &binary) { extensions.size() + extendedSets.size() + memoryModel.size() + entryPoints.size() + executionModes.size() + decorations.size() + - typesGlobalValues.size() + functions.size(); + typesGlobalValues.size() + functions.size() + graphs.size(); binary.clear(); binary.reserve(moduleSize); @@ -133,6 +133,7 @@ void Serializer::collect(SmallVectorImpl &binary) { binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); + binary.append(graphs.begin(), graphs.end()); } #ifndef NDEBUG @@ -457,9 +458,12 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, auto typeEnum = spirv::Opcode::OpTypeVoid; bool deferSerialization = false; - if ((isa(type) && - succeeded(prepareFunctionType(loc, cast(type), typeEnum, - operands))) || + if ((llvm::isa(type) && + succeeded(prepareFunctionType(loc, llvm::cast(type), + typeEnum, operands))) || + (llvm::isa(type) && + succeeded(prepareGraphType(loc, llvm::cast(type), typeEnum, + operands))) || succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, deferSerialization, serializationCtx))) { if (deferSerialization) @@ -490,6 +494,7 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, return success(); } + emitError(loc, "failed to process type: ") << type; return failure(); } @@ -805,6 +810,35 @@ Serializer::prepareFunctionType(Location loc, FunctionType type, return success(); } +LogicalResult +Serializer::prepareGraphType(Location loc, GraphType type, + spirv::Opcode &typeEnum, + SmallVectorImpl &operands) { + typeEnum = spirv::Opcode::OpTypeGraphARM; + assert(type.getNumResults() >= 1 && + "serialization requires at least a return value"); + + operands.push_back(type.getNumInputs()); + + for (auto &res : type.getInputs()) { + uint32_t argTypeID = 0; + if (failed(processType(loc, res, argTypeID))) { + return failure(); + } + operands.push_back(argTypeID); + } + + for (auto &res : type.getResults()) { + uint32_t resultID = 0; + if (failed(processType(loc, res, resultID))) { + return failure(); + } + operands.push_back(resultID); + } + + return success(); +} + //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// @@ -1056,6 +1090,41 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, return resultID; } +uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType, + IntegerAttr intAttr) { + // De-duplicate graph constants. + if (auto id = getGraphConstantARMId(intAttr)) { + return id; + } + + // Process the type for this graph constant. + uint32_t typeID = 0; + if (failed(processType(loc, graphConstType, typeID))) { + return 0; + } + + auto resultID = getNextID(); + APInt value = intAttr.getValue(); + unsigned bitwidth = value.getBitWidth(); + if (bitwidth > 32) { + emitError(loc, "Too wide attribute for OpGraphConstantARM: ") + << bitwidth << " bits"; + return 0; + } + bool isSigned = value.isSignedIntN(bitwidth); + + uint32_t word = 0; + if (isSigned) { + word = static_cast(value.getSExtValue()); + } else { + word = static_cast(value.getZExtValue()); + } + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM, + {typeID, resultID, word}); + graphConstIDMap[intAttr] = resultID; + return resultID; +} + uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec) { if (!isSpec) { @@ -1329,9 +1398,19 @@ LogicalResult Serializer::processOperation(Operation *opInst) { }) .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) + .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); }) + .Case([&](spirv::GraphEntryPointARMOp op) { + return processGraphEntryPointARMOp(op); + }) + .Case([&](spirv::GraphOutputsARMOp op) { + return processGraphOutputsARMOp(op); + }) .Case([&](spirv::GlobalVariableOp op) { return processGlobalVariableOp(op); }) + .Case([&](spirv::GraphConstantARMOp op) { + return processGraphConstantARMOp(op); + }) .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index 9edb0f4af008d..e26e873f02daa 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -116,6 +116,8 @@ class Serializer { LogicalResult processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); + LogicalResult processGraphConstantARMOp(spirv::GraphConstantARMOp op); + /// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces a SSA /// value to use with other operations. The SPIR-V spec recommends that /// OpUndef be generated at module level. The serialization generates an @@ -129,6 +131,15 @@ class Serializer { LogicalResult processFuncOp(spirv::FuncOp op); LogicalResult processFuncParameter(spirv::FuncOp op); + /// Processes a SPIR-V GraphARM op. + LogicalResult processGraphARMOp(spirv::GraphARMOp op); + + /// Processes a SPIR-V GraphEntryPointARM op. + LogicalResult processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op); + + /// Processes a SPIR-V GraphOutputsARMOp op. + LogicalResult processGraphOutputsARMOp(spirv::GraphOutputsARMOp op); + LogicalResult processVariableOp(spirv::VariableOp op); /// Process a SPIR-V GlobalVariableOp @@ -183,6 +194,10 @@ class Serializer { spirv::Opcode &typeEnum, SmallVectorImpl &operands); + LogicalResult prepareGraphType(Location loc, GraphType type, + spirv::Opcode &typeEnum, + SmallVectorImpl &operands); + //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// @@ -227,6 +242,13 @@ class Serializer { uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec = false); + uint32_t getGraphConstantARMId(Attribute value) const { + return graphConstIDMap.lookup(value); + } + + uint32_t prepareGraphConstantId(Location loc, Type graphConstType, + IntegerAttr intAttr); + uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec = false); @@ -323,7 +345,7 @@ class Serializer { spirv::ModuleOp module; /// An MLIR builder for getting MLIR constructs. - mlir::Builder mlirBuilder; + mlir::OpBuilder mlirBuilder; /// Serialization options. SerializationOptions options; @@ -355,6 +377,7 @@ class Serializer { SmallVector decorations; SmallVector typesGlobalValues; SmallVector functions; + SmallVector graphs; /// Recursive struct references are serialized as OpTypePointer instructions /// to the recursive struct type. However, the OpTypePointer instruction @@ -371,15 +394,22 @@ class Serializer { recursiveStructInfos; /// `functionHeader` contains all the instructions that must be in the first - /// block in the function, and `functionBody` contains the rest. After - /// processing FuncOp, the encoded instructions of a function are appended to - /// `functions`. An example of instructions in `functionHeader` in order: + /// block in the function or graph, and `functionBody` contains the rest. + /// After processing FuncOp/GraphARMOp, the encoded instructions of a function + /// or graph are appended to `functions` or `graphs` respectively. Examples of + /// instructions in `functionHeader` in order: + /// + /// For a FuncOp: /// OpFunction ... /// OpFunctionParameter ... /// OpFunctionParameter ... /// OpLabel ... /// OpVariable ... /// OpVariable ... + /// + /// For a GraphARMOp + /// OpGraphARM ... + /// OpGraphInputARM ... SmallVector functionHeader; SmallVector functionBody; @@ -392,6 +422,9 @@ class Serializer { /// Map from specialization constant names to their s. llvm::StringMap specConstIDMap; + /// Map from graph constant ID value to their s. + DenseMap graphConstIDMap; + /// Map from GlobalVariableOps name to s. llvm::StringMap globalVarIDMap; diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index 64ba8e3fc249e..77dbdfaca19b9 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -278,3 +278,20 @@ func.func @set_mesh_outputs(%0 : i32, %1 : i32) -> () { spirv.EXT.SetMeshOutputs %0, %1 : i32, i32 spirv.Return } + +//===----------------------------------------------------------------------===// +// GraphARM ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: graph_arm +spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + // CHECK: spirv.ARM.GraphOutputs min version: v1.0 + // CHECK: spirv.ARM.GraphOutputs max version: v1.6 + // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ] + // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ] + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> +// CHECK: spirv.ARM.Graph min version: v1.0 +// CHECK: spirv.ARM.Graph max version: v1.6 +// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ] +// CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ] +} diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir new file mode 100644 index 0000000000000..90c31e19db382 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spirv.ARM.GraphConstant +//===----------------------------------------------------------------------===// + +spirv.module Logical Vulkan requires #spirv.vce { + // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32> + %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32> + + // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> + // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] + spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0 + // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> + %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> + // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16> + spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16> + } + + // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> + } +} diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir index 10fbcf06eb052..515162bf99aea 100644 --- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir @@ -14,7 +14,7 @@ func.func @unknown_attr_on_region(%arg: i32 {spirv.something}) { // ----- -// expected-error @+1 {{cannot attach SPIR-V attributes to region result}} +// expected-error @+1 {{found unsupported 'spirv.something' attribute on region argument}} func.func @unknown_attr_on_region() -> (i32 {spirv.something}) { %0 = arith.constant 10.0 : f32 return %0: f32 @@ -101,6 +101,27 @@ func.func @interface_var( // ----- +// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>} +func.func @interface_var(%arg: f32) -> ( + f32 {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>} +) { return %arg : f32 } + +// ----- + +// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1), Uniform>} +func.func @interface_var(%arg: f32) -> ( + f32 {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1), Uniform>} +) { return %arg : f32 } + +// ----- + +// expected-error @+1 {{'spirv.interface_var_abi' attribute cannot specify storage class when attaching to a non-scalar value}} +func.func @interface_var(%arg0 : memref<4xf32>) -> ( + memref<4xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1), Uniform>} +) { return %arg0 : memref<4xf32> } + +// ----- + //===----------------------------------------------------------------------===// // spirv.resource_limits //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir index bd51a07843652..9f5694135d623 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -35,6 +35,28 @@ spirv.module Logical GLSL450 { // ----- +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: spirv.module +spirv.module Logical Vulkan { + // CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> + // CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> + + // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] + // CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true} + spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} { + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> + } +} // end spirv.module + +} // end module + +// ----- + module { // expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}} spirv.module Logical GLSL450 {} diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 2b237665ffc4a..2482af8927aa4 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -231,3 +231,14 @@ spirv.module Logical GLSL450 attributes { spirv.ReturnValue %val : bf16 } } + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, + #spirv.resource_limits<>> +} { + spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> { + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8> + } +} diff --git a/mlir/test/Target/SPIRV/graph-ops.mlir b/mlir/test/Target/SPIRV/graph-ops.mlir new file mode 100644 index 0000000000000..5b39d33cd49b9 --- /dev/null +++ b/mlir/test/Target/SPIRV/graph-ops.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce { +spirv.module Logical Vulkan requires #spirv.vce { + // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> + // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] + spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0 + // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> + %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> + // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16> + } + + // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = false} { + spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> + } +} diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp index 2e5e591fe5f91..9efca825a663d 100644 --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -21,7 +21,7 @@ using namespace mlir; namespace { /// A pass for testing SPIR-V op availability. struct PrintOpAvailability - : public PassWrapper> { + : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability) void runOnOperation() override; @@ -33,12 +33,10 @@ struct PrintOpAvailability } // namespace void PrintOpAvailability::runOnOperation() { - auto f = getOperation(); - llvm::outs() << f.getName() << "\n"; - + auto moduleOp = getOperation(); Dialect *spirvDialect = getContext().getLoadedDialect("spirv"); - f->walk([&](Operation *op) { + auto opCallback = [&](Operation *op) { if (op->getDialect() != spirvDialect) return WalkResult::advance(); @@ -89,6 +87,16 @@ void PrintOpAvailability::runOnOperation() { os.flush(); return WalkResult::advance(); + }; + + moduleOp.walk([&](func::FuncOp f) { + llvm::outs() << f.getName() << "\n"; + f->walk(opCallback); + }); + + moduleOp.walk([&](spirv::GraphARMOp g) { + llvm::outs() << g.getName() << "\n"; + g->walk(opCallback); }); }