diff --git a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp index 23d32a041a91d..fafc2a4b18e9b 100644 --- a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp +++ b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp @@ -127,3 +127,51 @@ TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) { (void *)&addMethodWrapper, Result, ExecutorAddr::fromPtr(&AddObj), 2)); EXPECT_EQ(Result, (int32_t)3); } + +// A non-SPS wrapper function that calculates the sum of a byte array. +static __orc_rt_CWrapperFunctionResult sumArrayRawWrapper(const char *ArgData, + size_t ArgSize) { + auto WFR = WrapperFunctionResult::allocate(1); + *WFR.data() = 0; + for (unsigned I = 0; I != ArgSize; ++I) + *WFR.data() += ArgData[I]; + return WFR.release(); +} + +TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) { + { + // Check raw wrapper function calls. + char A[] = {1, 2, 3, 4}; + + WrapperFunctionCall WFC{ExecutorAddr::fromPtr(sumArrayRawWrapper), + ExecutorAddrRange(ExecutorAddr::fromPtr(A), + ExecutorAddrDiff(sizeof(A)))}; + + WrapperFunctionResult WFR(WFC.run()); + EXPECT_EQ(WFR.size(), 1U); + EXPECT_EQ(WFR.data()[0], 10); + } + + { + // Check calls to void functions. + WrapperFunctionCall WFC{ExecutorAddr::fromPtr(voidNoopWrapper), + ExecutorAddrRange()}; + auto Err = WFC.runWithSPSRet(); + EXPECT_FALSE(!!Err); + } + + { + // Check calls with arguments and return values. + auto ArgWFR = + WrapperFunctionResult::fromSPSArgs>(2, 4); + WrapperFunctionCall WFC{ + ExecutorAddr::fromPtr(addWrapper), + ExecutorAddrRange(ExecutorAddr::fromPtr(ArgWFR.data()), + ExecutorAddrDiff(ArgWFR.size()))}; + + int32_t Result = 0; + auto Err = WFC.runWithSPSRet(Result); + EXPECT_FALSE(!!Err); + EXPECT_EQ(Result, 6); + } +} diff --git a/compiler-rt/lib/orc/wrapper_function_utils.h b/compiler-rt/lib/orc/wrapper_function_utils.h index cf92ad890cd17..23385e1bd7944 100644 --- a/compiler-rt/lib/orc/wrapper_function_utils.h +++ b/compiler-rt/lib/orc/wrapper_function_utils.h @@ -104,6 +104,16 @@ class WrapperFunctionResult { return createOutOfBandError(Msg.c_str()); } + template + static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { + auto Result = allocate(SPSArgListT::size(Args...)); + SPSOutputBuffer OB(Result.data(), Result.size()); + if (!SPSArgListT::serialize(OB, Args...)) + return createOutOfBandError( + "Error serializing arguments to blob in call"); + return Result; + } + /// If this value is an out-of-band error then this returns the error message, /// otherwise returns nullptr. const char *getOutOfBandError() const { @@ -116,17 +126,6 @@ class WrapperFunctionResult { namespace detail { -template -WrapperFunctionResult -serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { - auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...)); - SPSOutputBuffer OB(Result.data(), Result.size()); - if (!SPSArgListT::serialize(OB, Args...)) - return WrapperFunctionResult::createOutOfBandError( - "Error serializing arguments to blob in call"); - return Result; -} - template class WrapperFunctionHandlerCaller { public: template @@ -212,15 +211,14 @@ class WrapperFunctionHandlerHelper class ResultSerializer { public: static WrapperFunctionResult serialize(RetT Result) { - return serializeViaSPSToWrapperFunctionResult>( - Result); + return WrapperFunctionResult::fromSPSArgs>(Result); } }; template class ResultSerializer { public: static WrapperFunctionResult serialize(Error Err) { - return serializeViaSPSToWrapperFunctionResult>( + return WrapperFunctionResult::fromSPSArgs>( toSPSSerializable(std::move(Err))); } }; @@ -229,7 +227,7 @@ template class ResultSerializer> { public: static WrapperFunctionResult serialize(Expected E) { - return serializeViaSPSToWrapperFunctionResult>( + return WrapperFunctionResult::fromSPSArgs>( toSPSSerializable(std::move(E))); } }; @@ -304,8 +302,7 @@ class WrapperFunction { return make_error("__orc_rt_jit_dispatch not set"); auto ArgBuffer = - detail::serializeViaSPSToWrapperFunctionResult>( - Args...); + WrapperFunctionResult::fromSPSArgs>(Args...); if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) return make_error(ErrMsg); @@ -397,6 +394,64 @@ makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { return MethodWrapperHandler(Method); } +/// Represents a call to a wrapper function. +struct WrapperFunctionCall { + ExecutorAddr Func; + ExecutorAddrRange ArgData; + + WrapperFunctionCall() = default; + WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData) + : Func(Func), ArgData(ArgData) {} + + /// Run and return result as WrapperFunctionResult. + WrapperFunctionResult run() { + WrapperFunctionResult WFR( + Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()( + ArgData.Start.toPtr(), + static_cast(ArgData.size().getValue()))); + return WFR; + } + + /// Run call and deserialize result using SPS. + template Error runWithSPSRet(RetT &RetVal) { + auto WFR = run(); + if (const char *ErrMsg = WFR.getOutOfBandError()) + return make_error(ErrMsg); + SPSInputBuffer IB(WFR.data(), WFR.size()); + if (!SPSSerializationTraits::deserialize(IB, RetVal)) + return make_error("Could not deserialize result from " + "serialized wrapper function call"); + return Error::success(); + } + + /// Overload for SPS functions returning void. + Error runWithSPSRet() { + SPSEmpty E; + return runWithSPSRet(E); + } +}; + +class SPSWrapperFunctionCall {}; + +template <> +class SPSSerializationTraits { +public: + static size_t size(const WrapperFunctionCall &WFC) { + return SPSArgList::size(WFC.Func, + WFC.ArgData); + } + + static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { + return SPSArgList::serialize( + OB, WFC.Func, WFC.ArgData); + } + + static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { + return SPSArgList::deserialize( + IB, WFC.Func, WFC.ArgData); + } +}; + } // end namespace __orc_rt #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H