Skip to content

Commit

Permalink
[Enrollment] Perform state key retrieval as late as possible
Browse files Browse the repository at this point in the history
This CL therefore moves the state key retrieval after the PSM requests
which determine whether we actually need to fetch the enrollment state.

Bug: b/288250694
Tests: Refactored unit tests

Change-Id: Ib7d89c86bbfcd860241ed0c17c7db2b7d1dd3fe6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4987602
Reviewed-by: Sergiy Belozorov <sergiyb@chromium.org>
Commit-Queue: Roland Bock <rbock@google.com>
Cr-Commit-Position: refs/heads/main@{#1216870}
  • Loading branch information
Roland Bock authored and Chromium LUCI CQ committed Oct 30, 2023
1 parent b5c352b commit 8f8d2e5
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 95 deletions.
116 changes: 59 additions & 57 deletions chrome/browser/ash/policy/enrollment/enrollment_state_fetcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,49 +254,6 @@ class DeviceIdentifiers {
}
};

// Class to obtain state keys.
//
// This is a step in enrollment state fetch (see Sequence class below).
class StateKeys {
static constexpr int kMaxAttempts = 10;

public:
StateKeys() = default;
StateKeys(const StateKeys&) = delete;
StateKeys& operator=(const StateKeys&) = delete;

using CompletionCallback =
base::OnceCallback<void(absl::optional<std::string>)>;

// This will try up to `kMaxAttempts` times to obtain the state keys. If
// successful, it will return the current state key by calling the completion
// callback.
// Otherwise, it will return `absl::nullopt`.
void Retrieve(ServerBackedStateKeysBroker* state_key_broker,
CompletionCallback completion_callback) {
++attempts_;
state_key_broker->RequestStateKeys(base::BindOnce(
&StateKeys::OnStateKeysRetrieved, weak_factory_.GetWeakPtr(),
state_key_broker, std::move(completion_callback)));
}

private:
void OnStateKeysRetrieved(ServerBackedStateKeysBroker* state_key_broker,
CompletionCallback completion_callback,
const std::vector<std::string>& state_keys) {
if (state_keys.empty() || state_keys[0].empty()) {
if (attempts_ >= kMaxAttempts) {
return std::move(completion_callback).Run(absl::nullopt);
}
return Retrieve(state_key_broker, std::move(completion_callback));
}
return std::move(completion_callback).Run(state_keys[0]);
}

int attempts_ = 0;
base::WeakPtrFactory<StateKeys> weak_factory_{this};
};

// Class to send RLWE OPRF request as part of PSM protocol.
//
// This is a step in enrollment state fetch (see Sequence class below).
Expand All @@ -313,6 +270,10 @@ class RlweOprf {
void Request(DeterminationContext& context,
CompletionCallback completion_callback) {
DCHECK(completion_callback);

context.psm_rlwe_client = context.rlwe_client_factory.Run(
private_membership::rlwe::CROS_DEVICE_STATE_UNIFIED,
ConstructPlainttextId(context.rlz_brand_code, context.serial_number));
const auto oprf_request = context.psm_rlwe_client->CreateOprfRequest();
if (!oprf_request.ok()) {
LOG(ERROR) << "Failed to create PSM RLWE OPRF request: "
Expand Down Expand Up @@ -408,6 +369,7 @@ class RlweQuery {
oprf_response,
CompletionCallback completion_callback) {
DCHECK(completion_callback);
DCHECK(context.psm_rlwe_client);
const auto query_request =
context.psm_rlwe_client->CreateQueryRequest(oprf_response);

Expand Down Expand Up @@ -516,6 +478,49 @@ class RlweQuery {
base::WeakPtrFactory<RlweQuery> weak_factory_{this};
};

// Class to obtain state keys.
//
// This is a step in enrollment state fetch (see Sequence class below).
class StateKeys {
static constexpr int kMaxAttempts = 10;

public:
StateKeys() = default;
StateKeys(const StateKeys&) = delete;
StateKeys& operator=(const StateKeys&) = delete;

using CompletionCallback =
base::OnceCallback<void(absl::optional<std::string>)>;

// This will try up to `kMaxAttempts` times to obtain the state keys. If
// successful, it will return the current state key by calling the completion
// callback.
// Otherwise, it will return `absl::nullopt`.
void Retrieve(ServerBackedStateKeysBroker* state_key_broker,
CompletionCallback completion_callback) {
++attempts_;
state_key_broker->RequestStateKeys(base::BindOnce(
&StateKeys::OnStateKeysRetrieved, weak_factory_.GetWeakPtr(),
state_key_broker, std::move(completion_callback)));
}

private:
void OnStateKeysRetrieved(ServerBackedStateKeysBroker* state_key_broker,
CompletionCallback completion_callback,
const std::vector<std::string>& state_keys) {
if (state_keys.empty() || state_keys[0].empty()) {
if (attempts_ >= kMaxAttempts) {
return std::move(completion_callback).Run(absl::nullopt);
}
return Retrieve(state_key_broker, std::move(completion_callback));
}
return std::move(completion_callback).Run(state_keys[0]);
}

int attempts_ = 0;
base::WeakPtrFactory<StateKeys> weak_factory_{this};
};

// Class to send state request to DMServer.
//
// This is a step in enrollment state fetch (see Sequence class below).
Expand Down Expand Up @@ -907,20 +912,6 @@ class EnrollmentStateFetcherImpl::Sequence {
return ReportResult(AutoEnrollmentState::kNoEnrollment);
}

state_keys_.Retrieve(context_.state_key_broker,
base::BindOnce(&Sequence::OnStateKeysRetrieved,
weak_factory_.GetWeakPtr()));
}

void OnStateKeysRetrieved(absl::optional<std::string> state_key) {
ReportStepDurationAndResetTimer(kUMASuffixStateKeyRetrieval);
base::UmaHistogramBoolean(kUMAStateDeterminationStateKeysRetrieved,
state_key.has_value());
LOG_IF(WARNING, !state_key) << "Failed to obtain state keys";
context_.state_key = state_key;
context_.psm_rlwe_client = context_.rlwe_client_factory.Run(
private_membership::rlwe::CROS_DEVICE_STATE_UNIFIED,
ConstructPlainttextId(context_.rlz_brand_code, context_.serial_number));
oprf_.Request(context_, base::BindOnce(&Sequence::OnOprfRequestDone,
weak_factory_.GetWeakPtr()));
}
Expand Down Expand Up @@ -956,6 +947,17 @@ class EnrollmentStateFetcherImpl::Sequence {
return ReportResult(AutoEnrollmentState::kNoEnrollment);
}
query_.StoreResponse(local_state_, result.value());
state_keys_.Retrieve(context_.state_key_broker,
base::BindOnce(&Sequence::OnStateKeysRetrieved,
weak_factory_.GetWeakPtr()));
}

void OnStateKeysRetrieved(absl::optional<std::string> state_key) {
ReportStepDurationAndResetTimer(kUMASuffixStateKeyRetrieval);
base::UmaHistogramBoolean(kUMAStateDeterminationStateKeysRetrieved,
state_key.has_value());
LOG_IF(WARNING, !state_key) << "Failed to obtain state keys";
context_.state_key = state_key;
state_.Request(context_, base::BindOnce(&Sequence::OnStateRequestDone,
weak_factory_.GetWeakPtr()));
}
Expand Down

0 comments on commit 8f8d2e5

Please sign in to comment.