diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index af3db4a64..c95234e8e 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -467,14 +467,14 @@ namespace winrt::impl } template - Expression&& await_transform(Expression&& expression) + auto await_transform(Expression&& expression) { if (Status() == AsyncStatus::Canceled) { throw winrt::hresult_canceled(); } - return std::forward(expression); + return notify_awaiter{ static_cast(expression) }; } cancellation_token await_transform(get_cancellation_token_t) noexcept diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index b983333f5..9f9e75a3c 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -53,6 +53,154 @@ namespace winrt::impl } } } + + template + class has_awaitable_member + { + template ().await_ready())> static constexpr bool get_value(int) { return true; } + template static constexpr bool get_value(...) { return false; } + + public: + + static constexpr bool value = get_value(0); + }; + + template + class has_awaitable_free + { + template ()))> static constexpr bool get_value(int) { return true; } + template static constexpr bool get_value(...) { return false; } + + public: + + static constexpr bool value = get_value(0); + }; + + template + struct free_await_adapter_impl + { + T&& awaitable; + + bool ready() + { + return await_ready(awaitable); + } + + template + auto suspend(std::experimental::coroutine_handle handle) + { + return await_suspend(awaitable, handle); + } + + auto resume() + { + return await_resume(awaitable); + } + }; + + template + struct free_await_adapter + { + T&& awaitable; + + bool await_ready() + { + return free_await_adapter_impl{ static_cast(awaitable) }.ready(); + } + + template + auto await_suspend(std::experimental::coroutine_handle handle) + { + return free_await_adapter_impl{ static_cast(awaitable) }.suspend(handle); + } + + auto await_resume() + { + return free_await_adapter_impl{ static_cast(awaitable) }.resume(); + } + }; + + template + struct member_await_adapter + { + T&& awaitable; + + bool await_ready() + { + return awaitable.await_ready(); + } + + template + auto await_suspend(std::experimental::coroutine_handle handle) + { + return awaitable.await_suspend(handle); + } + + auto await_resume() + { + return awaitable.await_resume(); + } + }; + + template + auto get_awaiter(T&& value) noexcept -> decltype(static_cast(value).operator co_await()) + { + return static_cast(value).operator co_await(); + } + + template + auto get_awaiter(T&& value) noexcept -> decltype(operator co_await(static_cast(value))) + { + return operator co_await(static_cast(value)); + } + + template ::value, int> = 0> + auto get_awaiter(T&& value) noexcept + { + return member_await_adapter{ static_cast(value) }; + } + + template ::value, int> = 0> + auto get_awaiter(T&& value) noexcept + { + return free_await_adapter{ static_cast(value) }; + } + + template + struct notify_awaiter + { + decltype(get_awaiter(std::declval())) awaitable; + + notify_awaiter(T&& awaitable) : awaitable(get_awaiter(static_cast(awaitable))) + { + } + + bool await_ready() + { + if (winrt_suspend_handler) + { + winrt_suspend_handler(this); + } + + return awaitable.await_ready(); + } + + template + auto await_suspend(std::experimental::coroutine_handle handle) + { + return awaitable.await_suspend(handle); + } + + auto await_resume() + { + if (winrt_resume_handler) + { + winrt_resume_handler(this); + } + + return awaitable.await_resume(); + } + }; } WINRT_EXPORT namespace winrt @@ -378,6 +526,12 @@ namespace std::experimental { winrt::terminate(); } + + template + auto await_transform(Expression&& expression) + { + return winrt::impl::notify_awaiter{ static_cast(expression) }; + } }; }; } diff --git a/strings/base_error.h b/strings/base_error.h index 9332ac797..4c03162a4 100644 --- a/strings/base_error.h +++ b/strings/base_error.h @@ -1,6 +1,4 @@ -__declspec(selectany) int32_t (__stdcall *winrt_to_hresult_handler)(void* address) noexcept{}; - namespace winrt::impl { struct heap_traits diff --git a/strings/base_extern.h b/strings/base_extern.h index 0c5fd116c..a149ebd69 100644 --- a/strings/base_extern.h +++ b/strings/base_extern.h @@ -1,4 +1,8 @@ +__declspec(selectany) int32_t(__stdcall* winrt_to_hresult_handler)(void* address) noexcept {}; +__declspec(selectany) void(__stdcall* winrt_suspend_handler)(void const* token) noexcept {}; +__declspec(selectany) void(__stdcall* winrt_resume_handler)(void const* token) noexcept {}; + extern "C" { int32_t __stdcall WINRT_GetRestrictedErrorInfo(void** info) noexcept; diff --git a/test/test/disconnected.cpp b/test/test/disconnected.cpp index 32f53bf35..583b7df36 100644 --- a/test/test/disconnected.cpp +++ b/test/test/disconnected.cpp @@ -79,16 +79,20 @@ TEST_CASE("disconnected") { auto async = ActionProgress(); + handle signal{ CreateEventW(nullptr, true, false, nullptr) }; async.Progress([](auto&&...) { throw hresult_error(RPC_E_DISCONNECTED); }); - async.Completed([](auto&&...) + async.Completed([&](auto&&...) { + SetEvent(signal.get()); throw hresult_error(RPC_E_DISCONNECTED); }); + + WaitForSingleObject(signal.get(), INFINITE); } { @@ -102,15 +106,19 @@ TEST_CASE("disconnected") { auto async = OperationProgress(); + handle signal{ CreateEventW(nullptr, true, false, nullptr) }; async.Progress([](auto&&...) { throw hresult_error(RPC_E_DISCONNECTED); }); - async.Completed([](auto&&...) + async.Completed([&](auto&&...) { + SetEvent(signal.get()); throw hresult_error(RPC_E_DISCONNECTED); }); + + WaitForSingleObject(signal.get(), INFINITE); } } diff --git a/test/test/generic_types.cpp b/test/test/generic_types.cpp index 4c8e68e9b..77c61931e 100644 --- a/test/test/generic_types.cpp +++ b/test/test/generic_types.cpp @@ -8,5 +8,9 @@ TEST_CASE("generic_types") REQUIRE_EQUAL_NAME(L"Windows.Foundation.Uri", Uri); REQUIRE_EQUAL_NAME(L"Windows.Foundation.PropertyType", PropertyType); REQUIRE_EQUAL_NAME(L"Windows.Foundation.Point", Point); + + // Clang 9 doesn't think this is a constant expression. +#ifndef __clang__ REQUIRE_EQUAL_NAME(L"{96369f54-8eb6-48f0-abce-c1b211e627c3}", IStringable); +#endif } diff --git a/test/test/notify_awaiter.cpp b/test/test/notify_awaiter.cpp new file mode 100644 index 000000000..9b074d2e4 --- /dev/null +++ b/test/test/notify_awaiter.cpp @@ -0,0 +1,217 @@ +#include "pch.h" + +using namespace winrt; +using namespace Windows::Foundation; + +namespace +{ + struct free_awaitable + { + }; + bool await_ready(free_awaitable) + { + return true; + } + void await_suspend(free_awaitable, std::experimental::coroutine_handle<>) + { + } + void await_resume(free_awaitable) + { + + } + + struct member_awaitable + { + bool await_ready() + { + return true; + } + void await_suspend(std::experimental::coroutine_handle<>) + { + } + void await_resume() + { + + } + }; + + struct free_operator_awaitable + { + }; + auto operator co_await(free_operator_awaitable) + { + struct awaitable + { + bool await_ready() + { + return true; + } + void await_suspend(std::experimental::coroutine_handle<>) + { + } + void await_resume() + { + } + }; + return awaitable{}; + } + + struct member_operator_awaitable + { + auto operator co_await() + { + struct awaitable + { + bool await_ready() + { + return true; + } + void await_suspend(std::experimental::coroutine_handle<>) + { + } + void await_resume() + { + } + }; + return awaitable{}; + } + }; + + struct no_copy_awaitable + { + no_copy_awaitable() = default; + no_copy_awaitable(no_copy_awaitable const&) = delete; + + bool await_ready() + { + return true; + } + void await_suspend(std::experimental::coroutine_handle<>) + { + } + void await_resume() + { + + } + }; + + IAsyncAction AsyncAction() + { + co_return; + } + IAsyncActionWithProgress AsyncActionWithProgress() + { + co_return; + } + IAsyncOperation AsyncOperation() + { + co_return 0; + } + IAsyncOperationWithProgress AsyncOperationWithProgress() + { + co_return 0; + } + + struct notification + { + uint32_t suspend{}; + uint32_t resume{}; + }; + + static std::map watcher; + static slim_mutex lock; + static handle start_racing{ CreateEventW(nullptr, true, false, nullptr) }; + constexpr size_t test_coroutines = 20; + constexpr size_t test_suspension_points = 12; + + IAsyncAction Async() + { + co_await resume_on_signal(start_racing.get()); + co_await resume_background(); + co_await resume_background(); + co_await free_awaitable{}; + co_await member_awaitable{}; + co_await free_operator_awaitable{}; + co_await member_operator_awaitable{}; + co_await no_copy_awaitable{}; + co_await AsyncAction(); + co_await AsyncActionWithProgress(); + co_await AsyncOperation(); + co_await AsyncOperationWithProgress(); + } +} + +TEST_CASE("notify_awaiter") +{ + // Everything works fine when nobody is watching. + + REQUIRE(!winrt_suspend_handler); + REQUIRE(!winrt_resume_handler); + SetEvent(start_racing.get()); + Async().get(); + ResetEvent(start_racing.get()); + + // Hook up some watchers. + + winrt_suspend_handler = [](void const* token) noexcept + { + slim_lock_guard guard(lock); + watcher[token].suspend += 1; + }; + + winrt_resume_handler = [](void const* token) noexcept + { + slim_lock_guard guard(lock); + watcher[token].resume += 1; + }; + + // Prepare a few coroutines. + + std::vector concurrency; + REQUIRE(watcher.empty()); + + for (size_t i = 0; i != test_coroutines; ++i) + { + concurrency.push_back(Async()); + } + + // Give coroutines a moment to get to the starting line. + + Sleep(1000); + + // Each coroutine should have suspended once. + + REQUIRE(concurrency.size() == test_coroutines); + REQUIRE(watcher.size() == test_coroutines); + + for (auto&& [_, tally] : watcher) + { + REQUIRE(tally.suspend == 1); + REQUIRE(tally.resume == 0); + } + + // And the race is on! + + SetEvent(start_racing.get()); + + for (auto&& async : concurrency) + { + async.get(); + } + + // Each suspension point should have been recorded. + + REQUIRE(watcher.size() == test_coroutines * test_suspension_points); + + for (auto&& [_, tally] : watcher) + { + // And should be be perfectly balanced. + REQUIRE(tally.suspend == 1); + REQUIRE(tally.resume == 1); + } + + // Remove watchers. + + winrt_suspend_handler = nullptr; + winrt_resume_handler = nullptr; +} diff --git a/test/test/test.vcxproj b/test/test/test.vcxproj index 18600b663..398cf4e3c 100644 --- a/test/test/test.vcxproj +++ b/test/test/test.vcxproj @@ -225,6 +225,7 @@ + NotUsing