Skip to content

Commit

Permalink
[segmentation_platform] Added logic to filter out signals
Browse files Browse the repository at this point in the history
This CL adds SignalFilterProcessor which listens to model metadata
updates and sets up signal filters for signal handlers.

Bug: 1204692
Change-Id: I48a8b3d9b317026a1ca9bd6cd71f22639e7755b2
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2863315
Commit-Queue: Shakti Sahu <shaktisahu@chromium.org>
Reviewed-by: Tommy Nyquist <nyquist@chromium.org>
Cr-Commit-Position: refs/heads/master@{#879630}
  • Loading branch information
Shakti Sahu authored and Chromium LUCI CQ committed May 6, 2021
1 parent d01cac9 commit 63a7d96
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 3 deletions.
@@ -1,2 +1,3 @@
SegmentationPlatformServiceImplTest.*
SignalFilterProcessorTest.*
UserActionSignalHandlerTest.*
7 changes: 7 additions & 0 deletions components/segmentation_platform/internal/BUILD.gn
Expand Up @@ -14,9 +14,12 @@ static_library("internal") {
]

sources = [
"database/segment_info_database.h",
"database/user_action_database.h",
"segmentation_platform_service_impl.cc",
"segmentation_platform_service_impl.h",
"signals/signal_filter_processor.cc",
"signals/signal_filter_processor.h",
"signals/user_action_signal_handler.cc",
"signals/user_action_signal_handler.h",
]
Expand All @@ -25,6 +28,7 @@ static_library("internal") {
"//base",
"//components/keyed_service/core",
"//components/leveldb_proto",
"//components/optimization_guide/proto:optimization_guide_proto",
"//components/segmentation_platform/internal/proto",
"//components/segmentation_platform/public",
]
Expand All @@ -38,7 +42,10 @@ source_set("unit_tests") {
# IMPORTANT NOTE: When adding new tests, also remember to update the list of
# tests in //components/segmentation_platform/components_unittests.filter
sources = [
"database/test_segment_info_database.cc",
"database/test_segment_info_database.h",
"segmentation_platform_service_impl_unittest.cc",
"signals/signal_filter_processor_unittest.cc",
"signals/user_action_signal_handler_unittest.cc",
]

Expand Down
@@ -0,0 +1,38 @@
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_

#include <vector>

#include "base/callback.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"

using optimization_guide::proto::OptimizationTarget;

namespace segmentation_platform {

// Represents a DB layer that stores model metadata and prediction results to
// the disk. This is an abstract interface; concrete storage is supplied by
// subclasses.
class SegmentInfoDatabase {
public:
// Callback carrying a single bool.
// NOTE(review): presumably true means the operation succeeded - confirm
// against the implementations that use it.
using SuccessCallback = base::OnceCallback<void(bool)>;
// Callback receiving, by value, a list of (segment ID, segment info) pairs.
using AllSegmentInfoCallback = base::OnceCallback<void(
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>)>;

// Virtual so that implementations can be destroyed through this interface.
virtual ~SegmentInfoDatabase() = default;

// TODO(shaktisahu): Initialize DB before instantiating dependent classes.

// Convenience method to return combined info for all the segments in the
// database. The result is delivered through |callback|; this interface does
// not specify whether |callback| runs synchronously or asynchronously.
virtual void GetAllSegmentInfo(AllSegmentInfoCallback callback) = 0;
};

} // namespace segmentation_platform

#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_
@@ -0,0 +1,52 @@
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/database/test_segment_info_database.h"

#include "base/metrics/metrics_hashes.h"

namespace segmentation_platform {

namespace test {

// Creates a fake database with no segments; entries are added through
// AddUserAction().
TestSegmentInfoDatabase::TestSegmentInfoDatabase() = default;

// Nothing beyond member destruction is required.
TestSegmentInfoDatabase::~TestSegmentInfoDatabase() = default;

// SegmentInfoDatabase override. Runs |callback| synchronously with a copy of
// every segment currently held by the fake; the stored entries themselves are
// left untouched.
void TestSegmentInfoDatabase::GetAllSegmentInfo(
    AllSegmentInfoCallback callback) {
  auto snapshot = segment_infos_;
  std::move(callback).Run(std::move(snapshot));
}

// Test helper. Appends one user-action feature to the model metadata of
// |segment_id|, creating the segment entry first if it does not exist yet.
// The feature stores base::HashMetricName(|user_action_name|) as its hash.
// Every call appends a new feature; existing features are not de-duplicated.
void TestSegmentInfoDatabase::AddUserAction(
    OptimizationTarget segment_id,
    const std::string& user_action_name) {
  const auto action_hash = base::HashMetricName(user_action_name);
  proto::UserActionFeature* action_feature =
      FindOrCreateSegment(segment_id)
          ->mutable_model_metadata()
          ->add_features()
          ->mutable_user_action();
  action_feature->set_user_action_hash(action_hash);
}

// Returns the stored SegmentInfo for |segment_id|, appending a
// default-constructed entry when none exists. The returned pointer refers to
// an element of |segment_infos_|, so it should be used before another entry is
// appended (vector growth may relocate the elements).
proto::SegmentInfo* TestSegmentInfoDatabase::FindOrCreateSegment(
    OptimizationTarget segment_id) {
  // Linear scan; the first entry with a matching ID wins.
  for (auto& entry : segment_infos_) {
    if (entry.first == segment_id)
      return &entry.second;
  }

  // No match: create a fresh entry at the back and hand that out.
  segment_infos_.emplace_back(segment_id, proto::SegmentInfo());
  return &segment_infos_.back().second;
}

} // namespace test

} // namespace segmentation_platform
@@ -0,0 +1,40 @@
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_

#include "components/segmentation_platform/internal/database/segment_info_database.h"

#include "base/logging.h"

namespace segmentation_platform {

namespace test {

// A fake, in-memory SegmentInfoDatabase for tests. It starts out empty and is
// populated through the helper methods below.
class TestSegmentInfoDatabase : public SegmentInfoDatabase {
 public:
  TestSegmentInfoDatabase();
  ~TestSegmentInfoDatabase() override;

  // SegmentInfoDatabase overrides.
  // Runs |callback| with a copy of all segments added so far.
  void GetAllSegmentInfo(AllSegmentInfoCallback callback) override;

  // Test helper methods.
  // Adds a user-action feature, identified by the metric-name hash of
  // |user_action_name|, to the metadata of |segment_id|. The segment is
  // created if it does not exist yet.
  void AddUserAction(OptimizationTarget segment_id,
                     const std::string& user_action_name);

 private:
  // Finds a segment with given |segment_id|. Creates one if it doesn't exist.
  proto::SegmentInfo* FindOrCreateSegment(OptimizationTarget segment_id);

  // Backing store: (segment ID, segment info) pairs in insertion order.
  std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>> segment_infos_;
};

} // namespace test

} // namespace segmentation_platform

#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_
Expand Up @@ -22,15 +22,17 @@ enum TimeUnit {
}

message HistogramValueFeature {
optional fixed64 name_hash = 1;
optional string name = 1;
optional fixed64 name_hash = 2;
}

message HistogramEnumFeature {
optional fixed64 name_hash = 1;
optional string name = 1;
optional fixed64 name_hash = 2;

// Matches are only valid when the enum ID matches any of these.
// Works like an OR condition, e.g.: [url, search, …] or just [url].
repeated int32 enum_ids = 2;
repeated int32 enum_ids = 3;
}

message UserActionFeature {
Expand Down
@@ -0,0 +1,51 @@
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/signals/signal_filter_processor.h"

#include <set>

#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"

namespace segmentation_platform {

// Stores the two collaborators as raw pointers. Neither is owned: the
// defaulted destructor below does not delete them.
// NOTE(review): both pointers are dereferenced without null checks by the
// other methods, so callers presumably must pass non-null objects that outlive
// this processor - confirm at the construction site.
SignalFilterProcessor::SignalFilterProcessor(
SegmentInfoDatabase* segment_database,
UserActionSignalHandler* user_action_signal_handler)
: segment_database_(segment_database),
user_action_signal_handler_(user_action_signal_handler) {}

// Nothing beyond member destruction is required.
SignalFilterProcessor::~SignalFilterProcessor() = default;

// Asks the segment database for the info of all segments and routes the
// result to FilterSignals(). The callback is bound to a weak pointer obtained
// from |weak_ptr_factory_| rather than to a raw |this|.
void SignalFilterProcessor::OnSignalListUpdated() {
segment_database_->GetAllSegmentInfo(base::BindOnce(
&SignalFilterProcessor::FilterSignals, weak_ptr_factory_.GetWeakPtr()));
}

// Callback for SegmentInfoDatabase::GetAllSegmentInfo(). Collects the hash of
// every user-action feature found in the model metadata of |segment_infos|
// and passes the resulting set (duplicates collapsed) to the user action
// signal handler. The handler is called even when the set is empty.
void SignalFilterProcessor::FilterSignals(
    std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>
        segment_infos) {
  std::set<uint64_t> relevant_hashes;
  for (const auto& entry : segment_infos) {
    for (const auto& feature : entry.second.model_metadata().features()) {
      // TODO(shaktisahu): Do the same for enum and value histograms.
      // Only user-action features that actually carry a hash are relevant.
      if (!feature.has_user_action())
        continue;
      const auto& action = feature.user_action();
      if (!action.has_user_action_hash())
        continue;
      relevant_hashes.insert(action.user_action_hash());
    }
  }

  user_action_signal_handler_->SetRelevantUserActions(relevant_hashes);
}

// Forwards |enable_metrics| unchanged to the user action signal handler; no
// state is kept in this class.
void SignalFilterProcessor::EnableMetrics(bool enable_metrics) {
user_action_signal_handler_->EnableMetrics(enable_metrics);
}

} // namespace segmentation_platform
@@ -0,0 +1,57 @@
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SIGNALS_SIGNAL_FILTER_PROCESSOR_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SIGNALS_SIGNAL_FILTER_PROCESSOR_H_

#include "base/memory/weak_ptr.h"
#include "components/optimization_guide/proto/models.pb.h"

using optimization_guide::proto::OptimizationTarget;

namespace segmentation_platform {

namespace proto {
class SegmentInfo;
} // namespace proto

class SegmentInfoDatabase;
class UserActionSignalHandler;

// Responsible for listening to the metadata updates for the models and
// registering various signal handlers for the relevant UMA signals specified
// in the metadata.
class SignalFilterProcessor {
public:
// Both arguments are stored as raw pointers and are not owned by this class.
SignalFilterProcessor(SegmentInfoDatabase* segment_database,
UserActionSignalHandler* user_action_signal_handler);
~SignalFilterProcessor();

// Disallow copy/assign.
SignalFilterProcessor(const SignalFilterProcessor&) = delete;
SignalFilterProcessor& operator=(const SignalFilterProcessor&) = delete;

// Called whenever the metadata about the models are updated. Registers
// handlers for the relevant signals specified in the metadata. If handlers
// are already registered, it will reset and register again with the new set
// of signals.
void OnSignalListUpdated();

// Called to enable or disable metrics collection for segmentation platform.
void EnableMetrics(bool enable_metrics);

private:
// Receives the (segment ID, segment info) pairs requested by
// OnSignalListUpdated() and derives the signal filters from their metadata.
void FilterSignals(
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>
segment_infos);

// Source of the segment metadata. Not owned.
SegmentInfoDatabase* segment_database_;
// Receives the set of relevant user actions and the metrics toggle. Not owned.
UserActionSignalHandler* user_action_signal_handler_;

// Produces the weak pointers that database callbacks are bound to.
base::WeakPtrFactory<SignalFilterProcessor> weak_ptr_factory_{this};
};

} // namespace segmentation_platform

#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SIGNALS_SIGNAL_FILTER_PROCESSOR_H_
@@ -0,0 +1,76 @@
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/signals/signal_filter_processor.h"

#include "base/metrics/metrics_hashes.h"
#include "base/run_loop.h"
#include "base/test/task_environment.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/test_segment_info_database.h"
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using testing::_;
using testing::Contains;
using testing::SaveArg;

namespace segmentation_platform {

// gMock replacement for UserActionSignalHandler that records the calls made
// by SignalFilterProcessor. The base class is constructed with a null
// argument; the mocked methods never reach the base implementation.
class MockUserActionSignalHandler : public UserActionSignalHandler {
 public:
  MockUserActionSignalHandler() : UserActionSignalHandler(nullptr) {}

  // `override` makes the compiler verify that these really replace the
  // virtual methods SignalFilterProcessor calls through the base pointer; a
  // signature change in the base class then breaks the build instead of
  // silently bypassing the mock.
  MOCK_METHOD(void, SetRelevantUserActions, (std::set<uint64_t>), (override));
  MOCK_METHOD(void, EnableMetrics, (bool), (override));
};

class SignalFilterProcessorTest : public testing::Test {
public:
SignalFilterProcessorTest() = default;
~SignalFilterProcessorTest() override = default;

void SetUp() override {
base::SetRecordActionTaskRunner(
task_environment_.GetMainThreadTaskRunner());
segment_database_ = std::make_unique<test::TestSegmentInfoDatabase>();
user_action_signal_handler_ =
std::make_unique<MockUserActionSignalHandler>();
signal_filter_processor_ = std::make_unique<SignalFilterProcessor>(
segment_database_.get(), user_action_signal_handler_.get());
}

base::test::TaskEnvironment task_environment_;
std::unique_ptr<test::TestSegmentInfoDatabase> segment_database_;
std::unique_ptr<MockUserActionSignalHandler> user_action_signal_handler_;
std::unique_ptr<SignalFilterProcessor> signal_filter_processor_;
};

// Verifies that user actions listed in the metadata of different segments are
// collected into one set and handed to the user action signal handler exactly
// once per OnSignalListUpdated() call.
TEST_F(SignalFilterProcessorTest, UserActionRegistrationFlow) {
  const std::string kUserActionName1 = "some_action_1";
  segment_database_->AddUserAction(
      OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS, kUserActionName1);
  const std::string kUserActionName2 = "some_action_2";
  segment_database_->AddUserAction(
      OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
      kUserActionName2);

  std::set<uint64_t> actions;
  EXPECT_CALL(*user_action_signal_handler_, SetRelevantUserActions(_))
      .Times(1)
      .WillOnce(SaveArg<0>(&actions));

  // The fake database runs its callback synchronously, so |actions| is
  // populated by the time this call returns.
  signal_filter_processor_->OnSignalListUpdated();

  // Non-fatal expectations so that every mismatch is reported in one run. The
  // size check guards against unrelated hashes leaking into the filter.
  EXPECT_EQ(2u, actions.size());
  EXPECT_THAT(actions, Contains(base::HashMetricName(kUserActionName1)));
  EXPECT_THAT(actions, Contains(base::HashMetricName(kUserActionName2)));
}

// Verifies that the metrics toggle is forwarded unchanged to the user action
// signal handler, first for `true` and then for `false`.
TEST_F(SignalFilterProcessorTest, EnableMetrics) {
  // Sets the expectation for one value, then triggers the forwarding call.
  const auto expect_forwarded = [this](bool enable) {
    EXPECT_CALL(*user_action_signal_handler_, EnableMetrics(enable));
    signal_filter_processor_->EnableMetrics(enable);
  };

  expect_forwarded(true);
  expect_forwarded(false);
}

} // namespace segmentation_platform

0 comments on commit 63a7d96

Please sign in to comment.