diff --git a/envoy b/envoy index 48d361dfcd..9153a6077d 160000 --- a/envoy +++ b/envoy @@ -1 +1 @@ -Subproject commit 48d361dfcd9a8f0d5d5e9f9793a6a5651747bc03 +Subproject commit 9153a6077d17ed4af1457b998a9a6b3c75572456 diff --git a/envoy_build_config/extensions_build_config.bzl b/envoy_build_config/extensions_build_config.bzl index ba6c53a9d5..469bedd242 100644 --- a/envoy_build_config/extensions_build_config.bzl +++ b/envoy_build_config/extensions_build_config.bzl @@ -105,7 +105,7 @@ EXTENSIONS = { # SSL # - "envoy.extensions.common.crypto.utility_lib": "@envoy_openssl//source/extensions/common/crypto:utility_lib", + "envoy.common.crypto.utility_lib": "@envoy_openssl//source/extensions/common/crypto:utility_lib", # # Stat sinks diff --git a/source/extensions/filters/http/lua/BUILD b/source/extensions/filters/http/lua/BUILD index 7033364551..6d44e33dab 100644 --- a/source/extensions/filters/http/lua/BUILD +++ b/source/extensions/filters/http/lua/BUILD @@ -5,6 +5,7 @@ licenses(["notice"]) # Apache 2 load( "@envoy//bazel:envoy_build_system.bzl", + "envoy_cc_extension", "envoy_cc_library", "envoy_package", ) @@ -47,16 +48,17 @@ envoy_cc_library( ], ) -envoy_cc_library( +envoy_cc_extension( name = "config", repository = "@envoy", srcs = ["config.cc"], hdrs = ["config.h"], + security_posture = "robust_to_untrusted_downstream", deps = [ ":lua_filter_lib", "@envoy//include/envoy/registry", - "@envoy//source/common/config:filter_json_lib", "@envoy//source/extensions/filters/http:well_known_names", "@envoy//source/extensions/filters/http/common:factory_base_lib", + "@envoy_api//envoy/config/filter/http/lua/v2:pkg_cc_proto", ], ) diff --git a/source/extensions/filters/http/lua/config.cc b/source/extensions/filters/http/lua/config.cc index 11c3756606..686b82105c 100644 --- a/source/extensions/filters/http/lua/config.cc +++ b/source/extensions/filters/http/lua/config.cc @@ -3,8 +3,6 @@ #include "envoy/config/filter/http/lua/v2/lua.pb.validate.h" #include "envoy/registry/registry.h" -#include "common/config/filter_json.h" - #include "extensions/filters/http/lua/lua_filter.h" namespace Envoy { @@ -22,15 +20,6 @@ Http::FilterFactoryCb LuaFilterConfig::createFilterFactoryFromProtoTyped( }; } -Http::FilterFactoryCb -LuaFilterConfig::createFilterFactory(const Json::Object& json_config, - const std::string& stat_prefix, - Server::Configuration::FactoryContext& context) { - envoy::config::filter::http::lua::v2::Lua proto_config; - Config::FilterJson::translateLuaFilter(json_config, proto_config); - return createFilterFactoryFromProtoTyped(proto_config, stat_prefix, context); -} - /** * Static registration for the Lua filter. @see RegisterFactory. */ diff --git a/source/extensions/filters/http/lua/config.h b/source/extensions/filters/http/lua/config.h index 5f5f9aaf3a..3d007e0c74 100644 --- a/source/extensions/filters/http/lua/config.h +++ b/source/extensions/filters/http/lua/config.h @@ -18,10 +18,6 @@ class LuaFilterConfig : public Common::FactoryBase(raw_body, body_size); - headers->insertContentLength().value(body_size); + headers->setContentLength(body_size); } // Once we respond we treat that as the end of the script even if there is more code. Thus we @@ -193,7 +193,7 @@ int StreamHandleWrapper::luaHttpCall(lua_State* state) { if (body != nullptr) { message->body() = std::make_unique(body, body_size); - message->headers().insertContentLength().value(body_size); + message->headers().setContentLength(body_size); } absl::optional timeout; diff --git a/source/extensions/filters/http/lua/lua_filter.h b/source/extensions/filters/http/lua/lua_filter.h index c6676ff09c..2b99219124 100644 --- a/source/extensions/filters/http/lua/lua_filter.h +++ b/source/extensions/filters/http/lua/lua_filter.h @@ -5,7 +5,6 @@ #include "common/crypto/utility.h" -#include "extensions/common/crypto/crypto_impl.h" #include "extensions/filters/common/lua/wrappers.h" #include "extensions/filters/http/lua/wrappers.h" #include "extensions/filters/http/well_known_names.h" diff --git a/source/extensions/filters/http/lua/wrappers.cc b/source/extensions/filters/http/lua/wrappers.cc index 4874a716ce..a772f3c1ed 100644 --- a/source/extensions/filters/http/lua/wrappers.cc +++ b/source/extensions/filters/http/lua/wrappers.cc @@ -79,12 +79,7 @@ int HeaderMapWrapper::luaReplace(lua_State* state) { const char* value = luaL_checkstring(state, 3); const Http::LowerCaseString lower_key(key); - Http::HeaderEntry* entry = headers_.get(lower_key); - if (entry != nullptr) { - entry->value(value, strlen(value)); - } else { - headers_.addCopy(lower_key, value); - } + headers_.setCopy(lower_key, value); return 0; } diff --git a/source/extensions/grpc_credentials/aws_iam/config.cc b/source/extensions/grpc_credentials/aws_iam/config.cc index d2b423f744..b711f4a844 100644 --- a/source/extensions/grpc_credentials/aws_iam/config.cc +++ b/source/extensions/grpc_credentials/aws_iam/config.cc @@ -119,9 +119,9 @@ AwsIamHeaderAuthenticator::buildMessageToSign(absl::string_view service_url, Http::Utility::extractHostPathFromUri(uri, host, path); Http::RequestMessageImpl message; - message.headers().insertMethod().value().setReference(Http::Headers::get().MethodValues.Post); - message.headers().insertHost().value(host); - message.headers().insertPath().value(path); + message.headers().setReferenceMethod(Http::Headers::get().MethodValues.Post); + message.headers().setHost(host); + message.headers().setPath(path); return message; } diff --git a/source/extensions/transport_sockets/tls/config.cc b/source/extensions/transport_sockets/tls/config.cc index 248b47392c..9e617ebef5 100644 --- a/source/extensions/transport_sockets/tls/config.cc +++ b/source/extensions/transport_sockets/tls/config.cc @@ -53,7 +53,9 @@ Ssl::ContextManagerPtr SslContextManagerFactory::createContextManager(TimeSource return std::make_unique(time_source); } -REGISTER_FACTORY(SslContextManagerFactory, Ssl::ContextManagerFactory); +static Envoy::Registry::RegisterInternalFactory + ssl_manager_registered; } // namespace Tls } // namespace TransportSockets diff --git a/source/extensions/transport_sockets/tls/context_impl.cc b/source/extensions/transport_sockets/tls/context_impl.cc index c733b97273..f5b35f6b23 100644 --- a/source/extensions/transport_sockets/tls/context_impl.cc +++ b/source/extensions/transport_sockets/tls/context_impl.cc @@ -98,7 +98,7 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c } throw EnvoyException(fmt::format("Failed to initialize cipher suites {}. The following " "ciphers were rejected when tried individually: {}", - config.cipherSuites(), StringUtil::join(bad_ciphers, ", "))); + config.cipherSuites(), absl::StrJoin(bad_ciphers, ", "))); } if (!SSL_CTX_set1_curves_list(ctx.ssl_ctx_.get(), config.ecdhCurves().c_str())) { @@ -1046,7 +1046,7 @@ int ServerContextImpl::sessionTicketProcess(SSL*, uint8_t* key_name, uint8_t* iv // This RELEASE_ASSERT is logically a static_assert, but we can't actually get // EVP_CIPHER_key_length(cipher) at compile-time - RELEASE_ASSERT(key.aes_key_.size() == (unsigned) EVP_CIPHER_key_length(cipher), ""); + RELEASE_ASSERT(key.aes_key_.size() == static_cast(EVP_CIPHER_key_length(cipher)), ""); if (!EVP_EncryptInit_ex(ctx, cipher, nullptr, key.aes_key_.data(), iv)) { return -1; } @@ -1067,7 +1067,7 @@ int ServerContextImpl::sessionTicketProcess(SSL*, uint8_t* key_name, uint8_t* iv return -1; } - RELEASE_ASSERT(key.aes_key_.size() == (unsigned) EVP_CIPHER_key_length(cipher), ""); + RELEASE_ASSERT(key.aes_key_.size() == static_cast(EVP_CIPHER_key_length(cipher)), ""); if (!EVP_DecryptInit_ex(ctx, cipher, nullptr, key.aes_key_.data(), iv)) { return -1; } diff --git a/test/common/grpc/grpc_client_integration_test_harness.h b/test/common/grpc/grpc_client_integration_test_harness.h index cb9dc9fa3c..8b37766adb 100644 --- a/test/common/grpc/grpc_client_integration_test_harness.h +++ b/test/common/grpc/grpc_client_integration_test_harness.h @@ -148,10 +148,11 @@ class HelloworldStream : public MockAsyncStreamCallbacks } void expectGrpcStatus(Status::GrpcStatus grpc_status) { - if (grpc_status == Status::GrpcStatus::InvalidCode) { + if (grpc_status == Status::WellKnownGrpcStatus::InvalidCode) { EXPECT_CALL(*this, onRemoteClose(_, _)).WillExitIfNeeded(); - } else if (grpc_status > Status::GrpcStatus::MaximumValid) { - EXPECT_CALL(*this, onRemoteClose(Status::GrpcStatus::InvalidCode, _)).WillExitIfNeeded(); + } else if (grpc_status > Status::WellKnownGrpcStatus::MaximumKnown) { + EXPECT_CALL(*this, onRemoteClose(Status::WellKnownGrpcStatus::InvalidCode, _)) + .WillExitIfNeeded(); } else { EXPECT_CALL(*this, onRemoteClose(grpc_status, _)).WillExitIfNeeded(); } diff --git a/test/extensions/filters/http/common/aws/signer_impl_test.cc b/test/extensions/filters/http/common/aws/signer_impl_test.cc index fe4991c66d..c09a986b5f 100644 --- a/test/extensions/filters/http/common/aws/signer_impl_test.cc +++ b/test/extensions/filters/http/common/aws/signer_impl_test.cc @@ -30,9 +30,9 @@ class SignerImplTest : public testing::Test { time_system_.setSystemTime(std::chrono::milliseconds(1514862245000)); } - void addMethod(const std::string& method) { message_->headers().insertMethod().value(method); } + void addMethod(const std::string& method) { message_->headers().setMethod(method); } - void addPath(const std::string& path) { message_->headers().insertPath().value(path); } + void addPath(const std::string& path) { message_->headers().setPath(path); } void addHeader(const std::string& key, const std::string& value) { message_->headers().addCopy(Http::LowerCaseString(key), value); diff --git a/test/extensions/filters/http/lua/config_test.cc b/test/extensions/filters/http/lua/config_test.cc index 343abb9d09..5267896739 100644 --- a/test/extensions/filters/http/lua/config_test.cc +++ b/test/extensions/filters/http/lua/config_test.cc @@ -1,8 +1,11 @@ +#include + #include "envoy/config/filter/http/lua/v2/lua.pb.validate.h" #include "extensions/filters/http/lua/config.h" #include "test/mocks/server/mocks.h" +#include "test/test_common/utility.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -23,16 +26,15 @@ TEST(LuaFilterConfigTest, ValidateFail) { } TEST(LuaFilterConfigTest, LuaFilterInJson) { - std::string json_string = R"EOF( - { - "inline_code" : "print(5)" - } + const std::string yaml_string = R"EOF( + inline_code : "print(5)" )EOF"; - Json::ObjectSharedPtr json_config = Json::Factory::loadFromString(json_string); + envoy::config::filter::http::lua::v2::Lua proto_config; + TestUtility::loadFromYaml(yaml_string, proto_config); NiceMock context; LuaFilterConfig factory; - Http::FilterFactoryCb cb = factory.createFilterFactory(*json_config, "stats", context); + Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(proto_config, "stats", context); Http::MockFilterChainFactoryCallbacks filter_callback; EXPECT_CALL(filter_callback, addStreamFilter(_)); cb(filter_callback); diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc b/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc index 7d74e1a0a0..a79515d979 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc @@ -101,7 +101,7 @@ BENCHMARK(BM_TlsInspector)->Unit(benchmark::kMicrosecond); int main(int argc, char** argv) { Envoy::Thread::MutexBasicLockable lock; Envoy::Logger::Context logging_context(spdlog::level::warn, - Envoy::Logger::Logger::DEFAULT_LOG_FORMAT, lock); + Envoy::Logger::Logger::DEFAULT_LOG_FORMAT, lock, false); benchmark::Initialize(&argc, argv); if (benchmark::ReportUnrecognizedArguments(argc, argv)) { diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index ca85802fea..007891d03c 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -103,7 +103,7 @@ class TestUtilOptions : public TestUtilOptionsBase { bool expect_success, Network::Address::IpVersion version) : TestUtilOptionsBase(expect_success, version), client_ctx_yaml_(client_ctx_yaml), server_ctx_yaml_(server_ctx_yaml), expect_no_cert_(false), expect_no_cert_chain_(false), - expect_private_key_method_(false), expect_premature_disconnect_(false), + expect_premature_disconnect_(false), expected_server_close_event_(Network::ConnectionEvent::RemoteClose) { if (expect_success) { setExpectedServerStats("ssl.handshake"); @@ -233,7 +233,6 @@ class TestUtilOptions : public TestUtilOptionsBase { bool expect_no_cert_; bool expect_no_cert_chain_; - bool expect_private_key_method_; bool expect_premature_disconnect_; Network::ConnectionEvent expected_server_close_event_; std::string expected_digest_; @@ -267,8 +266,8 @@ void testUtil(const TestUtilOptions& options) { server_stats_store, std::vector{}); Event::DispatcherPtr dispatcher = server_api->allocateDispatcher(); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(options.version()), - nullptr, true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(options.version()), nullptr, true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = dispatcher->createListener(socket, callbacks, true); @@ -288,7 +287,7 @@ void testUtil(const TestUtilOptions& options) { ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection_ptr = client_connection.get(); Network::ConnectionPtr server_connection; @@ -555,8 +554,8 @@ const std::string testUtilV2(const TestUtilOptionsV2& options) { server_stats_store, server_names); Event::DispatcherPtr dispatcher(server_api->allocateDispatcher()); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(options.version()), - nullptr, true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(options.version()), nullptr, true); NiceMock callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = dispatcher->createListener(socket, callbacks, true); @@ -572,7 +571,7 @@ const std::string testUtilV2(const TestUtilOptionsV2& options) { ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), client_ssl_socket_factory.createTransportSocket(options.transportSocketOptions()), nullptr); if (!options.clientSession().empty()) { @@ -2167,14 +2166,14 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, server_stats_store, std::vector{}); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, - true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = dispatcher_->createListener(socket, callbacks, true); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), Network::Test::createRawBufferSocket(), nullptr); client_connection->connect(); Network::MockConnectionCallbacks client_connection_callbacks; @@ -2222,8 +2221,8 @@ TEST_P(SslSocketTest, HalfClose) { ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, server_stats_store, std::vector{}); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, - true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); Network::MockListenerCallbacks listener_callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = dispatcher_->createListener(socket, listener_callbacks, true); @@ -2241,7 +2240,7 @@ TEST_P(SslSocketTest, HalfClose) { ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->enableHalfClose(true); client_connection->addReadFilter(client_read_filter); @@ -2303,8 +2302,8 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, server_stats_store, std::vector{}); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, - true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = dispatcher_->createListener(socket, callbacks, true); @@ -2324,7 +2323,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { Stats::IsolatedStoreImpl client_stats_store; ClientSslSocketFactory ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(nullptr), nullptr); // Verify that server sent list with 2 acceptable client certificate CA names. @@ -2396,10 +2395,10 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, ServerSslSocketFactory server_ssl_socket_factory2(std::move(server_cfg2), manager, server_stats_store, server_names2); - Network::TcpListenSocket socket1(Network::Test::getCanonicalLoopbackAddress(ip_version), nullptr, - true); - Network::TcpListenSocket socket2(Network::Test::getCanonicalLoopbackAddress(ip_version), nullptr, - true); + auto socket1 = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(ip_version), nullptr, true); + auto socket2 = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(ip_version), nullptr, true); NiceMock callbacks; Network::MockConnectionHandler connection_handler; Event::DispatcherPtr dispatcher(server_api->allocateDispatcher()); @@ -2419,7 +2418,7 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, std::make_unique(client_tls_context, client_factory_context); ClientSslSocketFactory ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher->createClientConnection( - socket1.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket1->localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(nullptr), nullptr); Network::MockConnectionCallbacks client_connection_callbacks; @@ -2430,7 +2429,7 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, Network::ConnectionPtr server_connection; EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { - Network::TransportSocketFactory& tsf = socket->localAddress() == socket1.localAddress() + Network::TransportSocketFactory& tsf = socket->localAddress() == socket1->localAddress() ? server_ssl_socket_factory1 : server_ssl_socket_factory2; server_connection = dispatcher->createServerConnection(std::move(socket), @@ -2455,7 +2454,7 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, EXPECT_EQ(0UL, client_stats_store.counter("ssl.session_reused").value()); client_connection = dispatcher->createClientConnection( - socket2.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket2->localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); const SslSocketInfo* ssl_socket = @@ -2468,7 +2467,7 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { - Network::TransportSocketFactory& tsf = socket->localAddress() == socket1.localAddress() + Network::TransportSocketFactory& tsf = socket->localAddress() == socket1->localAddress() ? server_ssl_socket_factory1 : server_ssl_socket_factory2; server_connection = dispatcher->createServerConnection(std::move(socket), @@ -2814,10 +2813,10 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { ServerSslSocketFactory server2_ssl_socket_factory(std::move(server2_cfg), manager, server_stats_store, std::vector{}); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, - true); - Network::TcpListenSocket socket2(Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, - true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); + auto socket2 = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = dispatcher_->createListener(socket, callbacks, true); @@ -2838,7 +2837,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { Stats::IsolatedStoreImpl client_stats_store; ClientSslSocketFactory ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(nullptr), nullptr); Network::MockConnectionCallbacks client_connection_callbacks; @@ -2851,7 +2850,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& accepted_socket) -> void { Network::TransportSocketFactory& tsf = - accepted_socket->localAddress() == socket.localAddress() ? server_ssl_socket_factory + accepted_socket->localAddress() == socket->localAddress() ? server_ssl_socket_factory : server2_ssl_socket_factory; server_connection = dispatcher_->createServerConnection(std::move(accepted_socket), tsf.createTransportSocket(nullptr)); @@ -2878,7 +2877,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { EXPECT_EQ(1UL, client_stats_store.counter("ssl.handshake").value()); client_connection = dispatcher_->createClientConnection( - socket2.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket2->localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); const SslSocketInfo* ssl_socket = @@ -2891,7 +2890,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& accepted_socket) -> void { Network::TransportSocketFactory& tsf = - accepted_socket->localAddress() == socket.localAddress() ? server_ssl_socket_factory + accepted_socket->localAddress() == socket->localAddress() ? server_ssl_socket_factory : server2_ssl_socket_factory; server_connection = dispatcher_->createServerConnection(std::move(accepted_socket), tsf.createTransportSocket(nullptr)); @@ -2929,8 +2928,8 @@ void SslSocketTest::testClientSessionResumption(const std::string& server_ctx_ya ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, server_stats_store, std::vector{}); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(version), nullptr, - true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(version), nullptr, true); NiceMock callbacks; Network::MockConnectionHandler connection_handler; Api::ApiPtr api = Api::createApiForTest(server_stats_store, time_system_); @@ -2954,7 +2953,7 @@ void SslSocketTest::testClientSessionResumption(const std::string& server_ctx_ya ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); Network::MockConnectionCallbacks client_connection_callbacks; @@ -3015,7 +3014,7 @@ void SslSocketTest::testClientSessionResumption(const std::string& server_ctx_ya close_count = 0; client_connection = dispatcher->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); client_connection->connect(); @@ -3189,14 +3188,14 @@ TEST_P(SslSocketTest, SslError) { ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, server_stats_store, std::vector{}); - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, - true); + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = dispatcher_->createListener(socket, callbacks, true); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), Network::Test::createRawBufferSocket(), nullptr); client_connection->connect(); Buffer::OwnedImpl bad_data("bad_handshake_data"); @@ -3874,6 +3873,8 @@ class SslReadBufferLimitTest : public SslSocketTest { server_ssl_socket_factory_ = std::make_unique( std::move(server_cfg), *manager_, server_stats_store_, std::vector{}); + socket_ = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); listener_ = dispatcher_->createListener(socket_, listener_callbacks_, true); TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml_), upstream_tls_context_); @@ -3885,7 +3886,7 @@ class SslReadBufferLimitTest : public SslSocketTest { auto transport_socket = client_ssl_socket_factory_->createTransportSocket(nullptr); client_transport_socket_ = transport_socket.get(); client_connection_ = dispatcher_->createClientConnection( - socket_.localAddress(), source_address_, std::move(transport_socket), nullptr); + socket_->localAddress(), source_address_, std::move(transport_socket), nullptr); client_connection_->addConnectionCallbacks(client_callbacks_); client_connection_->connect(); read_filter_.reset(new Network::MockReadFilter()); @@ -4018,8 +4019,7 @@ class SslReadBufferLimitTest : public SslSocketTest { Stats::IsolatedStoreImpl server_stats_store_; Stats::IsolatedStoreImpl client_stats_store_; - Network::TcpListenSocket socket_{Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, - true}; + std::shared_ptr socket_; Network::MockListenerCallbacks listener_callbacks_; Network::MockConnectionHandler connection_handler_; const std::string server_ctx_yaml_ = R"EOF( @@ -4187,7 +4187,7 @@ void fakePauseJob() { ASYNC_JOB* job; ASYNC_WAIT_CTX* waitctx; OSSL_ASYNC_FD pipefds[2] = {0, 0}; - OSSL_ASYNC_FD* writefd; + void* void_writefd; char buf = 'X'; if ((job = ASYNC_get_current_job()) == NULL) @@ -4195,10 +4195,10 @@ void fakePauseJob() { waitctx = ASYNC_get_wait_ctx(job); - if (ASYNC_WAIT_CTX_get_fd(waitctx, fake_engine_id, &pipefds[0], (void**)&writefd)) { - pipefds[1] = *writefd; + if (ASYNC_WAIT_CTX_get_fd(waitctx, fake_engine_id, &pipefds[0], &void_writefd)) { + pipefds[1] = *static_cast(void_writefd); } else { - writefd = static_cast(OPENSSL_malloc(sizeof(*writefd))); + OSSL_ASYNC_FD* writefd = static_cast(OPENSSL_malloc(sizeof(*writefd))); if (writefd == NULL) return; if (pipe(pipefds) != 0) { diff --git a/test/integration/autonomous_upstream.cc b/test/integration/autonomous_upstream.cc index bd00fd176b..dfecdc5f66 100644 --- a/test/integration/autonomous_upstream.cc +++ b/test/integration/autonomous_upstream.cc @@ -92,10 +92,8 @@ bool AutonomousUpstream::createNetworkFilterChain(Network::Connection& connectio bool AutonomousUpstream::createListenerFilterChain(Network::ListenerFilterManager&) { return true; } -bool AutonomousUpstream::createUdpListenerFilterChain(Network::UdpListenerFilterManager&, - Network::UdpReadFilterCallbacks&) { - return true; -} +void AutonomousUpstream::createUdpListenerFilterChain(Network::UdpListenerFilterManager&, + Network::UdpReadFilterCallbacks&) {} void AutonomousUpstream::setLastRequestHeaders(const Http::HeaderMap& headers) { Thread::LockGuard lock(headers_lock_); diff --git a/test/integration/autonomous_upstream.h b/test/integration/autonomous_upstream.h index 6783b6c5a8..15110798b8 100644 --- a/test/integration/autonomous_upstream.h +++ b/test/integration/autonomous_upstream.h @@ -65,7 +65,7 @@ class AutonomousUpstream : public FakeUpstream { createNetworkFilterChain(Network::Connection& connection, const std::vector& filter_factories) override; bool createListenerFilterChain(Network::ListenerFilterManager& listener) override; - bool createUdpListenerFilterChain(Network::UdpListenerFilterManager& listener, + void createUdpListenerFilterChain(Network::UdpListenerFilterManager& listener, Network::UdpReadFilterCallbacks& callbacks) override; void setLastRequestHeaders(const Http::HeaderMap& headers); diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index f3884981a1..e94f96e2a0 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -16,6 +16,7 @@ #include "common/network/address_impl.h" #include "common/network/listen_socket_impl.h" #include "common/network/raw_buffer_socket.h" +#include "common/network/socket_option_factory.h" #include "common/network/utility.h" #include "server/connection_handler_impl.h" @@ -357,7 +358,7 @@ FakeUpstream::FakeUpstream(const std::string& uds_path, FakeHttpConnection::Type static Network::SocketPtr makeTcpListenSocket(const Network::Address::InstanceConstSharedPtr& address) { - return Network::SocketPtr{new Network::TcpListenSocket(address, nullptr, true)}; + return std::make_unique(address, nullptr, true); } static Network::SocketPtr makeTcpListenSocket(uint32_t port, Network::Address::IpVersion version) { @@ -365,14 +366,25 @@ static Network::SocketPtr makeTcpListenSocket(uint32_t port, Network::Address::I Network::Utility::parseInternetAddress(Network::Test::getAnyAddressString(version), port)); } +static Network::SocketPtr +makeUdpListenSocket(const Network::Address::InstanceConstSharedPtr& address) { + auto socket = std::make_unique(address, nullptr, true); + // TODO(mattklein123): These options are set in multiple locations. We should centralize them for + // UDP listeners. + socket->addOptions(Network::SocketOptionFactory::buildIpPacketInfoOptions()); + socket->addOptions(Network::SocketOptionFactory::buildRxQueueOverFlowOptions()); + return socket; +} + FakeUpstream::FakeUpstream(const Network::Address::InstanceConstSharedPtr& address, FakeHttpConnection::Type type, Event::TestTimeSystem& time_system, - bool enable_half_close) - : FakeUpstream(Network::Test::createRawBufferSocketFactory(), makeTcpListenSocket(address), + bool enable_half_close, bool udp_fake_upstream) + : FakeUpstream(Network::Test::createRawBufferSocketFactory(), + udp_fake_upstream ? makeUdpListenSocket(address) : makeTcpListenSocket(address), type, time_system, enable_half_close) { - ENVOY_LOG(info, "starting fake server on socket {}:{}. Address version is {}", + ENVOY_LOG(info, "starting fake server on socket {}:{}. Address version is {}. UDP={}", address->ip()->addressAsString(), address->ip()->port(), - Network::Test::addressVersionAsString(address->ip()->version())); + Network::Test::addressVersionAsString(address->ip()->version()), udp_fake_upstream); } FakeUpstream::FakeUpstream(uint32_t port, FakeHttpConnection::Type type, @@ -381,7 +393,7 @@ FakeUpstream::FakeUpstream(uint32_t port, FakeHttpConnection::Type type, : FakeUpstream(Network::Test::createRawBufferSocketFactory(), makeTcpListenSocket(port, version), type, time_system, enable_half_close) { ENVOY_LOG(info, "starting fake server on port {}. Address version is {}", - this->localAddress()->ip()->port(), Network::Test::addressVersionAsString(version)); + localAddress()->ip()->port(), Network::Test::addressVersionAsString(version)); } FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, @@ -390,13 +402,14 @@ FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket : FakeUpstream(std::move(transport_socket_factory), makeTcpListenSocket(port, version), type, time_system, false) { ENVOY_LOG(info, "starting fake SSL server on port {}. Address version is {}", - this->localAddress()->ip()->port(), Network::Test::addressVersionAsString(version)); + localAddress()->ip()->port(), Network::Test::addressVersionAsString(version)); } FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, Network::SocketPtr&& listen_socket, FakeHttpConnection::Type type, Event::TestTimeSystem& time_system, bool enable_half_close) - : http_type_(type), socket_(std::move(listen_socket)), + : http_type_(type), socket_(Network::SocketSharedPtr(listen_socket.release())), + socket_factory_(std::make_shared(socket_)), api_(Api::createApiForTest(stats_store_)), time_system_(time_system), dispatcher_(api_->allocateDispatcher()), handler_(new Server::ConnectionHandlerImpl(*dispatcher_, "fake_upstream")), @@ -426,15 +439,15 @@ bool FakeUpstream::createNetworkFilterChain(Network::Connection& connection, auto connection_wrapper = std::make_unique(connection, allow_unexpected_disconnects_); connection_wrapper->moveIntoListBack(std::move(connection_wrapper), new_connections_); - new_connection_event_.notifyOne(); + upstream_event_.notifyOne(); return true; } bool FakeUpstream::createListenerFilterChain(Network::ListenerFilterManager&) { return true; } -bool FakeUpstream::createUdpListenerFilterChain(Network::UdpListenerFilterManager&, - Network::UdpReadFilterCallbacks&) { - return true; +void FakeUpstream::createUdpListenerFilterChain(Network::UdpListenerFilterManager& udp_listener, + Network::UdpReadFilterCallbacks& callbacks) { + udp_listener.addReadFilter(std::make_unique(*this, callbacks)); } void FakeUpstream::threadRoutine() { @@ -462,7 +475,7 @@ AssertionResult FakeUpstream::waitForHttpConnection(Event::Dispatcher& client_di if (time_system.monotonicTime() >= end_time) { return AssertionFailure() << "Timed out waiting for new connection."; } - time_system_.waitFor(lock_, new_connection_event_, 5ms); + time_system_.waitFor(lock_, upstream_event_, 5ms); if (new_connections_.empty()) { // Run the client dispatcher since we may need to process window updates, etc. client_dispatcher.run(Event::Dispatcher::RunType::NonBlock); @@ -495,7 +508,7 @@ FakeUpstream::waitForHttpConnection(Event::Dispatcher& client_dispatcher, FakeUpstream& upstream = *it; Thread::ReleasableLockGuard lock(upstream.lock_); if (upstream.new_connections_.empty()) { - time_system.waitFor(upstream.lock_, upstream.new_connection_event_, 5ms); + time_system.waitFor(upstream.lock_, upstream.upstream_event_, 5ms); } if (upstream.new_connections_.empty()) { @@ -522,7 +535,7 @@ AssertionResult FakeUpstream::waitForRawConnection(FakeRawConnectionPtr& connect Thread::LockGuard lock(lock_); if (new_connections_.empty()) { ENVOY_LOG(debug, "waiting for raw connection"); - time_system_.waitFor(lock_, new_connection_event_, + time_system_.waitFor(lock_, upstream_event_, timeout); // Safe since CondVar::waitFor won't throw. } @@ -545,6 +558,36 @@ SharedConnectionWrapper& FakeUpstream::consumeConnection() { return connection_wrapper->shared_connection(); } +testing::AssertionResult FakeUpstream::waitForUdpDatagram(Network::UdpRecvData& data_to_fill, + std::chrono::milliseconds timeout) { + Thread::LockGuard lock(lock_); + auto end_time = time_system_.monotonicTime() + timeout; + while (received_datagrams_.empty()) { + if (time_system_.monotonicTime() >= end_time) { + return AssertionFailure() << "Timed out waiting for UDP datagram."; + } + time_system_.waitFor(lock_, upstream_event_, 5ms); // Safe since CondVar::waitFor won't throw. + } + data_to_fill = std::move(received_datagrams_.front()); + received_datagrams_.pop_front(); + return AssertionSuccess(); +} + +void FakeUpstream::onRecvDatagram(Network::UdpRecvData& data) { + Thread::LockGuard lock(lock_); + received_datagrams_.emplace_back(std::move(data)); + upstream_event_.notifyOne(); +} + +void FakeUpstream::sendUdpDatagram(const std::string& buffer, + const Network::Address::Instance& peer) { + dispatcher_->post([this, buffer, &peer] { + const auto rc = Network::Utility::writeToSocket(socket_->ioHandle(), Buffer::OwnedImpl(buffer), + nullptr, peer); + EXPECT_TRUE(rc.rc_ == buffer.length()); + }); +} + AssertionResult FakeRawConnection::waitForData(uint64_t num_bytes, std::string* data, milliseconds timeout) { Thread::LockGuard lock(lock_); diff --git a/test/integration/fake_upstream.h b/test/integration/fake_upstream.h index a7ab469680..9a7c0edf34 100644 --- a/test/integration/fake_upstream.h +++ b/test/integration/fake_upstream.h @@ -29,6 +29,8 @@ #include "common/network/listen_socket_impl.h" #include "common/stats/isolated_store_impl.h" +#include "server/active_raw_udp_listener_config.h" + #include "test/test_common/printers.h" #include "test/test_common/test_time_system.h" #include "test/test_common/utility.h" @@ -526,7 +528,7 @@ class FakeUpstream : Logger::Loggable, // Creates a fake upstream bound to the specified |address|. FakeUpstream(const Network::Address::InstanceConstSharedPtr& address, FakeHttpConnection::Type type, Event::TestTimeSystem& time_system, - bool enable_half_close = false); + bool enable_half_close = false, bool udp_fake_upstream = false); // Creates a fake upstream bound to INADDR_ANY and the specified |port|. FakeUpstream(uint32_t port, FakeHttpConnection::Type type, Network::Address::IpVersion version, @@ -560,6 +562,15 @@ class FakeUpstream : Logger::Loggable, FakeHttpConnectionPtr& connection, std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + // Waits for 1 UDP datagram to be received. + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForUdpDatagram(Network::UdpRecvData& data_to_fill, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + // Send a UDP datagram on the fake upstream thread. + void sendUdpDatagram(const std::string& buffer, const Network::Address::Instance& peer); + // Network::FilterChainManager const Network::FilterChain* findFilterChain(const Network::ConnectionSocket&) const override { return filter_chain_.get(); @@ -570,7 +581,7 @@ class FakeUpstream : Logger::Loggable, createNetworkFilterChain(Network::Connection& connection, const std::vector& filter_factories) override; bool createListenerFilterChain(Network::ListenerFilterManager& listener) override; - bool createUdpListenerFilterChain(Network::UdpListenerFilterManager& udp_listener, + void createUdpListenerFilterChain(Network::UdpListenerFilterManager& udp_listener, Network::UdpReadFilterCallbacks& callbacks) override; void set_allow_unexpected_disconnects(bool value) { allow_unexpected_disconnects_ = value; } @@ -589,16 +600,52 @@ class FakeUpstream : Logger::Loggable, Network::SocketPtr&& connection, FakeHttpConnection::Type type, Event::TestTimeSystem& time_system, bool enable_half_close); + class FakeListenSocketFactory : public Network::ListenSocketFactory { + public: + FakeListenSocketFactory(Network::SocketSharedPtr socket) : socket_(socket) {} + + // Network::ListenSocketFactory + Network::Address::SocketType socketType() const override { return socket_->socketType(); } + + const Network::Address::InstanceConstSharedPtr& localAddress() const override { + return socket_->localAddress(); + } + + Network::SocketSharedPtr getListenSocket() override { return socket_; } + absl::optional> sharedSocket() const override { + return *socket_; + } + + private: + Network::SocketSharedPtr socket_; + }; + + class FakeUpstreamUdpFilter : public Network::UdpListenerReadFilter { + public: + FakeUpstreamUdpFilter(FakeUpstream& parent, Network::UdpReadFilterCallbacks& callbacks) + : UdpListenerReadFilter(callbacks), parent_(parent) {} + + // Network::UdpListenerReadFilter + void onData(Network::UdpRecvData& data) override { parent_.onRecvDatagram(data); } + void onReceiveError(Api::IoError::IoErrorCode) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + + private: + FakeUpstream& parent_; + }; + class FakeListener : public Network::ListenerConfig { public: - FakeListener(FakeUpstream& parent) : parent_(parent), name_("fake_upstream") {} + FakeListener(FakeUpstream& parent) + : parent_(parent), name_("fake_upstream"), + udp_listener_factory_(std::make_unique()) {} private: // Network::ListenerConfig Network::FilterChainManager& filterChainManager() override { return parent_; } Network::FilterChainFactory& filterChainFactory() override { return parent_; } - Network::Socket& socket() override { return *parent_.socket_; } - const Network::Socket& socket() const override { return *parent_.socket_; } + Network::ListenSocketFactory& listenSocketFactory() override { + return *parent_.socket_factory_; + } bool bindToPort() override { return true; } bool handOffRestoredDestinationConnections() const override { return false; } uint32_t perConnectionBufferLimitBytes() const override { return 0; } @@ -609,7 +656,9 @@ class FakeUpstream : Logger::Loggable, Stats::Scope& listenerScope() override { return parent_.stats_store_; } uint64_t listenerTag() const override { return 0; } const std::string& name() const override { return name_; } - Network::ActiveUdpListenerFactory* udpListenerFactory() override { return nullptr; } + Network::ActiveUdpListenerFactory* udpListenerFactory() override { + return udp_listener_factory_.get(); + } Network::ConnectionBalancer& connectionBalancer() override { return connection_balancer_; } envoy::api::v2::core::TrafficDirection direction() const override { return envoy::api::v2::core::TrafficDirection::UNSPECIFIED; @@ -618,18 +667,21 @@ class FakeUpstream : Logger::Loggable, FakeUpstream& parent_; const std::string name_; Network::NopConnectionBalancerImpl connection_balancer_; + const Network::ActiveUdpListenerFactoryPtr udp_listener_factory_; }; void threadRoutine(); SharedConnectionWrapper& consumeConnection() ABSL_EXCLUSIVE_LOCKS_REQUIRED(lock_); + void onRecvDatagram(Network::UdpRecvData& data); - Network::SocketPtr socket_; + Network::SocketSharedPtr socket_; + Network::ListenSocketFactorySharedPtr socket_factory_; ConditionalInitializer server_initialized_; // Guards any objects which can be altered both in the upstream thread and the // main test thread. Thread::MutexBasicLockable lock_; Thread::ThreadPtr thread_; - Thread::CondVar new_connection_event_; + Thread::CondVar upstream_event_; Api::ApiPtr api_; Event::TestTimeSystem& time_system_; Event::DispatcherPtr dispatcher_; @@ -644,6 +696,7 @@ class FakeUpstream : Logger::Loggable, const bool enable_half_close_; FakeListener listener_; const Network::FilterChainSharedPtr filter_chain_; + std::list received_datagrams_ ABSL_GUARDED_BY(lock_); }; using FakeUpstreamPtr = std::unique_ptr; diff --git a/test/integration/integration.cc b/test/integration/integration.cc index ef18b2e704..210775d1da 100644 --- a/test/integration/integration.cc +++ b/test/integration/integration.cc @@ -403,7 +403,7 @@ void BaseIntegrationTest::registerTestServerPorts(const std::vector auto listeners = test_server_->server().listenerManager().listeners(); auto listener_it = listeners.cbegin(); for (; port_it != port_names.end() && listener_it != listeners.end(); ++port_it, ++listener_it) { - const auto listen_addr = listener_it->get().socket().localAddress(); + const auto listen_addr = listener_it->get().listenSocketFactory().localAddress(); if (listen_addr->type() == Network::Address::Type::Ip) { ENVOY_LOG(debug, "registered '{}' as port {}.", *port_it, listen_addr->ip()->port()); registerPort(*port_it, listen_addr->ip()->port()); @@ -421,7 +421,7 @@ void BaseIntegrationTest::createGeneratedApiTestServer(const std::string& bootst bool reject_unknown_dynamic_fields, bool allow_lds_rejection) { test_server_ = IntegrationTestServer::create( - bootstrap_path, version_, on_server_init_function_, deterministic_, timeSystem(), *api_, + bootstrap_path, version_, on_server_ready_function_, on_server_init_function_, deterministic_, timeSystem(), *api_, defer_listener_finalization_, process_object_, allow_unknown_static_fields, reject_unknown_dynamic_fields, concurrency_); if (config_helper_.bootstrap().static_resources().listeners_size() > 0 && @@ -475,7 +475,7 @@ void BaseIntegrationTest::createApiTestServer(const ApiFilesystemConfig& api_fil void BaseIntegrationTest::createTestServer(const std::string& json_path, const std::vector& port_names) { test_server_ = createIntegrationTestServer( - TestEnvironment::temporaryFileSubstitute(json_path, port_map_, version_), nullptr, + TestEnvironment::temporaryFileSubstitute(json_path, port_map_, version_), nullptr, nullptr, timeSystem()); registerTestServerPorts(port_names); } @@ -499,9 +499,10 @@ void BaseIntegrationTest::sendRawHttpAndWaitForResponse(int port, const char* ra IntegrationTestServerPtr BaseIntegrationTest::createIntegrationTestServer(const std::string& bootstrap_path, + std::function on_server_ready_function, std::function on_server_init_function, Event::TestTimeSystem& time_system) { - return IntegrationTestServer::create(bootstrap_path, version_, on_server_init_function, + return IntegrationTestServer::create(bootstrap_path, version_, on_server_ready_function, on_server_init_function, deterministic_, time_system, *api_, defer_listener_finalization_); } @@ -666,7 +667,7 @@ AssertionResult BaseIntegrationTest::compareDeltaDiscoveryRequest( request.error_detail().code(), expected_error_code, request.error_detail().message()); } - if (expected_error_code != Grpc::Status::GrpcStatus::Ok && + if (expected_error_code != Grpc::Status::WellKnownGrpcStatus::Ok && request.error_detail().message().find(expected_error_substring) == std::string::npos) { return AssertionFailure() << "\"" << expected_error_substring << "\" is not a substring of actual error message \"" diff --git a/test/integration/integration.h b/test/integration/integration.h index 3cdabb86ab..66461a88a1 100644 --- a/test/integration/integration.h +++ b/test/integration/integration.h @@ -222,14 +222,13 @@ class BaseIntegrationTest : Logger::Loggable { // available if you're writing a SotW/delta-specific test. // TODO(fredlas) expect_node was defaulting false here; the delta+SotW unification work restores // it. - AssertionResult - compareDiscoveryRequest(const std::string& expected_type_url, const std::string& expected_version, - const std::vector& expected_resource_names, - const std::vector& expected_resource_names_added, - const std::vector& expected_resource_names_removed, - bool expect_node = true, - const Protobuf::int32 expected_error_code = Grpc::Status::GrpcStatus::Ok, - const std::string& expected_error_message = ""); + AssertionResult compareDiscoveryRequest( + const std::string& expected_type_url, const std::string& expected_version, + const std::vector& expected_resource_names, + const std::vector& expected_resource_names_added, + const std::vector& expected_resource_names_removed, bool expect_node = true, + const Protobuf::int32 expected_error_code = Grpc::Status::WellKnownGrpcStatus::Ok, + const std::string& expected_error_message = ""); template void sendDiscoveryResponse(const std::string& type_url, const std::vector& state_of_the_world, const std::vector& added_or_updated, @@ -245,7 +244,7 @@ class BaseIntegrationTest : Logger::Loggable { const std::string& expected_type_url, const std::vector& expected_resource_subscriptions, const std::vector& expected_resource_unsubscriptions, - const Protobuf::int32 expected_error_code = Grpc::Status::GrpcStatus::Ok, + const Protobuf::int32 expected_error_code = Grpc::Status::WellKnownGrpcStatus::Ok, const std::string& expected_error_message = "") { return compareDeltaDiscoveryRequest(expected_type_url, expected_resource_subscriptions, expected_resource_unsubscriptions, xds_stream_, @@ -256,7 +255,7 @@ class BaseIntegrationTest : Logger::Loggable { const std::string& expected_type_url, const std::vector& expected_resource_subscriptions, const std::vector& expected_resource_unsubscriptions, FakeStreamPtr& stream, - const Protobuf::int32 expected_error_code = Grpc::Status::GrpcStatus::Ok, + const Protobuf::int32 expected_error_code = Grpc::Status::WellKnownGrpcStatus::Ok, const std::string& expected_error_message = ""); // TODO(fredlas) expect_node was defaulting false here; the delta+SotW unification work restores @@ -264,7 +263,7 @@ class BaseIntegrationTest : Logger::Loggable { AssertionResult compareSotwDiscoveryRequest( const std::string& expected_type_url, const std::string& expected_version, const std::vector& expected_resource_names, bool expect_node = true, - const Protobuf::int32 expected_error_code = Grpc::Status::GrpcStatus::Ok, + const Protobuf::int32 expected_error_code = Grpc::Status::WellKnownGrpcStatus::Ok, const std::string& expected_error_message = ""); template @@ -334,6 +333,7 @@ class BaseIntegrationTest : Logger::Loggable { // Will not return until that server is listening. virtual IntegrationTestServerPtr createIntegrationTestServer(const std::string& bootstrap_path, + std::function on_server_ready_function, std::function on_server_init_function, Event::TestTimeSystem& time_system); @@ -350,6 +350,9 @@ class BaseIntegrationTest : Logger::Loggable { // The ProcessObject to use when constructing the envoy server. absl::optional> process_object_{absl::nullopt}; + // Steps that should be done before the envoy server starting. + std::function on_server_ready_function_; + // Steps that should be done in parallel with the envoy server starting. E.g., xDS // pre-init, control plane synchronization needed for server start. std::function on_server_init_function_; @@ -368,6 +371,9 @@ class BaseIntegrationTest : Logger::Loggable { bool enable_half_close_{false}; + // Whether the default created fake upstreams are UDP listeners. + bool udp_fake_upstream_{false}; + // True if test will use a fixed RNG value. bool deterministic_{}; diff --git a/test/integration/server.cc b/test/integration/server.cc index fe4d46d2b0..50554b867a 100644 --- a/test/integration/server.cc +++ b/test/integration/server.cc @@ -52,12 +52,16 @@ OptionsImpl createTestOptionsImpl(const std::string& config_path, const std::str IntegrationTestServerPtr IntegrationTestServer::create( const std::string& config_path, const Network::Address::IpVersion version, + std::function server_ready_function, std::function on_server_init_function, bool deterministic, Event::TestTimeSystem& time_system, Api::Api& api, bool defer_listener_finalization, absl::optional> process_object, bool allow_unknown_static_fields, bool reject_unknown_dynamic_fields, uint32_t concurrency) { IntegrationTestServerPtr server{ std::make_unique(time_system, api, config_path)}; + if (server_ready_function != nullptr) { + server->setOnServerReadyCb(server_ready_function); + } server->start(version, on_server_init_function, deterministic, defer_listener_finalization, process_object, allow_unknown_static_fields, reject_unknown_dynamic_fields, concurrency); @@ -111,7 +115,7 @@ void IntegrationTestServer::start( if (tap_path) { std::vector ports; for (auto listener : server().listenerManager().listeners()) { - const auto listen_addr = listener.get().socket().localAddress(); + const auto listen_addr = listener.get().listenSocketFactory().localAddress(); if (listen_addr->type() == Network::Address::Type::Ip) { ports.push_back(listen_addr->ip()->port()); } @@ -152,6 +156,9 @@ void IntegrationTestServer::onWorkerListenerRemoved() { void IntegrationTestServer::serverReady() { pending_listeners_ = server().listenerManager().listeners().size(); + if (on_server_ready_cb_ != nullptr) { + on_server_ready_cb_(*this); + } server_set_.setReady(); } diff --git a/test/integration/server.h b/test/integration/server.h index 513b5c6d1a..9a1ca60e75 100644 --- a/test/integration/server.h +++ b/test/integration/server.h @@ -233,6 +233,7 @@ class IntegrationTestServer : public Logger::Loggable, public: static IntegrationTestServerPtr create(const std::string& config_path, const Network::Address::IpVersion version, + std::function on_server_ready_function, std::function on_server_init_function, bool deterministic, Event::TestTimeSystem& time_system, Api::Api& api, bool defer_listener_finalization = false, @@ -252,6 +253,9 @@ class IntegrationTestServer : public Logger::Loggable, void setOnWorkerListenerRemovedCb(std::function on_worker_listener_removed) { on_worker_listener_removed_cb_ = std::move(on_worker_listener_removed); } + void setOnServerReadyCb(std::function on_server_ready) { + on_server_ready_cb_ = std::move(on_server_ready); + } void onRuntimeCreated() override; void start(const Network::Address::IpVersion version, @@ -354,6 +358,7 @@ class IntegrationTestServer : public Logger::Loggable, std::function on_worker_listener_added_cb_; std::function on_worker_listener_removed_cb_; TcpDumpPtr tcp_dump_; + std::function on_server_ready_cb_; }; // Default implementation of IntegrationTestServer diff --git a/test/integration/utility.cc b/test/integration/utility.cc index 74dbdd298c..ca1f05955b 100644 --- a/test/integration/utility.cc +++ b/test/integration/utility.cc @@ -85,12 +85,12 @@ IntegrationUtil::makeSingleRequest(const Network::Address::InstanceConstSharedPt encoder.getStream().addCallbacks(*response); Http::HeaderMapImpl headers; - headers.insertMethod().value(method); - headers.insertPath().value(url); - headers.insertHost().value(host); - headers.insertScheme().value(Http::Headers::get().SchemeValues.Http); + headers.setMethod(method); + headers.setPath(url); + headers.setHost(host); + headers.setReferenceScheme(Http::Headers::get().SchemeValues.Http); if (!content_type.empty()) { - headers.insertContentType().value(content_type); + headers.setContentType(content_type); } encoder.encodeHeaders(headers, body.empty()); if (!body.empty()) { diff --git a/test/mocks/server/mocks.cc b/test/mocks/server/mocks.cc index 3a95f145c1..16ea2134de 100644 --- a/test/mocks/server/mocks.cc +++ b/test/mocks/server/mocks.cc @@ -91,9 +91,10 @@ MockOverloadManager::~MockOverloadManager() = default; MockListenerComponentFactory::MockListenerComponentFactory() : socket_(std::make_shared>()) { ON_CALL(*this, createListenSocket(_, _, _, _)) - .WillByDefault(Invoke( - [&](Network::Address::InstanceConstSharedPtr, Network::Address::SocketType, - const Network::Socket::OptionsSharedPtr& options, bool) -> Network::SocketSharedPtr { + .WillByDefault( + Invoke([&](Network::Address::InstanceConstSharedPtr, Network::Address::SocketType, + const Network::Socket::OptionsSharedPtr& options, + const ListenSocketCreationParams&) -> Network::SocketSharedPtr { if (!Network::Socket::applyOptions(options, *socket_, envoy::api::v2::core::SocketOption::STATE_PREBIND)) { throw EnvoyException("MockListenerComponentFactory: Setting socket options failed"); @@ -115,7 +116,8 @@ MockWorkerFactory::~MockWorkerFactory() = default; MockWorker::MockWorker() { ON_CALL(*this, addListener(_, _)) .WillByDefault( - Invoke([this](Network::ListenerConfig&, AddListenerCompletion completion) -> void { + Invoke([this](Network::ListenerConfig& config, AddListenerCompletion completion) -> void { + config.listenSocketFactory().getListenSocket(); EXPECT_EQ(nullptr, add_listener_completion_); add_listener_completion_ = completion; })); diff --git a/test/mocks/server/mocks.h b/test/mocks/server/mocks.h index 9889503a09..097e2382e1 100644 --- a/test/mocks/server/mocks.h +++ b/test/mocks/server/mocks.h @@ -77,6 +77,7 @@ class MockOptions : public Options { MOCK_CONST_METHOD0(componentLogLevels, const std::vector>&()); MOCK_CONST_METHOD0(logFormat, const std::string&()); + MOCK_CONST_METHOD0(logFormatEscaped, bool()); MOCK_CONST_METHOD0(logPath, const std::string&()); MOCK_CONST_METHOD0(parentShutdownTime, std::chrono::seconds()); MOCK_CONST_METHOD0(restartEpoch, uint64_t()); @@ -88,7 +89,6 @@ class MockOptions : public Options { MOCK_CONST_METHOD0(hotRestartDisabled, bool()); MOCK_CONST_METHOD0(signalHandlingEnabled, bool()); MOCK_CONST_METHOD0(mutexTracingEnabled, bool()); - MOCK_CONST_METHOD0(libeventBufferEnabled, bool()); MOCK_CONST_METHOD0(fakeSymbolTableEnabled, bool()); MOCK_CONST_METHOD0(cpusetThreadsEnabled, bool()); MOCK_CONST_METHOD0(toCommandLineOptions, Server::CommandLineOptionsPtr()); @@ -259,7 +259,7 @@ class MockListenerComponentFactory : public ListenerComponentFactory { Network::SocketSharedPtr(Network::Address::InstanceConstSharedPtr address, Network::Address::SocketType socket_type, const Network::Socket::OptionsSharedPtr& options, - bool bind_to_port)); + const ListenSocketCreationParams& params)); MOCK_METHOD1(createDrainManager_, DrainManager*(envoy::api::v2::Listener::DrainType drain_type)); MOCK_METHOD0(nextListenerTag, uint64_t()); @@ -280,6 +280,8 @@ class MockListenerManager : public ListenerManager { MOCK_METHOD1(startWorkers, void(GuardDog& guard_dog)); MOCK_METHOD1(stopListeners, void(StopListenersType listeners_type)); MOCK_METHOD0(stopWorkers, void()); + MOCK_METHOD0(beginListenerUpdate, void()); + MOCK_METHOD1(endListenerUpdate, void(ListenerManager::FailureStates&&)); }; class MockServerLifecycleNotifier : public ServerLifecycleNotifier {