Skip to content

Commit

Permalink
Introduce DefaultModelManager
Browse files Browse the repository at this point in the history
This CL introduces DefaultModelManager to handle default-model-related
queries. It also adds helper methods that return metadata from both
the server model and the default model. SignalFilterProcessor and
DatabaseMaintenanceImpl were updated to track signals from both types of
models.

Bug: 1298756
Change-Id: I5470cea068ba7d9311912489b278bde90bbbbbc6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3529267
Reviewed-by: Tommy Nyquist <nyquist@chromium.org>
Reviewed-by: Siddhartha S <ssid@chromium.org>
Commit-Queue: Shakti Sahu <shaktisahu@chromium.org>
Cr-Commit-Position: refs/heads/main@{#984358}
  • Loading branch information
Shakti Sahu authored and Chromium LUCI CQ committed Mar 23, 2022
1 parent 461a08f commit b885cd3
Show file tree
Hide file tree
Showing 16 changed files with 507 additions and 21 deletions.
@@ -1,4 +1,5 @@
DatabaseMaintenanceImplTest.*
DefaultModelManagerTest.*
DummyModelExecutionManagerTest.*
DummySegmentationPlatformServiceTest.*
FeatureAggregatorImplTest.*
Expand Down
3 changes: 3 additions & 0 deletions components/segmentation_platform/internal/BUILD.gn
Expand Up @@ -52,6 +52,8 @@ static_library("internal") {
"dummy_ukm_data_manager.h",
"execution/custom_input_processor.cc",
"execution/custom_input_processor.h",
"execution/default_model_manager.cc",
"execution/default_model_manager.h",
"execution/dummy_model_execution_manager.cc",
"execution/dummy_model_execution_manager.h",
"execution/feature_aggregator.h",
Expand Down Expand Up @@ -201,6 +203,7 @@ source_set("unit_tests") {
"database/ukm_metrics_table_unittest.cc",
"database/ukm_url_table_unittest.cc",
"dummy_segmentation_platform_service_unittest.cc",
"execution/default_model_manager_unittest.cc",
"execution/dummy_model_execution_manager_unittest.cc",
"execution/feature_aggregator_impl_unittest.cc",
"execution/feature_list_query_processor_unittest.cc",
Expand Down
Expand Up @@ -24,6 +24,7 @@
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/signal_database.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/execution/default_model_manager.h"
#include "components/segmentation_platform/internal/proto/types.pb.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
Expand Down Expand Up @@ -93,20 +94,22 @@ DatabaseMaintenanceImpl::DatabaseMaintenanceImpl(
base::Clock* clock,
SegmentInfoDatabase* segment_info_database,
SignalDatabase* signal_database,
SignalStorageConfig* signal_storage_config)
SignalStorageConfig* signal_storage_config,
DefaultModelManager* default_model_manager)
: segment_ids_(segment_ids),
clock_(clock),
segment_info_database_(segment_info_database),
signal_database_(signal_database),
signal_storage_config_(signal_storage_config) {}
signal_storage_config_(signal_storage_config),
default_model_manager_(default_model_manager) {}

DatabaseMaintenanceImpl::~DatabaseMaintenanceImpl() = default;

void DatabaseMaintenanceImpl::ExecuteMaintenanceTasks() {
std::vector<OptimizationTarget> segment_ids(segment_ids_.begin(),
segment_ids_.end());
segment_info_database_->GetSegmentInfoForSegments(
segment_ids,
default_model_manager_->GetAllSegmentInfoFromBothModels(
segment_ids, segment_info_database_,
base::BindOnce(&DatabaseMaintenanceImpl::OnSegmentInfoCallback,
weak_ptr_factory_.GetWeakPtr()));
}
Expand Down
Expand Up @@ -27,6 +27,7 @@ class Time;
using optimization_guide::proto::OptimizationTarget;

namespace segmentation_platform {
class DefaultModelManager;
class SignalDatabase;
class SignalStorageConfig;

Expand All @@ -42,7 +43,8 @@ class DatabaseMaintenanceImpl : public DatabaseMaintenance {
base::Clock* clock,
SegmentInfoDatabase* segment_info_database,
SignalDatabase* signal_database,
SignalStorageConfig* signal_storage_config);
SignalStorageConfig* signal_storage_config,
DefaultModelManager* default_model_manager);
~DatabaseMaintenanceImpl() override;

// DatabaseMaintenance overrides.
Expand Down Expand Up @@ -92,6 +94,9 @@ class DatabaseMaintenanceImpl : public DatabaseMaintenance {
raw_ptr<SignalDatabase> signal_database_;
raw_ptr<SignalStorageConfig> signal_storage_config_;

// Default model provider.
raw_ptr<DefaultModelManager> default_model_manager_;

base::WeakPtrFactory<DatabaseMaintenanceImpl> weak_ptr_factory_{this};
};

Expand Down
Expand Up @@ -13,12 +13,14 @@
#include "base/test/gmock_callback_support.h"
#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
#include "base/threading/thread_task_runner_handle.h"
#include "base/time/time.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/database/mock_signal_database.h"
#include "components/segmentation_platform/internal/database/mock_signal_storage_config.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/database/test_segment_info_database.h"
#include "components/segmentation_platform/internal/execution/default_model_manager.h"
#include "components/segmentation_platform/internal/proto/aggregation.pb.h"
#include "components/segmentation_platform/internal/proto/types.pb.h"
#include "components/segmentation_platform/public/config.h"
Expand Down Expand Up @@ -57,6 +59,33 @@ struct SignalData {

} // namespace

// Test double for DefaultModelManager that is constructed without a model
// provider factory and without segment IDs. Default-model queries yield an
// empty list, and combined queries are handed straight to the database.
// TODO(shaktisahu): Move this class to its own file.
class TestDefaultModelManager : public DefaultModelManager {
 public:
  TestDefaultModelManager()
      : DefaultModelManager(nullptr, std::vector<OptimizationTarget>()) {}
  ~TestDefaultModelManager() override = default;

  // Posts |callback| to the current task runner with an empty list.
  // |segment_ids| is not consulted.
  void GetAllSegmentInfoFromDefaultModel(
      const std::vector<OptimizationTarget>& segment_ids,
      MultipleSegmentInfoCallback callback) override {
    auto empty_result =
        std::make_unique<DefaultModelManager::SegmentInfoList>();
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE,
        base::BindOnce(std::move(callback), std::move(empty_result)));
  }

  // Forwards the query, and |callback|, to |segment_database| unchanged.
  void GetAllSegmentInfoFromBothModels(
      const std::vector<OptimizationTarget>& segment_ids,
      SegmentInfoDatabase* segment_database,
      MultipleSegmentInfoCallback callback) override {
    segment_database->GetSegmentInfoForSegments(segment_ids,
                                                std::move(callback));
  }
};

class DatabaseMaintenanceImplTest : public testing::Test {
public:
DatabaseMaintenanceImplTest() = default;
Expand All @@ -69,9 +98,11 @@ class DatabaseMaintenanceImplTest : public testing::Test {
base::flat_set<OptimizationTarget> segment_ids = {
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE};
default_model_manager_ = std::make_unique<TestDefaultModelManager>();
database_maintenance_ = std::make_unique<DatabaseMaintenanceImpl>(
segment_ids, &clock_, segment_info_database_.get(),
signal_database_.get(), signal_storage_config_.get());
signal_database_.get(), signal_storage_config_.get(),
default_model_manager_.get());

clock_.SetNow(base::Time::Now());
}
Expand Down Expand Up @@ -113,6 +144,7 @@ class DatabaseMaintenanceImplTest : public testing::Test {
std::unique_ptr<test::TestSegmentInfoDatabase> segment_info_database_;
std::unique_ptr<MockSignalDatabase> signal_database_;
std::unique_ptr<MockSignalStorageConfig> signal_storage_config_;
std::unique_ptr<TestDefaultModelManager> default_model_manager_;

std::unique_ptr<DatabaseMaintenanceImpl> database_maintenance_;
};
Expand Down
@@ -0,0 +1,121 @@
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/execution/default_model_manager.h"

#include "base/threading/thread_task_runner_handle.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"

namespace segmentation_platform {

// Asks |model_provider_factory| for a default ModelProvider for every entry in
// |segment_ids| and keeps the ones that were actually created. A null factory
// leaves the manager without any default providers.
DefaultModelManager::DefaultModelManager(
    ModelProviderFactory* model_provider_factory,
    const std::vector<OptimizationTarget>& segment_ids) {
  // Nothing can be created without a factory.
  if (!model_provider_factory)
    return;

  for (OptimizationTarget segment_id : segment_ids) {
    // Segments for which the factory returns null are simply not registered.
    if (std::unique_ptr<ModelProvider> provider =
            model_provider_factory->CreateDefaultProvider(segment_id)) {
      default_model_providers_.emplace(segment_id, std::move(provider));
    }
  }
}

DefaultModelManager::~DefaultModelManager() = default;

// Kicks off the one-at-a-time walk over |segment_ids|: starts with an empty
// result list and a queue holding every requested ID, and hands both to
// GetNextSegmentInfoFromDefaultModel() together with |callback|.
void DefaultModelManager::GetAllSegmentInfoFromDefaultModel(
    const std::vector<OptimizationTarget>& segment_ids,
    MultipleSegmentInfoCallback callback) {
  GetNextSegmentInfoFromDefaultModel(
      std::make_unique<SegmentInfoList>(),
      std::deque<OptimizationTarget>(segment_ids.begin(), segment_ids.end()),
      std::move(callback));
}

void DefaultModelManager::GetNextSegmentInfoFromDefaultModel(
std::unique_ptr<SegmentInfoList> result,
std::deque<OptimizationTarget> remaining_segment_ids,
MultipleSegmentInfoCallback callback) {
OptimizationTarget segment_id =
OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
ModelProvider* default_provider = nullptr;

// Find the next available default provider.
while (!default_provider && !remaining_segment_ids.empty()) {
segment_id = remaining_segment_ids.front();
remaining_segment_ids.pop_front();
if (default_model_providers_.count(segment_id) == 1) {
default_provider = default_model_providers_[segment_id].get();
break;
}
}

if (!default_provider) {
// If there are no more default providers, return the result so far.
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), std::move(result)));
return;
}

default_provider->InitAndFetchModel(base::BindRepeating(
&DefaultModelManager::OnFetchDefaultModel, weak_ptr_factory_.GetWeakPtr(),
base::Passed(&result), remaining_segment_ids, base::Passed(&callback)));
}

// Receives the model fetched for |segment_id| from its default provider.
// Builds a SegmentInfo from |segment_id|, |metadata| and |model_version|,
// appends it to |result|, and then continues with the IDs still left in
// |remaining_segment_ids|. |callback| is carried along until the walk ends.
void DefaultModelManager::OnFetchDefaultModel(
    std::unique_ptr<SegmentInfoList> result,
    std::deque<OptimizationTarget> remaining_segment_ids,
    MultipleSegmentInfoCallback callback,
    OptimizationTarget segment_id,
    proto::SegmentationModelMetadata metadata,
    int64_t model_version) {
  proto::SegmentInfo segment_info;
  segment_info.set_segment_id(segment_id);
  segment_info.mutable_model_metadata()->CopyFrom(metadata);
  segment_info.set_model_version(model_version);
  // The local proto is not needed afterwards; move it into the list instead
  // of deep-copying it through a temporary pair.
  result->emplace_back(segment_id, std::move(segment_info));

  GetNextSegmentInfoFromDefaultModel(
      std::move(result), std::move(remaining_segment_ids), std::move(callback));
}

// First stage of the combined query: asks |segment_database| for the segment
// infos of |segment_ids|. The database result is routed to
// OnGetAllSegmentInfoFromDatabase(), which receives a copy of |segment_ids|
// and the caller's |callback| so it can continue with the default models.
void DefaultModelManager::GetAllSegmentInfoFromBothModels(
    const std::vector<OptimizationTarget>& segment_ids,
    SegmentInfoDatabase* segment_database,
    MultipleSegmentInfoCallback callback) {
  auto on_database_result = base::BindOnce(
      &DefaultModelManager::OnGetAllSegmentInfoFromDatabase,
      weak_ptr_factory_.GetWeakPtr(), segment_ids, std::move(callback));
  segment_database->GetSegmentInfoForSegments(segment_ids,
                                              std::move(on_database_result));
}

// Second stage of the combined query: the database has answered with
// |segment_infos|. Queries the default models for the same |segment_ids| and
// arranges for OnGetAllSegmentInfoFromDefaultModel() to receive both the
// database result and the caller's |callback|.
void DefaultModelManager::OnGetAllSegmentInfoFromDatabase(
    const std::vector<OptimizationTarget>& segment_ids,
    MultipleSegmentInfoCallback callback,
    std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos) {
  auto on_default_model_result = base::BindOnce(
      &DefaultModelManager::OnGetAllSegmentInfoFromDefaultModel,
      weak_ptr_factory_.GetWeakPtr(), std::move(callback),
      std::move(segment_infos));
  GetAllSegmentInfoFromDefaultModel(segment_ids,
                                    std::move(on_default_model_result));
}

// Final stage of the combined query: appends every entry of
// |segment_infos_from_default_model| after the entries of
// |segment_infos_from_db| and runs |callback| with the combined list. No
// de-duplication is done, so a segment ID present in both inputs appears
// twice in the output.
void DefaultModelManager::OnGetAllSegmentInfoFromDefaultModel(
    MultipleSegmentInfoCallback callback,
    std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos_from_db,
    std::unique_ptr<SegmentInfoDatabase::SegmentInfoList>
        segment_infos_from_default_model) {
  // The database list is reused as the output to avoid re-copying its entries.
  std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> merged_results =
      std::move(segment_infos_from_db);
  merged_results->reserve(merged_results->size() +
                          segment_infos_from_default_model->size());
  // Iterate by non-const reference: std::move on a const reference would
  // silently fall back to copying each SegmentInfo.
  for (auto& segment_info : *segment_infos_from_default_model)
    merged_results->push_back(std::move(segment_info));

  std::move(callback).Run(std::move(merged_results));
}

} // namespace segmentation_platform
@@ -0,0 +1,93 @@
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_

#include <deque>
#include <map>
#include <memory>
#include <set>
#include <vector>

#include "base/callback.h"
#include "base/containers/flat_map.h"
#include "base/logging.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "third_party/abseil-cpp/absl/types/optional.h"

using optimization_guide::proto::OptimizationTarget;

namespace segmentation_platform {
class SegmentInfoDatabase;

// DefaultModelManager provides support to query all default models available.
// It also provides useful methods to combine results from both the database and
// the default model.
//
// Default model providers are obtained from a ModelProviderFactory at
// construction time, one per segment ID, and are owned by this object.
// Results are delivered through a MultipleSegmentInfoCallback.
class DefaultModelManager {
public:
// Creates a default ModelProvider for each entry of |segment_ids| using
// |model_provider_factory|. Segments for which the factory returns no
// provider are skipped. |model_provider_factory| may be null, in which case
// no default providers are created.
DefaultModelManager(ModelProviderFactory* model_provider_factory,
const std::vector<OptimizationTarget>& segment_ids);
virtual ~DefaultModelManager();

// Disallow copy/assign.
DefaultModelManager(const DefaultModelManager&) = delete;
DefaultModelManager& operator=(const DefaultModelManager&) = delete;

// List of (segment ID, segment info) pairs.
// The same segment ID can be repeated multiple times.
using SegmentInfoList =
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>;
// Callback for returning a list of segment infos associated with IDs.
using MultipleSegmentInfoCallback =
base::OnceCallback<void(std::unique_ptr<SegmentInfoList>)>;

// Utility function to get the segment info from both the database and the
// default model for a given set of segment IDs. The result can contain
// the same segment ID multiple times: entries read from |segment_database|
// come first, followed by entries built from the default models.
// Virtual so that tests can substitute their own behavior.
virtual void GetAllSegmentInfoFromBothModels(
const std::vector<OptimizationTarget>& segment_ids,
SegmentInfoDatabase* segment_database,
MultipleSegmentInfoCallback callback);

private:
// Called to get the segment info from the default model for a given set of
// segment IDs. IDs without a default model provider contribute no entry to
// the list passed to |callback|.
virtual void GetAllSegmentInfoFromDefaultModel(
const std::vector<OptimizationTarget>& segment_ids,
MultipleSegmentInfoCallback callback);

// Takes IDs off the front of |remaining_segment_ids| until one with a default
// provider is found and fetches that provider's model. When no such ID is
// left, posts |callback| with |result|.
void GetNextSegmentInfoFromDefaultModel(
std::unique_ptr<SegmentInfoList> result,
std::deque<OptimizationTarget> remaining_segment_ids,
MultipleSegmentInfoCallback callback);

// Model-fetch callback for a single default provider. Appends a SegmentInfo
// built from |segment_id|, |metadata| and |model_version| to |result|, then
// continues with |remaining_segment_ids|.
void OnFetchDefaultModel(std::unique_ptr<SegmentInfoList> result,
std::deque<OptimizationTarget> remaining_segment_ids,
MultipleSegmentInfoCallback callback,
OptimizationTarget segment_id,
proto::SegmentationModelMetadata metadata,
int64_t model_version);

// Database callback for GetAllSegmentInfoFromBothModels(). Holds on to
// |segment_infos| while the default models for |segment_ids| are queried.
void OnGetAllSegmentInfoFromDatabase(
const std::vector<OptimizationTarget>& segment_ids,
MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoList> segment_infos);

// Appends |segment_infos_from_default_model| to |segment_infos_from_db| and
// runs |callback| with the combined list.
void OnGetAllSegmentInfoFromDefaultModel(
MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoList> segment_infos_from_db,
std::unique_ptr<SegmentInfoList> segment_infos_from_default_model);

// Default model providers, keyed by segment ID. Only segments for which the
// factory produced a provider have an entry.
std::map<OptimizationTarget, std::unique_ptr<ModelProvider>>
default_model_providers_;

// Produces the weak pointers bound into the callbacks above.
// NOTE(review): base/memory/weak_ptr.h is not among this header's includes;
// presumably it arrives transitively -- consider including it directly.
base::WeakPtrFactory<DefaultModelManager> weak_ptr_factory_{this};
};

} // namespace segmentation_platform

#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_

0 comments on commit b885cd3

Please sign in to comment.