Skip to content

Commit

Permalink
[merge-103]Cleanup experimental recorder to use result provider
Browse files Browse the repository at this point in the history
The recorder used database result and computed rank instead of using the
result provider which is wrapper on database to provide rank.
Additionally records uncomputed segments as "Unknown" instead of
skipping, so we can know how many users do not meet requirements in the
metrics.

BUG=1325414

(cherry picked from commit 5d8c5f0)

Change-Id: If033a22a34f7884fc2addbce5782fe7f1e5ed0d8
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3661623
Commit-Queue: Siddhartha S <ssid@chromium.org>
Reviewed-by: Min Qin <qinmin@chromium.org>
Cr-Original-Commit-Position: refs/heads/main@{#1007514}
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3669718
Commit-Queue: Krishna Govind <govind@chromium.org>
Reviewed-by: Krishna Govind <govind@chromium.org>
Owners-Override: Krishna Govind <govind@chromium.org>
Cr-Commit-Position: refs/branch-heads/5060@{#290}
Cr-Branched-From: b83393d-refs/heads/main@{#1002911}
  • Loading branch information
ssiddhartha authored and Chromium LUCI CQ committed May 27, 2022
1 parent 769dd3a commit 56c233c
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 32 deletions.
8 changes: 4 additions & 4 deletions chrome/browser/segmentation_platform/service_browsertest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ IN_PROC_BROWSER_TEST_F(SegmentationPlatformUkmModelTest,
MockModelProvider* provider = utils_.GetDefaultOverride(kSegmentId);

EXPECT_CALL(*provider, ExecuteModelWithInput(_, _))
.WillOnce(Invoke([&](const std::vector<float>& inputs,
ModelProvider::ExecutionCallback callback) {
.WillRepeatedly(Invoke([&](const std::vector<float>& inputs,
ModelProvider::ExecutionCallback callback) {
// There are no UKM metrics written to the database, count = 0.
EXPECT_EQ(std::vector<float>({0}), inputs);
std::move(callback).Run(0.5);
Expand Down Expand Up @@ -219,8 +219,8 @@ IN_PROC_BROWSER_TEST_F(SegmentationPlatformUkmModelTest,
MockModelProvider* provider = utils_.GetDefaultOverride(kSegmentId);

EXPECT_CALL(*provider, ExecuteModelWithInput(_, _))
.WillOnce(Invoke([](const std::vector<float>& inputs,
ModelProvider::ExecutionCallback callback) {
.WillRepeatedly(Invoke([](const std::vector<float>& inputs,
ModelProvider::ExecutionCallback callback) {
// Expected input is 2 since we recorded 2 UKM metrics in the previous
// session.
EXPECT_EQ(std::vector<float>({2}), inputs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#include "base/bind.h"
#include "base/strings/strcat.h"
#include "components/segmentation_platform/internal/database/metadata_utils.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/metric_filter_utils.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/field_trial_register.h"

Expand All @@ -17,43 +17,35 @@
namespace segmentation_platform {

ExperimentalGroupRecorder::ExperimentalGroupRecorder(
SegmentInfoDatabase* segment_database,
SegmentResultProvider* result_provider,
FieldTrialRegister* field_trial_register,
const std::string& segmentation_key,
optimization_guide::proto::OptimizationTarget selected_segment)
: field_trial_register_(field_trial_register),
segmentation_key_(segmentation_key) {
segment_database->GetSegmentInfo(
selected_segment, base::BindOnce(&ExperimentalGroupRecorder::OnGetSegment,
weak_ptr_factory_.GetWeakPtr()));
segmentation_key_(segmentation_key),
segment_id_(selected_segment) {
result_provider->GetSegmentResult(
segment_id_,
base::StrCat({segmentation_key, kSubsegmentDiscreteMappingSuffix}),
base::BindOnce(&ExperimentalGroupRecorder::OnGetSegment,
weak_ptr_factory_.GetWeakPtr()));
}

ExperimentalGroupRecorder::~ExperimentalGroupRecorder() = default;

void ExperimentalGroupRecorder::OnGetSegment(
absl::optional<proto::SegmentInfo> result) {
if (!result || !result->has_prediction_result()) {
return;
}
const float score = result->prediction_result().result();
std::string subsegment_key =
base::StrCat({segmentation_key_, kSubsegmentDiscreteMappingSuffix});
auto iter = result->model_metadata().discrete_mappings().find(subsegment_key);
if (iter == result->model_metadata().discrete_mappings().end()) {
// TODO(ssid): Move this check into ConvertToDiscreteScore().
return;
}
const int rank = metadata_utils::ConvertToDiscreteScore(
base::StrCat({segmentation_key_, kSubsegmentDiscreteMappingSuffix}),
score, result->model_metadata());

std::unique_ptr<SegmentResultProvider::SegmentResult> result) {
const std::string trial_name = stats::SegmentationKeyToSubsegmentTrialName(
segmentation_key_, result->segment_id());
segmentation_key_, segment_id_);
int rank = 0;
if (result && result->rank) {
rank = *result->rank;
}

// Can be nullptr in tests.
if (field_trial_register_) {
field_trial_register_->RegisterSubsegmentFieldTrialIfNeeded(
trial_name, result->segment_id(), rank);
trial_name, segment_id_, rank);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
#include "base/memory/weak_ptr.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "third_party/abseil-cpp/absl/types/optional.h"

namespace segmentation_platform {

class SegmentInfoDatabase;
class FieldTrialRegister;

// Records experimental sub groups for the given optimization target.
Expand All @@ -23,7 +23,7 @@ class ExperimentalGroupRecorder {
// subsegment based on the score. This class must be kept alive till the
// recording is complete, can be used only once.
ExperimentalGroupRecorder(
SegmentInfoDatabase* storage_service,
SegmentResultProvider* result_provider,
FieldTrialRegister* field_trial_register,
const std::string& segmentation_key,
optimization_guide::proto::OptimizationTarget selected_segment);
Expand All @@ -33,10 +33,12 @@ class ExperimentalGroupRecorder {
ExperimentalGroupRecorder& operator=(ExperimentalGroupRecorder&) = delete;

private:
void OnGetSegment(absl::optional<proto::SegmentInfo> result);
void OnGetSegment(
std::unique_ptr<SegmentResultProvider::SegmentResult> result);

const raw_ptr<FieldTrialRegister> field_trial_register_;
const std::string segmentation_key_;
const optimization_guide::proto::OptimizationTarget segment_id_;

base::WeakPtrFactory<ExperimentalGroupRecorder> weak_ptr_factory_{this};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void SegmentSelectorImpl::OnPlatformInitialized(
for (const OptimizationTarget segment_id : config_->segment_ids) {
experimental_group_recorder_.emplace_back(
std::make_unique<ExperimentalGroupRecorder>(
segment_database_, field_trial_register_,
segment_result_provider_.get(), field_trial_register_,
config_->segmentation_key, segment_id));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,17 @@ TEST_F(SegmentSelectorTest, SubsegmentRecording) {
&config_, &field_trial_register_, &clock_,
PlatformOptions::CreateDefault(), default_manager_.get());

// When segment result is missing, unknown subsegment is recorded.
EXPECT_CALL(
field_trial_register_,
RegisterSubsegmentFieldTrialIfNeeded(
base::StringPiece("Segmentation_TestKey_Share"), segment_id0, 0));
EXPECT_CALL(
field_trial_register_,
RegisterSubsegmentFieldTrialIfNeeded(
base::StringPiece("Segmentation_TestKey_NewTab"),
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 0));

// The new selector will record subsegment metric groups based on the mapping.
base::RunLoop wait_for_subsegment;
EXPECT_CALL(field_trial_register_,
Expand Down

0 comments on commit 56c233c

Please sign in to comment.