Skip to content

Commit

Permalink
[segmentation] Change result provider to use default manager
Browse files Browse the repository at this point in the history
The result provider had code to fetch segment info from the database.
The default manager already has this utility, so changes the default
manager API to work for this case and share code.

Bug: 1307083
Change-Id: I794c067d7869d9b22470411cbe56cbb7d40e2272
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3553306
Reviewed-by: Shakti Sahu <shaktisahu@chromium.org>
Commit-Queue: Siddhartha S <ssid@chromium.org>
Cr-Commit-Position: refs/heads/main@{#985575}
  • Loading branch information
ssiddhartha authored and Chromium LUCI CQ committed Mar 26, 2022
1 parent 0b33323 commit 8ccb5fe
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 105 deletions.
Expand Up @@ -40,10 +40,10 @@ using CleanupItem = DatabaseMaintenanceImpl::CleanupItem;

namespace {
std::set<SignalIdentifier> CollectAllSignalIdentifiers(
const SegmentInfoDatabase::SegmentInfoList& segment_infos) {
const DefaultModelManager::SegmentInfoList& segment_infos) {
std::set<SignalIdentifier> signal_ids;
for (const auto& pair : segment_infos) {
const proto::SegmentInfo& segment_info = pair.second;
for (const auto& info : segment_infos) {
const proto::SegmentInfo& segment_info = info->segment_info;
const auto& metadata = segment_info.model_metadata();
auto features =
metadata_utils::GetAllUmaFeatures(metadata, /*include_outputs=*/true);
Expand Down Expand Up @@ -115,9 +115,9 @@ void DatabaseMaintenanceImpl::ExecuteMaintenanceTasks() {
}

void DatabaseMaintenanceImpl::OnSegmentInfoCallback(
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos) {
DefaultModelManager::SegmentInfoList segment_infos) {
std::set<SignalIdentifier> signal_ids =
CollectAllSignalIdentifiers(*segment_infos);
CollectAllSignalIdentifiers(segment_infos);
stats::RecordMaintenanceSignalIdentifierCount(signal_ids.size());

auto all_tasks = GetAllTasks(signal_ids);
Expand Down
Expand Up @@ -16,7 +16,7 @@
#include "base/memory/weak_ptr.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/database/database_maintenance.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/execution/default_model_manager.h"
#include "components/segmentation_platform/internal/proto/types.pb.h"

namespace base {
Expand All @@ -28,6 +28,7 @@ using optimization_guide::proto::OptimizationTarget;

namespace segmentation_platform {
class DefaultModelManager;
class SegmentInfoDatabase;
class SignalDatabase;
class SignalStorageConfig;

Expand Down Expand Up @@ -58,7 +59,7 @@ class DatabaseMaintenanceImpl : public DatabaseMaintenance {
// All tasks currently need information about various segments, so this is
// the callback after the initial database lookup for this data.
void OnSegmentInfoCallback(
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos);
DefaultModelManager::SegmentInfoList segment_infos);

// Returns an ordered vector of all the tasks we are supposed to perform.
// These are unfinished and also need to be linked to the next task to be
Expand Down
Expand Up @@ -71,18 +71,30 @@ class TestDefaultModelManager : public DefaultModelManager {
const std::vector<OptimizationTarget>& segment_ids,
MultipleSegmentInfoCallback callback) override {
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE,
base::BindOnce(
std::move(callback),
std::make_unique<DefaultModelManager::SegmentInfoList>()));
FROM_HERE, base::BindOnce(std::move(callback),
DefaultModelManager::SegmentInfoList()));
}

void GetAllSegmentInfoFromBothModels(
const std::vector<OptimizationTarget>& segment_ids,
SegmentInfoDatabase* segment_database,
MultipleSegmentInfoCallback callback) override {
segment_database->GetSegmentInfoForSegments(segment_ids,
std::move(callback));
segment_database->GetSegmentInfoForSegments(
segment_ids,
base::BindOnce(
[](DefaultModelManager::MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> db_list) {
DefaultModelManager::SegmentInfoList list;
for (auto& pair : *db_list) {
list.push_back(std::make_unique<
DefaultModelManager::SegmentInfoWrapper>());
list.back()->segment_source =
DefaultModelManager::SegmentSource::DATABASE;
list.back()->segment_info.Swap(&pair.second);
}
std::move(callback).Run(std::move(list));
},
std::move(callback)));
}
};

Expand Down
Expand Up @@ -9,6 +9,9 @@

namespace segmentation_platform {

DefaultModelManager::SegmentInfoWrapper::SegmentInfoWrapper() = default;
DefaultModelManager::SegmentInfoWrapper::~SegmentInfoWrapper() = default;

DefaultModelManager::DefaultModelManager(
ModelProviderFactory* model_provider_factory,
const std::vector<OptimizationTarget>& segment_ids)
Expand Down Expand Up @@ -66,7 +69,7 @@ void DefaultModelManager::GetNextSegmentInfoFromDefaultModel(
if (!default_provider) {
// If there are no more default providers, return the result so far.
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), std::move(result)));
FROM_HERE, base::BindOnce(std::move(callback), std::move(*result)));
return;
}

Expand All @@ -82,11 +85,12 @@ void DefaultModelManager::OnFetchDefaultModel(
OptimizationTarget segment_id,
proto::SegmentationModelMetadata metadata,
int64_t model_version) {
proto::SegmentInfo segment_info;
segment_info.set_segment_id(segment_id);
segment_info.mutable_model_metadata()->CopyFrom(metadata);
segment_info.set_model_version(model_version);
result->push_back(std::make_pair(segment_id, segment_info));
auto info = std::make_unique<SegmentInfoWrapper>();
info->segment_source = DefaultModelManager::SegmentSource::DEFAULT_MODEL;
info->segment_info.set_segment_id(segment_id);
info->segment_info.mutable_model_metadata()->CopyFrom(metadata);
info->segment_info.set_model_version(model_version);
result->push_back(std::move(info));

GetNextSegmentInfoFromDefaultModel(
std::move(result), std::move(remaining_segment_ids), std::move(callback));
Expand Down Expand Up @@ -117,12 +121,19 @@ void DefaultModelManager::OnGetAllSegmentInfoFromDatabase(
void DefaultModelManager::OnGetAllSegmentInfoFromDefaultModel(
MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos_from_db,
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList>
segment_infos_from_default_model) {
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> merged_results =
std::move(segment_infos_from_db);
for (const auto& segment_info : *segment_infos_from_default_model)
merged_results->push_back(std::move(segment_info));
SegmentInfoList segment_infos_from_default_model) {
SegmentInfoList merged_results;
if (segment_infos_from_db) {
for (auto it : *segment_infos_from_db) {
merged_results.push_back(std::make_unique<SegmentInfoWrapper>());
merged_results.back()->segment_source = SegmentSource::DATABASE;
merged_results.back()->segment_info.Swap(&it.second);
}
}
merged_results.insert(
merged_results.end(),
std::make_move_iterator(segment_infos_from_default_model.begin()),
std::make_move_iterator(segment_infos_from_default_model.end()));

std::move(callback).Run(std::move(merged_results));
}
Expand Down
Expand Up @@ -14,6 +14,7 @@
#include "base/callback.h"
#include "base/containers/flat_map.h"
#include "base/logging.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/model_provider.h"
Expand All @@ -39,10 +40,21 @@ class DefaultModelManager {

// Callback for returning a list of segment infos associated with IDs.
// The same segment ID can be repeated multiple times.
using SegmentInfoList =
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>;
using MultipleSegmentInfoCallback =
base::OnceCallback<void(std::unique_ptr<SegmentInfoList>)>;
enum class SegmentSource {
DATABASE,
DEFAULT_MODEL,
};
struct SegmentInfoWrapper {
SegmentInfoWrapper();
~SegmentInfoWrapper();
SegmentInfoWrapper(const SegmentInfoWrapper&) = delete;
SegmentInfoWrapper& operator=(const SegmentInfoWrapper&) = delete;

SegmentSource segment_source;
proto::SegmentInfo segment_info;
};
using SegmentInfoList = std::vector<std::unique_ptr<SegmentInfoWrapper>>;
using MultipleSegmentInfoCallback = base::OnceCallback<void(SegmentInfoList)>;

// Utility function to get the segment info from both the database and the
// default model for a given set of segment IDs. The result can contain
Expand Down Expand Up @@ -80,12 +92,13 @@ class DefaultModelManager {
void OnGetAllSegmentInfoFromDatabase(
const std::vector<OptimizationTarget>& segment_ids,
MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoList> segment_infos);
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos);

void OnGetAllSegmentInfoFromDefaultModel(
MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoList> segment_infos_from_db,
std::unique_ptr<SegmentInfoList> segment_infos_from_default_model);
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList>
segment_infos_from_db,
SegmentInfoList segment_infos_from_default_model);

// Default model providers.
std::map<OptimizationTarget, std::unique_ptr<ModelProvider>>
Expand Down
Expand Up @@ -34,21 +34,20 @@ class DefaultModelManagerTest : public testing::Test {
.second;
}

void OnGetAllSegments(
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> entries) {
void OnGetAllSegments(DefaultModelManager::SegmentInfoList entries) {
get_all_segment_result_.swap(entries);
}

const SegmentInfoDatabase::SegmentInfoList& get_all_segment_result() const {
return *get_all_segment_result_;
const DefaultModelManager::SegmentInfoList& get_all_segment_result() const {
return get_all_segment_result_;
}

base::test::TaskEnvironment task_environment_;
test::TestSegmentInfoDatabase segment_database_;
TestModelProviderFactory::Data model_provider_data_;
TestModelProviderFactory model_provider_factory_;
std::unique_ptr<DefaultModelManager> default_model_manager_;
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> get_all_segment_result_;
DefaultModelManager::SegmentInfoList get_all_segment_result_;
base::WeakPtrFactory<DefaultModelManagerTest> weak_ptr_factory_{this};
};

Expand Down Expand Up @@ -101,15 +100,15 @@ TEST_F(DefaultModelManagerTest, BasicTest) {
// Verify that model exists from both sources in order: segment_1 from db,
// segment_1 from model, segment_2 from model.
EXPECT_EQ(3u, get_all_segment_result().size());
EXPECT_EQ(segment_1, get_all_segment_result()[0].first);
EXPECT_EQ(segment_1, get_all_segment_result()[0]->segment_info.segment_id());
EXPECT_EQ(model_version_db,
get_all_segment_result()[0].second.model_version());
EXPECT_EQ(segment_1, get_all_segment_result()[1].first);
get_all_segment_result()[0]->segment_info.model_version());
EXPECT_EQ(segment_1, get_all_segment_result()[1]->segment_info.segment_id());
EXPECT_EQ(model_version_default,
get_all_segment_result()[1].second.model_version());
EXPECT_EQ(segment_2, get_all_segment_result()[2].first);
get_all_segment_result()[1]->segment_info.model_version());
EXPECT_EQ(segment_2, get_all_segment_result()[2]->segment_info.segment_id());
EXPECT_EQ(model_version_default,
get_all_segment_result()[2].second.model_version());
get_all_segment_result()[2]->segment_info.model_version());

// Query again, this time with a segment ID that doesn't exist in either
// sources.
Expand All @@ -130,7 +129,7 @@ TEST_F(DefaultModelManagerTest, BasicTest) {
weak_ptr_factory_.GetWeakPtr()));
task_environment_.RunUntilIdle();
EXPECT_EQ(1u, get_all_segment_result().size());
EXPECT_EQ(segment_2, get_all_segment_result()[0].first);
EXPECT_EQ(segment_2, get_all_segment_result()[0]->segment_info.segment_id());

// Query for a model only available in the database.
default_model_manager_->GetAllSegmentInfoFromBothModels(
Expand All @@ -139,7 +138,7 @@ TEST_F(DefaultModelManagerTest, BasicTest) {
weak_ptr_factory_.GetWeakPtr()));
task_environment_.RunUntilIdle();
EXPECT_EQ(1u, get_all_segment_result().size());
EXPECT_EQ(segment_3, get_all_segment_result()[0].first);
EXPECT_EQ(segment_3, get_all_segment_result()[0]->segment_info.segment_id());
}

} // namespace segmentation_platform

0 comments on commit 8ccb5fe

Please sign in to comment.