Skip to content

Commit

Permalink
[Merge-103][Segmentation] Check that validation messages only contain…
Browse files Browse the repository at this point in the history
… data after UKM approval

The input tensors may contain data before UKM approval. This CL
will verify that this doesn't happen.

(cherry picked from commit cae984c)

Bug: 1327419
Change-Id: I835f982748090b49e21a2efb34402877516b118d
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3655131
Commit-Queue: Min Qin <qinmin@chromium.org>
Reviewed-by: Siddhartha S <ssid@chromium.org>
Cr-Original-Commit-Position: refs/heads/main@{#1005600}
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3659636
Cr-Commit-Position: refs/branch-heads/5060@{#186}
Cr-Branched-From: b83393d-refs/heads/main@{#1002911}
  • Loading branch information
Min Qin authored and Chromium LUCI CQ committed May 23, 2022
1 parent 715cbf7 commit 208206b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,8 @@ bool TrainingDataCollectorImpl::CanReportTrainingData(
base::TimeDelta signal_storage_length =
model_metadata.signal_storage_length() *
metadata_utils::GetTimeUnit(model_metadata);
if (LocalStateHelper::GetInstance().GetPrefTime(
kSegmentationUkmMostRecentAllowedTimeKey) +
signal_storage_length >=
clock_->Now()) {
if (!SegmentationUkmHelper::AllowedToUploadData(signal_storage_length,
clock_)) {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kPartialDataNotAllowed);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct ModelExecutorImpl::ExecutionState {
std::vector<float> input_tensor;
base::Time total_execution_start_time;
base::Time model_execution_start_time;
base::TimeDelta signal_storage_length;
};

ModelExecutorImpl::ModelExecutionTraceEvent::ModelExecutionTraceEvent(
Expand Down Expand Up @@ -120,6 +121,10 @@ void ModelExecutorImpl::ExecuteModel(const proto::SegmentInfo& segment_info,
}

state->model_version = segment_info.model_version();
const proto::SegmentationModelMetadata& model_metadata =
segment_info.model_metadata();
state->signal_storage_length = model_metadata.signal_storage_length() *
metadata_utils::GetTimeUnit(model_metadata);
feature_list_query_processor_->ProcessFeatureList(
segment_info.model_metadata(), segment_id, clock_->Now(),
FeatureListQueryProcessor::ProcessOption::kInputsOnly,
Expand Down Expand Up @@ -180,7 +185,8 @@ void ModelExecutorImpl::OnModelExecutionComplete(
<< optimization_guide::proto::OptimizationTarget_Name(
state->segment_id);
stats::RecordModelExecutionResult(state->segment_id, result.value());
if (state->model_version) {
if (state->model_version && SegmentationUkmHelper::AllowedToUploadData(
state->signal_storage_length, clock_)) {
SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
state->segment_id, state->model_version, state->input_tensor,
result.value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
#include "base/metrics/field_trial_params.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/time/clock.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/local_state_helper.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "services/metrics/public/cpp/ukm_recorder.h"

Expand Down Expand Up @@ -200,4 +203,14 @@ int64_t SegmentationUkmHelper::FloatToInt64(float f) {
return base::bit_cast<int64_t>(static_cast<double>(f));
}

// static
bool SegmentationUkmHelper::AllowedToUploadData(
base::TimeDelta signal_storage_length,
base::Clock* clock) {
return LocalStateHelper::GetInstance().GetPrefTime(
kSegmentationUkmMostRecentAllowedTimeKey) +
signal_storage_length <
clock->Now();
}

} // namespace segmentation_platform
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@

#include "base/containers/flat_set.h"
#include "base/no_destructor.h"
#include "base/time/time.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
#include "third_party/abseil-cpp/absl/types/optional.h"

using optimization_guide::proto::OptimizationTarget;

namespace base {
class Clock;
}

namespace ukm::builders {
class Segmentation_ModelExecution;
} // namespace ukm::builders
Expand Down Expand Up @@ -52,6 +57,11 @@ class SegmentationUkmHelper {
// Helper method to encode a float number into int64.
static int64_t FloatToInt64(float f);

// Helper method to check if data is allowed to upload through ukm
// given a clock and the signal storage length.
static bool AllowedToUploadData(base::TimeDelta signal_storage_length,
base::Clock* clock);

// Gets a set of segment IDs that are allowed to upload metrics.
const base::flat_set<OptimizationTarget>& allowed_segment_ids() {
return allowed_segment_ids_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
#include "base/bit_cast.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
#include "components/prefs/testing_pref_service.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/local_state_helper.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/ukm/test_ukm_recorder.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "testing/gtest/include/gtest/gtest.h"
Expand Down Expand Up @@ -248,4 +253,23 @@ TEST_F(SegmentationUkmHelperTest, OutputsValidation) {
ASSERT_NE(source_id, ukm::kInvalidSourceId);
}

TEST_F(SegmentationUkmHelperTest, AllowedToUploadData) {
TestingPrefServiceSimple prefs;
SegmentationPlatformService::RegisterLocalStatePrefs(prefs.registry());
LocalStateHelper::GetInstance().Initialize(&prefs);

base::SimpleTestClock clock;
clock.SetNow(base::Time::Now());
LocalStateHelper::GetInstance().SetPrefTime(
kSegmentationUkmMostRecentAllowedTimeKey, clock.Now());

ASSERT_FALSE(
SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
clock.Advance(base::Seconds(10));
ASSERT_TRUE(
SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
ASSERT_FALSE(
SegmentationUkmHelper::AllowedToUploadData(base::Seconds(11), &clock));
}

} // namespace segmentation_platform

0 comments on commit 208206b

Please sign in to comment.