Skip to content

Commit

Permalink
[CDM] Cache and database implementation with ModelSource.
Browse files Browse the repository at this point in the history
This CL implements the database with default model support.
It handles the logic for the APIs that take ModelSource as a param.

In the next set of CLs, we will introduce a few new methods for default
models, as well as rewrite the logic for the Update API to handle deleting
entries from SegmentInfoDatabase for both server and default models.

The plan is to:
1. Add model source to the database API and migrate callers to call the API
   with the server model. (This doesn't break anything, as it is essentially
   the same flow as before.)
2. Implement the logic for database API.
3. Handle the callers logic to support default models.

This ensures that the compilation works. Also, different clients will
have different handling for default models at different layers, so
client-side handling of default models could be part of another set of CLs.

Planned CL:
- Handling changes required for default models in all the callers.

Change-Id: I127b04f9fc2bb5d8b467eef393c2e2b185e3617a
Bug: b/284427798
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4568822
Reviewed-by: Shakti Sahu <shaktisahu@chromium.org>
Commit-Queue: Ritika Gupta <ritikagup@google.com>
Cr-Commit-Position: refs/heads/main@{#1153599}
  • Loading branch information
Ritika Gupta authored and Chromium LUCI CQ committed Jun 6, 2023
1 parent 3f950e3 commit fbbee37
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ SegmentInfoCache::SegmentInfoCache() = default;
SegmentInfoCache::~SegmentInfoCache() = default;

absl::optional<SegmentInfo> SegmentInfoCache::GetSegmentInfo(
SegmentId segment_id) const {
auto it = segment_info_cache_.find(segment_id);
SegmentId segment_id,
ModelSource model_source) const {
auto it = segment_info_cache_.find(std::make_pair(segment_id, model_source));
return (it == segment_info_cache_.end()) ? absl::nullopt
: absl::make_optional(it->second);
}
Expand All @@ -30,7 +31,8 @@ SegmentInfoCache::GetSegmentInfoForSegments(
std::unique_ptr<SegmentInfoCache::SegmentInfoList> segments_found =
std::make_unique<SegmentInfoCache::SegmentInfoList>();
for (SegmentId target : segment_ids) {
absl::optional<SegmentInfo> info = GetSegmentInfo(target);
absl::optional<SegmentInfo> info =
GetSegmentInfo(target, ModelSource::SERVER_MODEL_SOURCE);
if (info.has_value()) {
segments_found->emplace_back(
std::make_pair(target, std::move(info.value())));
Expand All @@ -41,11 +43,14 @@ SegmentInfoCache::GetSegmentInfoForSegments(

void SegmentInfoCache::UpdateSegmentInfo(
SegmentId segment_id,
ModelSource model_source,
absl::optional<SegmentInfo> segment_info) {
if (segment_info.has_value()) {
segment_info_cache_[segment_id] = std::move(segment_info.value());
segment_info_cache_[std::make_pair(segment_id, model_source)] =
std::move(segment_info.value());
} else {
segment_info_cache_.erase(segment_info_cache_.find(segment_id));
segment_info_cache_.erase(
segment_info_cache_.find(std::make_pair(segment_id, model_source)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

namespace segmentation_platform {

using proto::ModelSource;
using proto::SegmentId;
using proto::SegmentInfo;

Expand All @@ -34,22 +35,26 @@ class SegmentInfoCache {
SegmentInfoCache(const SegmentInfoCache&) = delete;
SegmentInfoCache& operator=(const SegmentInfoCache&) = delete;

// Returns an optional SegmentInfo for a `segment_id`.
absl::optional<SegmentInfo> GetSegmentInfo(SegmentId segment_id) const;
// Returns an optional SegmentInfo for a `segment_id` based on `model_source`.
absl::optional<SegmentInfo> GetSegmentInfo(SegmentId segment_id,
ModelSource model_source) const;

// Returns list of segment info for list of `segment_ids` found in the cache.
// If segment info is not found for a segment id, nothing is returned for it.
// This only returns segment info list for server side segments.
std::unique_ptr<SegmentInfoList> GetSegmentInfoForSegments(
const base::flat_set<SegmentId>& segment_ids) const;

// Updates cache with `segment_info` for a `segment_id`.
// It deletes the entry in cache if `segment_info` is nullopt.
// Updates cache with `segment_info` for a `segment_id` based on
// `model_source`. It deletes the entry in cache if `segment_info` is nullopt.
void UpdateSegmentInfo(SegmentId segment_id,
ModelSource model_source,
absl::optional<SegmentInfo> segment_info);

private:
// Map storing SegmentInfo for a SegmentId.
base::flat_map<SegmentId, SegmentInfo> segment_info_cache_;
// Map storing SegmentInfo for a SegmentId and ModelSource.
base::flat_map<std::pair<SegmentId, ModelSource>, SegmentInfo>
segment_info_cache_;
};

} // namespace segmentation_platform
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ const proto::SegmentId kSegmentId =
const proto::SegmentId kSegmentId2 =
proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHOPPING_USER;

proto::SegmentInfo CreateSegment(SegmentId segment_id) {
proto::SegmentInfo CreateSegment(SegmentId segment_id,
ModelSource model_source) {
proto::SegmentInfo info;
info.set_segment_id(segment_id);
info.set_model_source(model_source);
return info;
}
} // namespace
Expand All @@ -45,34 +47,58 @@ class SegmentInfoCacheTest : public testing::Test {
};

TEST_F(SegmentInfoCacheTest, GetSegmentInfoFromEmptyCache) {
auto segment_info_ = segment_info_cache_->GetSegmentInfo(kSegmentId);
auto segment_info_ = segment_info_cache_->GetSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE);
EXPECT_EQ(absl::nullopt, segment_info_);
}

TEST_F(SegmentInfoCacheTest, GetSegmentInfoFromCache) {
segment_info_cache_->UpdateSegmentInfo(kSegmentId, CreateSegment(kSegmentId));
auto segment_info_ = segment_info_cache_->GetSegmentInfo(kSegmentId);
segment_info_cache_->UpdateSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE,
CreateSegment(kSegmentId, ModelSource::SERVER_MODEL_SOURCE));
auto segment_info_ = segment_info_cache_->GetSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE);
EXPECT_TRUE(segment_info_.has_value());
EXPECT_EQ(kSegmentId, segment_info_.value().segment_id());

// Calling GetSegmentInfo method again.
segment_info_ = segment_info_cache_->GetSegmentInfo(kSegmentId);
segment_info_ = segment_info_cache_->GetSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE);
EXPECT_TRUE(segment_info_.has_value());
EXPECT_EQ(kSegmentId, segment_info_.value().segment_id());

// Calling GetSegmentInfo method again for default model.
segment_info_ = segment_info_cache_->GetSegmentInfo(
kSegmentId, ModelSource::DEFAULT_MODEL_SOURCE);
EXPECT_FALSE(segment_info_.has_value());
}

TEST_F(SegmentInfoCacheTest, GetSegmentInfoForSegmentsFromCache) {
// Updating SegmentInfo for 'kSegmentId' and calling
// GetSegmentInfoForSegments with superset of segment ids.
segment_info_cache_->UpdateSegmentInfo(kSegmentId, CreateSegment(kSegmentId));
segment_info_cache_->UpdateSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE,
CreateSegment(kSegmentId, ModelSource::SERVER_MODEL_SOURCE));
auto segments_found =
segment_info_cache_->GetSegmentInfoForSegments({kSegmentId, kSegmentId2});
EXPECT_EQ(1u, segments_found.get()->size());
EXPECT_EQ(kSegmentId, segments_found.get()->at(0).first);

// Creating default model segment for 'kSegmentId2' and calling
// GetSegmentInfoForSegments with all segment ids.
segment_info_cache_->UpdateSegmentInfo(
kSegmentId2, ModelSource::DEFAULT_MODEL_SOURCE,
CreateSegment(kSegmentId2, ModelSource::DEFAULT_MODEL_SOURCE));
segments_found =
segment_info_cache_->GetSegmentInfoForSegments({kSegmentId, kSegmentId2});
EXPECT_EQ(1u, segments_found.get()->size());
EXPECT_EQ(kSegmentId, segments_found.get()->at(0).first);

// Updating SegmentInfo for 'kSegmentId2' and calling
// GetSegmentInfoForSegments with all segment ids.
segment_info_cache_->UpdateSegmentInfo(kSegmentId2,
CreateSegment(kSegmentId2));
segment_info_cache_->UpdateSegmentInfo(
kSegmentId2, ModelSource::SERVER_MODEL_SOURCE,
CreateSegment(kSegmentId2, ModelSource::SERVER_MODEL_SOURCE));
segments_found =
segment_info_cache_->GetSegmentInfoForSegments({kSegmentId, kSegmentId2});
EXPECT_EQ(2u, segments_found.get()->size());
Expand All @@ -81,26 +107,32 @@ TEST_F(SegmentInfoCacheTest, GetSegmentInfoForSegmentsFromCache) {

// Updating absl::nullopt for 'kSegmentId2' and calling
// GetSegmentInfoForSegments with all segment ids.
segment_info_cache_->UpdateSegmentInfo(kSegmentId2, absl::nullopt);
segment_info_cache_->UpdateSegmentInfo(
kSegmentId2, ModelSource::SERVER_MODEL_SOURCE, absl::nullopt);
segments_found =
segment_info_cache_->GetSegmentInfoForSegments({kSegmentId, kSegmentId2});
EXPECT_EQ(1u, segments_found.get()->size());
EXPECT_EQ(kSegmentId, segments_found.get()->at(0).first);
}

TEST_F(SegmentInfoCacheTest, UpdateSegmentInfo) {
proto::SegmentInfo created_segment_info = CreateSegment(kSegmentId);
segment_info_cache_->UpdateSegmentInfo(kSegmentId, created_segment_info);
proto::SegmentInfo created_segment_info =
CreateSegment(kSegmentId, ModelSource::SERVER_MODEL_SOURCE);
segment_info_cache_->UpdateSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE, created_segment_info);

auto segment_info_ = segment_info_cache_->GetSegmentInfo(kSegmentId);
auto segment_info_ = segment_info_cache_->GetSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE);
EXPECT_TRUE(segment_info_.has_value());
EXPECT_EQ(kSegmentId, segment_info_.value().segment_id());

// Update model_version of segment_info
created_segment_info.set_model_version(4);
segment_info_cache_->UpdateSegmentInfo(kSegmentId, created_segment_info);
segment_info_cache_->UpdateSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE, created_segment_info);

segment_info_ = segment_info_cache_->GetSegmentInfo(kSegmentId);
segment_info_ = segment_info_cache_->GetSegmentInfo(
kSegmentId, ModelSource::SERVER_MODEL_SOURCE);
EXPECT_TRUE(segment_info_.has_value());
EXPECT_EQ(kSegmentId, segment_info_.value().segment_id());
EXPECT_EQ(4, segment_info_.value().model_version());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,26 @@ namespace segmentation_platform {

namespace {

std::string ToString(SegmentId segment_id) {
return base::NumberToString(static_cast<int>(segment_id));
std::string ToString(SegmentId segment_id, ModelSource model_source) {
std::string prefix =
(model_source == ModelSource::DEFAULT_MODEL_SOURCE ? "DEFAULT_" : "");
return prefix + base::NumberToString(static_cast<int>(segment_id));
}

// Resolves the ModelSource under which `segment_info` is keyed in the cache
// and the database. A missing `segment_info`, or one whose model source is
// unset (UNKNOWN_MODEL_SOURCE), is treated as coming from a server model.
// Takes the optional by const reference so that callers which already hold an
// absl::optional<SegmentInfo> do not pay for a copy of the proto.
ModelSource GetModelSource(
    const absl::optional<proto::SegmentInfo>& segment_info) {
  const ModelSource model_source = segment_info.has_value()
                                       ? segment_info->model_source()
                                       : ModelSource::UNKNOWN_MODEL_SOURCE;
  switch (model_source) {
    // If model source is not set in some segment info present in database, we
    // consider it to be from server models.
    case ModelSource::UNKNOWN_MODEL_SOURCE:
    case ModelSource::SERVER_MODEL_SOURCE:
      return ModelSource::SERVER_MODEL_SOURCE;
    case ModelSource::DEFAULT_MODEL_SOURCE:
      return ModelSource::DEFAULT_MODEL_SOURCE;
  }
  // Not reached for the enumerators handled above. Returning here keeps the
  // function well-defined (no flowing off the end of a non-void function)
  // should `model_source` ever hold a value outside the listed cases.
  return ModelSource::SERVER_MODEL_SOURCE;
}

} // namespace
Expand Down Expand Up @@ -48,23 +66,24 @@ void SegmentInfoDatabase::GetSegmentInfoForSegments(
}

void SegmentInfoDatabase::GetSegmentInfo(SegmentId segment_id,
ModelSource model_source,
proto::ModelSource model_source,
SegmentInfoCallback callback) {
std::move(callback).Run(cache_->GetSegmentInfo(segment_id));
std::move(callback).Run(cache_->GetSegmentInfo(segment_id, model_source));
}

absl::optional<SegmentInfo> SegmentInfoDatabase::GetCachedSegmentInfo(
SegmentId segment_id,
ModelSource model_source) {
return cache_->GetSegmentInfo(segment_id);
proto::ModelSource model_source) {
return cache_->GetSegmentInfo(segment_id, model_source);
}

void SegmentInfoDatabase::GetTrainingData(SegmentId segment_id,
ModelSource model_source,
TrainingRequestId request_id,
bool delete_from_db,
TrainingDataCallback callback) {
absl::optional<SegmentInfo> segment_info = cache_->GetSegmentInfo(segment_id);
absl::optional<SegmentInfo> segment_info =
cache_->GetSegmentInfo(segment_id, model_source);
absl::optional<proto::TrainingData> result;

// Ignore results if the metadata no longer exists.
Expand Down Expand Up @@ -101,7 +120,8 @@ void SegmentInfoDatabase::UpdateSegment(
SegmentId segment_id,
absl::optional<proto::SegmentInfo> segment_info,
SuccessCallback callback) {
cache_->UpdateSegmentInfo(segment_id, segment_info);
ModelSource model_source = GetModelSource(segment_info);
cache_->UpdateSegmentInfo(segment_id, model_source, segment_info);

// The cache has been updated now. We can notify the client synchronously.
std::move(callback).Run(/*success=*/true);
Expand All @@ -111,10 +131,10 @@ void SegmentInfoDatabase::UpdateSegment(
std::vector<std::pair<std::string, proto::SegmentInfo>>>();
auto keys_to_delete = std::make_unique<std::vector<std::string>>();
if (segment_info.has_value()) {
entries_to_save->emplace_back(
std::make_pair(ToString(segment_id), segment_info.value()));
entries_to_save->emplace_back(std::make_pair(
ToString(segment_id, model_source), segment_info.value()));
} else {
keys_to_delete->emplace_back(ToString(segment_id));
keys_to_delete->emplace_back(ToString(segment_id, model_source));
}
database_->UpdateEntries(std::move(entries_to_save),
std::move(keys_to_delete), base::DoNothing());
Expand All @@ -130,21 +150,24 @@ void SegmentInfoDatabase::UpdateMultipleSegments(
for (auto& segment : segments_to_update) {
const proto::SegmentId segment_id = segment.first;
auto& segment_info = segment.second;

ModelSource model_source = GetModelSource(segment_info);
// Updating the cache.
cache_->UpdateSegmentInfo(segment_id, absl::make_optional(segment_info));
cache_->UpdateSegmentInfo(segment_id, model_source,
absl::make_optional(segment_info));

// Determining entries to save for database.
entries_to_save->emplace_back(
std::make_pair(ToString(segment_id), std::move(segment_info)));
entries_to_save->emplace_back(std::make_pair(
ToString(segment_id, model_source), std::move(segment_info)));
}

// The cache has been updated now. We can notify the client synchronously.
std::move(callback).Run(/*success=*/true);

// TODO (ritikagup@) : Add handling for default models, if required.
// Now write to the database asynchronously.
for (auto& segment_id : segments_to_delete) {
entries_to_delete->emplace_back(ToString(segment_id));
entries_to_delete->emplace_back(
ToString(segment_id, proto::ModelSource::SERVER_MODEL_SOURCE));
}

database_->UpdateEntries(std::move(entries_to_save),
Expand All @@ -156,7 +179,7 @@ void SegmentInfoDatabase::SaveSegmentResult(
ModelSource model_source,
absl::optional<proto::PredictionResult> result,
SuccessCallback callback) {
auto segment_info = cache_->GetSegmentInfo(segment_id);
auto segment_info = cache_->GetSegmentInfo(segment_id, model_source);

// Ignore results if the metadata no longer exists.
if (!segment_info.has_value()) {
Expand Down Expand Up @@ -184,7 +207,7 @@ void SegmentInfoDatabase::SaveTrainingData(SegmentId segment_id,
ModelSource model_source,
const proto::TrainingData& data,
SuccessCallback callback) {
auto segment_info = cache_->GetSegmentInfo(segment_id);
auto segment_info = cache_->GetSegmentInfo(segment_id, model_source);

// Ignore data if the metadata no longer exists.
if (!segment_info.has_value()) {
Expand Down Expand Up @@ -222,7 +245,8 @@ void SegmentInfoDatabase::OnLoadAllEntries(
if (success) {
// Add all the entries to the cache on startup.
for (auto info : *all_infos.get()) {
cache_->UpdateSegmentInfo(info.segment_id(), info);
ModelSource model_source = GetModelSource(info);
cache_->UpdateSegmentInfo(info.segment_id(), model_source, info);
}
}
std::move(callback).Run(success);
Expand Down

0 comments on commit fbbee37

Please sign in to comment.