diff --git a/common/internal/byte_string.cc b/common/internal/byte_string.cc index 4f5263f94..b9f479225 100644 --- a/common/internal/byte_string.cc +++ b/common/internal/byte_string.cc @@ -286,6 +286,106 @@ bool ByteString::EndsWith(const absl::Cord& rhs) const { [&rhs](const absl::Cord& lhs) -> bool { return lhs.EndsWith(rhs); })); } +absl::optional ByteString::Find(absl::string_view needle, + size_t pos) const { + ABSL_DCHECK_LE(pos, size()); + + return Visit(absl::Overload( + [&needle, pos](absl::string_view lhs) -> absl::optional { + absl::string_view::size_type i = lhs.find(needle, pos); + if (i == absl::string_view::npos) { + return absl::nullopt; + } + return i; + }, + [&needle, pos](const absl::Cord& lhs) -> absl::optional { + absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos); + absl::Cord::CharIterator it = cord.Find(needle); + if (it == cord.char_end()) { + return absl::nullopt; + } + return pos + + static_cast(absl::Cord::Distance(cord.char_begin(), it)); + })); +} + +absl::optional ByteString::Find(const absl::Cord& needle, + size_t pos) const { + ABSL_DCHECK_LE(pos, size()); + + return Visit(absl::Overload( + [&needle, pos](absl::string_view lhs) -> absl::optional { + if (auto flat_needle = needle.TryFlat(); flat_needle) { + absl::string_view::size_type i = lhs.find(*flat_needle, pos); + if (i == absl::string_view::npos) { + return absl::nullopt; + } + return i; + } + // Needle is fragmented, we have to do a linear scan. + const size_t needle_size = needle.size(); + if (pos + needle_size > lhs.size()) { + return absl::nullopt; + } + if (ABSL_PREDICT_FALSE(needle_size == 0)) { + return pos; + } + // Optimization: find the first chunk of the needle, then compare the + // rest. If the first chunk is empty, `lhs.find` will return + // `current_pos`, which correctly degrades to a linear scan. 
+ absl::string_view first_chunk = *needle.Chunks().begin(); + absl::Cord rest_of_needle = needle.Subcord( + first_chunk.size(), needle_size - first_chunk.size()); + size_t current_pos = pos; + while (true) { + size_t found_pos = lhs.find(first_chunk, current_pos); + if (found_pos == absl::string_view::npos || + found_pos > lhs.size() - needle_size) { + return absl::nullopt; + } + if (lhs.substr(found_pos + first_chunk.size(), + rest_of_needle.size()) == rest_of_needle) { + return found_pos; + } + current_pos = found_pos + 1; + } + }, + [&needle, pos](const absl::Cord& lhs) -> absl::optional { + absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos); + absl::Cord::CharIterator it = cord.Find(needle); + if (it == cord.char_end()) { + return absl::nullopt; + } + return pos + + static_cast(absl::Cord::Distance(cord.char_begin(), it)); + })); +} + +ByteString ByteString::Substring(size_t pos, size_t npos) const { + ABSL_DCHECK_LE(npos, size()); + ABSL_DCHECK_LE(pos, npos); + + switch (GetKind()) { + case ByteStringKind::kSmall: { + ByteString result; + result.rep_.header.kind = ByteStringKind::kSmall; + result.rep_.small.size = npos - pos; + std::memcpy(result.rep_.small.data, rep_.small.data + pos, + result.rep_.small.size); + result.rep_.small.arena = GetSmallArena(); + return result; + } + case ByteStringKind::kMedium: { + ByteString result(*this); + result.rep_.medium.data += pos; + result.rep_.medium.size = npos - pos; + return result; + } + case ByteStringKind::kLarge: + return ByteString(GetLarge().Subcord(pos, npos - pos)); + } +} + void ByteString::RemovePrefix(size_t n) { ABSL_DCHECK_LE(n, size()); if (n == 0) { diff --git a/common/internal/byte_string.h b/common/internal/byte_string.h index 00fedaf9e..e539b558e 100644 --- a/common/internal/byte_string.h +++ b/common/internal/byte_string.h @@ -159,7 +159,7 @@ union CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ByteStringRep final { absl::string_view LegacyByteString(const ByteString& string, bool stable, 
google::protobuf::Arena* absl_nonnull arena); -// `ByteString` is an vocabulary type capable of representing copy-on-write +// `ByteString` is a vocabulary type capable of representing copy-on-write // strings efficiently for arenas and reference counting. The contents of the // byte string are owned by an arena or managed by a reference count. All byte // strings have an associated allocator specified at construction, once the byte @@ -275,6 +275,24 @@ ByteString final { bool EndsWith(const absl::Cord& rhs) const; bool EndsWith(const ByteString& rhs) const; + // Finds the first occurrence of `needle` in this object, starting at byte + // position `pos`. Returns `absl::nullopt` if `needle` is not found. + // Note: Positions are byte-based, not code point based as in + // `cel::StringValue`. + absl::optional Find(absl::string_view needle, size_t pos = 0) const; + absl::optional Find(const absl::Cord& needle, size_t pos = 0) const; + absl::optional Find(const ByteString& needle, size_t pos = 0) const; + + // Returns a new `ByteString` that is a substring of this object, covering + // the byte range [`pos`, `npos`); `npos` is an end position, not a length. + // Note: Positions are byte-based, not code point based as in + // `cel::StringValue`. 
+ ByteString Substring(size_t pos, size_t npos) const; + ByteString Substring(size_t pos) const { + ABSL_DCHECK_LE(pos, size()); + return Substring(pos, size()); + } + void RemovePrefix(size_t n); void RemoveSuffix(size_t n); @@ -501,6 +519,17 @@ inline bool ByteString::EndsWith(const ByteString& rhs) const { [this](const absl::Cord& rhs) -> bool { return EndsWith(rhs); })); } +inline absl::optional ByteString::Find(const ByteString& needle, + size_t pos) const { + return needle.Visit(absl::Overload( + [this, pos](absl::string_view rhs) -> absl::optional { + return Find(rhs, pos); + }, + [this, pos](const absl::Cord& rhs) -> absl::optional { + return Find(rhs, pos); + })); +} + inline bool operator==(const ByteString& lhs, const ByteString& rhs) { return lhs.Equals(rhs); } diff --git a/common/internal/byte_string_test.cc b/common/internal/byte_string_test.cc index 36c43eb32..bd9633845 100644 --- a/common/internal/byte_string_test.cc +++ b/common/internal/byte_string_test.cc @@ -747,6 +747,151 @@ TEST_P(ByteStringTest, EndsWith) { GetMediumOrLargeCord().size() - kSmallByteStringCapacity))); } +TEST_P(ByteStringTest, Find) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + + // Find string_view + EXPECT_THAT(byte_string.Find("A string"), Optional(0)); + EXPECT_THAT( + byte_string.Find("small string optimization!"), + Optional(GetMediumStringView().find("small string optimization!"))); + EXPECT_THAT(byte_string.Find("not found"), Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find(""), Optional(0)); + EXPECT_THAT(byte_string.Find("", 3), Optional(3)); + EXPECT_THAT(byte_string.Find("A string", 1), Eq(absl::nullopt)); + + // Find cord + EXPECT_THAT(byte_string.Find(absl::Cord("A string")), Optional(0)); + EXPECT_THAT( + byte_string.Find(absl::Cord("small string optimization!")), + Optional(GetMediumStringView().find("small string optimization!"))); + EXPECT_THAT( + byte_string.Find(absl::MakeFragmentedCord( + {"A string", " that is too 
large for the small string optimization!", + " extra"})), + Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find(GetMediumOrLargeFragmentedCord()), Optional(0)); + EXPECT_THAT(byte_string.Find(absl::Cord("not found")), Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find(absl::Cord("")), Optional(0)); + EXPECT_THAT(byte_string.Find(absl::Cord(""), 3), Optional(3)); +} + +TEST_P(ByteStringTest, FindEdgeCases) { + ByteString empty_byte_string(GetAllocator(), ""); + EXPECT_THAT(empty_byte_string.Find("a"), Eq(absl::nullopt)); + EXPECT_THAT(empty_byte_string.Find(""), Optional(0)); + ByteString cord_byte_string = + ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_THAT(cord_byte_string.Find("not found"), Eq(absl::nullopt)); + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + + // Needle longer than haystack. + EXPECT_THAT(byte_string.Find(std::string(byte_string.size() + 1, 'a')), + Eq(absl::nullopt)); + + // Needle at the end. + absl::string_view suffix = "optimization!"; + EXPECT_THAT(byte_string.Find(suffix), + Optional(byte_string.size() - suffix.size())); + + // pos at the end. + EXPECT_THAT(byte_string.Find("a", byte_string.size()), Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find("", byte_string.size()), + Optional(byte_string.size())); + + // Search in a cord-backed ByteString with pos > 0. + EXPECT_THAT(cord_byte_string.Find("string", 1), + Optional(GetMediumStringView().find("string", 1))); + + // Needle at the end of a cord-backed ByteString. + absl::string_view suffix_sv = "optimization!"; + EXPECT_THAT(cord_byte_string.Find(suffix_sv), + Optional(cord_byte_string.size() - suffix_sv.size())); + EXPECT_THAT(cord_byte_string.Find(absl::Cord(suffix_sv)), + Optional(cord_byte_string.size() - suffix_sv.size())); + + // Fragmented needle with empty first chunk. 
+ absl::Cord fragmented_with_empty_chunk; + fragmented_with_empty_chunk.Append(""); + fragmented_with_empty_chunk.Append("A string"); + EXPECT_THAT(byte_string.Find(fragmented_with_empty_chunk), Optional(0)); + + // Search with fragmented cord needle on string_view backed ByteString with + // partial match. + ByteString partial_match_haystack(GetAllocator(), "abababac"); + absl::Cord partial_match_needle = absl::MakeFragmentedCord({"aba", "c"}); + EXPECT_THAT(partial_match_haystack.Find(partial_match_needle), Optional(4)); + + // Search with fragmented cord needle where first chunk is found but not + // enough space for the rest. + ByteString short_haystack(GetAllocator(), "abcdefg"); + absl::Cord needle_too_long = absl::MakeFragmentedCord({"ef", "gh"}); + EXPECT_THAT(short_haystack.Find(needle_too_long), Eq(absl::nullopt)); + + // Search with a fragmented empty cord. + absl::Cord fragmented_empty_cord = absl::MakeFragmentedCord({"", ""}); + EXPECT_THAT(byte_string.Find(fragmented_empty_cord), Optional(0)); + EXPECT_THAT(byte_string.Find(fragmented_empty_cord, 3), Optional(3)); + + // Search for suffix in a fragmented cord. 
+ ByteString fragmented_cord_byte_string(GetAllocator(), + GetMediumOrLargeFragmentedCord()); + EXPECT_THAT(fragmented_cord_byte_string.Find(suffix_sv), + Optional(fragmented_cord_byte_string.size() - suffix_sv.size())); + EXPECT_THAT(fragmented_cord_byte_string.Find(absl::Cord(suffix_sv)), + Optional(fragmented_cord_byte_string.size() - suffix_sv.size())); +} + +#ifndef NDEBUG +TEST_P(ByteStringTest, FindOutOfBounds) { + ByteString byte_string = ByteString(GetAllocator(), "test"); + EXPECT_DEATH(byte_string.Find("t", 5), _); +} +#endif + +TEST_P(ByteStringTest, Substring) { + // small byte_string substring + ByteString small_byte_string = + ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(small_byte_string.Substring(1, 5), + GetSmallStringView().substr(1, 4)); + EXPECT_EQ(small_byte_string.Substring(0, small_byte_string.size()), + GetSmallStringView()); + EXPECT_EQ(small_byte_string.Substring(1, 1), ""); + // medium byte_string substring + ByteString medium_byte_string = + ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(medium_byte_string.Substring(2, 12), + GetMediumStringView().substr(2, 10)); + EXPECT_EQ(medium_byte_string.Substring(0, medium_byte_string.size()), + GetMediumStringView()); + // large byte_string substring + ByteString large_byte_string = + ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(large_byte_string.Substring(3, 15), + GetMediumOrLargeCord().Subcord(3, 12)); + EXPECT_EQ(large_byte_string.Substring(0, large_byte_string.size()), + GetMediumOrLargeCord()); + // substring with one parameter + ByteString tacocat_byte_string = ByteString(GetAllocator(), "tacocat"); + EXPECT_EQ(tacocat_byte_string.Substring(4), "cat"); +} + +TEST_P(ByteStringTest, SubstringEdgeCases) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.Substring(byte_string.size(), byte_string.size()), ""); + EXPECT_EQ(byte_string.Substring(0, 0), ""); +} + +#ifndef NDEBUG 
+TEST_P(ByteStringTest, SubstringOutOfBounds) { + ByteString byte_string = ByteString(GetAllocator(), "test"); + EXPECT_DEATH(static_cast(byte_string.Substring(5, 5)), _); + EXPECT_DEATH(static_cast(byte_string.Substring(0, 5)), _); + EXPECT_DEATH(static_cast(byte_string.Substring(3, 2)), _); +} +#endif + TEST_P(ByteStringTest, RemovePrefixSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); byte_string.RemovePrefix(1);