Skip to content

Commit

Permalink
Add methods to support getting reports for and sending reports from a…
Browse files Browse the repository at this point in the history
…ggregation service internals WebUI

Bug: 1348029
Change-Id: I28fd52314c38a47654049373287f592679669e2d
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3794786
Reviewed-by: John Delaney <johnidel@chromium.org>
Commit-Queue: Nan Lin <linnan@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1035160}
  • Loading branch information
linnan-github authored and Chromium LUCI CQ committed Aug 15, 2022
1 parent 81ad161 commit 6a52078
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 22 deletions.
18 changes: 18 additions & 0 deletions content/browser/aggregation_service/aggregation_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
#ifndef CONTENT_BROWSER_AGGREGATION_SERVICE_AGGREGATION_SERVICE_H_
#define CONTENT_BROWSER_AGGREGATION_SERVICE_AGGREGATION_SERVICE_H_

#include <vector>

#include "base/callback_forward.h"
#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
#include "content/browser/aggregation_service/aggregatable_report_sender.h"
#include "content/browser/aggregation_service/aggregation_service_storage.h"
#include "content/public/browser/storage_partition.h"

class GURL;
Expand Down Expand Up @@ -70,6 +74,20 @@ class AggregationService {
// time. It is stored on disk (unless in incognito) until then. See the
// `AggregatableReportScheduler` for details.
virtual void ScheduleReport(AggregatableReportRequest report_request) = 0;

// Gets all pending report requests that are currently stored. Used for
// populating WebUI.
// TODO(linnan): Consider enforcing a limit on the number of requests
// returned.
virtual void GetPendingReportRequestsForWebUI(
base::OnceCallback<void(
std::vector<AggregationServiceStorage::RequestAndId>)> callback) = 0;

// Sends the given reports immediately, and runs `reports_sent_callback` once
// they have all been sent.
virtual void SendReportsForWebUI(
const std::vector<AggregationServiceStorage::RequestId>& ids,
base::OnceClosure reports_sent_callback) = 0;
};

} // namespace content
Expand Down
65 changes: 55 additions & 10 deletions content/browser/aggregation_service/aggregation_service_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
#include <utility>
#include <vector>

#include "base/barrier_closure.h"
#include "base/bind.h"
#include "base/callback.h"
#include "base/callback_helpers.h"
#include "base/check_op.h"
#include "base/files/file_path.h"
#include "base/memory/ptr_util.h"
Expand Down Expand Up @@ -144,25 +146,20 @@ void AggregationServiceImpl::ScheduleReport(

void AggregationServiceImpl::OnScheduledReportTimeReached(
std::vector<AggregationServiceStorage::RequestAndId> requests_and_ids) {
for (AggregationServiceStorage::RequestAndId& elem : requests_and_ids) {
GURL reporting_url = elem.request.GetReportingUrl();
AssembleReport(
std::move(elem.request),
base::BindOnce(
&AggregationServiceImpl::OnReportAssemblyComplete,
// `base::Unretained` is safe as the assembler is owned by `this`.
base::Unretained(this), elem.id, std::move(reporting_url)));
}
AssembleAndSendReports(std::move(requests_and_ids),
/*done=*/base::DoNothing());
}

void AggregationServiceImpl::OnReportAssemblyComplete(
base::OnceClosure done,
AggregationServiceStorage::RequestId request_id,
GURL reporting_url,
absl::optional<AggregatableReport> report,
AggregatableReportAssembler::AssemblyStatus status) {
DCHECK_EQ(report.has_value(),
status == AggregatableReportAssembler::AssemblyStatus::kOk);
if (!report.has_value()) {
std::move(done).Run();
scheduler_->NotifyInProgressRequestFailed(request_id);
return;
}
Expand All @@ -172,12 +169,15 @@ void AggregationServiceImpl::OnReportAssemblyComplete(
base::BindOnce(
&AggregationServiceImpl::OnReportSendingComplete,
// `base::Unretained` is safe as the sender is owned by `this`.
base::Unretained(this), request_id));
base::Unretained(this), std::move(done), request_id));
}

void AggregationServiceImpl::OnReportSendingComplete(
base::OnceClosure done,
AggregationServiceStorage::RequestId request_id,
AggregatableReportSender::RequestStatus status) {
std::move(done).Run();

if (status == AggregatableReportSender::RequestStatus::kOk) {
scheduler_->NotifyInProgressRequestSucceeded(request_id);
} else {
Expand All @@ -192,4 +192,49 @@ void AggregationServiceImpl::SetPublicKeysForTesting(
.WithArgs(url, keyset);
}

void AggregationServiceImpl::AssembleAndSendReports(
std::vector<AggregationServiceStorage::RequestAndId> requests_and_ids,
base::RepeatingClosure done) {
for (AggregationServiceStorage::RequestAndId& elem : requests_and_ids) {
GURL reporting_url = elem.request.GetReportingUrl();
AssembleReport(
std::move(elem.request),
base::BindOnce(
&AggregationServiceImpl::OnReportAssemblyComplete,
// `base::Unretained` is safe as the assembler is owned by `this`.
base::Unretained(this), done, elem.id, std::move(reporting_url)));
}
}

void AggregationServiceImpl::GetPendingReportRequestsForWebUI(
base::OnceCallback<
void(std::vector<AggregationServiceStorage::RequestAndId>)> callback) {
storage_.AsyncCall(&AggregationServiceStorage::GetRequestsReportingOnOrBefore)
.WithArgs(/*not_after_time=*/base::Time::Max())
.Then(std::move(callback));
}

void AggregationServiceImpl::SendReportsForWebUI(
const std::vector<AggregationServiceStorage::RequestId>& ids,
base::OnceClosure reports_sent_callback) {
storage_.AsyncCall(&AggregationServiceStorage::GetRequests)
.WithArgs(ids)
.Then(base::BindOnce(
&AggregationServiceImpl::OnGetRequestsToSendFromWebUI,
weak_factory_.GetWeakPtr(), std::move(reports_sent_callback)));
}

void AggregationServiceImpl::OnGetRequestsToSendFromWebUI(
base::OnceClosure reports_sent_callback,
std::vector<AggregationServiceStorage::RequestAndId> requests_and_ids) {
if (requests_and_ids.empty()) {
std::move(reports_sent_callback).Run();
return;
}

auto barrier = base::BarrierClosure(requests_and_ids.size(),
std::move(reports_sent_callback));
AssembleAndSendReports(std::move(requests_and_ids), std::move(barrier));
}

} // namespace content
24 changes: 23 additions & 1 deletion content/browser/aggregation_service/aggregation_service_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
#include <memory>
#include <vector>

#include "base/callback_forward.h"
#include "base/containers/flat_map.h"
#include "base/memory/weak_ptr.h"
#include "base/threading/sequence_bound.h"
#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
#include "content/browser/aggregation_service/aggregatable_report_scheduler.h"
#include "content/browser/aggregation_service/aggregatable_report_sender.h"
#include "content/browser/aggregation_service/aggregation_service.h"
#include "content/browser/aggregation_service/aggregation_service_storage.h"
#include "content/browser/aggregation_service/aggregation_service_storage_context.h"
#include "content/common/content_export.h"
#include "content/public/browser/storage_partition.h"
Expand Down Expand Up @@ -75,6 +78,13 @@ class CONTENT_EXPORT AggregationServiceImpl
StoragePartition::StorageKeyMatcherFunction filter,
base::OnceClosure done) override;
void ScheduleReport(AggregatableReportRequest report_request) override;
void GetPendingReportRequestsForWebUI(
base::OnceCallback<
void(std::vector<AggregationServiceStorage::RequestAndId>)> callback)
override;
void SendReportsForWebUI(
const std::vector<AggregationServiceStorage::RequestId>& ids,
base::OnceClosure reports_sent_callback) override;

// AggregationServiceStorageContext:
const base::SequenceBound<AggregationServiceStorage>& GetStorage() override;
Expand All @@ -97,18 +107,30 @@ class CONTENT_EXPORT AggregationServiceImpl
std::vector<AggregationServiceStorage::RequestAndId> requests_and_ids);

void OnReportAssemblyComplete(
base::OnceClosure done,
AggregationServiceStorage::RequestId request_id,
GURL reporting_url,
absl::optional<AggregatableReport> report,
AggregatableReportAssembler::AssemblyStatus status);

void OnReportSendingComplete(AggregationServiceStorage::RequestId request_id,
void OnReportSendingComplete(base::OnceClosure done,
AggregationServiceStorage::RequestId request_id,
AggregatableReportSender::RequestStatus status);

void AssembleAndSendReports(
std::vector<AggregationServiceStorage::RequestAndId> requests_and_ids,
base::RepeatingClosure done);

void OnGetRequestsToSendFromWebUI(
base::OnceClosure reports_sent_callback,
std::vector<AggregationServiceStorage::RequestAndId> requests_and_ids);

base::SequenceBound<AggregationServiceStorage> storage_;
std::unique_ptr<AggregatableReportScheduler> scheduler_;
std::unique_ptr<AggregatableReportAssembler> assembler_;
std::unique_ptr<AggregatableReportSender> sender_;

base::WeakPtrFactory<AggregationServiceImpl> weak_factory_{this};
};

} // namespace content
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "content/browser/aggregation_service/aggregation_service_impl.h"

#include <stddef.h>
#include <stdint.h>

#include <map>
Expand All @@ -18,6 +19,7 @@
#include "base/files/scoped_temp_dir.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "content/browser/aggregation_service/aggregatable_report.h"
Expand All @@ -29,12 +31,18 @@
#include "content/public/test/browser_task_environment.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"

namespace content {

namespace {
using aggregation_service::RequestIdIs;
using testing::ElementsAre;
} // namespace

// TODO(alexmt): Consider rewriting these tests using gmock.

class TestAggregatableReportAssembler : public AggregatableReportAssembler {
Expand All @@ -49,6 +57,11 @@ class TestAggregatableReportAssembler : public AggregatableReportAssembler {
void AssembleReport(AggregatableReportRequest request,
AssemblyCallback callback) override {
callbacks_.emplace(unique_id_counter_++, std::move(callback));

if (callbacks_.size() < min_requests_count_)
return;

wait_loop_.Quit();
}

void TriggerResponse(int64_t report_id,
Expand All @@ -61,9 +74,19 @@ class TestAggregatableReportAssembler : public AggregatableReportAssembler {
callbacks_.erase(report_id);
}

void WaitForRequests(size_t num_requests) {
min_requests_count_ = num_requests;
if (callbacks_.size() >= num_requests)
return;
wait_loop_.Run();
}

private:
int64_t unique_id_counter_ = 0;
std::map<int64_t, AssemblyCallback> callbacks_;

size_t min_requests_count_ = 0;
base::RunLoop wait_loop_;
};

class TestAggregatableReportSender : public AggregatableReportSender {
Expand Down Expand Up @@ -211,6 +234,12 @@ class AggregationServiceImplTest : public testing::Test {
service()->ScheduleReport(std::move(request));
}

void StoreReport(AggregatableReportRequest request) {
service()
->storage_.AsyncCall(&AggregationServiceStorage::StoreRequest)
.WithArgs(std::move(request));
}

AggregationServiceImpl* service() { return service_impl_.get(); }
TestAggregatableReportAssembler* assembler() { return test_assembler_; }
TestAggregatableReportSender* sender() { return test_sender_; }
Expand Down Expand Up @@ -453,4 +482,32 @@ TEST_F(AggregationServiceImplTest,
.value());
}

TEST_F(AggregationServiceImplTest, GetPendingReportRequestsForWebUI) {
StoreReport(aggregation_service::CreateExampleRequest());
StoreReport(aggregation_service::CreateExampleRequest());

base::RunLoop run_loop;
service()->GetPendingReportRequestsForWebUI(base::BindLambdaForTesting(
[&](std::vector<AggregationServiceStorage::RequestAndId>
requests_and_ids) {
// IDs autoincrement from 1.
EXPECT_THAT(
requests_and_ids,
ElementsAre(RequestIdIs(AggregationServiceStorage::RequestId(1)),
RequestIdIs(AggregationServiceStorage::RequestId(2))));
run_loop.Quit();
}));
run_loop.Run();
}

TEST_F(AggregationServiceImplTest, SendReportsForWebUI) {
StoreReport(aggregation_service::CreateExampleRequest());

// IDs autoincrement from 1.
service()->SendReportsForWebUI({AggregationServiceStorage::RequestId(1)},
base::DoNothing());

assembler()->WaitForRequests(/*num_requests=*/1);
}

} // namespace content
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,17 @@ class AggregationServiceStorage {
base::Time strictly_after_time) = 0;

// Returns requests with report times on or before `not_after_time`. The
// returned reports are ordered by report time.
// returned requests are ordered by report time.
// TODO(crbug.com/1340046): Limit the number of in-progress reports kept in
// memory at the same time.
virtual std::vector<RequestAndId> GetRequestsReportingOnOrBefore(
base::Time not_after_time) = 0;

// Returns the requests with the given IDs. Empty vector is returned if `ids`
// is empty.
virtual std::vector<RequestAndId> GetRequests(
const std::vector<RequestId>& ids) = 0;

// Adjusts the report time of all reports with report times strictly before
// `now`. Each new report time is `now` + a random delay. The random delay for
// each report is picked independently from a uniform distribution between
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,38 @@ AggregationServiceStorageSql::GetRequestsReportingOnOrBefore(
return result;
}

std::vector<AggregationServiceStorage::RequestAndId>
AggregationServiceStorageSql::GetRequests(
const std::vector<AggregationServiceStorage::RequestId>& ids) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

if (!EnsureDatabaseOpen(DbCreationPolicy::kFailIfAbsent))
return {};

static constexpr char kGetRequestSql[] =
"SELECT request_id,request_proto FROM report_requests "
"WHERE request_id=?";
sql::Statement statement(
db_.GetCachedStatement(SQL_FROM_HERE, kGetRequestSql));

std::vector<AggregationServiceStorage::RequestAndId> result;
for (AggregationServiceStorage::RequestId id : ids) {
statement.Reset(/*clear_bound_vars=*/true);
statement.BindInt64(0, *id);
if (!statement.Step())
continue;
absl::optional<AggregatableReportRequest> parsed_request =
AggregatableReportRequest::Deserialize(statement.ColumnBlob(1));
if (!parsed_request)
continue;
result.push_back(AggregationServiceStorage::RequestAndId{
.request = std::move(*parsed_request),
.id = id,
});
}
return result;
}

absl::optional<base::Time>
AggregationServiceStorageSql::AdjustOfflineReportTimes(
base::Time now,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class CONTENT_EXPORT AggregationServiceStorageSql
base::Time strictly_after_time) override;
std::vector<AggregationServiceStorage::RequestAndId>
GetRequestsReportingOnOrBefore(base::Time not_after_time) override;
std::vector<AggregationServiceStorage::RequestAndId> GetRequests(
const std::vector<AggregationServiceStorage::RequestId>& ids) override;
absl::optional<base::Time> AdjustOfflineReportTimes(
base::Time now,
base::TimeDelta min_delay,
Expand Down

0 comments on commit 6a52078

Please sign in to comment.