Skip to content

Commit

Permalink
[Security] Reapply Move ownership of tsi_ssl_client_handshaker_factory (
Browse files Browse the repository at this point in the history
#34726)

This reverts commit 601aaf8, which
results in rolling forward #34408.

The 601aaf reversion happened because of a deadlock found in Python. The
root cause ended up being an issue with the Python wrapper and was fixed
in #34712 , so this can be rolled forward again
  • Loading branch information
gtcooke94 committed Oct 18, 2023
1 parent a0c1027 commit 74c1de6
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 70 deletions.
124 changes: 115 additions & 9 deletions src/core/lib/security/credentials/ssl/ssl_credentials.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "src/core/lib/surface/api_trace.h"
#include "src/core/tsi/ssl/session_cache/ssl_session_cache.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security_interface.h"

//
// SSL Channel Credentials.
Expand All @@ -47,6 +48,26 @@ grpc_ssl_credentials::grpc_ssl_credentials(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const grpc_ssl_verify_peer_options* verify_options) {
build_config(pem_root_certs, pem_key_cert_pair, verify_options);
// Use default (e.g. OS) root certificates if the user did not pass any root
// certificates.
if (config_.pem_root_certs == nullptr) {
const char* pem_root_certs =
grpc_core::DefaultSslRootStore::GetPemRootCerts();
if (pem_root_certs == nullptr) {
gpr_log(GPR_ERROR, "Could not get default pem root certs.");
} else {
char* default_roots = gpr_strdup(pem_root_certs);
config_.pem_root_certs = default_roots;
root_store_ = grpc_core::DefaultSslRootStore::GetRootStore();
}
} else {
config_.pem_root_certs = config_.pem_root_certs;
root_store_ = nullptr;
}

client_handshaker_initialization_status_ = InitializeClientHandshakerFactory(
&config_, config_.pem_root_certs, root_store_, nullptr,
&client_handshaker_factory_);
}

grpc_ssl_credentials::~grpc_ssl_credentials() {
Expand All @@ -56,26 +77,67 @@ grpc_ssl_credentials::~grpc_ssl_credentials() {
config_.verify_options.verify_peer_destruct(
config_.verify_options.verify_peer_callback_userdata);
}
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_);
}

grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_ssl_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, grpc_core::ChannelArgs* args) {
if (config_.pem_root_certs == nullptr) {
gpr_log(GPR_ERROR,
"No root certs in config. Client-side security connector must have "
"root certs.");
return nullptr;
}
absl::optional<std::string> overridden_target_name =
args->GetOwnedString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG);
auto* ssl_session_cache = args->GetObject<tsi::SslSessionLRUCache>();
grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
grpc_ssl_channel_security_connector_create(
this->Ref(), std::move(call_creds), &config_, target,
overridden_target_name.has_value() ? overridden_target_name->c_str()
: nullptr,
ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr());
if (sc == nullptr) {
return sc;
tsi_ssl_session_cache* session_cache =
ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr();

grpc_core::RefCountedPtr<grpc_channel_security_connector> security_connector =
nullptr;
if (session_cache != nullptr) {
// We need a separate factory and SSL_CTX if there's a cache in the channel
// args. SSL_CTX should live with the factory and that should live on the
// credentials. However, there is a way to configure a session cache in the
// channel args, so that prevents us from also keeping the session cache at
// the credentials level. In the case of a session cache, we still need to
// keep a separate factory and SSL_CTX at the subchannel/security_connector
// level.
tsi_ssl_client_handshaker_factory* factory_with_cache = nullptr;
grpc_security_status status = InitializeClientHandshakerFactory(
&config_, config_.pem_root_certs, root_store_, session_cache,
&factory_with_cache);
if (status != GRPC_SECURITY_OK) {
gpr_log(GPR_ERROR,
"InitializeClientHandshakerFactory returned bad "
"status.");
return nullptr;
}
security_connector = grpc_ssl_channel_security_connector_create(
this->Ref(), std::move(call_creds), &config_, target,
overridden_target_name.has_value() ? overridden_target_name->c_str()
: nullptr,
factory_with_cache);
tsi_ssl_client_handshaker_factory_unref(factory_with_cache);
} else {
if (client_handshaker_initialization_status_ != GRPC_SECURITY_OK) {
return nullptr;
}
security_connector = grpc_ssl_channel_security_connector_create(
this->Ref(), std::move(call_creds), &config_, target,
overridden_target_name.has_value() ? overridden_target_name->c_str()
: nullptr,
client_handshaker_factory_);
}

if (security_connector == nullptr) {
return security_connector;
}
*args = args->Set(GRPC_ARG_HTTP2_SCHEME, "https");
return sc;
return security_connector;
}

grpc_core::UniqueTypeName grpc_ssl_credentials::Type() {
Expand Down Expand Up @@ -118,6 +180,50 @@ void grpc_ssl_credentials::set_max_tls_version(
config_.max_tls_version = max_tls_version;
}

grpc_security_status grpc_ssl_credentials::InitializeClientHandshakerFactory(
const grpc_ssl_config* config, const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store,
tsi_ssl_session_cache* ssl_session_cache,
tsi_ssl_client_handshaker_factory** handshaker_factory) {
// This class level factory can't have a session cache by design. If we want
// to init one with a cache we need to make a new one
if (client_handshaker_factory_ != nullptr && ssl_session_cache == nullptr) {
return GRPC_SECURITY_OK;
}

bool has_key_cert_pair = config->pem_key_cert_pair != nullptr &&
config->pem_key_cert_pair->private_key != nullptr &&
config->pem_key_cert_pair->cert_chain != nullptr;
tsi_ssl_client_handshaker_options options;
if (pem_root_certs == nullptr) {
gpr_log(
GPR_ERROR,
"Handshaker factory creation failed. pem_root_certs cannot be nullptr");
return GRPC_SECURITY_ERROR;
}
options.pem_root_certs = pem_root_certs;
options.root_store = root_store;
options.alpn_protocols =
grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
if (has_key_cert_pair) {
options.pem_key_cert_pair = config->pem_key_cert_pair;
}
options.cipher_suites = grpc_get_ssl_cipher_suites();
options.session_cache = ssl_session_cache;
options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version);
options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version);
const tsi_result result =
tsi_create_ssl_client_handshaker_factory_with_options(&options,
handshaker_factory);
gpr_free(options.alpn_protocols);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
return GRPC_SECURITY_ERROR;
}
return GRPC_SECURITY_OK;
}

// Deprecated in favor of grpc_ssl_credentials_create_ex. Will be removed
// once all of its call sites are migrated to grpc_ssl_credentials_create_ex.
grpc_channel_credentials* grpc_ssl_credentials_create(
Expand Down
14 changes: 14 additions & 0 deletions src/core/lib/security/credentials/ssl/ssl_credentials.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,21 @@ class grpc_ssl_credentials : public grpc_channel_credentials {
grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const grpc_ssl_verify_peer_options* verify_options);

// InitializeClientHandshakerFactory constructs a client handshaker factory
// that is stored on this credentials object. This handshaker factory will be
// used when creating handshakers using these credentials except in the case
// that there is a session cache. If a session cache is used, a new handshaker
// factory will be created and used that contains that session cache.
grpc_security_status InitializeClientHandshakerFactory(
const grpc_ssl_config* config, const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store,
tsi_ssl_session_cache* ssl_session_cache,
tsi_ssl_client_handshaker_factory** handshaker_factory);

grpc_ssl_config config_;
tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr;
const tsi_ssl_root_certs_store* root_store_ = nullptr;
grpc_security_status client_handshaker_initialization_status_;
};

struct grpc_ssl_server_certificate_config {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,12 @@ class grpc_ssl_channel_security_connector final
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name)
const char* overridden_target_name,
tsi_ssl_client_handshaker_factory* client_handshaker_factory)
: grpc_channel_security_connector(GRPC_SSL_URL_SCHEME,
std::move(channel_creds),
std::move(request_metadata_creds)),
client_handshaker_factory_(client_handshaker_factory),
overridden_target_name_(
overridden_target_name == nullptr ? "" : overridden_target_name),
verify_options_(&config->verify_options) {
Expand All @@ -98,39 +100,6 @@ class grpc_ssl_channel_security_connector final
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_);
}

grpc_security_status InitializeHandshakerFactory(
const grpc_ssl_config* config, const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store,
tsi_ssl_session_cache* ssl_session_cache) {
bool has_key_cert_pair =
config->pem_key_cert_pair != nullptr &&
config->pem_key_cert_pair->private_key != nullptr &&
config->pem_key_cert_pair->cert_chain != nullptr;
tsi_ssl_client_handshaker_options options;
GPR_DEBUG_ASSERT(pem_root_certs != nullptr);
options.pem_root_certs = pem_root_certs;
options.root_store = root_store;
options.alpn_protocols =
grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
if (has_key_cert_pair) {
options.pem_key_cert_pair = config->pem_key_cert_pair;
}
options.cipher_suites = grpc_get_ssl_cipher_suites();
options.session_cache = ssl_session_cache;
options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version);
options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version);
const tsi_result result =
tsi_create_ssl_client_handshaker_factory_with_options(
&options, &client_handshaker_factory_);
gpr_free(options.alpn_protocols);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
return GRPC_SECURITY_ERROR;
}
return GRPC_SECURITY_OK;
}

void add_handshakers(const grpc_core::ChannelArgs& args,
grpc_pollset_set* /*interested_parties*/,
grpc_core::HandshakeManager* handshake_mgr) override {
Expand Down Expand Up @@ -204,7 +173,7 @@ class grpc_ssl_channel_security_connector final
}

private:
tsi_ssl_client_handshaker_factory* client_handshaker_factory_;
tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr;
std::string target_name_;
std::string overridden_target_name_;
const verify_peer_options* verify_options_;
Expand Down Expand Up @@ -410,36 +379,17 @@ grpc_ssl_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
tsi_ssl_session_cache* ssl_session_cache) {
tsi_ssl_client_handshaker_factory* client_factory) {
if (config == nullptr || target_name == nullptr) {
gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name.");
return nullptr;
}

const char* pem_root_certs;
const tsi_ssl_root_certs_store* root_store;
if (config->pem_root_certs == nullptr) {
// Use default root certificates.
pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts();
if (pem_root_certs == nullptr) {
gpr_log(GPR_ERROR, "Could not get default pem root certs.");
return nullptr;
}
root_store = grpc_core::DefaultSslRootStore::GetRootStore();
} else {
pem_root_certs = config->pem_root_certs;
root_store = nullptr;
}

grpc_core::RefCountedPtr<grpc_ssl_channel_security_connector> c =
grpc_core::MakeRefCounted<grpc_ssl_channel_security_connector>(
std::move(channel_creds), std::move(request_metadata_creds), config,
target_name, overridden_target_name);
const grpc_security_status result = c->InitializeHandshakerFactory(
config, pem_root_certs, root_store, ssl_session_cache);
if (result != GRPC_SECURITY_OK) {
return nullptr;
}
target_name, overridden_target_name,
tsi_ssl_client_handshaker_factory_ref(client_factory));
return c;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ grpc_ssl_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
tsi_ssl_session_cache* ssl_session_cache);
tsi_ssl_client_handshaker_factory* factory);

// Config for ssl servers.
struct grpc_ssl_server_config {
Expand Down
7 changes: 7 additions & 0 deletions src/core/tsi/ssl_transport_security.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,13 @@ void tsi_ssl_client_handshaker_factory_unref(
tsi_ssl_handshaker_factory_unref(&factory->base);
}

tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref(
tsi_ssl_client_handshaker_factory* client_factory) {
if (client_factory == nullptr) return nullptr;
return reinterpret_cast<tsi_ssl_client_handshaker_factory*>(
tsi_ssl_handshaker_factory_ref(&client_factory->base));
}

static void tsi_ssl_client_handshaker_factory_destroy(
tsi_ssl_handshaker_factory* factory) {
if (factory == nullptr) return;
Expand Down
4 changes: 4 additions & 0 deletions src/core/tsi/ssl_transport_security.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
const char* server_name_indication, size_t network_bio_buf_size,
size_t ssl_bio_buf_size, tsi_handshaker** handshaker);

// Increments reference count of the client handshaker factory.
tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref(
tsi_ssl_client_handshaker_factory* client_factory);

// Decrements reference count of the handshaker factory. Handshaker factory will
// be destroyed once no references exist.
void tsi_ssl_client_handshaker_factory_unref(
Expand Down
7 changes: 6 additions & 1 deletion test/core/tsi/ssl_transport_security_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,9 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() {
TSI_OK);
}

client_handshaker_factory =
tsi_ssl_client_handshaker_factory_ref(client_handshaker_factory);

tsi_handshaker_destroy(handshaker[1]);
ASSERT_FALSE(handshaker_factory_destructor_called);

Expand All @@ -970,8 +973,10 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() {
ASSERT_FALSE(handshaker_factory_destructor_called);

tsi_handshaker_destroy(handshaker[2]);
ASSERT_TRUE(handshaker_factory_destructor_called);
ASSERT_FALSE(handshaker_factory_destructor_called);

tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory);
ASSERT_TRUE(handshaker_factory_destructor_called);
gpr_free(cert_chain);
}

Expand Down
3 changes: 1 addition & 2 deletions test/cpp/end2end/ssl_credentials_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ void DoRpc(const std::string& server_addr,
grpc::testing::EchoResponse response;
request.set_message(kMessage);
ClientContext context;
context.set_deadline(grpc_timeout_milliseconds_to_deadline(
/*time_ms=*/5000 * grpc_test_slowdown_factor()));
context.set_deadline(grpc_timeout_seconds_to_deadline(/*time_s=*/10));
grpc::Status result = stub->Echo(&context, request, &response);
EXPECT_TRUE(result.ok());
if (!result.ok()) {
Expand Down

0 comments on commit 74c1de6

Please sign in to comment.