Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This CL introduces DefaultModelManager to handle default model related queries. It also provides helper methods to provide metadata from both the server model and default model. SignalFilterProcessor and DatabaseMaintenanceImpl were updated to track signals from both types of models. Bug: 1298756 Change-Id: I5470cea068ba7d9311912489b278bde90bbbbbc6 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3529267 Reviewed-by: Tommy Nyquist <nyquist@chromium.org> Reviewed-by: Siddhartha S <ssid@chromium.org> Commit-Queue: Shakti Sahu <shaktisahu@chromium.org> Cr-Commit-Position: refs/heads/main@{#984358}
- Loading branch information
Shakti Sahu
authored and
Chromium LUCI CQ
committed
Mar 23, 2022
1 parent
461a08f
commit b885cd3
Showing
16 changed files
with
507 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
121 changes: 121 additions & 0 deletions
121
components/segmentation_platform/internal/execution/default_model_manager.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
// Copyright 2022 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "components/segmentation_platform/internal/execution/default_model_manager.h" | ||
|
||
#include "base/threading/thread_task_runner_handle.h" | ||
#include "components/segmentation_platform/internal/database/segment_info_database.h" | ||
|
||
namespace segmentation_platform { | ||
|
||
DefaultModelManager::DefaultModelManager( | ||
ModelProviderFactory* model_provider_factory, | ||
const std::vector<OptimizationTarget>& segment_ids) { | ||
for (OptimizationTarget segment_id : segment_ids) { | ||
if (!model_provider_factory) | ||
continue; | ||
std::unique_ptr<ModelProvider> provider = | ||
model_provider_factory->CreateDefaultProvider(segment_id); | ||
if (!provider) | ||
continue; | ||
default_model_providers_.emplace( | ||
std::make_pair(segment_id, std::move(provider))); | ||
} | ||
} | ||
|
||
DefaultModelManager::~DefaultModelManager() = default; | ||
|
||
void DefaultModelManager::GetAllSegmentInfoFromDefaultModel( | ||
const std::vector<OptimizationTarget>& segment_ids, | ||
MultipleSegmentInfoCallback callback) { | ||
auto result = std::make_unique<SegmentInfoList>(); | ||
std::deque<OptimizationTarget> remaining_segment_ids(segment_ids.begin(), | ||
segment_ids.end()); | ||
GetNextSegmentInfoFromDefaultModel( | ||
std::move(result), std::move(remaining_segment_ids), std::move(callback)); | ||
} | ||
|
||
void DefaultModelManager::GetNextSegmentInfoFromDefaultModel( | ||
std::unique_ptr<SegmentInfoList> result, | ||
std::deque<OptimizationTarget> remaining_segment_ids, | ||
MultipleSegmentInfoCallback callback) { | ||
OptimizationTarget segment_id = | ||
OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN; | ||
ModelProvider* default_provider = nullptr; | ||
|
||
// Find the next available default provider. | ||
while (!default_provider && !remaining_segment_ids.empty()) { | ||
segment_id = remaining_segment_ids.front(); | ||
remaining_segment_ids.pop_front(); | ||
if (default_model_providers_.count(segment_id) == 1) { | ||
default_provider = default_model_providers_[segment_id].get(); | ||
break; | ||
} | ||
} | ||
|
||
if (!default_provider) { | ||
// If there are no more default providers, return the result so far. | ||
base::ThreadTaskRunnerHandle::Get()->PostTask( | ||
FROM_HERE, base::BindOnce(std::move(callback), std::move(result))); | ||
return; | ||
} | ||
|
||
default_provider->InitAndFetchModel(base::BindRepeating( | ||
&DefaultModelManager::OnFetchDefaultModel, weak_ptr_factory_.GetWeakPtr(), | ||
base::Passed(&result), remaining_segment_ids, base::Passed(&callback))); | ||
} | ||
|
||
void DefaultModelManager::OnFetchDefaultModel( | ||
std::unique_ptr<SegmentInfoList> result, | ||
std::deque<OptimizationTarget> remaining_segment_ids, | ||
MultipleSegmentInfoCallback callback, | ||
OptimizationTarget segment_id, | ||
proto::SegmentationModelMetadata metadata, | ||
int64_t model_version) { | ||
proto::SegmentInfo segment_info; | ||
segment_info.set_segment_id(segment_id); | ||
segment_info.mutable_model_metadata()->CopyFrom(metadata); | ||
segment_info.set_model_version(model_version); | ||
result->push_back(std::make_pair(segment_id, segment_info)); | ||
|
||
GetNextSegmentInfoFromDefaultModel( | ||
std::move(result), std::move(remaining_segment_ids), std::move(callback)); | ||
} | ||
|
||
void DefaultModelManager::GetAllSegmentInfoFromBothModels( | ||
const std::vector<OptimizationTarget>& segment_ids, | ||
SegmentInfoDatabase* segment_database, | ||
MultipleSegmentInfoCallback callback) { | ||
segment_database->GetSegmentInfoForSegments( | ||
segment_ids, | ||
base::BindOnce(&DefaultModelManager::OnGetAllSegmentInfoFromDatabase, | ||
weak_ptr_factory_.GetWeakPtr(), segment_ids, | ||
std::move(callback))); | ||
} | ||
|
||
void DefaultModelManager::OnGetAllSegmentInfoFromDatabase( | ||
const std::vector<OptimizationTarget>& segment_ids, | ||
MultipleSegmentInfoCallback callback, | ||
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos) { | ||
GetAllSegmentInfoFromDefaultModel( | ||
segment_ids, | ||
base::BindOnce(&DefaultModelManager::OnGetAllSegmentInfoFromDefaultModel, | ||
weak_ptr_factory_.GetWeakPtr(), std::move(callback), | ||
std::move(segment_infos))); | ||
} | ||
|
||
void DefaultModelManager::OnGetAllSegmentInfoFromDefaultModel( | ||
MultipleSegmentInfoCallback callback, | ||
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos_from_db, | ||
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> | ||
segment_infos_from_default_model) { | ||
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> merged_results = | ||
std::move(segment_infos_from_db); | ||
for (const auto& segment_info : *segment_infos_from_default_model) | ||
merged_results->push_back(std::move(segment_info)); | ||
|
||
std::move(callback).Run(std::move(merged_results)); | ||
} | ||
|
||
} // namespace segmentation_platform |
93 changes: 93 additions & 0 deletions
93
components/segmentation_platform/internal/execution/default_model_manager.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright 2022 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_ | ||
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_ | ||
|
||
#include <deque> | ||
#include <map> | ||
#include <memory> | ||
#include <set> | ||
#include <vector> | ||
|
||
#include "base/callback.h" | ||
#include "base/containers/flat_map.h" | ||
#include "base/logging.h" | ||
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h" | ||
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h" | ||
#include "components/segmentation_platform/public/model_provider.h" | ||
#include "third_party/abseil-cpp/absl/types/optional.h" | ||
|
||
using optimization_guide::proto::OptimizationTarget; | ||
|
||
namespace segmentation_platform { | ||
class SegmentInfoDatabase; | ||
|
||
// DefaultModelManager provides support to query all default models available. | ||
// It also provides useful methods to combine results from both the database and | ||
// the default model. | ||
class DefaultModelManager { | ||
public: | ||
DefaultModelManager(ModelProviderFactory* model_provider_factory, | ||
const std::vector<OptimizationTarget>& segment_ids); | ||
virtual ~DefaultModelManager(); | ||
|
||
// Disallow copy/assign. | ||
DefaultModelManager(const DefaultModelManager&) = delete; | ||
DefaultModelManager& operator=(const DefaultModelManager&) = delete; | ||
|
||
// Callback for returning a list of segment infos associated with IDs. | ||
// The same segment ID can be repeated multiple times. | ||
using SegmentInfoList = | ||
std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>; | ||
using MultipleSegmentInfoCallback = | ||
base::OnceCallback<void(std::unique_ptr<SegmentInfoList>)>; | ||
|
||
// Utility function to get the segment info from both the database and the | ||
// default model for a given set of segment IDs. The result can contain | ||
// the same segment ID multiple times. | ||
virtual void GetAllSegmentInfoFromBothModels( | ||
const std::vector<OptimizationTarget>& segment_ids, | ||
SegmentInfoDatabase* segment_database, | ||
MultipleSegmentInfoCallback callback); | ||
|
||
private: | ||
// Called to get the segment info from the default model for a given set of | ||
// segment IDs. | ||
virtual void GetAllSegmentInfoFromDefaultModel( | ||
const std::vector<OptimizationTarget>& segment_ids, | ||
MultipleSegmentInfoCallback callback); | ||
|
||
void GetNextSegmentInfoFromDefaultModel( | ||
std::unique_ptr<SegmentInfoList> result, | ||
std::deque<OptimizationTarget> remaining_segment_ids, | ||
MultipleSegmentInfoCallback callback); | ||
|
||
void OnFetchDefaultModel(std::unique_ptr<SegmentInfoList> result, | ||
std::deque<OptimizationTarget> remaining_segment_ids, | ||
MultipleSegmentInfoCallback callback, | ||
OptimizationTarget segment_id, | ||
proto::SegmentationModelMetadata metadata, | ||
int64_t model_version); | ||
|
||
void OnGetAllSegmentInfoFromDatabase( | ||
const std::vector<OptimizationTarget>& segment_ids, | ||
MultipleSegmentInfoCallback callback, | ||
std::unique_ptr<SegmentInfoList> segment_infos); | ||
|
||
void OnGetAllSegmentInfoFromDefaultModel( | ||
MultipleSegmentInfoCallback callback, | ||
std::unique_ptr<SegmentInfoList> segment_infos_from_db, | ||
std::unique_ptr<SegmentInfoList> segment_infos_from_default_model); | ||
|
||
// Default model providers. | ||
std::map<OptimizationTarget, std::unique_ptr<ModelProvider>> | ||
default_model_providers_; | ||
|
||
base::WeakPtrFactory<DefaultModelManager> weak_ptr_factory_{this}; | ||
}; | ||
|
||
} // namespace segmentation_platform | ||
|
||
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_ |
Oops, something went wrong.