Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions common/internal/byte_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,106 @@ bool ByteString::EndsWith(const absl::Cord& rhs) const {
[&rhs](const absl::Cord& lhs) -> bool { return lhs.EndsWith(rhs); }));
}

absl::optional<size_t> ByteString::Find(absl::string_view needle,
size_t pos) const {
ABSL_DCHECK_LE(pos, size());

return Visit(absl::Overload(
[&needle, pos](absl::string_view lhs) -> absl::optional<size_t> {
absl::string_view::size_type i = lhs.find(needle, pos);
if (i == absl::string_view::npos) {
return absl::nullopt;
}
return i;
},
[&needle, pos](const absl::Cord& lhs) -> absl::optional<size_t> {
absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos);
absl::Cord::CharIterator it = cord.Find(needle);
if (it == cord.char_end()) {
return absl::nullopt;
}
return pos +
static_cast<size_t>(absl::Cord::Distance(cord.char_begin(), it));
}));
}

absl::optional<size_t> ByteString::Find(const absl::Cord& needle,
size_t pos) const {
ABSL_DCHECK_LE(pos, size());

return Visit(absl::Overload(
[&needle, pos](absl::string_view lhs) -> absl::optional<size_t> {
if (auto flat_needle = needle.TryFlat(); flat_needle) {
absl::string_view::size_type i = lhs.find(*flat_needle, pos);
if (i == absl::string_view::npos) {
return absl::nullopt;
}
return i;
}
// Needle is fragmented, we have to do a linear scan.
const size_t needle_size = needle.size();
if (pos + needle_size > lhs.size()) {
return absl::nullopt;
}
if (ABSL_PREDICT_FALSE(needle_size == 0)) {
return pos;
}
// Optimization: find the first chunk of the needle, then compare the
// rest. If the first chunk is empty, `lhs.find` will return
// `current_pos`, which correctly degrades to a linear scan.
absl::string_view first_chunk = *needle.Chunks().begin();
absl::Cord rest_of_needle = needle.Subcord(
first_chunk.size(), needle_size - first_chunk.size());
size_t current_pos = pos;
while (true) {
size_t found_pos = lhs.find(first_chunk, current_pos);
if (found_pos == absl::string_view::npos ||
found_pos > lhs.size() - needle_size) {
return absl::nullopt;
}
if (lhs.substr(found_pos + first_chunk.size(),
rest_of_needle.size()) == rest_of_needle) {
return found_pos;
}
current_pos = found_pos + 1;
}
},
[&needle, pos](const absl::Cord& lhs) -> absl::optional<size_t> {
absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos);
absl::Cord::CharIterator it = cord.Find(needle);
if (it == cord.char_end()) {
return absl::nullopt;
}
return pos +
static_cast<size_t>(absl::Cord::Distance(cord.char_begin(), it));
}));
}

ByteString ByteString::Substring(size_t pos, size_t npos) const {
ABSL_DCHECK_LE(npos, size());
ABSL_DCHECK_LE(pos, npos);

switch (GetKind()) {
case ByteStringKind::kSmall: {
ByteString result;
result.rep_.header.kind = ByteStringKind::kSmall;
result.rep_.small.size = npos - pos;
std::memcpy(result.rep_.small.data, rep_.small.data + pos,
result.rep_.small.size);
result.rep_.small.arena = GetSmallArena();
return result;
}
case ByteStringKind::kMedium: {
ByteString result(*this);
result.rep_.medium.data += pos;
result.rep_.medium.size = npos - pos;
return result;
}
case ByteStringKind::kLarge:
return ByteString(GetLarge().Subcord(pos, npos - pos));
}
}

void ByteString::RemovePrefix(size_t n) {
ABSL_DCHECK_LE(n, size());
if (n == 0) {
Expand Down
31 changes: 30 additions & 1 deletion common/internal/byte_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ union CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ByteStringRep final {
absl::string_view LegacyByteString(const ByteString& string, bool stable,
google::protobuf::Arena* absl_nonnull arena);

// `ByteString` is an vocabulary type capable of representing copy-on-write
// `ByteString` is a vocabulary type capable of representing copy-on-write
// strings efficiently for arenas and reference counting. The contents of the
// byte string are owned by an arena or managed by a reference count. All byte
// strings have an associated allocator specified at construction, once the byte
Expand Down Expand Up @@ -275,6 +275,24 @@ ByteString final {
bool EndsWith(const absl::Cord& rhs) const;
bool EndsWith(const ByteString& rhs) const;

// Finds the first occurrence of `needle` in this object, starting at byte
// position `pos`. Returns `absl::nullopt` if `needle` is not found.
// Note: Positions are byte-based, not code point based as in
// `cel::StringValue`.
absl::optional<size_t> Find(absl::string_view needle, size_t pos = 0) const;
absl::optional<size_t> Find(const absl::Cord& needle, size_t pos = 0) const;
absl::optional<size_t> Find(const ByteString& needle, size_t pos = 0) const;

// Returns a new `ByteString` that is a substring of this object, starting at
// byte position `pos` and with a length of `npos` bytes.
// Note: Positions are byte-based, not code point based as in
// `cel::StringValue`.
ByteString Substring(size_t pos, size_t npos) const;
ByteString Substring(size_t pos) const {
ABSL_DCHECK_LE(pos, size());
return Substring(pos, size());
}

void RemovePrefix(size_t n);

void RemoveSuffix(size_t n);
Expand Down Expand Up @@ -501,6 +519,17 @@ inline bool ByteString::EndsWith(const ByteString& rhs) const {
[this](const absl::Cord& rhs) -> bool { return EndsWith(rhs); }));
}

inline absl::optional<size_t> ByteString::Find(const ByteString& needle,
size_t pos) const {
return needle.Visit(absl::Overload(
[this, pos](absl::string_view rhs) -> absl::optional<size_t> {
return Find(rhs, pos);
},
[this, pos](const absl::Cord& rhs) -> absl::optional<size_t> {
return Find(rhs, pos);
}));
}

inline bool operator==(const ByteString& lhs, const ByteString& rhs) {
return lhs.Equals(rhs);
}
Expand Down
145 changes: 145 additions & 0 deletions common/internal/byte_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,151 @@ TEST_P(ByteStringTest, EndsWith) {
GetMediumOrLargeCord().size() - kSmallByteStringCapacity)));
}

TEST_P(ByteStringTest, Find) {
ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView());

// Find string_view
EXPECT_THAT(byte_string.Find("A string"), Optional(0));
EXPECT_THAT(
byte_string.Find("small string optimization!"),
Optional(GetMediumStringView().find("small string optimization!")));
EXPECT_THAT(byte_string.Find("not found"), Eq(absl::nullopt));
EXPECT_THAT(byte_string.Find(""), Optional(0));
EXPECT_THAT(byte_string.Find("", 3), Optional(3));
EXPECT_THAT(byte_string.Find("A string", 1), Eq(absl::nullopt));

// Find cord
EXPECT_THAT(byte_string.Find(absl::Cord("A string")), Optional(0));
EXPECT_THAT(
byte_string.Find(absl::Cord("small string optimization!")),
Optional(GetMediumStringView().find("small string optimization!")));
EXPECT_THAT(
byte_string.Find(absl::MakeFragmentedCord(
{"A string", " that is too large for the small string optimization!",
" extra"})),
Eq(absl::nullopt));
EXPECT_THAT(byte_string.Find(GetMediumOrLargeFragmentedCord()), Optional(0));
EXPECT_THAT(byte_string.Find(absl::Cord("not found")), Eq(absl::nullopt));
EXPECT_THAT(byte_string.Find(absl::Cord("")), Optional(0));
EXPECT_THAT(byte_string.Find(absl::Cord(""), 3), Optional(3));
}

TEST_P(ByteStringTest, FindEdgeCases) {
ByteString empty_byte_string(GetAllocator(), "");
EXPECT_THAT(empty_byte_string.Find("a"), Eq(absl::nullopt));
EXPECT_THAT(empty_byte_string.Find(""), Optional(0));
ByteString cord_byte_string =
ByteString(GetAllocator(), GetMediumOrLargeCord());
EXPECT_THAT(cord_byte_string.Find("not found"), Eq(absl::nullopt));
ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView());

// Needle longer than haystack.
EXPECT_THAT(byte_string.Find(std::string(byte_string.size() + 1, 'a')),
Eq(absl::nullopt));

// Needle at the end.
absl::string_view suffix = "optimization!";
EXPECT_THAT(byte_string.Find(suffix),
Optional(byte_string.size() - suffix.size()));

// pos at the end.
EXPECT_THAT(byte_string.Find("a", byte_string.size()), Eq(absl::nullopt));
EXPECT_THAT(byte_string.Find("", byte_string.size()),
Optional(byte_string.size()));

// Search in a cord-backed ByteString with pos > 0.
EXPECT_THAT(cord_byte_string.Find("string", 1),
Optional(GetMediumStringView().find("string", 1)));

// Needle at the end of a cord-backed ByteString.
absl::string_view suffix_sv = "optimization!";
EXPECT_THAT(cord_byte_string.Find(suffix_sv),
Optional(cord_byte_string.size() - suffix_sv.size()));
EXPECT_THAT(cord_byte_string.Find(absl::Cord(suffix_sv)),
Optional(cord_byte_string.size() - suffix_sv.size()));

// Fragmented needle with empty first chunk.
absl::Cord fragmented_with_empty_chunk;
fragmented_with_empty_chunk.Append("");
fragmented_with_empty_chunk.Append("A string");
EXPECT_THAT(byte_string.Find(fragmented_with_empty_chunk), Optional(0));

// Search with fragmented cord needle on string_view backed ByteString with
// partial match.
ByteString partial_match_haystack(GetAllocator(), "abababac");
absl::Cord partial_match_needle = absl::MakeFragmentedCord({"aba", "c"});
EXPECT_THAT(partial_match_haystack.Find(partial_match_needle), Optional(4));

// Search with fragmented cord needle where first chunk is found but not
// enough space for the rest.
ByteString short_haystack(GetAllocator(), "abcdefg");
absl::Cord needle_too_long = absl::MakeFragmentedCord({"ef", "gh"});
EXPECT_THAT(short_haystack.Find(needle_too_long), Eq(absl::nullopt));

// Search with a fragmented empty cord.
absl::Cord fragmented_empty_cord = absl::MakeFragmentedCord({"", ""});
EXPECT_THAT(byte_string.Find(fragmented_empty_cord), Optional(0));
EXPECT_THAT(byte_string.Find(fragmented_empty_cord, 3), Optional(3));

// Search for suffix in a fragmented cord.
ByteString fragmented_cord_byte_string(GetAllocator(),
GetMediumOrLargeFragmentedCord());
EXPECT_THAT(fragmented_cord_byte_string.Find(suffix_sv),
Optional(fragmented_cord_byte_string.size() - suffix_sv.size()));
EXPECT_THAT(fragmented_cord_byte_string.Find(absl::Cord(suffix_sv)),
Optional(fragmented_cord_byte_string.size() - suffix_sv.size()));
}

#ifndef NDEBUG
TEST_P(ByteStringTest, FindOutOfBounds) {
ByteString byte_string = ByteString(GetAllocator(), "test");
EXPECT_DEATH(byte_string.Find("t", 5), _);
}
#endif

TEST_P(ByteStringTest, Substring) {
// small byte_string substring
ByteString small_byte_string =
ByteString(GetAllocator(), GetSmallStringView());
EXPECT_EQ(small_byte_string.Substring(1, 5),
GetSmallStringView().substr(1, 4));
EXPECT_EQ(small_byte_string.Substring(0, small_byte_string.size()),
GetSmallStringView());
EXPECT_EQ(small_byte_string.Substring(1, 1), "");
// medium byte_string substring
ByteString medium_byte_string =
ByteString(GetAllocator(), GetMediumStringView());
EXPECT_EQ(medium_byte_string.Substring(2, 12),
GetMediumStringView().substr(2, 10));
EXPECT_EQ(medium_byte_string.Substring(0, medium_byte_string.size()),
GetMediumStringView());
// large byte_string substring
ByteString large_byte_string =
ByteString(GetAllocator(), GetMediumOrLargeCord());
EXPECT_EQ(large_byte_string.Substring(3, 15),
GetMediumOrLargeCord().Subcord(3, 12));
EXPECT_EQ(large_byte_string.Substring(0, large_byte_string.size()),
GetMediumOrLargeCord());
// substring with one parameter
ByteString tacocat_byte_string = ByteString(GetAllocator(), "tacocat");
EXPECT_EQ(tacocat_byte_string.Substring(4), "cat");
}

TEST_P(ByteStringTest, SubstringEdgeCases) {
ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView());
EXPECT_EQ(byte_string.Substring(byte_string.size(), byte_string.size()), "");
EXPECT_EQ(byte_string.Substring(0, 0), "");
}

#ifndef NDEBUG
TEST_P(ByteStringTest, SubstringOutOfBounds) {
ByteString byte_string = ByteString(GetAllocator(), "test");
EXPECT_DEATH(static_cast<void>(byte_string.Substring(5, 5)), _);
EXPECT_DEATH(static_cast<void>(byte_string.Substring(0, 5)), _);
EXPECT_DEATH(static_cast<void>(byte_string.Substring(3, 2)), _);
}
#endif

TEST_P(ByteStringTest, RemovePrefixSmall) {
ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView());
byte_string.RemovePrefix(1);
Expand Down