Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[segmentation_platform] Added logic to filter out signals
This CL adds SignalFilterProcessor which listens to model metadata updates and sets up signal filters for signal handlers. Bug: 1204692 Change-Id: I48a8b3d9b317026a1ca9bd6cd71f22639e7755b2 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2863315 Commit-Queue: Shakti Sahu <shaktisahu@chromium.org> Reviewed-by: Tommy Nyquist <nyquist@chromium.org> Cr-Commit-Position: refs/heads/master@{#879630}
- Loading branch information
Shakti Sahu
authored and
Chromium LUCI CQ
committed
May 6, 2021
1 parent
d01cac9
commit 63a7d96
Showing
9 changed files
with
327 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
SegmentationPlatformServiceImplTest.* | ||
SignalFilterProcessorTest.* | ||
UserActionSignalHandlerTest.* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
38 changes: 38 additions & 0 deletions
38
components/segmentation_platform/internal/database/segment_info_database.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Copyright 2021 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_ | ||
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_ | ||
|
||
#include <vector> | ||
|
||
#include "base/callback.h" | ||
#include "components/optimization_guide/proto/models.pb.h" | ||
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h" | ||
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h" | ||
|
||
using optimization_guide::proto::OptimizationTarget; | ||
|
||
namespace segmentation_platform { | ||
|
||
// Represents a DB layer that stores model metadata and prediction results to | ||
// the disk. | ||
class SegmentInfoDatabase { | ||
public: | ||
using SuccessCallback = base::OnceCallback<void(bool)>; | ||
using AllSegmentInfoCallback = base::OnceCallback<void( | ||
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>)>; | ||
|
||
virtual ~SegmentInfoDatabase() = default; | ||
|
||
// TODO(shaktisahu): Initialize DB before instantiating dependent classes. | ||
|
||
// Convenient method to return combined info for all the segments in the | ||
// database. | ||
virtual void GetAllSegmentInfo(AllSegmentInfoCallback callback) = 0; | ||
}; | ||
|
||
} // namespace segmentation_platform | ||
|
||
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_ |
52 changes: 52 additions & 0 deletions
52
components/segmentation_platform/internal/database/test_segment_info_database.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright 2021 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "components/segmentation_platform/internal/database/test_segment_info_database.h" | ||
|
||
#include "base/metrics/metrics_hashes.h" | ||
|
||
namespace segmentation_platform { | ||
|
||
namespace test { | ||
|
||
TestSegmentInfoDatabase::TestSegmentInfoDatabase() = default; | ||
|
||
TestSegmentInfoDatabase::~TestSegmentInfoDatabase() = default; | ||
|
||
void TestSegmentInfoDatabase::GetAllSegmentInfo( | ||
AllSegmentInfoCallback callback) { | ||
std::move(callback).Run(segment_infos_); | ||
} | ||
|
||
void TestSegmentInfoDatabase::AddUserAction( | ||
OptimizationTarget segment_id, | ||
const std::string& user_action_name) { | ||
proto::SegmentInfo* info = FindOrCreateSegment(segment_id); | ||
proto::SegmentationModelMetadata* metadata = info->mutable_model_metadata(); | ||
proto::Feature* feature = metadata->add_features(); | ||
proto::UserActionFeature* user_action = feature->mutable_user_action(); | ||
user_action->set_user_action_hash(base::HashMetricName(user_action_name)); | ||
} | ||
|
||
proto::SegmentInfo* TestSegmentInfoDatabase::FindOrCreateSegment( | ||
OptimizationTarget segment_id) { | ||
proto::SegmentInfo* info = nullptr; | ||
for (auto& pair : segment_infos_) { | ||
if (pair.first == segment_id) { | ||
info = &pair.second; | ||
break; | ||
} | ||
} | ||
|
||
if (info == nullptr) { | ||
segment_infos_.emplace_back(segment_id, proto::SegmentInfo()); | ||
info = &segment_infos_.back().second; | ||
} | ||
|
||
return info; | ||
} | ||
|
||
} // namespace test | ||
|
||
} // namespace segmentation_platform |
40 changes: 40 additions & 0 deletions
40
components/segmentation_platform/internal/database/test_segment_info_database.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright 2021 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_ | ||
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_ | ||
|
||
#include "components/segmentation_platform/internal/database/segment_info_database.h" | ||
|
||
#include "base/logging.h" | ||
|
||
namespace segmentation_platform { | ||
|
||
namespace test { | ||
|
||
// A fake database with sample entries that can be used for tests. | ||
class TestSegmentInfoDatabase : public SegmentInfoDatabase { | ||
public: | ||
TestSegmentInfoDatabase(); | ||
~TestSegmentInfoDatabase() override; | ||
|
||
// SegmentInfoDatabase overrides. | ||
void GetAllSegmentInfo(AllSegmentInfoCallback callback) override; | ||
|
||
// Test helper methods. | ||
void AddUserAction(OptimizationTarget segment_id, | ||
const std::string& user_action); | ||
|
||
private: | ||
// Finds a segment with given |segment_id|. Creates one if it doesn't exist. | ||
proto::SegmentInfo* FindOrCreateSegment(OptimizationTarget segment_id); | ||
|
||
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>> segment_infos_; | ||
}; | ||
|
||
} // namespace test | ||
|
||
} // namespace segmentation_platform | ||
|
||
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
components/segmentation_platform/internal/signals/signal_filter_processor.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright 2021 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "components/segmentation_platform/internal/signals/signal_filter_processor.h" | ||
|
||
#include <set> | ||
|
||
#include "components/segmentation_platform/internal/database/segment_info_database.h" | ||
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h" | ||
|
||
namespace segmentation_platform { | ||
|
||
SignalFilterProcessor::SignalFilterProcessor( | ||
SegmentInfoDatabase* segment_database, | ||
UserActionSignalHandler* user_action_signal_handler) | ||
: segment_database_(segment_database), | ||
user_action_signal_handler_(user_action_signal_handler) {} | ||
|
||
SignalFilterProcessor::~SignalFilterProcessor() = default; | ||
|
||
void SignalFilterProcessor::OnSignalListUpdated() { | ||
segment_database_->GetAllSegmentInfo(base::BindOnce( | ||
&SignalFilterProcessor::FilterSignals, weak_ptr_factory_.GetWeakPtr())); | ||
} | ||
|
||
void SignalFilterProcessor::FilterSignals( | ||
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>> | ||
segment_infos) { | ||
std::set<uint64_t> user_actions; | ||
for (const auto& pair : segment_infos) { | ||
const proto::SegmentInfo& segment_info = pair.second; | ||
const auto& metadata = segment_info.model_metadata(); | ||
for (int i = 0; i < metadata.features_size(); i++) { | ||
const auto& feature = metadata.features(i); | ||
if (feature.has_user_action() && | ||
feature.user_action().has_user_action_hash()) { | ||
user_actions.insert(feature.user_action().user_action_hash()); | ||
} | ||
// TODO(shaktisahu): Do the same for enum and value histograms. | ||
} | ||
} | ||
|
||
user_action_signal_handler_->SetRelevantUserActions(user_actions); | ||
} | ||
|
||
void SignalFilterProcessor::EnableMetrics(bool enable_metrics) { | ||
user_action_signal_handler_->EnableMetrics(enable_metrics); | ||
} | ||
|
||
} // namespace segmentation_platform |
57 changes: 57 additions & 0 deletions
57
components/segmentation_platform/internal/signals/signal_filter_processor.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright 2021 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SIGNALS_SIGNAL_FILTER_PROCESSOR_H_ | ||
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SIGNALS_SIGNAL_FILTER_PROCESSOR_H_ | ||
|
||
#include "base/memory/weak_ptr.h" | ||
#include "components/optimization_guide/proto/models.pb.h" | ||
|
||
using optimization_guide::proto::OptimizationTarget; | ||
|
||
namespace segmentation_platform { | ||
|
||
namespace proto { | ||
class SegmentInfo; | ||
} // namespace proto | ||
|
||
class SegmentInfoDatabase; | ||
class UserActionSignalHandler; | ||
|
||
// Responsible for listening to the metadata updates for the models and | ||
// registers various signal handlers for the relevant UMA signals specified in | ||
// the metadata. | ||
class SignalFilterProcessor { | ||
public: | ||
SignalFilterProcessor(SegmentInfoDatabase* segment_database, | ||
UserActionSignalHandler* user_action_signal_handler); | ||
~SignalFilterProcessor(); | ||
|
||
// Disallow copy/assign. | ||
SignalFilterProcessor(const SignalFilterProcessor&) = delete; | ||
SignalFilterProcessor& operator=(const SignalFilterProcessor&) = delete; | ||
|
||
// Called whenever the metadata about the models are updated. Registers | ||
// handlers for the relevant signals specified in the metadata. If handlers | ||
// are already registered, it will reset and register again with the new set | ||
// of signals. | ||
void OnSignalListUpdated(); | ||
|
||
// Called to enable or disable metrics collection for segmentation platform. | ||
void EnableMetrics(bool enable_metrics); | ||
|
||
private: | ||
void FilterSignals( | ||
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>> | ||
segment_infos); | ||
|
||
SegmentInfoDatabase* segment_database_; | ||
UserActionSignalHandler* user_action_signal_handler_; | ||
|
||
base::WeakPtrFactory<SignalFilterProcessor> weak_ptr_factory_{this}; | ||
}; | ||
|
||
} // namespace segmentation_platform | ||
|
||
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SIGNALS_SIGNAL_FILTER_PROCESSOR_H_ |
76 changes: 76 additions & 0 deletions
76
components/segmentation_platform/internal/signals/signal_filter_processor_unittest.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright 2021 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "components/segmentation_platform/internal/signals/signal_filter_processor.h" | ||
|
||
#include "base/metrics/metrics_hashes.h" | ||
#include "base/run_loop.h" | ||
#include "base/test/task_environment.h" | ||
#include "components/segmentation_platform/internal/database/segment_info_database.h" | ||
#include "components/segmentation_platform/internal/database/test_segment_info_database.h" | ||
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h" | ||
#include "testing/gmock/include/gmock/gmock.h" | ||
#include "testing/gtest/include/gtest/gtest.h" | ||
|
||
using testing::_; | ||
using testing::Contains; | ||
using testing::SaveArg; | ||
|
||
namespace segmentation_platform { | ||
|
||
class MockUserActionSignalHandler : public UserActionSignalHandler { | ||
public: | ||
MockUserActionSignalHandler() : UserActionSignalHandler(nullptr) {} | ||
MOCK_METHOD(void, SetRelevantUserActions, (std::set<uint64_t>)); | ||
MOCK_METHOD(void, EnableMetrics, (bool)); | ||
}; | ||
|
||
class SignalFilterProcessorTest : public testing::Test { | ||
public: | ||
SignalFilterProcessorTest() = default; | ||
~SignalFilterProcessorTest() override = default; | ||
|
||
void SetUp() override { | ||
base::SetRecordActionTaskRunner( | ||
task_environment_.GetMainThreadTaskRunner()); | ||
segment_database_ = std::make_unique<test::TestSegmentInfoDatabase>(); | ||
user_action_signal_handler_ = | ||
std::make_unique<MockUserActionSignalHandler>(); | ||
signal_filter_processor_ = std::make_unique<SignalFilterProcessor>( | ||
segment_database_.get(), user_action_signal_handler_.get()); | ||
} | ||
|
||
base::test::TaskEnvironment task_environment_; | ||
std::unique_ptr<test::TestSegmentInfoDatabase> segment_database_; | ||
std::unique_ptr<MockUserActionSignalHandler> user_action_signal_handler_; | ||
std::unique_ptr<SignalFilterProcessor> signal_filter_processor_; | ||
}; | ||
|
||
TEST_F(SignalFilterProcessorTest, UserActionRegistrationFlow) { | ||
std::string kUserActionName1 = "some_action_1"; | ||
segment_database_->AddUserAction( | ||
OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS, kUserActionName1); | ||
std::string kUserActionName2 = "some_action_2"; | ||
segment_database_->AddUserAction( | ||
OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, | ||
kUserActionName2); | ||
|
||
std::set<uint64_t> actions; | ||
EXPECT_CALL(*user_action_signal_handler_, SetRelevantUserActions(_)) | ||
.Times(1) | ||
.WillOnce(SaveArg<0>(&actions)); | ||
|
||
signal_filter_processor_->OnSignalListUpdated(); | ||
ASSERT_THAT(actions, Contains(base::HashMetricName(kUserActionName1))); | ||
ASSERT_THAT(actions, Contains(base::HashMetricName(kUserActionName2))); | ||
} | ||
|
||
TEST_F(SignalFilterProcessorTest, EnableMetrics) { | ||
EXPECT_CALL(*user_action_signal_handler_, EnableMetrics(true)); | ||
signal_filter_processor_->EnableMetrics(true); | ||
EXPECT_CALL(*user_action_signal_handler_, EnableMetrics(false)); | ||
signal_filter_processor_->EnableMetrics(false); | ||
} | ||
|
||
} // namespace segmentation_platform |