diff --git a/cppwinrt/code_writers.h b/cppwinrt/code_writers.h index ee082f717..c7f19d335 100644 --- a/cppwinrt/code_writers.h +++ b/cppwinrt/code_writers.h @@ -1305,12 +1305,12 @@ namespace cppwinrt else if (type_name == "Windows.Foundation.Collections.IMapView`2") { w.write(R"( - auto TryLookup(param_type const& key) const noexcept + auto TryLookup(param_type const& key) const { if constexpr (std::is_base_of_v) { V result{ nullptr }; - WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMapView)->Lookup(get_abi(key), put_abi(result)); + impl::check_hresult_allow_bounds(WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMapView)->Lookup(get_abi(key), put_abi(result))); return result; } else @@ -1318,7 +1318,7 @@ namespace cppwinrt std::optional result; V value{ empty_value() }; - if (0 == WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMapView)->Lookup(get_abi(key), put_abi(value))) + if (0 == impl::check_hresult_allow_bounds(WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMapView)->Lookup(get_abi(key), put_abi(value)))) { result = std::move(value); } @@ -1331,12 +1331,12 @@ namespace cppwinrt else if (type_name == "Windows.Foundation.Collections.IMap`2") { w.write(R"( - auto TryLookup(param_type const& key) const noexcept + auto TryLookup(param_type const& key) const { if constexpr (std::is_base_of_v) { V result{ nullptr }; - WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMap)->Lookup(get_abi(key), put_abi(result)); + impl::check_hresult_allow_bounds(WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMap)->Lookup(get_abi(key), put_abi(result))); return result; } else @@ -1344,7 +1344,7 @@ namespace cppwinrt std::optional result; V value{ empty_value() }; - if (0 == WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMap)->Lookup(get_abi(key), put_abi(value))) + if (0 == impl::check_hresult_allow_bounds(WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMap)->Lookup(get_abi(key), put_abi(value)))) { result = std::move(value); } @@ -1352,6 +1352,11 @@ namespace cppwinrt return result; } } + + auto TryRemove(param_type const& key) const + { + return 0 == impl::check_hresult_allow_bounds(WINRT_IMPL_SHIM(Windows::Foundation::Collections::IMap)->Remove(get_abi(key))); + } )"); } else if (type_name == "Windows.Foundation.IAsyncAction") diff --git a/strings/base_collections_base.h b/strings/base_collections_base.h index 641d4aadb..d798df991 100644 --- a/strings/base_collections_base.h +++ b/strings/base_collections_base.h @@ -1,4 +1,3 @@ - WINRT_EXPORT namespace winrt { template @@ -415,8 +414,14 @@ WINRT_EXPORT namespace winrt void Remove(K const& key) { + auto& container = static_cast(*this).get_container(); + auto found = container.find(static_cast(*this).wrap_value(key)); + if (found == container.end()) + { + throw hresult_out_of_bounds(); + } this->increment_version(); - static_cast(*this).get_container().erase(static_cast(*this).wrap_value(key)); + container.erase(found); } void Clear() noexcept diff --git a/strings/base_error.h b/strings/base_error.h index 20a13a6e1..630e1dd33 100644 --- a/strings/base_error.h +++ b/strings/base_error.h @@ -579,3 +579,15 @@ WINRT_EXPORT namespace winrt abort(); } } + +namespace winrt::impl +{ + inline hresult check_hresult_allow_bounds(hresult const result) + { + if (result != impl::error_out_of_bounds) + { + check_hresult(result); + } + return result; + } +} \ No newline at end of file diff --git a/test/old_tests/UnitTests/TryLookup.cpp b/test/old_tests/UnitTests/TryLookup.cpp index c03cf1104..47d2b5536 100644 --- a/test/old_tests/UnitTests/TryLookup.cpp +++ b/test/old_tests/UnitTests/TryLookup.cpp @@ -90,3 +90,41 @@ TEST_CASE("TryLookup") REQUIRE(map.TryLookup(123).value() == 456); } } + +TEST_CASE("TryRemove") +{ + auto map = single_threaded_map(std::map{ + { 123, nullptr }, + { 124, make(L"remove") }, + { 125, make(L"keep") }, + }); + + REQUIRE(map.TryRemove(122) == false); + REQUIRE(map.TryRemove(123) == true); + REQUIRE(map.TryRemove(124) == true); + + // Should still have one item left. + REQUIRE(map.Size() == 1); + REQUIRE(map.Lookup(125).ToString() == L"keep"); +} + +TEST_CASE("TryLookup TryRemove error") +{ + // Simulate a non-agile map that is being accessed from the wrong thread. + // "Try" operations should throw rather than erroneously report "not found". + // Because they didn't even try. The operation never got off the ground. + struct incorrectly_used_non_agile_map : implements> + { + int Lookup(int) { throw hresult_wrong_thread(); } + int32_t Size() { throw hresult_wrong_thread(); } + bool HasKey(int) { throw hresult_wrong_thread(); } + IMapView GetView() { throw hresult_wrong_thread(); } + bool Insert(int, int) { throw hresult_wrong_thread(); } + void Remove(int) { throw hresult_wrong_thread(); } + void Clear() { throw hresult_wrong_thread(); } + }; + + auto map = make(); + REQUIRE_THROWS_AS(map.TryLookup(123), hresult_wrong_thread); + REQUIRE_THROWS_AS(map.TryRemove(123), hresult_wrong_thread); +} \ No newline at end of file diff --git a/test/old_tests/UnitTests/produce_map.cpp b/test/old_tests/UnitTests/produce_map.cpp index f81267294..b456082a4 100644 --- a/test/old_tests/UnitTests/produce_map.cpp +++ b/test/old_tests/UnitTests/produce_map.cpp @@ -93,7 +93,7 @@ TEST_CASE("produce_IMap_int32_t_hstring") REQUIRE(m.Size() == 2); m.Remove(1); // existing REQUIRE(m.Size() == 1); - m.Remove(3); // not existing + REQUIRE_THROWS_AS(m.Remove(3), hresult_out_of_bounds); // not existing REQUIRE(m.Size() == 1); m.Clear(); @@ -177,7 +177,8 @@ TEST_CASE("produce_IMap_hstring_int32_t") REQUIRE(m.Size() == 2); m.Remove(L"one"); // existing REQUIRE(m.Size() == 1); - m.Remove(L"three"); // not existing + REQUIRE_THROWS_AS(m.Remove(L"three"), hresult_out_of_bounds); // not existing + REQUIRE(!m.TryRemove(L"three")); // not existing REQUIRE(m.Size() == 1); m.Clear(); diff --git a/test/old_tests/UnitTests/single_threaded_map.cpp b/test/old_tests/UnitTests/single_threaded_map.cpp index 44be94c08..ac58c9fa8 100644 --- a/test/old_tests/UnitTests/single_threaded_map.cpp +++ b/test/old_tests/UnitTests/single_threaded_map.cpp @@ -28,6 +28,7 @@ namespace values.Insert(2,20); values.Insert(3,30); IIterator> first = values.First(); + REQUIRE(!values.TryRemove(999)); // failed removal does not invalidate REQUIRE(first.HasCurrent()); [[maybe_unused]] auto pair = first.Current(); REQUIRE(first.MoveNext()); @@ -52,7 +53,8 @@ namespace REQUIRE(!values.Insert(2, 20)); compare(values, { { 1,100 }, {2,20} }); - values.Remove(3); + REQUIRE_THROWS_AS(values.Remove(3), hresult_out_of_bounds); + REQUIRE(!values.TryRemove(3)); compare(values, { { 1,100 },{ 2,20 } }); values.Remove(2); compare(values, { { 1,100 } }); @@ -65,7 +67,7 @@ namespace compare(values, {}); test_invalidation(values, [&] { values.Clear(); }); - test_invalidation(values, [&] { values.Remove(10); }); + test_invalidation(values, [&] { values.Remove(1); }); test_invalidation(values, [&] { values.Insert(1,10); }); } }