Skip to content

Commit

Permalink
Fix a possible use-after-free with platform cert verification (#2692)
Browse files Browse the repository at this point in the history
Fix a possible use-after-free with platform cert verification by using a unique_ptr in the flat_hash_set of pending validations. The flat_hash_set does not ensure pointer stability, but the validation thread holds a pointer to the PendingVerification, which is problematic. This PR makes PendingVerification non-moveable and non-copyable which avoids this problem.

There is also another potential use-after free in that the task posted to the dispatcher deletes the PendingValidation, but the PendingValidation touches member variables after the call to post. Reordered the call to post to avoid this.

Fixes #2691

Signed-off-by: Ryan Hamilton rch@google.com
Signed-off-by: JP Simard <jp@jpsim.com>
  • Loading branch information
RyanTheOptimist authored and jpsim committed Nov 28, 2022
1 parent 79a98ae commit dab4c03
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,11 @@ ValidationResults PlatformBridgeCertValidator::doVerifyCertChain(
} else {
host = host_name;
}
PendingValidation validation(*this, std::move(certs), host, std::move(transport_socket_options),
std::move(callback));
auto insert_result = validations_.insert(std::move(validation));
ASSERT(insert_result.second);
PendingValidation& ref = const_cast<PendingValidation&>(*insert_result.first);
std::thread verification_thread(&PendingValidation::verifyCertsByPlatform, &ref);
auto validation = std::make_unique<PendingValidation>(
*this, std::move(certs), host, std::move(transport_socket_options), std::move(callback));
PendingValidation* validation_ptr = validation.get();
validations_.insert(std::move(validation));
std::thread verification_thread(&PendingValidation::verifyCertsByPlatform, validation_ptr);
std::thread::id thread_id = verification_thread.get_id();
validation_threads_[thread_id] = std::move(verification_thread);
return {ValidationResults::ValidationStatus::Pending, absl::nullopt, absl::nullopt};
Expand Down Expand Up @@ -138,7 +137,16 @@ void PlatformBridgeCertValidator::PendingValidation::verifyCertsByPlatform() {
void PlatformBridgeCertValidator::PendingValidation::postVerifyResultAndCleanUp(
bool success, absl::string_view error_details, uint8_t tls_alert,
OptRef<Stats::Counter> error_counter) {
ENVOY_LOG(trace,
"Finished platform cert validation for {}, post result callback to network thread",
host_name_);

if (parent_.platform_validator_->release_validator) {
parent_.platform_validator_->release_validator();
}
std::weak_ptr<size_t> weak_alive_indicator(parent_.alive_indicator_);

// Once this task runs, `this` will be deleted so this must be the last statement in the file.
result_callback_->dispatcher().post([this, weak_alive_indicator, success,
error = std::string(error_details), tls_alert, error_counter,
thread_id = std::this_thread::get_id()]() {
Expand All @@ -152,15 +160,8 @@ void PlatformBridgeCertValidator::PendingValidation::postVerifyResultAndCleanUp(
const_cast<Stats::Counter&>(error_counter.ref()).inc();
}
result_callback_->onCertValidationResult(success, error, tls_alert);
parent_.validations_.erase(*this);
parent_.validations_.erase(this);
});
ENVOY_LOG(trace,
"Finished platform cert validation for {}, post result callback to network thread",
host_name_);

if (parent_.platform_validator_->release_validator) {
parent_.platform_validator_->release_validator();
}
}

} // namespace Tls
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,15 @@ class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable<Logge
result_callback_(std::move(result_callback)),
transport_socket_options_(std::move(transport_socket_options)) {}

// Ensure that this class is never moved or copied to guarantee pointer stability.
PendingValidation(const PendingValidation&) = delete;
PendingValidation(PendingValidation&&) = delete;

void verifyCertsByPlatform();

void postVerifyResultAndCleanUp(bool success, absl::string_view error_details,
uint8_t tls_alert, OptRef<Stats::Counter> error_counter);

struct Hash {
size_t operator()(const PendingValidation& p) const {
return reinterpret_cast<size_t>(p.result_callback_.get());
}
};
struct Eq {
bool operator()(const PendingValidation& a, const PendingValidation& b) const {
return a.result_callback_.get() == b.result_callback_.get();
}
};

private:
Event::SchedulableCallbackPtr next_iteration_callback_;
PlatformBridgeCertValidator& parent_;
Expand All @@ -111,8 +104,7 @@ class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable<Logge
// latches the platform extension API.
const envoy_cert_validator* platform_validator_;
absl::flat_hash_map<std::thread::id, std::thread> validation_threads_;
absl::flat_hash_set<PendingValidation, PendingValidation::Hash, PendingValidation::Eq>
validations_;
absl::flat_hash_set<std::unique_ptr<PendingValidation>> validations_;
std::shared_ptr<size_t> alive_indicator_{new size_t(1)};
};

Expand Down

0 comments on commit dab4c03

Please sign in to comment.