Skip to content

Commit

Permalink
[IEnum] make_range<TStoredType> (#404)
Browse files Browse the repository at this point in the history
* Don't allow raw output type access; either it's a COM pointer, or you need to pass the type you want to hold (like unique_idlist)

* .

* ctad

* fix merge

* 'format'

* fix RegistryTests not passing due to signed/unsigned comparison

* .

* .
  • Loading branch information
asklar committed Jan 3, 2024
1 parent 7e98b00 commit ce27eed
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 20 deletions.
61 changes: 50 additions & 11 deletions include/wil/com.h
Original file line number Diff line number Diff line change
Expand Up @@ -3173,25 +3173,47 @@ namespace details
template <typename T>
constexpr bool has_next_v = has_next<T>::value;

template <typename T>
struct You_must_specify_Smart_Output_type_explicitly
{
// If you get this error, you must specify a smart pointer type to receive the enumerated objects.
// We deduce the enumerator's output type (the type of the second parameter to the Next method).
// If that type is a COM pointer type (IFoo*), then we use wil::com_ptr<IFoo>. Otherwise, you must explicitly
// specify a smart-object type to receive the enumerated objects as it is not obvious how to handle disposing
// of an enumerated object.
// For example, if you have an enumerator that enumerates BSTRs, you must specify wil::unique_bstr as the
// smart pointer type to receive the enumerated BSTRs.
// auto it = wil::com_iterator<wil::unique_bstr>(pEnumBStr);
static_assert(
wistd::is_same_v<T, void>,
"Couldn't deduce a smart pointer type for the enumerator's output. You must explicitly specify a smart-object type to receive the enumerated objects.");
};

template <typename Interface>
struct com_enumerator_traits
{
using Result = typename com_enumerator_next_traits<decltype(&Interface::Next)>::Result;
// If the result is a COM pointer type (IFoo*), then we use wil::com_ptr<IFoo>. Otherwise, we use the raw pointer type IFoo*.

// If the result is a COM pointer type (IFoo*), then we use wil::com_ptr<IFoo>.
// Otherwise, you must explicitly specify a smart output type.
using smart_result = wistd::conditional_t<
wistd::is_pointer_v<Result> && wistd::is_base_of_v<::IUnknown, wistd::remove_pointer_t<Result>>,
wil::com_ptr<wistd::remove_pointer_t<Result>>,
Result>;
You_must_specify_Smart_Output_type_explicitly<Interface>>;
};
} // namespace details
/// @endcond

template <typename IEnumType, typename TStoredType = typename details::com_enumerator_traits<IEnumType>::smart_result>
template <typename TStoredType, typename IEnumType>
struct com_iterator
{
using TActualStoredType =
wistd::conditional_t<wistd::is_same_v<TStoredType, void>, typename wil::details::com_enumerator_traits<IEnumType>::smart_result, TStoredType>;

wil::com_ptr<IEnumType> m_enum{};
TStoredType m_currentValue{};
TActualStoredType m_currentValue{};

using smart_result = TActualStoredType;
com_iterator(com_iterator&&) = default;
com_iterator(com_iterator const&) = default;
com_iterator& operator=(com_iterator&&) = default;
Expand Down Expand Up @@ -3244,7 +3266,7 @@ struct com_iterator
if (m_enum)
{
// we cannot say m_currentValue = {} because com_ptr has 2 operator= overloads: one for T* and one for nullptr_t
m_currentValue = TStoredType{};
m_currentValue = TActualStoredType{};
auto hr = m_enum->Next(1, &m_currentValue, nullptr);
if (hr == S_FALSE)
{
Expand All @@ -3258,31 +3280,48 @@ struct com_iterator
}
};

template <typename IEnumXxx, wistd::enable_if_t<wil::details::has_next_v<IEnumXxx*>, int> = 0>
// CTAD for com_iterator

template <typename IEnumType>
com_iterator(IEnumType*) -> com_iterator<void, IEnumType>;

template <typename TStoredType = void, typename IEnumXxx, wistd::enable_if_t<wil::details::has_next_v<IEnumXxx*>, int> = 0>
WI_NODISCARD auto make_range(IEnumXxx* enumPtr)
{
using TActualStoredType =
wistd::conditional_t<wistd::is_same_v<TStoredType, void>, typename wil::details::com_enumerator_traits<IEnumXxx>::smart_result, TStoredType>;

struct iterator_range
{
using TStoredType = typename wil::details::com_enumerator_traits<IEnumXxx>::smart_result;
com_iterator<IEnumXxx, TStoredType> m_begin;

iterator_range(IEnumXxx* enumPtr) : m_begin(enumPtr)
static_assert(!wistd::is_same_v<TActualStoredType, void>, "You must specify a type to receive the enumerated objects.");

// the stored type must be constructible from the output type of the enumerator
static_assert(
wistd::is_constructible_v<TActualStoredType, typename wil::details::com_enumerator_traits<IEnumXxx>::Result>,
"The type you specified cannot be converted to the enumerator's output type.");

using enumerator_type = com_iterator<TActualStoredType, IEnumXxx>;

IEnumXxx* m_enumerator{};
iterator_range(IEnumXxx* enumPtr) : m_enumerator(enumPtr)
{
}

WI_NODISCARD auto begin()
{
return m_begin;
return enumerator_type(m_enumerator);
}

WI_NODISCARD constexpr auto end() const noexcept
{
return com_iterator<IEnumXxx, TStoredType>(nullptr);
return enumerator_type(nullptr);
}
};

return iterator_range(enumPtr);
}

#endif // WIL_HAS_CXX_17
#endif // WIL_ENABLE_EXCEPTIONS

Expand Down
67 changes: 58 additions & 9 deletions tests/ComTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2884,6 +2884,8 @@ TEST_CASE("COMEnumerator", "[com][enumerator]")
using IEnumMuffins = EnumT<int32_t>;
using IEnumMuffinsCOM = EnumT<IUnknown*>;

using unique_idlist = wil::unique_any<LPITEMIDLIST, decltype(&ILFree), ILFree>;

SECTION("static_assert COM enumerator details")
{
using real_next_t = decltype(&IEnumIDList::Next);
Expand All @@ -2893,29 +2895,39 @@ TEST_CASE("COMEnumerator", "[com][enumerator]")

using traits_t = wil::details::com_enumerator_traits<IEnumIDList>;
static_assert(std::is_same_v<traits_t::Result, LPITEMIDLIST>);
static_assert(std::is_same_v<traits_t::smart_result, LPITEMIDLIST>);
static_assert(std::is_same_v<traits_t::smart_result, wil::details::You_must_specify_Smart_Output_type_explicitly<IEnumIDList>>); // no smart pointer for LPITEMIDLIST specified

static_assert(std::is_same_v<wil::details::com_enumerator_next_traits<decltype(&IEnumMuffins::Next)>::Result, int32_t>);
static_assert(std::is_same_v<wil::details::com_enumerator_next_traits<decltype(&IEnumMuffins::Next)>::Interface, IEnumMuffins>);
static_assert(std::is_same_v<wil::details::com_enumerator_traits<IEnumMuffins>::Result, int32_t>);
static_assert(std::is_same_v<wil::details::com_enumerator_traits<IEnumMuffins>::smart_result, int32_t>);
static_assert(
std::is_same_v<wil::details::com_enumerator_traits<IEnumMuffins>::smart_result, wil::details::You_must_specify_Smart_Output_type_explicitly<IEnumMuffins>>); // no smart type for int32_t specified

static_assert(std::is_same_v<wil::details::com_enumerator_next_traits<decltype(&IEnumMuffinsCOM::Next)>::Result, IUnknown*>);
static_assert(
std::is_same_v<wil::details::com_enumerator_next_traits<decltype(&IEnumMuffinsCOM::Next)>::Interface, IEnumMuffinsCOM>);
static_assert(std::is_same_v<wil::details::com_enumerator_traits<IEnumMuffinsCOM>::Result, IUnknown*>);
static_assert(std::is_same_v<wil::details::com_enumerator_traits<IEnumMuffinsCOM>::smart_result, wil::com_ptr<IUnknown>>);

{
using custom_stored_type_enumerator = decltype(wil::make_range<unique_idlist, IEnumIDList>(nullptr).begin());
static_assert(std::is_same_v<custom_stored_type_enumerator::smart_result, unique_idlist>);
}
{
using custom_stored_type_enumerator = decltype(wil::make_range<unique_idlist>(wistd::declval<IEnumIDList*>()).begin());
static_assert(std::is_same_v<custom_stored_type_enumerator::smart_result, unique_idlist>);
}
}
SECTION("static_assert com_iterator types")
{
using iterator_t = wil::com_iterator<IEnumIDList>;
static_assert(std::is_same_v<LPITEMIDLIST&, decltype(*iterator_t{nullptr})>);
using iterator_t = wil::com_iterator<unique_idlist, IEnumIDList>;
static_assert(std::is_same_v<unique_idlist&, decltype(*iterator_t{nullptr})>);
}
SECTION("Enumerate empty, non-COM type")
{
auto found = false;
auto muffins = IEnumMuffins(0, 42);
for (auto muffin : wil::make_range(&muffins))
for (auto muffin : wil::make_range<int>(&muffins))
{
REQUIRE(muffin == 0);
found = true;
Expand All @@ -2927,7 +2939,7 @@ TEST_CASE("COMEnumerator", "[com][enumerator]")
{
auto found = false;
auto muffins = IEnumMuffins(3, 42);
for (auto muffin : wil::make_range(&muffins))
for (auto muffin : wil::make_range<int>(&muffins))
{
REQUIRE(muffin == 42);
found = true;
Expand All @@ -2946,11 +2958,32 @@ TEST_CASE("COMEnumerator", "[com][enumerator]")
break;
}
REQUIRE(found);

auto muffinsCOM_nothrow = IEnumMuffinsCOM(1, nullptr);
found = false;
for (auto muffin : wil::make_range<wil::com_ptr_nothrow<IUnknown>>(&muffinsCOM_nothrow))
{
REQUIRE(muffin == nullptr);
found = true;
break;
}
REQUIRE(found);
}
SECTION("CTAD")
{
auto muffinsCOM = IEnumMuffinsCOM(1, nullptr);
using muffins_ctad_type = decltype(wil::com_iterator(&muffinsCOM));
static_assert(std::is_same_v<muffins_ctad_type::smart_result, wil::com_ptr<IUnknown>>);
static_assert(std::is_same_v<decltype(*std::declval<muffins_ctad_type>()), wil::com_ptr<IUnknown>&>);

wil::com_ptr<IEnumString> enumString;
auto it = wil::make_range<wil::unique_cotaskmem_string>(enumString.get());
static_assert(std::is_same_v<decltype(*(it.begin())), wil::unique_cotaskmem_string&>);
}
#if (NTDDI_VERSION >= NTDDI_VISTA)
SECTION("static_assert enumeration types for IEnumAssocHandlers")
{
using range_idlist = decltype(wil::make_range(std::declval<IEnumIDList*>()));
using range_idlist = decltype(wil::make_range<unique_idlist>(std::declval<IEnumIDList*>()));
using range_assochandler = decltype(wil::make_range(std::declval<IEnumAssocHandlers*>()));
// this iterator_range is not the same as this other iterator_range
static_assert(!std::is_same_v<range_idlist, range_assochandler>);
Expand Down Expand Up @@ -2994,11 +3027,27 @@ TEST_CASE("COMEnumerator", "[com][enumerator]")
REQUIRE(enumIDList);

auto count = 0;
for (auto pidl : wil::make_range(enumIDList.get()))
for (const auto& pidl : wil::make_range<unique_idlist>(enumIDList.get()))
{
REQUIRE(pidl);
count++;
break;
}
REQUIRE(count > 0);
}
SECTION("Enumerate IShellFolder, with custom stored type")
{
wil::com_ptr<IShellFolder> desktop;
REQUIRE_SUCCEEDED(::SHGetDesktopFolder(&desktop));
wil::com_ptr<IEnumIDList> enumIDList;
REQUIRE_SUCCEEDED(desktop->EnumObjects(nullptr, SHCONTF_NONFOLDERS, &enumIDList));
REQUIRE(enumIDList);

auto count = 0;
for (auto& pidl : wil::make_range<unique_idlist>(enumIDList.get()))
{
REQUIRE(pidl);
count++;
ILFree(pidl);
break;
}
REQUIRE(count > 0);
Expand Down

0 comments on commit ce27eed

Please sign in to comment.