Skip to content

Commit

Permalink
[segmentation] Separate ownership of execution and signal handlers
Browse files Browse the repository at this point in the history
Move ownership of signal handlers, and execution / processing classes to
separate classes. Make clear the dependencies between these 2 services.

BUG=1307083

Change-Id: Ic572e5b7d60b6a16108a91a2d1dff05b1be1b27b
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3531200
Reviewed-by: Tommy Nyquist <nyquist@chromium.org>
Commit-Queue: Siddhartha S <ssid@chromium.org>
Cr-Commit-Position: refs/heads/main@{#990173}
  • Loading branch information
ssiddhartha authored and Chromium LUCI CQ committed Apr 7, 2022
1 parent 19c2913 commit 7f85823
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 69 deletions.
4 changes: 4 additions & 0 deletions components/segmentation_platform/internal/BUILD.gn
Expand Up @@ -77,6 +77,8 @@ static_library("internal") {
"execution/uma_feature_processor.h",
"platform_options.cc",
"platform_options.h",
"scheduler/execution_service.cc",
"scheduler/execution_service.h",
"scheduler/model_execution_scheduler.h",
"scheduler/model_execution_scheduler_impl.cc",
"scheduler/model_execution_scheduler_impl.h",
Expand All @@ -103,6 +105,8 @@ static_library("internal") {
"signals/history_service_observer.h",
"signals/signal_filter_processor.cc",
"signals/signal_filter_processor.h",
"signals/signal_handler.cc",
"signals/signal_handler.h",
"signals/ukm_config.cc",
"signals/ukm_config.h",
"signals/ukm_observer.cc",
Expand Down
@@ -0,0 +1,60 @@
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/scheduler/execution_service.h"

#include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/signal_database.h"
#include "components/segmentation_platform/internal/execution/feature_aggregator_impl.h"
#include "components/segmentation_platform/internal/execution/feature_list_query_processor.h"
#include "components/segmentation_platform/internal/execution/model_execution_manager_factory.h"
#include "components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h"
#include "components/segmentation_platform/internal/signals/signal_handler.h"

namespace segmentation_platform {

ExecutionService::ExecutionService() = default;
ExecutionService::~ExecutionService() = default;

void ExecutionService::Initialize(
SignalDatabase* signal_database,
SegmentInfoDatabase* segment_info_database,
SignalStorageConfig* signal_storage_config,
SignalHandler* signal_handler,
base::Clock* clock,
ModelExecutionManager::SegmentationModelUpdatedCallback callback,
scoped_refptr<base::SequencedTaskRunner> task_runner,
const base::flat_set<OptimizationTarget>& all_segment_ids,
ModelProviderFactory* model_provider_factory,
std::vector<ModelExecutionScheduler::Observer*>&& observers,
const PlatformOptions& platform_options) {
feature_list_query_processor_ = std::make_unique<FeatureListQueryProcessor>(
signal_database, std::make_unique<FeatureAggregatorImpl>());

training_data_collector_ = TrainingDataCollector::Create(
segment_info_database, feature_list_query_processor_.get(),
signal_handler->deprecated_histogram_signal_handler(),
signal_storage_config, clock);
training_data_collector_->OnServiceInitialized();

model_execution_manager_ = CreateModelExecutionManager(
model_provider_factory, task_runner, all_segment_ids, clock,
segment_info_database, signal_database,
feature_list_query_processor_.get(), callback);

model_execution_scheduler_ = std::make_unique<ModelExecutionSchedulerImpl>(
std::move(observers), segment_info_database, signal_storage_config,
model_execution_manager_.get(), all_segment_ids, clock, platform_options);

model_execution_scheduler_->RequestModelExecutionForEligibleSegments(
/*expired_only=*/true);
}

void ExecutionService::OnNewModelInfoReady(
const proto::SegmentInfo& segment_info) {
model_execution_scheduler_->OnNewModelInfoReady(segment_info);
}

} // namespace segmentation_platform
@@ -0,0 +1,82 @@
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SCHEDULER_EXECUTION_SERVICE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SCHEDULER_EXECUTION_SERVICE_H_

#include <memory>
#include <vector>

#include "base/containers/flat_set.h"
#include "base/task/sequenced_task_runner.h"
#include "base/time/clock.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/execution/model_execution_manager.h"
#include "components/segmentation_platform/internal/scheduler/model_execution_scheduler.h"

namespace segmentation_platform {

struct PlatformOptions;
class FeatureListQueryProcessor;
class ModelProviderFactory;
class ModelExecutionSchedulerImpl;
class SegmentInfoDatabase;
class SignalDatabase;
class SignalHandler;
class SignalStorageConfig;
class TrainingDataCollector;

// Handles feature processing and model execution.
class ExecutionService {
public:
ExecutionService();
~ExecutionService();

ExecutionService(ExecutionService&) = delete;
ExecutionService& operator=(ExecutionService&) = delete;

void Initialize(
SignalDatabase* signal_database,
SegmentInfoDatabase* segment_info_database,
SignalStorageConfig* signal_storage_config,
SignalHandler* signal_handler,
base::Clock* clock,
ModelExecutionManager::SegmentationModelUpdatedCallback callback,
scoped_refptr<base::SequencedTaskRunner> task_runner,
const base::flat_set<OptimizationTarget>& all_segment_ids,
ModelProviderFactory* model_provider_factory,
std::vector<ModelExecutionScheduler::Observer*>&& observers,
const PlatformOptions& platform_options);

// Called whenever a new or updated model is available. Must be a valid
// SegmentInfo with valid metadata and features.
void OnNewModelInfoReady(const proto::SegmentInfo& segment_info);

// TODO(ssid): Remove this method and pass in ExecutionService to proxy
// service.
ModelExecutionSchedulerImpl* deprecated_model_execution_scheduler() {
return model_execution_scheduler_.get();
}
// TODO(ssid): Remove this method and pass in ExecutionService to selector.
ModelExecutionManager* deprecated_model_execution_manager() {
return model_execution_manager_.get();
}

private:
// Training/inference input data generation.
std::unique_ptr<FeatureListQueryProcessor> feature_list_query_processor_;

// Traing data collection logic.
std::unique_ptr<TrainingDataCollector> training_data_collector_;

// Model execution scheduling logic.
std::unique_ptr<ModelExecutionSchedulerImpl> model_execution_scheduler_;

// Model execution.
std::unique_ptr<ModelExecutionManager> model_execution_manager_;
};

} // namespace segmentation_platform

#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SCHEDULER_EXECUTION_SERVICE_H_
Expand Up @@ -130,21 +130,10 @@ SegmentationPlatformServiceImpl::SegmentationPlatformServiceImpl(
ukm_data_manager_->AddRef();

// Construct signal processors.
user_action_signal_handler_ =
std::make_unique<UserActionSignalHandler>(signal_database_.get());
histogram_signal_handler_ =
std::make_unique<HistogramSignalHandler>(signal_database_.get());
signal_filter_processor_ = std::make_unique<SignalFilterProcessor>(
segment_info_database_.get(), user_action_signal_handler_.get(),
histogram_signal_handler_.get(), ukm_data_manager_,
default_model_manager_.get(), segment_id_vec);

if (ukm_data_manager_->IsUkmEngineEnabled() && history_service) {
// If UKM engine is enabled and history service is not available, then we
// would write metrics without URLs to the database, which is OK.
history_service_observer_ = std::make_unique<HistoryServiceObserver>(
history_service, ukm_data_manager_->GetOrCreateUrlHandler());
}
signal_handler_.Initialize(signal_database_.get(),
segment_info_database_.get(),
ukm_data_manager_.get(), history_service,
default_model_manager_.get(), segment_id_vec);

for (const auto& config : configs_) {
segment_selectors_[config->segmentation_key] =
Expand Down Expand Up @@ -179,7 +168,7 @@ SegmentationPlatformServiceImpl::SegmentationPlatformServiceImpl(
}

SegmentationPlatformServiceImpl::~SegmentationPlatformServiceImpl() {
history_service_observer_.reset();
signal_handler_.TearDown();
ukm_data_manager_->RemoveRef();
}

Expand All @@ -200,7 +189,7 @@ SegmentSelectionResult SegmentationPlatformServiceImpl::GetCachedSegmentResult(

void SegmentationPlatformServiceImpl::EnableMetrics(
bool signal_collection_allowed) {
signal_filter_processor_->EnableMetrics(signal_collection_allowed);
signal_handler_.EnableMetrics(signal_collection_allowed);
}

ServiceProxy* SegmentationPlatformServiceImpl::GetServiceProxy() {
Expand Down Expand Up @@ -250,33 +239,19 @@ void SegmentationPlatformServiceImpl::MaybeRunPostInitializationRoutines() {
return;
}

feature_list_query_processor_ = std::make_unique<FeatureListQueryProcessor>(
signal_database_.get(), std::make_unique<FeatureAggregatorImpl>());

training_data_collector_ = TrainingDataCollector::Create(
segment_info_database_.get(), feature_list_query_processor_.get(),
histogram_signal_handler_.get(), signal_storage_config_.get(), clock_);
training_data_collector_->OnServiceInitialized();

model_execution_manager_ = CreateModelExecutionManager(
model_provider_factory_.get(), task_runner_, all_segment_ids_, clock_,
segment_info_database_.get(), signal_database_.get(),
feature_list_query_processor_.get(),
base::BindRepeating(
&SegmentationPlatformServiceImpl::OnSegmentationModelUpdated,
weak_ptr_factory_.GetWeakPtr()));
signal_handler_.OnSignalListUpdated();

std::vector<ModelExecutionSchedulerImpl::Observer*> observers;
for (auto& key_and_selector : segment_selectors_)
observers.push_back(key_and_selector.second.get());
model_execution_scheduler_ = std::make_unique<ModelExecutionSchedulerImpl>(
std::move(observers), segment_info_database_.get(),
signal_storage_config_.get(), model_execution_manager_.get(),
all_segment_ids_, clock_, platform_options_);

signal_filter_processor_->OnSignalListUpdated();
model_execution_scheduler_->RequestModelExecutionForEligibleSegments(
/*expired_only=*/true);
execution_service_.Initialize(
signal_database_.get(), segment_info_database_.get(),
signal_storage_config_.get(), &signal_handler_, clock_,
base::BindRepeating(
&SegmentationPlatformServiceImpl::OnSegmentationModelUpdated,
weak_ptr_factory_.GetWeakPtr()),
task_runner_, all_segment_ids_, model_provider_factory_.get(),
std::move(observers), platform_options_);

// Initiate database maintenance tasks with a small delay.
base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
Expand All @@ -286,10 +261,12 @@ void SegmentationPlatformServiceImpl::MaybeRunPostInitializationRoutines() {
weak_ptr_factory_.GetWeakPtr()),
kDatabaseMaintenanceDelay);

proxy_->SetModelExecutionScheduler(model_execution_scheduler_.get());
proxy_->SetModelExecutionScheduler(
execution_service_.deprecated_model_execution_scheduler());

for (auto& selector : segment_selectors_) {
selector.second->OnPlatformInitialized(model_execution_manager_.get());
selector.second->OnPlatformInitialized(
execution_service_.deprecated_model_execution_manager());
}
}

Expand All @@ -300,9 +277,9 @@ void SegmentationPlatformServiceImpl::OnSegmentationModelUpdated(

signal_storage_config_->OnSignalCollectionStarted(
segment_info.model_metadata());
signal_filter_processor_->OnSignalListUpdated();
signal_handler_.OnSignalListUpdated();

model_execution_scheduler_->OnNewModelInfoReady(segment_info);
execution_service_.OnNewModelInfoReady(segment_info);

// Update the service status for proxy.
base::ThreadTaskRunnerHandle::Get()->PostTask(
Expand Down
Expand Up @@ -16,8 +16,11 @@
#include "base/memory/weak_ptr.h"
#include "components/leveldb_proto/public/proto_database.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/execution/model_execution_manager.h"
#include "components/segmentation_platform/internal/platform_options.h"
#include "components/segmentation_platform/internal/scheduler/execution_service.h"
#include "components/segmentation_platform/internal/service_proxy_impl.h"
#include "components/segmentation_platform/internal/signals/signal_handler.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "third_party/abseil-cpp/absl/types/optional.h"

Expand Down Expand Up @@ -48,22 +51,14 @@ class SignalStorageConfigs;
struct Config;
class DatabaseMaintenanceImpl;
class DefaultModelManager;
class FeatureListQueryProcessor;
class HistogramSignalHandler;
class HistoryServiceObserver;
class ModelExecutionManager;
class ModelExecutionSchedulerImpl;
class ModelProviderFactory;
class SegmentationResultPrefs;
class SegmentInfoDatabase;
class SegmentSelectorImpl;
class SignalDatabaseImpl;
class SignalFilterProcessor;
class SignalStorageConfig;
class SegmentScoreProvider;
class TrainingDataCollector;
class UkmDataManager;
class UserActionSignalHandler;

// Qualifiers used to indicate service status. One or more qualifiers can
// be used at a time.
Expand Down Expand Up @@ -173,13 +168,7 @@ class SegmentationPlatformServiceImpl : public SegmentationPlatformService {
raw_ptr<UkmDataManager> ukm_data_manager_;

// Signal processing.
std::unique_ptr<UserActionSignalHandler> user_action_signal_handler_;
std::unique_ptr<HistogramSignalHandler> histogram_signal_handler_;
std::unique_ptr<SignalFilterProcessor> signal_filter_processor_;
std::unique_ptr<HistoryServiceObserver> history_service_observer_;

// Training/inference input data generation.
std::unique_ptr<FeatureListQueryProcessor> feature_list_query_processor_;
SignalHandler signal_handler_;

// Segment selection.
// TODO(shaktisahu): Determine safe destruction ordering between
Expand All @@ -190,14 +179,7 @@ class SegmentationPlatformServiceImpl : public SegmentationPlatformService {
// Segment results.
std::unique_ptr<SegmentScoreProvider> segment_score_provider_;

// Traing data collection logic.
std::unique_ptr<TrainingDataCollector> training_data_collector_;

// Model execution scheduling logic.
std::unique_ptr<ModelExecutionSchedulerImpl> model_execution_scheduler_;

// Model execution.
std::unique_ptr<ModelExecutionManager> model_execution_manager_;
ExecutionService execution_service_;

// Database maintenance.
std::unique_ptr<DatabaseMaintenanceImpl> database_maintenance_;
Expand Down
@@ -0,0 +1,55 @@
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/signals/signal_handler.h"

#include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
#include "components/segmentation_platform/internal/signals/history_service_observer.h"
#include "components/segmentation_platform/internal/signals/signal_filter_processor.h"
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
#include "components/segmentation_platform/internal/ukm_data_manager.h"

namespace segmentation_platform {

SignalHandler::SignalHandler() = default;
SignalHandler::~SignalHandler() = default;

void SignalHandler::Initialize(
SignalDatabase* signal_database,
SegmentInfoDatabase* segment_info_database,
UkmDataManager* ukm_data_manager,
history::HistoryService* history_service,
DefaultModelManager* default_model_manager,
const std::vector<optimization_guide::proto::OptimizationTarget>&
segment_ids) {
user_action_signal_handler_ =
std::make_unique<UserActionSignalHandler>(signal_database);
histogram_signal_handler_ =
std::make_unique<HistogramSignalHandler>(signal_database);
signal_filter_processor_ = std::make_unique<SignalFilterProcessor>(
segment_info_database, user_action_signal_handler_.get(),
histogram_signal_handler_.get(), ukm_data_manager, default_model_manager,
segment_ids);

if (ukm_data_manager->IsUkmEngineEnabled() && history_service) {
// If UKM engine is enabled and history service is not available, then we
// would write metrics without URLs to the database, which is OK.
history_service_observer_ = std::make_unique<HistoryServiceObserver>(
history_service, ukm_data_manager->GetOrCreateUrlHandler());
}
}

void SignalHandler::TearDown() {
history_service_observer_.reset();
}

void SignalHandler::EnableMetrics(bool signal_collection_allowed) {
signal_filter_processor_->EnableMetrics(signal_collection_allowed);
}

void SignalHandler::OnSignalListUpdated() {
signal_filter_processor_->OnSignalListUpdated();
}

} // namespace segmentation_platform

0 comments on commit 7f85823

Please sign in to comment.