Skip to content

Commit

Permalink
[merge-103][segmentation] Add support for experimental subgroups
Browse files Browse the repository at this point in the history
Uses model score to select a subsegment for heuristic models. This
allows the model to provide specific scores for users and that can be
used to log synthetic field trials specific to those subsegment. The
metadata must define an extra subsegment discrete mapping that maps
model scores to subsegment enum.

BUG=1325414

(cherry picked from commit 297d012)

Change-Id: Ica602e323f8af2c8bca30c80a8942b87450f7122
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3636859
Commit-Queue: Siddhartha S <ssid@chromium.org>
Reviewed-by: Tommy Nyquist <nyquist@chromium.org>
Cr-Original-Commit-Position: refs/heads/main@{#1005428}
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3661372
Reviewed-by: Min Qin <qinmin@chromium.org>
Auto-Submit: Siddhartha S <ssid@chromium.org>
Cr-Commit-Position: refs/branch-heads/5060@{#198}
Cr-Branched-From: b83393d-refs/heads/main@{#1002911}
  • Loading branch information
ssiddhartha authored and Chromium LUCI CQ committed May 23, 2022
1 parent 408d6da commit ef26ec7
Show file tree
Hide file tree
Showing 20 changed files with 525 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
#include "chrome/browser/segmentation_platform/default_model/feed_user_segment.h"

#include "base/metrics/field_trial_params.h"
#include "base/strings/strcat.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "chrome/browser/segmentation_platform/default_model/metadata_writer.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/model_provider.h"

namespace segmentation_platform {
Expand All @@ -31,6 +33,19 @@ constexpr int64_t kFeedUserDiscreteMappingRank = 1;
constexpr std::pair<float, int> kDiscreteMappings[] = {
{kFeedUserDiscreteMappingMinResult, kFeedUserDiscreteMappingRank}};

static constexpr std::array<std::pair<float, /*FeedUserSubsegment*/ int>, 8>
kFeedUserScoreToSubGroup = {{
{1.0, static_cast<int>(FeedUserSubsegment::kActiveOnFeedOnly)},
{0.8,
static_cast<int>(FeedUserSubsegment::kActiveOnFeedAndNtpFeatures)},
{0.7, static_cast<int>(FeedUserSubsegment::kNoFeedAndNtpFeatures)},
{0.5, static_cast<int>(FeedUserSubsegment::kMvtOnly)},
{0.4, static_cast<int>(FeedUserSubsegment::kReturnToCurrentTabOnly)},
{0.2, static_cast<int>(FeedUserSubsegment::kUsedNtpWithoutModules)},
{0.1, static_cast<int>(FeedUserSubsegment::kNoNTPOrHomeOpened)},
{0.0, static_cast<int>(FeedUserSubsegment::kUnknown)},
}};

// InputFeatures.
constexpr MetadataWriter::UMAFeature kFeedUserUMAFeatures[] = {
MetadataWriter::UMAFeature{
Expand Down Expand Up @@ -93,11 +108,53 @@ constexpr MetadataWriter::UMAFeature kFeedUserUMAFeatures[] = {

#define ARRAY_SIZE(ar) (sizeof(ar) / sizeof(ar[0]))

float GetScoreForSubsegment(FeedUserSubsegment subgroup) {
for (const auto& score_and_type : kFeedUserScoreToSubGroup) {
if (score_and_type.second == static_cast<int>(subgroup)) {
return score_and_type.first;
}
}
NOTREACHED();
return 0;
}

std::string FeedUserSubsegmentToString(FeedUserSubsegment feed_group) {
switch (feed_group) {
case FeedUserSubsegment::kUnknown:
return "Unknown";
case FeedUserSubsegment::kOther:
return "Other";
case FeedUserSubsegment::kActiveOnFeedOnly:
return "ActiveOnFeedOnly";
case FeedUserSubsegment::kActiveOnFeedAndNtpFeatures:
return "ActiveOnFeedAndNtpFeatures";
case FeedUserSubsegment::kNoFeedAndNtpFeatures:
return "NoFeedAndNtpFeatures";
case FeedUserSubsegment::kMvtOnly:
return "MvtOnly";
case FeedUserSubsegment::kReturnToCurrentTabOnly:
return "ReturnToCurrentTabOnly";
case FeedUserSubsegment::kUsedNtpWithoutModules:
return "UsedNtpWithoutModules";
case FeedUserSubsegment::kNoNTPOrHomeOpened:
return "NoNTPOrHomeOpened";
}
}

} // namespace

FeedUserSegment::FeedUserSegment()
: ModelProvider(kFeedUserOptimizationTarget) {}

absl::optional<std::string> FeedUserSegment::GetSubsegmentName(
int subsegment_rank) {
DCHECK(static_cast<int>(FeedUserSubsegment::kUnknown) <= subsegment_rank &&
subsegment_rank <= static_cast<int>(FeedUserSubsegment::kMaxValue));
FeedUserSubsegment subgroup =
static_cast<FeedUserSubsegment>(subsegment_rank);
return FeedUserSubsegmentToString(subgroup);
}

void FeedUserSegment::InitAndFetchModel(
const ModelUpdatedCallback& model_updated_callback) {
proto::SegmentationModelMetadata chrome_start_metadata;
Expand All @@ -110,6 +167,12 @@ void FeedUserSegment::InitAndFetchModel(
writer.AddDiscreteMappingEntries(kFeedUserDiscreteMappingKey,
kDiscreteMappings, 1);

// Add subsegment mapping.
writer.AddDiscreteMappingEntries(
base::StrCat(
{kFeedUserDiscreteMappingKey, kSubsegmentDiscreteMappingSuffix}),
kFeedUserScoreToSubGroup.data(), kFeedUserScoreToSubGroup.size());

// Set features.
writer.AddUmaFeatures(kFeedUserUMAFeatures, ARRAY_SIZE(kFeedUserUMAFeatures));

Expand All @@ -129,19 +192,34 @@ void FeedUserSegment::ExecuteModelWithInput(const std::vector<float>& inputs,
return;
}

float result = 0;
FeedUserSubsegment segment = FeedUserSubsegment::kNoNTPOrHomeOpened;

const bool feed_opened = (inputs[0] + inputs[1] + inputs[2]) >= 2;
const bool mv_tiles_used = inputs[3] >= 2;
const bool return_to_tab_used = inputs[8] >= 2;
const bool home_or_ntp_used = (inputs[4] + inputs[5]) >= 4;

if (feed_opened) {
result = 1;
} else if (mv_tiles_used) {
result = 0.75;
if (mv_tiles_used || return_to_tab_used) {
segment = FeedUserSubsegment::kActiveOnFeedAndNtpFeatures;
} else {
segment = FeedUserSubsegment::kActiveOnFeedOnly;
}
} else if (home_or_ntp_used) {
result = 0.5;
if (mv_tiles_used && return_to_tab_used) {
segment = FeedUserSubsegment::kNoFeedAndNtpFeatures;
} else if (mv_tiles_used) {
segment = FeedUserSubsegment::kMvtOnly;
} else if (return_to_tab_used) {
segment = FeedUserSubsegment::kReturnToCurrentTabOnly;
} else {
segment = segment = FeedUserSubsegment::kUsedNtpWithoutModules;
}
} else {
segment = segment = FeedUserSubsegment::kNoNTPOrHomeOpened;
}

float result = GetScoreForSubsegment(segment);
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), result));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@

namespace segmentation_platform {

// List of sub-segments for Feed segment.
enum class FeedUserSubsegment {
kUnknown = 0,
kOther = 1,
kActiveOnFeedOnly = 2,
kActiveOnFeedAndNtpFeatures = 3,
kNoFeedAndNtpFeatures = 4,
kMvtOnly = 5,
kReturnToCurrentTabOnly = 6,
kUsedNtpWithoutModules = 7,
kNoNTPOrHomeOpened = 8,
kMaxValue = kNoNTPOrHomeOpened
};

// Segmentation Chrome Feed user model provider. Provides a default model and
// metadata for the Feed user optimization target.
class FeedUserSegment : public ModelProvider {
Expand All @@ -19,6 +33,11 @@ class FeedUserSegment : public ModelProvider {
FeedUserSegment(FeedUserSegment&) = delete;
FeedUserSegment& operator=(FeedUserSegment&) = delete;

// Returns the name of the subsegment for the given segment and the
// `subsegment_rank`. The `subsegment_rank` should be computed based on the
// subsegment discrete mapping in the model metadata.
static absl::optional<std::string> GetSubsegmentName(int subsegment_rank);

// ModelProvider implementation.
void InitAndFetchModel(
const ModelUpdatedCallback& model_updated_callback) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,40 @@
#include "testing/gtest/include/gtest/gtest.h"

namespace segmentation_platform {
namespace {

// TODO(ssid): Use metadata_utils or share common code for this function.
int ConvertToDiscreteScore(const std::string& mapping_key,
float input_score,
const proto::SegmentationModelMetadata& metadata) {
auto iter = metadata.discrete_mappings().find(mapping_key);
if (iter == metadata.discrete_mappings().end()) {
iter =
metadata.discrete_mappings().find(metadata.default_discrete_mapping());
if (iter == metadata.discrete_mappings().end())
return 0;
}
DCHECK(iter != metadata.discrete_mappings().end());

const auto& mapping = iter->second;

// Iterate over the entries and find the largest entry whose min result is
// equal to or less than the input.
int discrete_result = 0;
float largest_score_below_input_score = std::numeric_limits<float>::min();
for (int i = 0; i < mapping.entries_size(); i++) {
const auto& entry = mapping.entries(i);
if (entry.min_result() <= input_score &&
entry.min_result() > largest_score_below_input_score) {
largest_score_below_input_score = entry.min_result();
discrete_result = entry.rank();
}
}

return discrete_result;
}

} // namespace

class FeedUserModelTest : public testing::Test {
public:
Expand Down Expand Up @@ -42,60 +76,76 @@ class FeedUserModelTest : public testing::Test {
int64_t) {
EXPECT_EQ(metadata_utils::ValidateMetadataAndFeatures(metadata),
metadata_utils::ValidationResult::kValidationSuccess);
fetched_metadata_ = metadata;
std::move(closure).Run();
}

void ExpectExecutionWithInput(const std::vector<float>& inputs,
bool expected_error,
float expected_result) {
absl::optional<float> ExpectExecutionWithInput(
const std::vector<float>& inputs) {
absl::optional<float> result;
base::RunLoop loop;
feed_user_model_->ExecuteModelWithInput(
inputs, base::BindOnce(&FeedUserModelTest::OnExecutionFinishedCallback,
base::Unretained(this), loop.QuitClosure(),
expected_error, expected_result));
inputs,
base::BindOnce(&FeedUserModelTest::OnExecutionFinishedCallback,
base::Unretained(this), loop.QuitClosure(), &result));
loop.Run();
return result;
}

void OnExecutionFinishedCallback(base::RepeatingClosure closure,
bool expected_error,
float expected_result,
absl::optional<float>* output,
const absl::optional<float>& result) {
if (expected_error) {
EXPECT_FALSE(result.has_value());
} else {
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), expected_result);
}
*output = result;
std::move(closure).Run();
}

protected:
base::test::TaskEnvironment task_environment_;
std::unique_ptr<FeedUserSegment> feed_user_model_;
absl::optional<proto::SegmentationModelMetadata> fetched_metadata_;
};

TEST_F(FeedUserModelTest, InitAndFetchModel) {
ExpectInitAndFetchModel();
}

TEST_F(FeedUserModelTest, ExecuteModelWithInput) {
ExpectInitAndFetchModel();
ASSERT_TRUE(fetched_metadata_);

std::vector<float> input(9, 0);

ExpectExecutionWithInput(input, false, 0);
absl::optional<float> result = ExpectExecutionWithInput(input);
ASSERT_TRUE(result);
EXPECT_EQ("NoNTPOrHomeOpened",
FeedUserSegment::GetSubsegmentName(ConvertToDiscreteScore(
"feed_user_segment_subsegment", *result, *fetched_metadata_)));

input[4] = 3;
input[5] = 2;
ExpectExecutionWithInput(input, false, 0.5);
result = ExpectExecutionWithInput(input);
ASSERT_TRUE(result);
EXPECT_EQ("UsedNtpWithoutModules",
FeedUserSegment::GetSubsegmentName(ConvertToDiscreteScore(
"feed_user_segment_subsegment", *result, *fetched_metadata_)));

input[3] = 3;
ExpectExecutionWithInput(input, false, 0.75);
result = ExpectExecutionWithInput(input);
ASSERT_TRUE(result);
EXPECT_EQ("MvtOnly",
FeedUserSegment::GetSubsegmentName(ConvertToDiscreteScore(
"feed_user_segment_subsegment", *result, *fetched_metadata_)));

input[0] = 1;
input[2] = 2;
ExpectExecutionWithInput(input, false, 1);

ExpectExecutionWithInput({}, true, 0);
ExpectExecutionWithInput({1, 2}, true, 0);
result = ExpectExecutionWithInput(input);
ASSERT_TRUE(result);
EXPECT_EQ("ActiveOnFeedAndNtpFeatures",
FeedUserSegment::GetSubsegmentName(ConvertToDiscreteScore(
"feed_user_segment_subsegment", *result, *fetched_metadata_)));

EXPECT_FALSE(ExpectExecutionWithInput({}));
EXPECT_FALSE(ExpectExecutionWithInput({1, 2}));
}

} // namespace segmentation_platform
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,23 @@ void FieldTrialRegisterImpl::RegisterFieldTrial(base::StringPiece trial_name,
variations::SyntheticTrialAnnotationMode::kCurrentLog);
}

void FieldTrialRegisterImpl::RegisterSubsegmentFieldTrialIfNeeded(
base::StringPiece trial_name,
OptimizationTarget segment_id,
int subsegment_rank) {
absl::optional<std::string> group_name;
// TODO(ssid): Make GetSubsegmentName as a ModelProvider API so that clients
// can simply implement it instead of adding conditions here, once the
// subsegment process is more stable.
if (segment_id ==
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER) {
group_name = FeedUserSegment::GetSubsegmentName(subsegment_rank);
}

if (!group_name) {
return;
}
RegisterFieldTrial(trial_name, *group_name);
}

} // namespace segmentation_platform
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class FieldTrialRegisterImpl : public FieldTrialRegister {
// FieldTrialRegister:
void RegisterFieldTrial(base::StringPiece trial_name,
base::StringPiece group_name) override;

void RegisterSubsegmentFieldTrialIfNeeded(
base::StringPiece trial_name,
optimization_guide::proto::OptimizationTarget segment_id,
int subsegment_rank) override;
};

} // namespace segmentation_platform
Expand Down
4 changes: 4 additions & 0 deletions components/segmentation_platform/internal/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ static_library("internal") {
"execution/processing/uma_feature_processor.h",
"local_state_helper_impl.cc",
"local_state_helper_impl.h",
"metric_filter_utils.cc",
"metric_filter_utils.h",
"platform_options.cc",
"platform_options.h",
"scheduler/execution_service.cc",
Expand All @@ -95,6 +97,8 @@ static_library("internal") {
"segmentation_platform_service_impl.h",
"segmentation_ukm_helper.cc",
"segmentation_ukm_helper.h",
"selection/experimental_group_recorder.cc",
"selection/experimental_group_recorder.h",
"selection/segment_result_provider.cc",
"selection/segment_result_provider.h",
"selection/segment_score_provider.cc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void TestSegmentInfoDatabase::AddPredictionResult(OptimizationTarget segment_id,

void TestSegmentInfoDatabase::AddDiscreteMapping(
OptimizationTarget segment_id,
float mappings[][2],
const float mappings[][2],
int num_pairs,
const std::string& discrete_mapping_key) {
proto::SegmentInfo* info = FindOrCreateSegment(segment_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class TestSegmentInfoDatabase : public SegmentInfoDatabase {
float score,
base::Time timestamp);
void AddDiscreteMapping(OptimizationTarget segment_id,
float mappings[][2],
const float mappings[][2],
int num_pairs,
const std::string& discrete_mapping_key);
void SetBucketDuration(OptimizationTarget segment_id,
Expand Down

0 comments on commit ef26ec7

Please sign in to comment.