Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client ssl filter: fix connection handling #235

Merged
merged 2 commits into from
Nov 19, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions source/common/filter/auth/client_ssl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,33 @@ Network::FilterStatus Instance::onNewConnection() {
if (!read_callbacks_->connection().ssl()) {
config_->stats().auth_no_ssl_.inc();
return Network::FilterStatus::Continue;
} else {
// Otherwise we need to wait for handshake to be complete before proceeding.
return Network::FilterStatus::StopIteration;
}
}

void Instance::onEvent(uint32_t events) {
if (!(events & Network::ConnectionEvent::Connected)) {
return;
}

ASSERT(read_callbacks_->connection().ssl());
if (config_->ipWhiteList().contains(read_callbacks_->connection().remoteAddress())) {
config_->stats().auth_ip_white_list_.inc();
return Network::FilterStatus::Continue;
read_callbacks_->continueReading();
return;
}

if (!config_->allowedPrincipals().allowed(
read_callbacks_->connection().ssl()->sha256PeerCertificateDigest())) {
config_->stats().auth_digest_no_match_.inc();
read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush);
return Network::FilterStatus::StopIteration;
return;
}

config_->stats().auth_digest_match_.inc();
return Network::FilterStatus::Continue;
read_callbacks_->continueReading();
}

} // Client Ssl
Expand Down
7 changes: 6 additions & 1 deletion source/common/filter/auth/client_ssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ typedef std::shared_ptr<Config> ConfigPtr;
/**
* A client SSL auth filter instance. One per connection.
*/
class Instance : public Network::ReadFilter {
class Instance : public Network::ReadFilter, public Network::ConnectionCallbacks {
public:
Instance(ConfigPtr config) : config_(config) {}

Expand All @@ -106,8 +106,13 @@ class Instance : public Network::ReadFilter {
Network::FilterStatus onNewConnection() override;
void initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) override {
read_callbacks_ = &callbacks;
read_callbacks_->connection().addConnectionCallbacks(*this);
}

// Network::ConnectionCallbacks
void onBufferChange(Network::ConnectionBufferType, uint64_t, int64_t) override {}
void onEvent(uint32_t events) override;

private:
ConfigPtr config_;
Network::ReadFilterCallbacks* read_callbacks_{};
Expand Down
4 changes: 3 additions & 1 deletion source/common/ssl/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ Network::ConnectionImpl::PostIoAction ConnectionImpl::doHandshake() {

handshake_complete_ = true;
raiseEvents(Network::ConnectionEvent::Connected);
return PostIoAction::KeepOpen;

// It's possible that we closed during the handshake callback.
return state() == State::Open ? PostIoAction::KeepOpen : PostIoAction::Close;
} else {
int err = SSL_get_error(ssl_.get(), rc);
conn_log_debug("handshake error: {}", *this, err);
Expand Down
33 changes: 25 additions & 8 deletions test/common/filter/auth/client_ssl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "test/test_common/utility.h"

using testing::_;
using testing::InSequence;
using testing::Invoke;
using testing::Return;
using testing::ReturnNew;
Expand Down Expand Up @@ -49,6 +50,7 @@ class ClientSslAuthFilterTest : public testing::Test {
}

void createAuthFilter() {
filter_callbacks_.connection_.callbacks_.clear();
instance_.reset(new Instance(config_));
instance_->initializeReadFilterCallbacks(filter_callbacks_);
}
Expand Down Expand Up @@ -91,7 +93,7 @@ TEST_F(ClientSslAuthFilterTest, NoCluster) {
EXPECT_THROW(new Config(loader, tls_, cm_, dispatcher_, stats_store_, runtime_), EnvoyException);
}

TEST_F(ClientSslAuthFilterTest, Basic) {
TEST_F(ClientSslAuthFilterTest, NoSsl) {
setup();
Buffer::OwnedImpl dummy("hello");

Expand All @@ -100,15 +102,27 @@ TEST_F(ClientSslAuthFilterTest, Basic) {
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection());
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_no_ssl").value());
}

TEST_F(ClientSslAuthFilterTest, Ssl) {
InSequence s;

setup();
Buffer::OwnedImpl dummy("hello");

// Create a new filter for an SSL connection, with no backing auth data yet.
createAuthFilter();
EXPECT_CALL(filter_callbacks_.connection_, ssl()).Times(2).WillRepeatedly(Return(&ssl_));
ON_CALL(filter_callbacks_.connection_, ssl()).WillByDefault(Return(&ssl_));
EXPECT_CALL(filter_callbacks_.connection_, remoteAddress())
.WillOnce(ReturnRefOfCopy(std::string("192.168.1.1")));
EXPECT_CALL(ssl_, sha256PeerCertificateDigest()).WillOnce(Return("digest"));
EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush));
EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection());
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected);
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

// Respond.
EXPECT_CALL(*interval_timer_, enableTimer(_));
Expand All @@ -121,26 +135,29 @@ TEST_F(ClientSslAuthFilterTest, Basic) {

// Create a new filter for an SSL connection with an authorized cert.
createAuthFilter();
EXPECT_CALL(filter_callbacks_.connection_, ssl()).Times(2).WillRepeatedly(Return(&ssl_));
EXPECT_CALL(filter_callbacks_.connection_, remoteAddress())
.WillOnce(ReturnRefOfCopy(std::string("192.168.1.1")));
EXPECT_CALL(ssl_, sha256PeerCertificateDigest())
.WillOnce(Return("1b7d42ef0025ad89c1c911d6c10d7e86a4cb7c5863b2980abcbad1895f8b5314"));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection());
EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection());
EXPECT_CALL(filter_callbacks_, continueReading());
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected);
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

// White list case.
createAuthFilter();
EXPECT_CALL(filter_callbacks_.connection_, ssl()).WillOnce(Return(&ssl_));
EXPECT_CALL(filter_callbacks_.connection_, remoteAddress())
.WillOnce(ReturnRefOfCopy(std::string("1.2.3.4")));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onNewConnection());
EXPECT_EQ(Network::FilterStatus::StopIteration, instance_->onNewConnection());
EXPECT_CALL(filter_callbacks_, continueReading());
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::Connected);
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
EXPECT_EQ(Network::FilterStatus::Continue, instance_->onData(dummy));
filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose);

EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.update_success").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_no_ssl").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_ip_white_list").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_digest_match").value());
EXPECT_EQ(1U, stats_store_.counter("auth.clientssl.vpn.auth_digest_no_match").value());
Expand Down Expand Up @@ -175,7 +192,6 @@ TEST_F(ClientSslAuthFilterTest, Basic) {
callbacks_->onFailure(Http::AsyncClient::FailureReason::Reset);

// Interval timer fires, cannot obtain async client.
EXPECT_CALL(*interval_timer_, enableTimer(_));
EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_));
EXPECT_CALL(cm_.async_client_, send_(_, _, _))
.WillOnce(
Expand All @@ -185,6 +201,7 @@ TEST_F(ClientSslAuthFilterTest, Basic) {
Http::HeaderMapPtr{new Http::TestHeaderMapImpl{{":status", "503"}}})});
return nullptr;
}));
EXPECT_CALL(*interval_timer_, enableTimer(_));
interval_timer_->callback_();

EXPECT_EQ(4U, stats_store_.counter("auth.clientssl.vpn.update_failure").value());
Expand Down
1 change: 1 addition & 0 deletions test/common/network/connection_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ TEST(ConnectionImplTest, BufferCallbacks) {
EXPECT_CALL(server_callbacks, onBufferChange(ConnectionBufferType::Read, 4, -4)).InSequence(s2);
EXPECT_CALL(server_callbacks, onEvent(ConnectionEvent::LocalClose)).InSequence(s2);

EXPECT_CALL(*read_filter, onNewConnection());
EXPECT_CALL(*read_filter, onData(_))
.WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus {
data.drain(data.length());
Expand Down
3 changes: 2 additions & 1 deletion test/common/network/proxy_protocol_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ TEST_F(ProxyProtocolTest, Basic) {
}));

read_filter_.reset(new MockReadFilter());
EXPECT_CALL(*read_filter_.get(), onData(BufferStringEqual("more data")));
EXPECT_CALL(*read_filter_, onNewConnection());
EXPECT_CALL(*read_filter_, onData(BufferStringEqual("more data")));

dispatcher_.run(Event::Dispatcher::RunType::NonBlock);
accepted_connection->close(ConnectionCloseType::NoFlush);
Expand Down