From e7d40a87ff230528131541f6ac17a2e1a7dc78e1 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 1 Feb 2024 10:04:36 +0100 Subject: [PATCH] [mlir][EmitC] Add func, call and return operations and conversions (#79612) This adds a `func`, `call` and `return` operation to the EmitC dialect, closely related to the corresponding operations of the Func dialect. In contrast to the operations of the Func dialect, the EmitC operations do not support multiple results. The `emitc.func` op features a `specifiers` argument that for example allows, with corresponding support in the emitter, to emit `inline static` functions. Furthermore, this adds patterns and a pass to convert the Func dialect to EmitC. A `func.func` op that is `private` is converted to `emitc.func` with a `"static"` specifier. --- .../mlir/Conversion/FuncToEmitC/FuncToEmitC.h | 18 ++ .../Conversion/FuncToEmitC/FuncToEmitCPass.h | 21 +++ mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 9 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 1 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 170 ++++++++++++++++++ mlir/lib/Conversion/CMakeLists.txt | 1 + .../lib/Conversion/FuncToEmitC/CMakeLists.txt | 16 ++ .../Conversion/FuncToEmitC/FuncToEmitC.cpp | 116 ++++++++++++ .../FuncToEmitC/FuncToEmitCPass.cpp | 47 +++++ mlir/lib/Dialect/EmitC/IR/CMakeLists.txt | 2 + mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 119 ++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 148 +++++++++++---- .../Conversion/FuncToEmitC/func-to-emitc.mlir | 55 ++++++ mlir/test/Dialect/EmitC/invalid_ops.mlir | 37 ++++ mlir/test/Dialect/EmitC/ops.mlir | 15 ++ mlir/test/Target/Cpp/func.mlir | 39 ++++ .../llvm-project-overlay/mlir/BUILD.bazel | 31 ++++ 18 files changed, 815 insertions(+), 31 deletions(-) create mode 100644 mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h create mode 100644 mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h create mode 100644 mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt create mode 100644 mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp create mode 100644 mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp create mode 100644 mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir create mode 100644 mlir/test/Target/Cpp/func.mlir diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h new file mode 100644 index 0000000000000..5c7f87e470306 --- /dev/null +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h @@ -0,0 +1,18 @@ +//===- FuncToEmitC.h - Func to EmitC Patterns -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H +#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H + +namespace mlir { +class RewritePatternSet; + +void populateFuncToEmitCPatterns(RewritePatternSet &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h new file mode 100644 index 0000000000000..65936703ee13e --- /dev/null +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h @@ -0,0 +1,21 @@ +//===- FuncToEmitCPass.h - Func to EmitC Pass -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H +#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTFUNCTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 0bfc5064c5dd7..81f69210fade8 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -29,6 +29,7 @@ #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index ec0a6284fe97d..94fc7a7d2194b 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -359,6 +359,15 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> { ]; } +//===----------------------------------------------------------------------===// +// FuncToEmitC +//===----------------------------------------------------------------------===// + +def ConvertFuncToEmitC : Pass<"convert-func-to-emitc", "ModuleOp"> { + let summary = "Convert Func dialect to EmitC dialect"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + //===----------------------------------------------------------------------===// // FuncToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 4dff26e23c428..3d38744527d59 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -20,6 +20,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc" diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index ada64f10a1675..5c8c3c9ce7bb3 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -16,8 +16,10 @@ include "mlir/Dialect/EmitC/IR/EmitCAttributes.td" include "mlir/Dialect/EmitC/IR/EmitCTypes.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/RegionKindInterface.td" @@ -386,6 +388,174 @@ def EmitC_ForOp : EmitC_Op<"for", let hasRegionVerifier = 1; } +def EmitC_CallOp : EmitC_Op<"call", + [CallOpInterface, + DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `emitc.call` operation represents a direct call to an `emitc.func` + that is within the same symbol scope as the call. The operands and result type + of the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType(); + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def EmitC_FuncOp : EmitC_Op<"func", [ + AutomaticAllocationScope, + FunctionOpInterface, IsolatedFromAbove +]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). While the MLIR textual form provides a nice + inline syntax for function arguments, they are internally represented as + “block arguments” to the first block in the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // A function with no results: + emitc.func @foo(%arg0 : i32) { + emitc.call_opaque "bar" (%arg0) : (i32) -> () + emitc.return + } + + // A function with its argument as single result: + emitc.func @foo(%arg0 : i32) -> i32 { + emitc.return %arg0 : i32 + } + + // A function with specifiers attribute: + emitc.func @example_specifiers_fn_attr() -> i32 + attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo" (): () -> i32 + emitc.return %0 : i32 + } + + ``` + }]; + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$specifiers, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, + ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `emitc.return` operation represents a return operation within a function. + The operation takes zero or exactly one operand and produces no results. + The operand number and type must match the signature of the function + that contains the operation. + + Example: + + ```mlir + emitc.func @foo() : (i32) { + ... + emitc.return %0 : i32 + } + ``` + }]; + let arguments = (ins Optional:$operand); + + let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?"; + let hasVerifier = 1; +} + def EmitC_IncludeOp : EmitC_Op<"include", [HasParent<"ModuleOp">]> { let summary = "Include operation"; diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 3a5dbc12c23f5..9e421f7c49dbc 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -18,6 +18,7 @@ add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSCF) add_subdirectory(ControlFlowToSPIRV) add_subdirectory(ConvertToLLVM) +add_subdirectory(FuncToEmitC) add_subdirectory(FuncToLLVM) add_subdirectory(FuncToSPIRV) add_subdirectory(GPUCommon) diff --git a/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt new file mode 100644 index 0000000000000..97752205bbcb4 --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRFuncToEmitC + FuncToEmitC.cpp + FuncToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/FuncToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIREmitCDialect + MLIRFuncDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp new file mode 100644 index 0000000000000..ac3d8297953f3 --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -0,0 +1,116 @@ +//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert the Func dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +namespace { +class CallOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Multiple results func was not converted to `emitc.func`. + if (callOp.getNumResults() > 1) + return rewriter.notifyMatchFailure( + callOp, "only functions with zero or one result can be converted"); + + rewriter.replaceOpWithNewOp( + callOp, + callOp.getNumResults() ? callOp.getResult(0).getType() : nullptr, + adaptor.getOperands(), callOp->getAttrs()); + + return success(); + } +}; + +class FuncOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (funcOp.getFunctionType().getNumResults() > 1) + return rewriter.notifyMatchFailure( + funcOp, "only functions with zero or one result can be converted"); + + if (funcOp.isDeclaration()) + return rewriter.notifyMatchFailure(funcOp, + "declarations cannot be converted"); + + // Create the converted `emitc.func` op. + emitc::FuncOp newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); + + // Copy over all attributes other than the function name and type. + for (const auto &namedAttr : funcOp->getAttrs()) { + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && + namedAttr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + // Add `static` to specifiers if `func.func` is private. + if (funcOp.isPrivate()) { + ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"}); + newFuncOp.setSpecifiersAttr(specifiers); + } + + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + rewriter.eraseOp(funcOp); + + return success(); + } +}; + +class ReturnOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (returnOp.getNumOperands() > 1) + return rewriter.notifyMatchFailure( + returnOp, "only zero or one operand is supported"); + + rewriter.replaceOpWithNewOp( + returnOp, + returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + + patterns.add(ctx); +} diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp new file mode 100644 index 0000000000000..26d32e29bef8c --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp @@ -0,0 +1,47 @@ +//===- FuncToEmitC.cpp - Func to EmitC Pass ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Func dialect to the EmitC dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTFUNCTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertFuncToEmitC + : public impl::ConvertFuncToEmitCBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertFuncToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalOp(); + + RewritePatternSet patterns(&getContext()); + populateFuncToEmitCPatterns(patterns); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt index 4665c41a62e80..4cc54201d2745 100644 --- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt @@ -9,8 +9,10 @@ add_mlir_dialect_library(MLIREmitCDialect MLIREmitCAttributesIncGen LINK_LIBS PUBLIC + MLIRCallInterfaces MLIRCastInterfaces MLIRControlFlowInterfaces + MLIRFunctionInterfaces MLIRIR MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5f502f1f7a171..df489c6d90fb1 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -8,7 +8,10 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -347,6 +350,122 @@ LogicalResult ForOp::verifyRegions() { return success(); } +//===----------------------------------------------------------------------===// +// CallOp +//===----------------------------------------------------------------------===// + +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this)->getAttrOfType("callee"); + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +FunctionType CallOp::getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +LogicalResult FuncOp::verify() { + if (getNumResults() > 1) + return emitOpError("requires zero or exactly one result, but has ") + << getNumResults(); + + if (isExternal()) + return emitOpError("does not support empty function bodies"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + if (getNumOperands() != function.getNumResults()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << function.getNumResults(); + + if (function.getNumResults() == 1) + if (getOperand().getType() != function.getResultTypes()[0]) + return emitError() << "type of the return operand (" + << getOperand().getType() + << ") doesn't match function result type (" + << function.getResultTypes()[0] << ")" + << " in function @" << function.getName(); + return success(); +} + //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 72b382709925e..c0c6105409f8d 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -504,18 +504,33 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } -static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { - if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) +static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, + StringRef callee) { + if (failed(emitter.emitAssignPrefix(*callOp))) return failure(); raw_ostream &os = emitter.ostream(); - os << callOp.getCallee() << "("; - if (failed(emitter.emitOperands(*callOp.getOperation()))) + os << callee << "("; + if (failed(emitter.emitOperands(*callOp))) return failure(); os << ")"; return success(); } +static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { + Operation *operation = callOp.getOperation(); + StringRef callee = callOp.getCallee(); + + return printCallOperation(emitter, operation, callee); +} + +static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { + Operation *operation = callOp.getOperation(); + StringRef callee = callOp.getCallee(); + + return printCallOperation(emitter, operation, callee); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOpaqueOp callOpaqueOp) { raw_ostream &os = emitter.ostream(); @@ -733,6 +748,19 @@ static LogicalResult printOperation(CppEmitter &emitter, } } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ReturnOp returnOp) { + raw_ostream &os = emitter.ostream(); + os << "return"; + if (returnOp.getNumOperands() == 0) + return success(); + + os << " "; + if (failed(emitter.emitOperand(returnOp.getOperand()))) + return failure(); + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { CppEmitter::Scope scope(emitter); @@ -743,39 +771,34 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { return success(); } -static LogicalResult printOperation(CppEmitter &emitter, - func::FuncOp functionOp) { - // We need to declare variables at top if the function has multiple blocks. - if (!emitter.shouldDeclareVariablesAtTop() && - functionOp.getBlocks().size() > 1) { - return functionOp.emitOpError( - "with multiple blocks needs variables declared at top"); - } - - CppEmitter::Scope scope(emitter); +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + Region::BlockArgListType arguments) { raw_indented_ostream &os = emitter.ostream(); - if (failed(emitter.emitTypes(functionOp.getLoc(), - functionOp.getFunctionType().getResults()))) - return failure(); - os << " " << functionOp.getName(); - os << "("; if (failed(interleaveCommaWithError( - functionOp.getArguments(), os, - [&](BlockArgument arg) -> LogicalResult { - if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) + arguments, os, [&](BlockArgument arg) -> LogicalResult { + if (failed(emitter.emitType(functionOp->getLoc(), arg.getType()))) return failure(); os << " " << emitter.getOrCreateName(arg); return success(); }))) return failure(); - os << ") {\n"; + + return success(); +} + +static LogicalResult printFunctionBody(CppEmitter &emitter, + Operation *functionOp, + Region::BlockListType &blocks) { + raw_indented_ostream &os = emitter.ostream(); os.indent(); + if (emitter.shouldDeclareVariablesAtTop()) { // Declare all variables that hold op results including those from nested // regions. WalkResult result = - functionOp.walk([&](Operation *op) -> WalkResult { + functionOp->walk([&](Operation *op) -> WalkResult { if (isa(op) || isa(op->getParentOp()) || (isa(op) && @@ -794,7 +817,6 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); } - Region::BlockListType &blocks = functionOp.getBlocks(); // Create label names for basic blocks. for (Block &block : blocks) { emitter.getOrCreateName(block); @@ -804,7 +826,7 @@ static LogicalResult printOperation(CppEmitter &emitter, for (Block &block : llvm::drop_begin(blocks)) { for (BlockArgument &arg : block.getArguments()) { if (emitter.hasValueInScope(arg)) - return functionOp.emitOpError(" block argument #") + return functionOp->emitOpError(" block argument #") << arg.getArgNumber() << " is out of scope"; if (failed( emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { @@ -833,7 +855,71 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); } } - os.unindent() << "}\n"; + + os.unindent(); + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + func::FuncOp functionOp) { + // We need to declare variables at top if the function has multiple blocks. + if (!emitter.shouldDeclareVariablesAtTop() && + functionOp.getBlocks().size() > 1) { + return functionOp.emitOpError( + "with multiple blocks needs variables declared at top"); + } + + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ") {\n"; + if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) + return failure(); + os << "}\n"; + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::FuncOp functionOp) { + // We need to declare variables at top if the function has multiple blocks. + if (!emitter.shouldDeclareVariablesAtTop() && + functionOp.getBlocks().size() > 1) { + return functionOp.emitOpError( + "with multiple blocks needs variables declared at top"); + } + + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + if (functionOp.getSpecifiers()) { + for (Attribute specifier : functionOp.getSpecifiersAttr()) { + os << cast(specifier).str() << " "; + } + } + + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ") {\n"; + if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) + return failure(); + os << "}\n"; + return success(); } @@ -1148,12 +1234,12 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case( [&](auto op) { return printOperation(*this, op); }) // EmitC ops. - .Case( + emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp, + emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir new file mode 100644 index 0000000000000..a1c8af2587aa0 --- /dev/null +++ b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir @@ -0,0 +1,55 @@ +// RUN: mlir-opt -split-input-file -convert-func-to-emitc %s | FileCheck %s + +// CHECK-LABEL: emitc.func @foo() +// CHECK-NEXT: emitc.return +func.func @foo() { + return +} + +// ----- + +// CHECK-LABEL: emitc.func private @foo() attributes {specifiers = ["static"]} +// CHECK-NEXT: emitc.return +func.func private @foo() { + return +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32) +func.func @foo(%arg0: i32) { + emitc.call_opaque "bar"(%arg0) : (i32) -> () + return +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32) -> i32 +// CHECK-NEXT: emitc.return %arg0 : i32 +func.func @foo(%arg0: i32) -> i32 { + return %arg0 : i32 +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32, %arg1: i32) -> i32 +func.func @foo(%arg0: i32, %arg1: i32) -> i32 { + %0 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: emitc.func private @return_i32(%arg0: i32) -> i32 attributes {specifiers = ["static"]} +// CHECK-NEXT: emitc.return %arg0 : i32 +func.func private @return_i32(%arg0: i32) -> i32 { + return %arg0 : i32 +} + +// CHECK-LABEL: emitc.func @call(%arg0: i32) -> i32 +// CHECK-NEXT: %0 = emitc.call @return_i32(%arg0) : (i32) -> i32 +// CHECK-NEXT: emitc.return %0 : i32 +func.func @call(%arg0: i32) -> i32 { + %0 = call @return_i32(%arg0) : (i32) -> (i32) + return %0 : i32 +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 46eccb1c24eea..707f9a5b23b0b 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -289,3 +289,40 @@ func.func @test_expression_multiple_results(%arg0: i32) -> i32 { } return %r : i32 } + +// ----- + +// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}} +emitc.func @multiple_results(%0: i32) -> (i32, i32) { + emitc.return %0 : i32 +} + +// ----- + +emitc.func @resulterror() -> i32 { +^bb42: + emitc.return // expected-error {{'emitc.return' op has 0 operands, but enclosing function (@resulterror) returns 1}} +} + +// ----- + +emitc.func @return_type_mismatch() -> i32 { + %0 = emitc.call_opaque "foo()"(): () -> f32 + emitc.return %0 : f32 // expected-error {{type of the return operand ('f32') doesn't match function result type ('i32') in function @return_type_mismatch}} +} + +// ----- + +func.func @return_inside_func.func(%0: i32) -> (i32) { + // expected-error@+1 {{'emitc.return' op expects parent op 'emitc.func'}} + emitc.return %0 : i32 +} +// ----- + +// expected-error@+1 {{expected non-function type}} +emitc.func @func_variadic(...) + +// ----- + +// expected-error@+1 {{'emitc.func' op does not support empty function bodies}} +emitc.func private @empty() diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 7ad3787558b7f..b41333faa4d4e 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -15,6 +15,21 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) { return } +emitc.func @func(%arg0 : i32) { + emitc.call_opaque "foo"(%arg0) : (i32) -> () + emitc.return +} + +emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo"(): () -> i32 + emitc.return %0 : i32 +} + +emitc.func @call() -> i32 { + %0 = emitc.call @return_i32() : () -> (i32) + emitc.return %0 : i32 +} + func.func @cast(%arg0: i32) { %1 = emitc.cast %arg0: i32 to f32 return diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir new file mode 100644 index 0000000000000..d2e14a9e5a7ae --- /dev/null +++ b/mlir/test/Target/Cpp/func.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP + + +emitc.func @emitc_func(%arg0 : i32) { + emitc.call_opaque "foo" (%arg0) : (i32) -> () + emitc.return +} +// CPP-DEFAULT: void emitc_func(int32_t [[V0:[^ ]*]]) { +// CPP-DEFAULT-NEXT: foo([[V0:[^ ]*]]); +// CPP-DEFAULT-NEXT: return; + + +emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo" (): () -> i32 + emitc.return %0 : i32 +} +// CPP-DEFAULT: static inline int32_t return_i32() { +// CPP-DEFAULT-NEXT: [[V0:[^ ]*]] = foo(); +// CPP-DEFAULT-NEXT: return [[V0:[^ ]*]]; + +// CPP-DECLTOP: static inline int32_t return_i32() { +// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DECLTOP-NEXT: [[V0:]] = foo(); +// CPP-DECLTOP-NEXT: return [[V0:[^ ]*]]; + + +emitc.func @emitc_call() -> i32 { + %0 = emitc.call @return_i32() : () -> (i32) + emitc.return %0 : i32 +} +// CPP-DEFAULT: int32_t emitc_call() { +// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = return_i32(); +// CPP-DEFAULT-NEXT: return [[V0:[^ ]*]]; + +// CPP-DECLTOP: int32_t emitc_call() { +// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DECLTOP-NEXT: [[V0:[^ ]*]] = return_i32(); +// CPP-DECLTOP-NEXT: return [[V0:[^ ]*]]; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index bc7f48a563b0e..7dd52ffff2583 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1583,8 +1583,10 @@ td_library( includes = ["include"], deps = [ ":BuiltinDialectTdFiles", + ":CallInterfacesTdFiles", ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + ":FunctionInterfacesTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -3687,10 +3689,12 @@ cc_library( ]), includes = ["include"], deps = [ + ":CallOpInterfaces", ":CastInterfaces", ":ControlFlowInterfaces", ":EmitCAttributesIncGen", ":EmitCOpsIncGen", + ":FunctionInterfaces", ":IR", ":SideEffectInterfaces", "//llvm:Support", @@ -3927,6 +3931,7 @@ cc_library( ":ControlFlowToSPIRV", ":ConversionPassIncGen", ":ConvertToLLVM", + ":FuncToEmitC", ":FuncToLLVM", ":FuncToSPIRV", ":GPUToGPURuntimeTransforms", @@ -6865,6 +6870,32 @@ cc_library( ], ) +cc_library( + name = "FuncToEmitC", + srcs = glob([ + "lib/Conversion/FuncToEmitC*.cpp", + "lib/Conversion/FuncToEmitC/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/FuncToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/FuncToEmitC", + ], + deps = [ + ":ConversionPassIncGen", + ":FuncDialect", + ":EmitCDialect", + ":IR", + ":Pass", + ":Support", + ":TransformUtils", + ":Transforms", + "//llvm:Support", + ], +) + cc_library( name = "FuncToSPIRV", srcs = glob([