diff --git a/source/common/filter/auth/client_ssl.cc b/source/common/filter/auth/client_ssl.cc index cd6de5344c62..f23188b48d64 100644 --- a/source/common/filter/auth/client_ssl.cc +++ b/source/common/filter/auth/client_ssl.cc @@ -108,22 +108,33 @@ Network::FilterStatus Instance::onNewConnection() { if (!read_callbacks_->connection().ssl()) { config_->stats().auth_no_ssl_.inc(); return Network::FilterStatus::Continue; + } else { + // Otherwise we need to wait for handshake to be complete before proceeding. + return Network::FilterStatus::StopIteration; + } +} + +void Instance::onEvent(uint32_t events) { + if (!(events & Network::ConnectionEvent::Connected)) { + return; } + ASSERT(read_callbacks_->connection().ssl()); if (config_->ipWhiteList().contains(read_callbacks_->connection().remoteAddress())) { config_->stats().auth_ip_white_list_.inc(); - return Network::FilterStatus::Continue; + read_callbacks_->continueReading(); + return; } if (!config_->allowedPrincipals().allowed( read_callbacks_->connection().ssl()->sha256PeerCertificateDigest())) { config_->stats().auth_digest_no_match_.inc(); read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); - return Network::FilterStatus::StopIteration; + return; } config_->stats().auth_digest_match_.inc(); - return Network::FilterStatus::Continue; + read_callbacks_->continueReading(); } } // Client Ssl diff --git a/source/common/filter/auth/client_ssl.h b/source/common/filter/auth/client_ssl.h index 6eca3187982a..faae983ab0a0 100644 --- a/source/common/filter/auth/client_ssl.h +++ b/source/common/filter/auth/client_ssl.h @@ -97,7 +97,7 @@ typedef std::shared_ptr ConfigPtr; /** * A client SSL auth filter instance. One per connection. */ -class Instance : public Network::ReadFilter { +class Instance : public Network::ReadFilter, public Network::ConnectionCallbacks { public: Instance(ConfigPtr config) : config_(config) {} @@ -106,8 +106,13 @@ class Instance : public Network::ReadFilter { Network::FilterStatus onNewConnection() override; void initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) override { read_callbacks_ = &callbacks; + read_callbacks_->connection().addConnectionCallbacks(*this); } + // Network::ConnectionCallbacks + void onBufferChange(Network::ConnectionBufferType, uint64_t, int64_t) override {} + void onEvent(uint32_t events) override; + private: ConfigPtr config_; Network::ReadFilterCallbacks* read_callbacks_{}; diff --git a/source/common/ssl/connection_impl.cc b/source/common/ssl/connection_impl.cc index d001f8f2c74c..014d0ab42a9c 100644 --- a/source/common/ssl/connection_impl.cc +++ b/source/common/ssl/connection_impl.cc @@ -92,7 +92,9 @@ Network::ConnectionImpl::PostIoAction ConnectionImpl::doHandshake() { handshake_complete_ = true; raiseEvents(Network::ConnectionEvent::Connected); - return PostIoAction::KeepOpen; + + // It's possible that we closed during the handshake callback. + return state() == State::Open ? PostIoAction::KeepOpen : PostIoAction::Close; } else { int err = SSL_get_error(ssl_.get(), rc); conn_log_debug("handshake error: {}", *this, err); diff --git a/test/common/filter/auth/client_ssl_test.cc b/test/common/filter/auth/client_ssl_test.cc index 1fb9238be5e3..d80f0d154ec5 100644 --- a/test/common/filter/auth/client_ssl_test.cc +++ b/test/common/filter/auth/client_ssl_test.cc @@ -9,6 +9,7 @@ #include "test/test_common/utility.h" using testing::_; +using testing::InSequence; using testing::Invoke; using testing::Return; using testing::ReturnNew; @@ -49,6 +50,7 @@ class ClientSslAuthFilterTest : public testing::Test { } void createAuthFilter() { + filter_callbacks_.connection_.callbacks_.clear(); instance_.reset(new Instance(config_)); instance_->initializeReadFilterCallbacks(filter_callbacks_); } @@ -91,7 +93,7 @@ TEST_F(ClientSslAuthFilterTest, NoCluster) { EXPECT_THROW(new Config(loader, tls_, cm_, dispatcher_, stats_store_, runtime_), EnvoyException); } -TEST_F(ClientSslAuthFilterTest, Basic) { +TEST_F(ClientSslAuthFilterTest, NoSsl) { setup(); Buffer::OwnedImpl dummy("hello"); @@ -100,15 +102,27 @@ TEST_F(ClientSslAuthFilterTest, Basic) { EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection()); EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy)); EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy)); + filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); + + EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_no_ssl").value()); +} + +TEST_F(ClientSslAuthFilterTest, Ssl) { + InSequence s; + + setup(); + Buffer::OwnedImpl dummy("hello"); // Create a new filter for an SSL connection, with no backing auth data yet. createAuthFilter(); - EXPECT_CALL(filter_callbacks_.connection_, ssl()).Times(2).WillRepeatedly(Return(&ssl_)); + ON_CALL(filter_callbacks_.connection_, ssl()).WillByDefault(Return(&ssl_)); EXPECT_CALL(filter_callbacks_.connection_, remoteAddress()) .WillOnce(ReturnRefOfCopy(std::string("192.168.1.1"))); EXPECT_CALL(ssl_, sha256PeerCertificateDigest()).WillOnce(Return("digest")); EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection()); + filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected); + filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); // Respond. EXPECT_CALL(*interval_timer_, enableTimer(_)); @@ -121,26 +135,29 @@ TEST_F(ClientSslAuthFilterTest, Basic) { // Create a new filter for an SSL connection with an authorized cert. createAuthFilter(); - EXPECT_CALL(filter_callbacks_.connection_, ssl()).Times(2).WillRepeatedly(Return(&ssl_)); EXPECT_CALL(filter_callbacks_.connection_, remoteAddress()) .WillOnce(ReturnRefOfCopy(std::string("192.168.1.1"))); EXPECT_CALL(ssl_, sha256PeerCertificateDigest()) .WillOnce(Return("1b7d42ef0025ad89c1c911d6c10d7e86a4cb7c5863b2980abcbad1895f8b5314")); - EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection()); + EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection()); + EXPECT_CALL(filter_callbacks_, continueReading()); + filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected); EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy)); EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy)); + filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); // White list case. createAuthFilter(); - EXPECT_CALL(filter_callbacks_.connection_, ssl()).WillOnce(Return(&ssl_)); EXPECT_CALL(filter_callbacks_.connection_, remoteAddress()) .WillOnce(ReturnRefOfCopy(std::string("1.2.3.4"))); - EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection()); + EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection()); + EXPECT_CALL(filter_callbacks_, continueReading()); + filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected); EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy)); EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy)); + filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.update_success").value()); - EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_no_ssl").value()); EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_ip_white_list").value()); EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_digest_match").value()); EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_digest_no_match").value()); @@ -175,7 +192,6 @@ TEST_F(ClientSslAuthFilterTest, Basic) { callbacks_->onFailure(Http::AsyncClient::FailureReason::Reset); // Interval timer fires, cannot obtain async client. - EXPECT_CALL(*interval_timer_, enableTimer(_)); EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_)); EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce( @@ -185,6 +201,7 @@ TEST_F(ClientSslAuthFilterTest, Basic) { Http::HeaderMapPtr{new Http::TestHeaderMapImpl{{":status", "503"}}})}); return nullptr; })); + EXPECT_CALL(*interval_timer_, enableTimer(_)); interval_timer_->callback_(); EXPECT_EQ(4U, stats_store_.counter("auth.clientssl.vpn.update_failure").value()); diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index a841a15b3f34..1f4c19523410 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -68,6 +68,7 @@ TEST(ConnectionImplTest, BufferCallbacks) { EXPECT_CALL(server_callbacks, onBufferChange(ConnectionBufferType::Read, 4, -4)).InSequence(s2); EXPECT_CALL(server_callbacks, onEvent(ConnectionEvent::LocalClose)).InSequence(s2); + EXPECT_CALL(*read_filter, onNewConnection()); EXPECT_CALL(*read_filter, onData(_)) .WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus { data.drain(data.length()); diff --git a/test/common/network/proxy_protocol_test.cc b/test/common/network/proxy_protocol_test.cc index bf9dc5217e82..a5b4dc2f79ae 100644 --- a/test/common/network/proxy_protocol_test.cc +++ b/test/common/network/proxy_protocol_test.cc @@ -50,7 +50,8 @@ TEST_F(ProxyProtocolTest, Basic) { })); read_filter_.reset(new MockReadFilter()); - EXPECT_CALL(*read_filter_.get(), onData(BufferStringEqual("more data"))); + EXPECT_CALL(*read_filter_, onNewConnection()); + EXPECT_CALL(*read_filter_, onData(BufferStringEqual("more data"))); dispatcher_.run(Event::Dispatcher::RunType::NonBlock); accepted_connection->close(ConnectionCloseType::NoFlush);