diff --git a/offload/liboffload/API/CMakeLists.txt b/offload/liboffload/API/CMakeLists.txt index e4baa4772a1ef..daa8f382197df 100644 --- a/offload/liboffload/API/CMakeLists.txt +++ b/offload/liboffload/API/CMakeLists.txt @@ -19,6 +19,7 @@ offload_tablegen(OffloadEntryPoints.inc -gen-entry-points) offload_tablegen(OffloadFuncs.inc -gen-func-names) offload_tablegen(OffloadImplFuncDecls.inc -gen-impl-func-decls) offload_tablegen(OffloadPrint.hpp -gen-print-header) +offload_tablegen(OffloadTypedGetInfo.inc -gen-get-info-wrappers) add_public_tablegen_target(OffloadGenerate) diff --git a/offload/test/tools/offload-tblgen/get_info_wrappers.td b/offload/test/tools/offload-tblgen/get_info_wrappers.td new file mode 100644 index 0000000000000..a35ece354f296 --- /dev/null +++ b/offload/test/tools/offload-tblgen/get_info_wrappers.td @@ -0,0 +1,60 @@ +// RUN: %offload-tblgen -gen-get-info-wrappers -I %S/../../../liboffload/API %s | %fcheck-generic + +include "APIDefs.td" + +def ol_foo_handle_t : Handle { +} + +def ol_foo_info_t : Enum { + let is_typed = 1; + let etors = [ + TaggedEtor<"INT", "int", "">, + TaggedEtor<"STRING", "char[]", "">, + TaggedEtor<"ARRAY", "int[]", "">, + ]; +} + +def olGetFooInfo : Function { + let params = [ + Param<"ol_foo_handle_t", "Foo", "", PARAM_IN>, + Param<"ol_foo_info_t", "PropName", "", PARAM_IN>, + Param<"size_t", "PropSize", "", PARAM_IN>, + TypeTaggedParam<"void*", "PropValue", "array of bytes holding the info.", PARAM_OUT, + TypeInfo<"PropName", "PropSize">> + ]; + let returns = [ + Return<"OL_FOO_INVALID"> + ]; +} + +// CHECK-LABEL: template inline auto get_info(ol_foo_handle_t Foo); +// CHECK-NEXT: template<> inline auto get_info(ol_foo_handle_t Foo) { +// CHECK-NEXT: int Result; +// CHECK-NEXT: if (auto Err = olGetFooInfo(Foo, OL_FOO_INFO_INT, 1, &Result)) +// CHECK-NEXT: return std::variant{Err}; +// CHECK-NEXT: else +// CHECK-NEXT: return std::variant{Result}; +// CHECK-NEXT: } +// CHECK-NEXT: template<> inline auto get_info(ol_foo_handle_t Foo) { +// CHECK-NEXT: std::string Result; +// CHECK-NEXT: size_t ResultSize = 0; +// CHECK-NEXT: if (auto Err = olGetFooInfoSize(Foo, OL_FOO_INFO_STRING, &ResultSize)) +// CHECK-NEXT: return std::variant{Err}; +// CHECK-NEXT: Result.resize(ResultSize - 1); +// CHECK-NEXT: if (auto Err = olGetFooInfo(Foo, OL_FOO_INFO_STRING, ResultSize, Result.data())) +// CHECK-NEXT: return std::variant{Err}; +// CHECK-NEXT: else +// CHECK-NEXT: return std::variant{Result}; +// CHECK-NEXT: } +// CHECK-NEXT: template<> inline auto get_info(ol_foo_handle_t Foo) { +// CHECK-NEXT: std::vector Result; +// CHECK-NEXT: size_t ResultSize = 0; +// CHECK-NEXT: if (auto Err = olGetFooInfoSize(Foo, OL_FOO_INFO_ARRAY, &ResultSize)) +// CHECK-NEXT: return std::variant, ol_result_t>{Err}; +// CHECK-NEXT: assert(ResultSize % sizeof(int) == 0); +// CHECK-NEXT: Result.resize(ResultSize / sizeof(int)); +// CHECK-NEXT: if (auto Err = olGetFooInfo(Foo, OL_FOO_INFO_ARRAY, ResultSize, Result.data())) +// CHECK-NEXT: return std::variant, ol_result_t>{Err}; +// CHECK-NEXT: else +// CHECK-NEXT: return std::variant, ol_result_t>{Result}; +// CHECK-NEXT: } diff --git a/offload/tools/offload-tblgen/CMakeLists.txt b/offload/tools/offload-tblgen/CMakeLists.txt index a5ae1c3757fbf..bc3c4fa5b6ef7 100644 --- a/offload/tools/offload-tblgen/CMakeLists.txt +++ b/offload/tools/offload-tblgen/CMakeLists.txt @@ -20,6 +20,7 @@ add_tablegen(offload-tblgen OFFLOAD offload-tblgen.cpp PrintGen.cpp RecordTypes.hpp + TypedGetInfoWrappers.cpp ) # Make sure that C++ headers are available, if libcxx is built at the same diff --git a/offload/tools/offload-tblgen/Generators.hpp b/offload/tools/offload-tblgen/Generators.hpp index fda63f8b198e5..84e14ea0c16e9 100644 --- a/offload/tools/offload-tblgen/Generators.hpp +++ b/offload/tools/offload-tblgen/Generators.hpp @@ -25,3 +25,5 @@ void EmitOffloadExports(const llvm::RecordKeeper &Records, void EmitOffloadErrcodes(const llvm::RecordKeeper &Records, llvm::raw_ostream &OS); void EmitOffloadInfo(const llvm::RecordKeeper &Records, llvm::raw_ostream &OS); +void EmitTypedGetInfoWrappers(const llvm::RecordKeeper &Records, + llvm::raw_ostream &OS); diff --git a/offload/tools/offload-tblgen/TypedGetInfoWrappers.cpp b/offload/tools/offload-tblgen/TypedGetInfoWrappers.cpp new file mode 100644 index 0000000000000..98ea538dcccbc --- /dev/null +++ b/offload/tools/offload-tblgen/TypedGetInfoWrappers.cpp @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen backend that produces typed C++ inline wrappers for +// various `olGet*Info interfaces. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#include "GenCommon.hpp" +#include "RecordTypes.hpp" + +using namespace llvm; +using namespace offload::tblgen; + +void EmitTypedGetInfoWrappers(const llvm::RecordKeeper &Records, + llvm::raw_ostream &OS) { + OS << GenericHeader; + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + auto Name = R->getName(); + if (!Name.starts_with("olGet") || !Name.ends_with("Info")) + continue; + auto F = FunctionRec{R}; + auto Params = F.getParams(); + assert(Params.size() == 4); + auto Object = Params[0]; + auto InfoDesc = Params[1]; + + OS << formatv("template <{} Desc> inline auto get_info({} {});\n", + InfoDesc.getType(), Object.getType(), Object.getName()); + + EnumRec E{Records.getDef(InfoDesc.getType())}; + for (auto &V : E.getValues()) { + auto Desc = E.getEnumValNamePrefix() + "_" + V.getName(); + auto TaggedType = V.getTaggedType(); + auto ElementType = TaggedType.rtrim("[]"); + auto ResultType = [&]() -> std::string { + if (!TaggedType.ends_with("[]")) + return TaggedType.str(); + if (TaggedType == "char[]") + return "std::string"; + + return ("std::vector<" + ElementType + ">").str(); + }(); + auto ReturnType = + "std::variant<" + ResultType + ", " + PrefixLower + "_result_t>"; + OS << formatv("template<> inline auto get_info<{}>({} {}) {{\n", Desc, + Object.getType(), Object.getName()); + if (TaggedType.ends_with("[]")) { + OS << TAB_1 << formatv("{0} Result;\n", ResultType); + OS << TAB_1 << "size_t ResultSize = 0;\n"; + OS << TAB_1 + << formatv("if (auto Err = {}Size({}, {}, &ResultSize))\n", + F.getName(), Object.getName(), Desc); + OS << TAB_2 << formatv("return {}{{Err};\n", ReturnType); + if (TaggedType == "char[]") { + // Null terminator isn't counted in `std::string::size()`. + OS << TAB_1 << "Result.resize(ResultSize - 1);\n"; + } else { + OS << TAB_1 + << formatv("assert(ResultSize % sizeof({}) == 0);\n", ElementType); + OS << TAB_1 + << formatv("Result.resize(ResultSize / sizeof({}));\n", + ElementType); + } + OS << TAB_1 + << formatv("if (auto Err = {}({}, {}, ResultSize, Result.data()))\n", + F.getName(), Object.getName(), Desc); + OS << TAB_2 << formatv("return {0}{{Err};\n", ReturnType); + OS << TAB_1 << "else\n"; + OS << TAB_2 << formatv("return {0}{{Result};\n", ReturnType); + } else { + OS << TAB_1 << formatv("{0} Result;\n", TaggedType); + OS << TAB_1 + << formatv("if (auto Err = {}({}, {}, 1, &Result))\n", F.getName(), + Object.getName(), Desc); + OS << TAB_2 << formatv("return {0}{{Err};\n", ReturnType); + OS << TAB_1 << "else\n"; + OS << TAB_2 << formatv("return {0}{{Result};\n", ReturnType); + } + OS << "}\n"; + } + OS << "\n"; + } +} diff --git a/offload/tools/offload-tblgen/offload-tblgen.cpp b/offload/tools/offload-tblgen/offload-tblgen.cpp index 18aaf9e00f08a..3a0f56acdc460 100644 --- a/offload/tools/offload-tblgen/offload-tblgen.cpp +++ b/offload/tools/offload-tblgen/offload-tblgen.cpp @@ -34,6 +34,7 @@ enum ActionType { GenExports, GenErrcodes, GenInfo, + GenTypedGetInfoWrappers, }; namespace { @@ -60,7 +61,10 @@ cl::opt Action( "Generate export file for the Offload library"), clEnumValN(GenErrcodes, "gen-errcodes", "Generate Offload Error Code enum"), - clEnumValN(GenInfo, "gen-info", "Generate Offload Info enum"))); + clEnumValN(GenInfo, "gen-info", "Generate Offload Info enum"), + clEnumValN(GenTypedGetInfoWrappers, "gen-get-info-wrappers", + "Generate typed C++ wrappers around various olGet*Info " + "interfaces"))); } static bool OffloadTableGenMain(raw_ostream &OS, const RecordKeeper &Records) { @@ -98,6 +102,8 @@ static bool OffloadTableGenMain(raw_ostream &OS, const RecordKeeper &Records) { case GenInfo: EmitOffloadInfo(Records, OS); break; + case GenTypedGetInfoWrappers: + EmitTypedGetInfoWrappers(Records, OS); } return false;