Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls: fix detection of the upstream connection close event. #13858

Merged
merged 8 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/root/version_history/current.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Bug Fixes
* http: fixed URL parsing for HTTP/1.1 fully qualified URLs and connect requests containing IPv6 addresses.
* http: sending CONNECT_ERROR for HTTP/2 where appropriate during CONNECT requests.
* proxy_proto: fixed a bug where the wrong downstream address got sent to upstream connections.
* tls: fix detection of the upstream connection close event.
PiotrSikora marked this conversation as resolved.
Show resolved Hide resolved
* tls: fix read resumption after triggering buffer high-watermark and all remaining request/response bytes are stored in the SSL connection's internal buffers.

Removed Config or Runtime
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/transport_sockets/tls/ssl_handshaker.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SslHandshakerImpl : public Ssl::ConnectionInfo, public Ssl::Handshaker {
// Ssl::Handshaker
Network::PostIoAction doHandshake() override;

Ssl::SocketState state() { return state_; }
Ssl::SocketState state() const { return state_; }
void setState(Ssl::SocketState state) { state_ = state; }
SSL* ssl() const { return ssl_.get(); }
Ssl::HandshakeCallbacks* handshakeCallbacks() { return handshake_callbacks_; }
Expand Down
10 changes: 9 additions & 1 deletion source/extensions/transport_sockets/tls/ssl_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,18 @@ Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) {
case SSL_ERROR_WANT_READ:
break;
case SSL_ERROR_ZERO_RETURN:
// Graceful shutdown using close_notify TLS alert.
end_stream = true;
break;
case SSL_ERROR_SYSCALL:
if (result.error_.value() == 0) {
// Non-graceful shutdown by closing the underlying socket.
end_stream = true;
PiotrSikora marked this conversation as resolved.
Show resolved Hide resolved
break;
}
FALLTHRU;
case SSL_ERROR_WANT_WRITE:
// Renegotiation has started. We don't handle renegotiation so just fall through.
// Renegotiation has started. We don't handle renegotiation so just fall through.
PiotrSikora marked this conversation as resolved.
Show resolved Hide resolved
default:
drainErrorQueue();
action = PostIoAction::Close;
Expand Down
179 changes: 179 additions & 0 deletions test/extensions/transport_sockets/tls/ssl_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2540,6 +2540,185 @@ TEST_P(SslSocketTest, HalfClose) {
dispatcher_->run(Event::Dispatcher::RunType::Block);
}

TEST_P(SslSocketTest, ShutdownWithCloseNotify) {
const std::string server_ctx_yaml = R"EOF(
common_tls_context:
tls_certificates:
certificate_chain:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_cert.pem"
private_key:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_key.pem"
validation_context:
trusted_ca:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/ca_certificates.pem"
)EOF";

envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context);
auto server_cfg = std::make_unique<ServerContextConfigImpl>(server_tls_context, factory_context_);
ContextManagerImpl manager(time_system_);
Stats::TestUtil::TestStore server_stats_store;
ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager,
server_stats_store, std::vector<std::string>{});

auto socket = std::make_shared<Network::TcpListenSocket>(
Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true);
Network::MockTcpListenerCallbacks listener_callbacks;
Network::MockConnectionHandler connection_handler;
Network::ListenerPtr listener =
dispatcher_->createListener(socket, listener_callbacks, true, ENVOY_TCP_BACKLOG_SIZE);
std::shared_ptr<Network::MockReadFilter> server_read_filter(new Network::MockReadFilter());
std::shared_ptr<Network::MockReadFilter> client_read_filter(new Network::MockReadFilter());

const std::string client_ctx_yaml = R"EOF(
common_tls_context:
)EOF";

envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml), tls_context);
auto client_cfg = std::make_unique<ClientContextConfigImpl>(tls_context, factory_context_);
Stats::TestUtil::TestStore client_stats_store;
ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager,
client_stats_store);
Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection(
socket->localAddress(), Network::Address::InstanceConstSharedPtr(),
client_ssl_socket_factory.createTransportSocket(nullptr), nullptr);
Network::MockConnectionCallbacks client_connection_callbacks;
client_connection->enableHalfClose(true);
client_connection->addReadFilter(client_read_filter);
client_connection->addConnectionCallbacks(client_connection_callbacks);
client_connection->connect();

Network::ConnectionPtr server_connection;
Network::MockConnectionCallbacks server_connection_callbacks;
EXPECT_CALL(listener_callbacks, onAccept_(_))
.WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void {
server_connection = dispatcher_->createServerConnection(
std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr),
stream_info_);
server_connection->enableHalfClose(true);
server_connection->addReadFilter(server_read_filter);
server_connection->addConnectionCallbacks(server_connection_callbacks);
}));
EXPECT_CALL(*server_read_filter, onNewConnection());
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Buffer::OwnedImpl data("hello");
server_connection->write(data, true);
EXPECT_EQ(data.length(), 0);
}));

EXPECT_CALL(*client_read_filter, onNewConnection())
.WillOnce(Return(Network::FilterStatus::Continue));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*client_read_filter, onData(BufferStringEqual("hello"), true))
.WillOnce(Invoke([&](Buffer::Instance& read_buffer, bool) -> Network::FilterStatus {
read_buffer.drain(read_buffer.length());
client_connection->close(Network::ConnectionCloseType::NoFlush);
return Network::FilterStatus::StopIteration;
}));
EXPECT_CALL(*server_read_filter, onData(_, true));

EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose));
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
server_connection->close(Network::ConnectionCloseType::NoFlush);
dispatcher_->exit();
}));

dispatcher_->run(Event::Dispatcher::RunType::Block);
}

TEST_P(SslSocketTest, ShutdownWithoutCloseNotify) {
const std::string server_ctx_yaml = R"EOF(
common_tls_context:
tls_certificates:
certificate_chain:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_cert.pem"
private_key:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_key.pem"
validation_context:
trusted_ca:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/ca_certificates.pem"
)EOF";

envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context);
auto server_cfg = std::make_unique<ServerContextConfigImpl>(server_tls_context, factory_context_);
ContextManagerImpl manager(time_system_);
Stats::TestUtil::TestStore server_stats_store;
ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager,
server_stats_store, std::vector<std::string>{});

auto socket = std::make_shared<Network::TcpListenSocket>(
Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true);
Network::MockTcpListenerCallbacks listener_callbacks;
Network::MockConnectionHandler connection_handler;
Network::ListenerPtr listener =
dispatcher_->createListener(socket, listener_callbacks, true, ENVOY_TCP_BACKLOG_SIZE);
std::shared_ptr<Network::MockReadFilter> server_read_filter(new Network::MockReadFilter());
std::shared_ptr<Network::MockReadFilter> client_read_filter(new Network::MockReadFilter());

const std::string client_ctx_yaml = R"EOF(
common_tls_context:
)EOF";

envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml), tls_context);
auto client_cfg = std::make_unique<ClientContextConfigImpl>(tls_context, factory_context_);
Stats::TestUtil::TestStore client_stats_store;
ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager,
client_stats_store);
Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection(
socket->localAddress(), Network::Address::InstanceConstSharedPtr(),
client_ssl_socket_factory.createTransportSocket(nullptr), nullptr);
Network::MockConnectionCallbacks client_connection_callbacks;
client_connection->enableHalfClose(true);
client_connection->addReadFilter(client_read_filter);
client_connection->addConnectionCallbacks(client_connection_callbacks);
client_connection->connect();

Network::ConnectionPtr server_connection;
Network::MockConnectionCallbacks server_connection_callbacks;
EXPECT_CALL(listener_callbacks, onAccept_(_))
.WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void {
server_connection = dispatcher_->createServerConnection(
std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr),
stream_info_);
server_connection->enableHalfClose(true);
server_connection->addReadFilter(server_read_filter);
server_connection->addConnectionCallbacks(server_connection_callbacks);
}));
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Buffer::OwnedImpl data("hello");
server_connection->write(data, false);
EXPECT_EQ(data.length(), 0);
// Close without sending close_notify alert.
const SslHandshakerImpl* ssl_socket =
dynamic_cast<const SslHandshakerImpl*>(server_connection->ssl().get());
EXPECT_EQ(ssl_socket->state(), Ssl::SocketState::HandshakeComplete);
SSL_set_quiet_shutdown(ssl_socket->ssl(), 1);
server_connection->close(Network::ConnectionCloseType::NoFlush);
PiotrSikora marked this conversation as resolved.
Show resolved Hide resolved
}));

EXPECT_CALL(*client_read_filter, onNewConnection())
.WillOnce(Return(Network::FilterStatus::Continue));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*client_read_filter, onData(BufferStringEqual("hello"), true))
.WillOnce(Invoke([&](Buffer::Instance& read_buffer, bool) -> Network::FilterStatus {
read_buffer.drain(read_buffer.length());
client_connection->close(Network::ConnectionCloseType::NoFlush);
return Network::FilterStatus::StopIteration;
}));

PiotrSikora marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_->exit(); }));

dispatcher_->run(Event::Dispatcher::RunType::Block);
}

TEST_P(SslSocketTest, ClientAuthMultipleCAs) {
const std::string server_ctx_yaml = R"EOF(
common_tls_context:
Expand Down