Skip to content
Permalink
Browse files

Cache SSL EC explicitly

  • Loading branch information
jvgutierrez authored and maskit committed Sep 27, 2019
1 parent 752f5e6 commit 757256129811441f29eea288b1d7e19bc54fab9c
@@ -43,6 +43,7 @@
#include "P_UnixNetVConnection.h"
#include "P_UnixNet.h"
#include "P_ALPNSupport.h"
#include "P_SSLUtils.h"

// These are included here because older OpenSSL libraries don't have them.
// Don't copy these defines, or use their values directly, they are merely
@@ -302,26 +303,35 @@ class SSLNetVConnection : public UnixNetVConnection, public ALPNSupport
return ssl ? SSL_get_cipher_name(ssl) : nullptr;
}

void
setSSLCurveNID(ssl_curve_id curve_nid)
{
sslCurveNID = curve_nid;
}

ssl_curve_id
getSSLCurveNID() const
{
return sslCurveNID;
}

const char *
getSSLCurve() const
{
if (!ssl) {
return nullptr;
}

ssl_curve_id curve = getSSLCurveNID();
#ifndef OPENSSL_IS_BORINGSSL
int curve_nid = SSL_get_shared_curve(ssl, 0);

if (curve_nid == NID_undef) {
if (curve == NID_undef) {
return nullptr;
}
return OBJ_nid2sn(curve_nid);
return OBJ_nid2sn(curve);
#else
if (uint16_t curve_id = SSL_get_curve_id(ssl); curve_id != 0) {
return SSL_get_curve_name(curve_id);
} else {
if (curve == 0) {
return nullptr;
}
return SSL_get_curve_name(curve);
#endif
}

@@ -418,6 +428,7 @@ class SSLNetVConnection : public UnixNetVConnection, public ALPNSupport
std::string_view map_tls_protocol_to_tag(const char *proto_string) const;
bool update_rbio(bool move_to_socket);
void increment_ssl_version_metric(int version) const;
void fetch_ssl_curve();

enum SSLHandshakeStatus sslHandshakeStatus = SSL_HANDSHAKE_ONGOING;
bool sslClientRenegotiationAbort = false;
@@ -426,6 +437,7 @@ class SSLNetVConnection : public UnixNetVConnection, public ALPNSupport
IOBufferReader *handShakeHolder = nullptr;
IOBufferReader *handShakeReader = nullptr;
int handShakeBioStored = 0;
int sslCurveNID = NID_undef;

bool transparentPassThrough = false;

@@ -39,6 +39,15 @@ class SSLNetVConnection;

typedef int ssl_error_t;

#ifndef OPENSSL_IS_BORING
typedef int ssl_curve_id;
#else
typedef uint16_t ssl_curve_id;
#endif

// Return the SSL Curve ID associated to the specified SSL connection
ssl_curve_id SSLGetCurveNID(SSL *ssl);

/**
@brief Load SSL certificates from ssl_multicert.config and setup SSLCertLookup for SSLCertificateConfig
*/
@@ -1255,6 +1255,7 @@ SSLNetVConnection::sslServerHandShakeEvent(int &err)
unsigned len = 0;

increment_ssl_version_metric(SSL_version(ssl));
fetch_ssl_curve();

// If it's possible to negotiate both NPN and ALPN, then ALPN
// is preferred since it is the server's preference. The server
@@ -1806,6 +1807,14 @@ SSLNetVConnection::increment_ssl_version_metric(int version) const
}
}

void
SSLNetVConnection::fetch_ssl_curve()
{
if (!getSSLSessionCacheHit()) {
setSSLCurveNID(SSLGetCurveNID(ssl));
}
}

std::string_view
SSLNetVConnection::map_tls_protocol_to_tag(const char *proto_string) const
{
@@ -60,7 +60,7 @@ SSLSessionCache::getSessionBuffer(const SSLSessionID &sid, char *buffer, int &le
}

bool
SSLSessionCache::getSession(const SSLSessionID &sid, SSL_SESSION **sess) const
SSLSessionCache::getSession(const SSLSessionID &sid, SSL_SESSION **sess, ssl_session_cache_exdata **data) const
{
uint64_t hash = sid.hash();
uint64_t target_bucket = hash % nbuckets;
@@ -73,7 +73,7 @@ SSLSessionCache::getSession(const SSLSessionID &sid, SSL_SESSION **sess) const
target_bucket, bucket, buf, hash);
}

return bucket->getSession(sid, sess);
return bucket->getSession(sid, sess, data);
}

void
@@ -97,7 +97,7 @@ SSLSessionCache::removeSession(const SSLSessionID &sid)
}

void
SSLSessionCache::insertSession(const SSLSessionID &sid, SSL_SESSION *sess)
SSLSessionCache::insertSession(const SSLSessionID &sid, SSL_SESSION *sess, SSL *ssl)
{
uint64_t hash = sid.hash();
uint64_t target_bucket = hash % nbuckets;
@@ -110,11 +110,11 @@ SSLSessionCache::insertSession(const SSLSessionID &sid, SSL_SESSION *sess)
target_bucket, bucket, buf, hash);
}

bucket->insertSession(sid, sess);
bucket->insertSession(sid, sess, ssl);
}

void
SSLSessionBucket::insertSession(const SSLSessionID &id, SSL_SESSION *sess)
SSLSessionBucket::insertSession(const SSLSessionID &id, SSL_SESSION *sess, SSL *ssl)
{
size_t len = i2d_SSL_SESSION(sess, nullptr); // make sure we're not going to need more than SSL_MAX_SESSION_SIZE bytes
/* do not cache a session that's too big. */
@@ -158,12 +158,19 @@ SSLSessionBucket::insertSession(const SSLSessionID &id, SSL_SESSION *sess)
}

Ptr<IOBufferData> buf;
buf = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
Ptr<IOBufferData> buf_exdata;
size_t len_exdata = sizeof(ssl_session_cache_exdata);
buf = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
ink_release_assert(static_cast<size_t>(buf->block_size()) >= len);
unsigned char *loc = reinterpret_cast<unsigned char *>(buf->data());
i2d_SSL_SESSION(sess, &loc);
buf_exdata = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
ink_release_assert(static_cast<size_t>(buf_exdata->block_size()) >= len_exdata);
ssl_session_cache_exdata *exdata = reinterpret_cast<ssl_session_cache_exdata *>(buf_exdata->data());
// This could be moved to a function in charge of populating exdata
exdata->curve = SSLGetCurveNID(ssl);

ats_scoped_obj<SSLSession> ssl_session(new SSLSession(id, buf, len));
ats_scoped_obj<SSLSession> ssl_session(new SSLSession(id, buf, len, buf_exdata));

/* do the actual insert */
queue.enqueue(ssl_session.release());
@@ -207,7 +214,7 @@ SSLSessionBucket::getSessionBuffer(const SSLSessionID &id, char *buffer, int &le
}

bool
SSLSessionBucket::getSession(const SSLSessionID &id, SSL_SESSION **sess)
SSLSessionBucket::getSession(const SSLSessionID &id, SSL_SESSION **sess, ssl_session_cache_exdata **data)
{
char buf[id.len * 2 + 1];
buf[0] = '\0'; // just to be safe.
@@ -237,6 +244,10 @@ SSLSessionBucket::getSession(const SSLSessionID &id, SSL_SESSION **sess)
if (node->session_id == id) {
const unsigned char *loc = reinterpret_cast<const unsigned char *>(node->asn1_data->data());
*sess = d2i_SSL_SESSION(nullptr, &loc, node->len_asn1_data);
if (data != nullptr) {
ssl_session_cache_exdata *exdata = reinterpret_cast<ssl_session_cache_exdata *>(node->extra_data->data());
*data = exdata;
}

return true;
}
@@ -32,6 +32,10 @@

#define SSL_MAX_SESSION_SIZE 256

struct ssl_session_cache_exdata {
ssl_curve_id curve;
};

struct SSLSessionID : public TSSslSessionID {
SSLSessionID(const unsigned char *s, size_t l)
{
@@ -115,9 +119,10 @@ class SSLSession
SSLSessionID session_id;
Ptr<IOBufferData> asn1_data; /* this is the ASN1 representation of the SSL_CTX */
size_t len_asn1_data;
Ptr<IOBufferData> extra_data;

SSLSession(const SSLSessionID &id, const Ptr<IOBufferData> &ssl_asn1_data, size_t len_asn1)
: session_id(id), asn1_data(ssl_asn1_data), len_asn1_data(len_asn1)
SSLSession(const SSLSessionID &id, const Ptr<IOBufferData> &ssl_asn1_data, size_t len_asn1, Ptr<IOBufferData> &exdata)
: session_id(id), asn1_data(ssl_asn1_data), len_asn1_data(len_asn1), extra_data(exdata)
{
}

@@ -129,8 +134,8 @@ class SSLSessionBucket
public:
SSLSessionBucket();
~SSLSessionBucket();
void insertSession(const SSLSessionID &, SSL_SESSION *ctx);
bool getSession(const SSLSessionID &, SSL_SESSION **ctx);
void insertSession(const SSLSessionID &, SSL_SESSION *ctx, SSL *ssl);
bool getSession(const SSLSessionID &, SSL_SESSION **ctx, ssl_session_cache_exdata **data);
int getSessionBuffer(const SSLSessionID &, char *buffer, int &len);
void removeSession(const SSLSessionID &);

@@ -146,9 +151,9 @@ class SSLSessionBucket
class SSLSessionCache
{
public:
bool getSession(const SSLSessionID &sid, SSL_SESSION **sess) const;
bool getSession(const SSLSessionID &sid, SSL_SESSION **sess, ssl_session_cache_exdata **data) const;
int getSessionBuffer(const SSLSessionID &sid, char *buffer, int &len) const;
void insertSession(const SSLSessionID &sid, SSL_SESSION *sess);
void insertSession(const SSLSessionID &sid, SSL_SESSION *sess, SSL *ssl);
void removeSession(const SSLSessionID &sid);
SSLSessionCache();
~SSLSessionCache();
@@ -200,9 +200,11 @@ ssl_get_cached_session(SSL *ssl, const unsigned char *id, int len, int *copy)
hook = hook->m_link.next;
}

SSL_SESSION *session = nullptr;
if (session_cache->getSession(sid, &session)) {
SSL_SESSION *session = nullptr;
ssl_session_cache_exdata *exdata = nullptr;
if (session_cache->getSession(sid, &session, &exdata)) {
ink_assert(session);
ink_assert(exdata);

// Double check the timeout
if (ssl_session_timed_out(session)) {
@@ -217,6 +219,7 @@ ssl_get_cached_session(SSL *ssl, const unsigned char *id, int len, int *copy)
SSLNetVConnection *netvc = SSLNetVCAccess(ssl);
SSL_INCREMENT_DYN_STAT(ssl_session_cache_hit);
netvc->setSSLSessionCacheHit(true);
netvc->setSSLCurveNID(exdata->curve);
}
} else {
SSL_INCREMENT_DYN_STAT(ssl_session_cache_miss);
@@ -229,6 +232,7 @@ ssl_new_cached_session(SSL *ssl, SSL_SESSION *sess)
{
unsigned int len = 0;
const unsigned char *id = SSL_SESSION_get_id(sess, &len);

SSLSessionID sid(id, len);

if (diags->tag_activated("ssl.session_cache")) {
@@ -239,7 +243,7 @@ ssl_new_cached_session(SSL *ssl, SSL_SESSION *sess)
}

SSL_INCREMENT_DYN_STAT(ssl_session_cache_new_session);
session_cache->insertSession(sid, sess);
session_cache->insertSession(sid, sess, ssl);

// Call hook after new session is created
APIHook *hook = ssl_hooks->get(TSSslHookInternalID(TS_SSL_SESSION_HOOK));
@@ -1975,3 +1979,13 @@ SSLMultiCertConfigLoader::clear_pw_references(SSL_CTX *ssl_ctx)
SSL_CTX_set_default_passwd_cb(ssl_ctx, nullptr);
SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, nullptr);
}

ssl_curve_id
SSLGetCurveNID(SSL *ssl)
{
#ifndef OPENSSL_IS_BORINGSSL
return SSL_get_shared_curve(ssl, 0);
#else
return SSL_get_curve_id(ssl);
#endif
}
@@ -9551,7 +9551,7 @@ TSSslSessionGet(const TSSslSessionID *session_id)
{
SSL_SESSION *session = nullptr;
if (session_id && session_cache) {
session_cache->getSession(reinterpret_cast<const SSLSessionID &>(*session_id), &session);
session_cache->getSession(reinterpret_cast<const SSLSessionID &>(*session_id), &session, nullptr);
}
return reinterpret_cast<TSSslSession>(session);
}
@@ -9568,7 +9568,7 @@ TSSslSessionGetBuffer(const TSSslSessionID *session_id, char *buffer, int *len_p
}

TSReturnCode
TSSslSessionInsert(const TSSslSessionID *session_id, TSSslSession add_session)
TSSslSessionInsert(const TSSslSessionID *session_id, TSSslSession add_session, TSSslConnection ssl_conn)
{
// Don't insert if there is no session id or the cache is not yet set up
if (session_id && session_cache) {
@@ -9579,7 +9579,8 @@ TSSslSessionInsert(const TSSslSessionID *session_id, TSSslSession add_session)
Debug("ssl.session_cache.insert", "TSSslSessionInsert: Inserting session '%s' ", buf);
}
SSL_SESSION *session = reinterpret_cast<SSL_SESSION *>(add_session);
session_cache->insertSession(reinterpret_cast<const SSLSessionID &>(*session_id), session);
SSL *ssl = reinterpret_cast<SSL *>(ssl_conn);
session_cache->insertSession(reinterpret_cast<const SSLSessionID &>(*session_id), session, ssl);
// insertSession returns void, assume all went well
return TS_SUCCESS;
} else {

0 comments on commit 7572561

Please sign in to comment.
You can’t perform that action at this time.