Skip to content

Commit 21c8e2c

Browse files
bobowenyjugl
authored andcommitted
Bug 1980886 p2 - Provide a forward_iterator for accessing the ACE_HEADERs in an ACL. r=yjuglaret,win-reviewers,gstoll
Differential Revision: https://phabricator.services.mozilla.com/D266682
1 parent 6def298 commit 21c8e2c

File tree

4 files changed

+372
-0
lines changed

4 files changed

+372
-0
lines changed

testing/gtest/mozilla/MozHelpers.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,26 @@ namespace mozilla::gtest {
5858
a; \
5959
}, \
6060
b)
61+
62+
// Wrap EXPECT_DEATH_* macros to also disable the crash reporter.
63+
# define EXPECT_DEATH_WRAP(a, b) \
64+
EXPECT_DEATH_IF_SUPPORTED( \
65+
{ \
66+
mozilla::gtest::DisableCrashReporter(); \
67+
a; \
68+
}, \
69+
b)
70+
# define EXPECT_DEBUG_DEATH_WRAP(a, b) \
71+
EXPECT_DEBUG_DEATH( \
72+
{ \
73+
mozilla::gtest::DisableCrashReporter(); \
74+
a; \
75+
}, \
76+
b)
6177
#else
6278
# define ASSERT_DEATH_WRAP(a, b)
79+
# define EXPECT_DEATH_WRAP(a, b)
80+
# define EXPECT_DEBUG_DEATH_WRAP(a, b)
6381
#endif
6482

6583
void DisableCrashReporter();

widget/windows/WinHeaderOnlyUtils.h

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mozilla/Attributes.h"
2222
#include "mozilla/DynamicallyLinkedFunctionPtr.h"
2323
#include "mozilla/Maybe.h"
24+
#include "mozilla/NotNull.h"
2425
#include "mozilla/ResultVariant.h"
2526
#include "mozilla/UniquePtr.h"
2627
#include "nsWindowsHelpers.h"
@@ -815,6 +816,109 @@ int MozPathGetDriveNumber(const T* aPath) {
815816
return ToDriveNumber(aPath);
816817
}
817818

819+
/**
820+
* Class to provide a forward_iterator for accessing the ACE_HEADERs in an ACL.
821+
* ACE_HEADERs start after the ACL struct and know the size of their ACE.
822+
*/
823+
class AclAceRange {
824+
public:
825+
explicit AclAceRange(const NotNull<const ACL*> aAcl) : mAcl(aAcl) {}
826+
827+
class Iterator {
828+
public:
829+
using iterator_category = std::forward_iterator_tag;
830+
using difference_type = WORD;
831+
using value_type = const ACE_HEADER;
832+
using pointer = value_type*;
833+
using reference = value_type&;
834+
835+
// Constructs an end iterator.
836+
Iterator() = default;
837+
838+
Iterator(const Iterator&) = default;
839+
Iterator& operator=(const Iterator& aOther) = default;
840+
Iterator(Iterator&&) = default;
841+
Iterator& operator=(Iterator&& aOther) = default;
842+
843+
reference operator*() const {
844+
MOZ_RELEASE_ASSERT(mAceCount,
845+
"Trying to dereference past end of AclAceRange");
846+
return *CurrentAceHeader();
847+
}
848+
pointer operator->() const {
849+
MOZ_RELEASE_ASSERT(mAceCount,
850+
"Trying to dereference past end of AclAceRange");
851+
return CurrentAceHeader();
852+
}
853+
854+
Iterator& operator++() {
855+
MOZ_ASSERT(mAceCount, "Iterating past end of AclAceRange");
856+
if (!mAceCount) {
857+
return *this;
858+
}
859+
860+
--mAceCount;
861+
if (!mAceCount) {
862+
return *this;
863+
}
864+
865+
mCharCurrentAceHeader += CurrentAceHeader()->AceSize;
866+
SetAtEndIfCurrentAcePastEndOfAcl();
867+
return *this;
868+
}
869+
870+
Iterator operator++(int) {
871+
auto tmp = *this;
872+
++*this;
873+
return tmp;
874+
}
875+
876+
bool operator==(const Iterator& aOther) const {
877+
return mAceCount == aOther.mAceCount;
878+
}
879+
bool operator!=(const Iterator& aOther) const { return !(*this == aOther); }
880+
881+
private:
882+
friend class AclAceRange;
883+
884+
explicit Iterator(const NotNull<const ACL*> aAcl)
885+
: mCharCurrentAceHeader(reinterpret_cast<const char*>(aAcl.get() + 1)),
886+
mCharEndAcl(reinterpret_cast<const char*>(aAcl.get()) +
887+
aAcl->AclSize),
888+
mAceCount(aAcl->AceCount) {
889+
if (mAceCount > 0) {
890+
SetAtEndIfCurrentAcePastEndOfAcl();
891+
} else if (mAceCount < 0) {
892+
SetAtEnd();
893+
}
894+
}
895+
896+
void SetAtEnd() { mAceCount = 0; }
897+
898+
void SetAtEndIfCurrentAcePastEndOfAcl() {
899+
if (mCharCurrentAceHeader + sizeof(ACE_HEADER) > mCharEndAcl ||
900+
mCharCurrentAceHeader + CurrentAceHeader()->AceSize > mCharEndAcl) {
901+
SetAtEnd();
902+
}
903+
}
904+
905+
pointer CurrentAceHeader() const {
906+
return reinterpret_cast<const ACE_HEADER*>(mCharCurrentAceHeader);
907+
}
908+
909+
const char* mCharCurrentAceHeader = nullptr;
910+
const char* mCharEndAcl = nullptr;
911+
// An mAceCount of 0 means we are at the end.
912+
int mAceCount = 0;
913+
};
914+
915+
Iterator begin() { return Iterator(mAcl); }
916+
Iterator end() { return Iterator(); }
917+
918+
private:
919+
const NotNull<const ACL*> mAcl;
920+
};
921+
818922
} // namespace mozilla
819923

820924
#endif // mozilla_WinHeaderOnlyUtils_h
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2+
/* vim: set ts=8 sts=2 et sw=2 tw=80: */
3+
/* This Source Code Form is subject to the terms of the Mozilla Public
4+
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
5+
* You can obtain one at http://mozilla.org/MPL/2.0/. */
6+
7+
#include "WinHeaderOnlyUtils.h"
8+
9+
#include <algorithm>
10+
11+
#include "gtest/gtest.h"
12+
#include "mozilla/gtest/MozHelpers.h"
13+
14+
using namespace mozilla;
15+
16+
struct TestAcl {
17+
ACL acl{ACL_REVISION, 0, sizeof(TestAcl), 3, 0};
18+
ACCESS_ALLOWED_ACE ace1{
19+
{ACCESS_ALLOWED_ACE_TYPE, OBJECT_INHERIT_ACE, sizeof(ACCESS_ALLOWED_ACE)},
20+
GENERIC_READ,
21+
0};
22+
ACCESS_ALLOWED_OBJECT_ACE ace2{{ACCESS_ALLOWED_OBJECT_ACE_TYPE, INHERITED_ACE,
23+
sizeof(ACCESS_ALLOWED_OBJECT_ACE)},
24+
GENERIC_READ,
25+
0};
26+
ACCESS_DENIED_ACE ace3{
27+
{ACCESS_DENIED_ACE_TYPE, INHERITED_ACE, sizeof(ACCESS_DENIED_ACE)},
28+
GENERIC_READ,
29+
0};
30+
NotNull<ACL*> AsAclPtr() { return WrapNotNull(reinterpret_cast<ACL*>(this)); }
31+
};
32+
33+
TEST(AclAceRange, SimpleCount)
34+
{
35+
TestAcl testAcl;
36+
int aceCount = 0;
37+
for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
38+
Unused << aceHeader;
39+
++aceCount;
40+
}
41+
42+
ASSERT_EQ(aceCount, 3);
43+
}
44+
45+
TEST(AclAceRange, SameAsGetAce)
46+
{
47+
TestAcl testAcl;
48+
int aceIdx = 0;
49+
for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
50+
VOID* pGetAceHeader = nullptr;
51+
EXPECT_TRUE(::GetAce(testAcl.AsAclPtr(), aceIdx, &pGetAceHeader));
52+
auto* getAceHeader = static_cast<ACE_HEADER*>(pGetAceHeader);
53+
EXPECT_EQ(getAceHeader->AceType, aceHeader.AceType);
54+
EXPECT_EQ(getAceHeader->AceFlags, aceHeader.AceFlags);
55+
EXPECT_EQ(getAceHeader->AceSize, aceHeader.AceSize);
56+
++aceIdx;
57+
}
58+
}
59+
60+
TEST(AclAceRange, WithFlagCount)
61+
{
62+
TestAcl testAcl;
63+
int aceCount = 0;
64+
for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
65+
if (aceHeader.AceFlags & INHERITED_ACE) {
66+
++aceCount;
67+
}
68+
}
69+
70+
ASSERT_EQ(aceCount, 2);
71+
}
72+
73+
TEST(AclAceRange, AclSizeCheckedAsWellAsCount)
74+
{
75+
TestAcl testAcl;
76+
testAcl.acl.AclSize -= sizeof(ACCESS_DENIED_ACE);
77+
int aceCount = 0;
78+
for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
79+
if (aceHeader.AceFlags & INHERITED_ACE) {
80+
++aceCount;
81+
}
82+
}
83+
84+
ASSERT_EQ(aceCount, 1);
85+
}
86+
87+
TEST(AclAceRange, ChecksAceHeaderSizeInAclSize)
88+
{
89+
TestAcl testAcl;
90+
testAcl.acl.AclSize -= 1;
91+
int aceCount = 0;
92+
for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
93+
if (aceHeader.AceFlags & INHERITED_ACE) {
94+
++aceCount;
95+
}
96+
}
97+
98+
ASSERT_EQ(aceCount, 1);
99+
}
100+
101+
TEST(AclAceRange, AceCountOfZeroResultsInNoIterations)
102+
{
103+
TestAcl testAcl;
104+
testAcl.acl.AceCount = 0;
105+
int aceCount = 0;
106+
for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
107+
Unused << aceHeader;
108+
++aceCount;
109+
}
110+
111+
ASSERT_EQ(aceCount, 0);
112+
}
113+
114+
TEST(AclAceRange, AclSizeTooSmallForAnyAcesResultsInNoIterations)
115+
{
116+
TestAcl testAcl;
117+
testAcl.acl.AclSize = sizeof(ACCESS_ALLOWED_ACE) - 1;
118+
int aceCount = 0;
119+
for (const auto& aceHeader : AclAceRange(testAcl.AsAclPtr())) {
120+
Unused << aceHeader;
121+
++aceCount;
122+
}
123+
124+
ASSERT_EQ(aceCount, 0);
125+
}
126+
127+
TEST(AclAceRange, weakly_incrementable)
128+
{
129+
TestAcl testAcl;
130+
AclAceRange aclAceRange(testAcl.AsAclPtr());
131+
auto iter = aclAceRange.begin();
132+
133+
EXPECT_TRUE(std::addressof(++iter) == std::addressof(iter))
134+
<< "addressof pre-increment result should match iterator";
135+
136+
// pre and post increment advance iterator.
137+
EXPECT_EQ(iter->AceType, testAcl.ace2.Header.AceType);
138+
EXPECT_EQ(iter->AceFlags, testAcl.ace2.Header.AceFlags);
139+
EXPECT_EQ(iter->AceSize, testAcl.ace2.Header.AceSize);
140+
iter++;
141+
EXPECT_EQ(iter->AceType, testAcl.ace3.Header.AceType);
142+
EXPECT_EQ(iter->AceFlags, testAcl.ace3.Header.AceFlags);
143+
EXPECT_EQ(iter->AceSize, testAcl.ace3.Header.AceSize);
144+
145+
// Moveable.
146+
auto moveConstructedIter(std::move(iter));
147+
EXPECT_EQ(moveConstructedIter->AceType, testAcl.ace3.Header.AceType);
148+
EXPECT_EQ(moveConstructedIter->AceFlags, testAcl.ace3.Header.AceFlags);
149+
EXPECT_EQ(moveConstructedIter->AceSize, testAcl.ace3.Header.AceSize);
150+
auto moveAssignedIter = std::move(iter);
151+
EXPECT_EQ(moveAssignedIter->AceType, testAcl.ace3.Header.AceType);
152+
EXPECT_EQ(moveAssignedIter->AceFlags, testAcl.ace3.Header.AceFlags);
153+
EXPECT_EQ(moveAssignedIter->AceSize, testAcl.ace3.Header.AceSize);
154+
}
155+
156+
TEST(AclAceRange, incrementable)
157+
{
158+
TestAcl testAcl;
159+
AclAceRange aclAceRange1(testAcl.AsAclPtr());
160+
AclAceRange aclAceRange2(testAcl.AsAclPtr());
161+
auto it1 = aclAceRange1.begin();
162+
auto it2 = aclAceRange2.begin();
163+
164+
// bool(a == b) implies bool(a++ == b)
165+
EXPECT_TRUE(it1 == it2) << "begin iterators for same ACL should be equal";
166+
EXPECT_TRUE(it1++ == it2);
167+
EXPECT_FALSE(it1 == it2);
168+
EXPECT_FALSE(it1++ == it2);
169+
170+
// bool(a == b) implies bool(((void)a++, a) == ++b)
171+
it1 = aclAceRange1.begin();
172+
EXPECT_TRUE(it1 == it2);
173+
EXPECT_TRUE(((void)it1++, it1) == ++it2);
174+
it1 = aclAceRange1.begin();
175+
EXPECT_FALSE(it1 == it2);
176+
EXPECT_FALSE(((void)it1++, it1) == ++it2);
177+
178+
// Copyable.
179+
auto copyConstructedIter(it2);
180+
EXPECT_EQ(copyConstructedIter->AceType, testAcl.ace3.Header.AceType);
181+
EXPECT_EQ(copyConstructedIter->AceFlags, testAcl.ace3.Header.AceFlags);
182+
EXPECT_EQ(copyConstructedIter->AceSize, testAcl.ace3.Header.AceSize);
183+
auto copyAssignedIter = it2;
184+
EXPECT_EQ(copyAssignedIter->AceType, testAcl.ace3.Header.AceType);
185+
EXPECT_EQ(copyAssignedIter->AceFlags, testAcl.ace3.Header.AceFlags);
186+
EXPECT_EQ(copyAssignedIter->AceSize, testAcl.ace3.Header.AceSize);
187+
188+
// Default constructable.
189+
AclAceRange::Iterator defaultConstructed;
190+
EXPECT_TRUE(defaultConstructed == aclAceRange1.end());
191+
}
192+
193+
TEST(AclAceRange, AlgorithmCountIf)
194+
{
195+
TestAcl testAcl;
196+
AclAceRange aclAceRange(testAcl.AsAclPtr());
197+
auto aceCount = std::count_if(
198+
aclAceRange.begin(), aclAceRange.end(),
199+
[](const auto& hdr) { return hdr.AceFlags & INHERITED_ACE; });
200+
201+
ASSERT_EQ(aceCount, 2);
202+
}
203+
204+
TEST(AclAceRange, AlgorithmAnyOf)
205+
{
206+
TestAcl testAcl;
207+
AclAceRange aclAceRange(testAcl.AsAclPtr());
208+
auto anyInherited =
209+
std::any_of(aclAceRange.begin(), aclAceRange.end(),
210+
[](const auto& hdr) { return hdr.AceFlags & INHERITED_ACE; });
211+
212+
ASSERT_TRUE(anyInherited);
213+
}
214+
215+
TEST(AclAceRange, DereferenceAtEndIsFatal)
216+
{
217+
#if DEBUG
218+
const auto* msg =
219+
"Assertion failure: mAceCount \\(Trying to dereference past end of "
220+
"AclAceRange\\)";
221+
#else
222+
const auto* msg = "";
223+
#endif
224+
225+
EXPECT_DEATH_WRAP(
226+
{
227+
TestAcl testAcl;
228+
AclAceRange aclAceRange(testAcl.AsAclPtr());
229+
auto aceItCurrent = aclAceRange.begin();
230+
for (; aceItCurrent != aclAceRange.end(); ++aceItCurrent) {
231+
}
232+
*aceItCurrent;
233+
},
234+
msg);
235+
}
236+
237+
TEST(AclAceRange, DebugAssertForIteratingPastEnd)
238+
{
239+
EXPECT_DEBUG_DEATH_WRAP(
240+
{
241+
TestAcl testAcl;
242+
AclAceRange aclAceRange(testAcl.AsAclPtr());
243+
auto aceItCurrent = aclAceRange.begin();
244+
for (; aceItCurrent != aclAceRange.end(); ++aceItCurrent) {
245+
}
246+
++aceItCurrent;
247+
},
248+
"Assertion failure: mAceCount \\(Iterating past end of AclAceRange\\)");
249+
}

0 commit comments

Comments
 (0)