From 5dee1ab123550064aaeed7c7d1c7ae375b63385d Mon Sep 17 00:00:00 2001 From: FuzzTest Team Date: Tue, 26 May 2026 02:53:18 -0700 Subject: [PATCH] Optimize FlatBuffers untyped table domain with fields_by_id_ map This refactor introduces an absl::btree_map for field storage in FlatbuffersTableUntypedDomainImpl, allowing deterministic field iteration and faster field ID lookup. PiperOrigin-RevId: 921342355 --- .../arbitrary_domains_flatbuffers_test.cc | 26 +++ fuzztest/internal/BUILD | 7 +- fuzztest/internal/CMakeLists.txt | 1 + fuzztest/internal/domains/BUILD | 2 +- .../domains/flatbuffers_domain_impl.cc | 52 +++--- .../domains/flatbuffers_domain_impl.h | 165 ++++++++++++------ fuzztest/internal/test_flatbuffers_64bits.fbs | 21 +++ 7 files changed, 194 insertions(+), 80 deletions(-) create mode 100644 fuzztest/internal/test_flatbuffers_64bits.fbs diff --git a/domain_tests/arbitrary_domains_flatbuffers_test.cc b/domain_tests/arbitrary_domains_flatbuffers_test.cc index 508ad269..d96c24dd 100644 --- a/domain_tests/arbitrary_domains_flatbuffers_test.cc +++ b/domain_tests/arbitrary_domains_flatbuffers_test.cc @@ -36,6 +36,7 @@ #include "./domain_tests/domain_testing.h" #include "./fuzztest/flatbuffers.h" #include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/test_flatbuffers_64bits_generated.h" #include "./fuzztest/internal/test_flatbuffers_generated.h" namespace fuzztest { @@ -43,6 +44,7 @@ namespace { using ::fuzztest::internal::BoolTable; using ::fuzztest::internal::DefaultTable; +using ::fuzztest::internal::DefaultTable64; using ::fuzztest::internal::OptionalTable; using ::fuzztest::internal::RecursiveTable; using ::fuzztest::internal::RequiredTable; @@ -592,5 +594,29 @@ TEST(FlatbuffersTableDomainImplTest, RecursiveTable) { ASSERT_THAT(new_table, IsNull()); } +TEST(FlatbuffersTableDomainImplTest, DefaultTable64ValueRoundTrip) { + flatbuffers::FlatBufferBuilder64 fbb; + auto str_offset = fbb.CreateString("foo bar baz"); + auto table_offset = internal::CreateDefaultTable64(fbb, str_offset); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto domain = Arbitrary(); + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*corpus)); + + auto ir = domain.SerializeCorpus(corpus.value()); + + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + + auto new_table = domain.GetValue(*new_corpus); + ASSERT_THAT(new_table, NotNull()); + ASSERT_THAT(new_table->str(), NotNull()); + EXPECT_EQ(new_table->str()->str(), "foo bar baz"); +} + } // namespace } // namespace fuzztest diff --git a/fuzztest/internal/BUILD b/fuzztest/internal/BUILD index f5155910..2bf470ff 100644 --- a/fuzztest/internal/BUILD +++ b/fuzztest/internal/BUILD @@ -616,8 +616,13 @@ cc_test( flatbuffer_library_public( name = "test_flatbuffers_fbs", - srcs = ["test_flatbuffers.fbs"], + srcs = [ + "test_flatbuffers.fbs", + "test_flatbuffers_64bits.fbs", + ], outs = [ + "test_flatbuffers_64bits_bfbs_generated.h", + "test_flatbuffers_64bits_generated.h", "test_flatbuffers_bfbs_generated.h", "test_flatbuffers_generated.h", ], diff --git a/fuzztest/internal/CMakeLists.txt b/fuzztest/internal/CMakeLists.txt index 0eb75aba..5c1173e6 100644 --- a/fuzztest/internal/CMakeLists.txt +++ b/fuzztest/internal/CMakeLists.txt @@ -574,6 +574,7 @@ if (FUZZTEST_BUILD_FLATBUFFERS) test_flatbuffers_headers SCHEMAS "test_flatbuffers.fbs" + "test_flatbuffers_64bits.fbs" FLAGS --bfbs-gen-embed --gen-name-strings TESTONLY diff --git a/fuzztest/internal/domains/BUILD b/fuzztest/internal/domains/BUILD index 545fef36..5682ffb8 100644 --- a/fuzztest/internal/domains/BUILD +++ b/fuzztest/internal/domains/BUILD @@ -187,9 +187,9 @@ cc_library( hdrs = ["flatbuffers_domain_impl.h"], deps = [ ":core_domains_impl", - "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/random:bit_gen_ref", diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.cc b/fuzztest/internal/domains/flatbuffers_domain_impl.cc index 7e8e144e..c675ae00 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.cc +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.cc @@ -38,13 +38,18 @@ namespace fuzztest::internal { FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( const reflection::Schema* absl_nonnull schema, const reflection::Object* absl_nonnull table_object) - : schema_(schema), table_object_(table_object) {} + : schema_(schema), table_object_(table_object) { + for (const auto& field : *table_object_->fields()) { + fields_by_id_[field->id()] = field; + } +} FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( const FlatbuffersTableUntypedDomainImpl& other) : DomainBase(other), schema_(other.schema_), - table_object_(other.table_object_) { + table_object_(other.table_object_), + fields_by_id_(other.fields_by_id_) { absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = other.domains_; @@ -55,6 +60,7 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( DomainBase::operator=(other); schema_ = other.schema_; table_object_ = other.table_object_; + fields_by_id_ = other.fields_by_id_; absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = other.domains_; @@ -63,7 +69,9 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( FlatbuffersTableUntypedDomainImpl&& other) - : schema_(other.schema_), table_object_(other.table_object_) { + : schema_(other.schema_), + table_object_(other.table_object_), + fields_by_id_(std::move(other.fields_by_id_)) { absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = std::move(other.domains_); @@ -74,6 +82,7 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( FlatbuffersTableUntypedDomainImpl&& other) { schema_ = other.schema_; table_object_ = other.table_object_; + fields_by_id_ = std::move(other.fields_by_id_); absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = std::move(other.domains_); @@ -87,7 +96,7 @@ FlatbuffersTableUntypedDomainImpl::Init(absl::BitGenRef prng) { return *seed; } corpus_type val; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, InitializeVisitor{*this, prng, val}); } return val; @@ -98,7 +107,7 @@ void FlatbuffersTableUntypedDomainImpl::Mutate( corpus_type& val, absl::BitGenRef prng, const domain_implementor::MutationMetadata& metadata, bool only_shrink) { uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, CountNumberOfMutableFieldsVisitor{*this, field_count, val, only_shrink}); @@ -112,7 +121,7 @@ void FlatbuffersTableUntypedDomainImpl::Mutate( uint64_t FlatbuffersTableUntypedDomainImpl::CountNumberOfFields( corpus_type& val) { uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField( schema_, field, CountNumberOfMutableFieldsVisitor{*this, field_count, val}); @@ -130,29 +139,16 @@ uint64_t FlatbuffersTableUntypedDomainImpl::MutateSelectedField( return fields_count; } - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { if (!IsSupportedField(field)) { if (only_shrink && !val.contains(field->id())) continue; } ++field_counter; - if (field_counter == selected_field_index) { - VisitFlatbufferField( - schema_, field, - MutateVisitor{*this, prng, metadata, only_shrink, val}); - return field_counter; - } - - if (field->type()->base_type() == reflection::BaseType::Obj) { - auto sub_object = schema_->objects()->Get(field->type()->index()); - if (!sub_object->is_struct()) { - field_counter += - GetCachedDomain(field).MutateSelectedField( - val[field->id()], prng, metadata, only_shrink, - selected_field_index - field_counter); - } - // TODO: Add support for structs. - } + VisitFlatbufferField( + schema_, field, + MutateSelectedFieldVisitor{*this, field_counter, val, prng, metadata, + only_shrink, selected_field_index}); if (field_counter >= selected_field_index) { return field_counter; @@ -163,7 +159,7 @@ uint64_t FlatbuffersTableUntypedDomainImpl::MutateSelectedField( absl::Status FlatbuffersTableUntypedDomainImpl::ValidateCorpusValue( const corpus_type& corpus_value) const { - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { absl::Status result; GenericDomainCorpusType field_corpus; if (auto it = corpus_value.find(field->id()); it != corpus_value.end()) { @@ -183,7 +179,7 @@ FlatbuffersTableUntypedDomainImpl::FromValue(const value_type& value) const { return std::nullopt; } corpus_type ret; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); } return ret; @@ -276,11 +272,11 @@ bool FlatbuffersTableUntypedDomainImpl::IsSupportedField( } uint32_t FlatbuffersTableUntypedDomainImpl::BuildTable( - const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const { + const corpus_type& value, flatbuffers::FlatBufferBuilder64& builder) const { // Add all the fields to the builder. // Offsets is the map of field id to its offset in the table. - absl::flat_hash_map + absl::flat_hash_map offsets; // Some fields are stored inline in the flatbuffer table itself (a.k.a diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.h b/fuzztest/internal/domains/flatbuffers_domain_impl.h index 1482a598..33500ba2 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.h +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.h @@ -26,9 +26,9 @@ #include #include -#include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/random/bit_gen_ref.h" @@ -37,7 +37,9 @@ #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "flatbuffers/base.h" +#include "flatbuffers/buffer.h" #include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/reflection.h" #include "flatbuffers/reflection_generated.h" #include "flatbuffers/string.h" #include "flatbuffers/table.h" @@ -82,67 +84,80 @@ struct FlatbuffersStructTag; struct FlatbuffersUnionTag; struct FlatbuffersVectorTag; +// Helper to wrap the visitor with the correct tag type. +template