Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hcm: add match_upstream to SchemeHeaderTransformation #34099

Merged
merged 10 commits into from
Jun 4, 2024
5 changes: 5 additions & 0 deletions api/envoy/config/core/v3/protocol.proto
Original file line number Diff line number Diff line change
Expand Up @@ -652,5 +652,10 @@ message SchemeHeaderTransformation {
oneof transformation {
// Overwrite any Scheme header with the contents of this string.
alyssawilk marked this conversation as resolved.
Show resolved Hide resolved
string scheme_to_overwrite = 1 [(validate.rules).string = {in: "http" in: "https"}];

wtzhang23 marked this conversation as resolved.
Show resolved Hide resolved
// Set the Scheme header to match the upstream transport protocol. For example, should a
// request be sent to the upstream over TLS, the scheme header will be set to "https". Should the
// request be sent over plaintext, the scheme header will be set to "http".
bool match_upstream = 2;
}
}
14 changes: 14 additions & 0 deletions envoy/stream_info/stream_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,20 @@ class StreamInfo {
*/
virtual void setDownstreamTransportFailureReason(absl::string_view failure_reason) PURE;

/**
* Checked by routing filters before forwarding a request upstream.
* @return to override the scheme header to match the upstream transport
* protocol at routing filters.
*/
virtual bool shouldSchemeMatchUpstream() const PURE;

/**
* Called if a filter decides that the scheme should match the upstream transport protocol
* @param should_match_upstream true to hint to routing filters to override the scheme header
* to match the upstream transport protocol.
*/
virtual void setShouldSchemeMatchUpstream(bool should_match_upstream) PURE;

/**
* Checked by streams after finishing serving the request.
* @return bool true if the connection should be drained once this stream has
Expand Down
1 change: 1 addition & 0 deletions source/common/http/async_client_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ AsyncStreamImpl::AsyncStreamImpl(AsyncClientImpl& parent, AsyncClient::StreamCal
stream_info_.dynamicMetadata().MergeFrom(options.metadata);
stream_info_.setIsShadow(options.is_shadow);
stream_info_.setUpstreamClusterInfo(parent_.cluster_);
stream_info_.setShouldSchemeMatchUpstream(true);
wtzhang23 marked this conversation as resolved.
Show resolved Hide resolved
stream_info_.route_ = route_;

if (options.buffer_body_for_retry) {
Expand Down
5 changes: 5 additions & 0 deletions source/common/http/conn_manager_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ class ConnectionManagerConfig {
*/
virtual const absl::optional<std::string>& schemeToSet() const PURE;

/**
* @return bool whether the scheme should be overwritten to match the upstream transport protocol.
*/
virtual bool shouldSchemeMatchUpstream() const PURE;

/**
* @return ConnectionManagerStats& the stats to write to.
*/
Expand Down
3 changes: 3 additions & 0 deletions source/common/http/conn_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,9 @@ ConnectionManagerImpl::ActiveStream::ActiveStream(ConnectionManagerImpl& connect
filter_manager_.streamInfo().setStreamIdProvider(
std::make_shared<HttpStreamIdProviderImpl>(*this));

filter_manager_.streamInfo().setShouldSchemeMatchUpstream(
connection_manager.config_->shouldSchemeMatchUpstream());

// TODO(chaoqin-li1123): can this be moved to the on demand filter?
static const std::string route_factory = "envoy.route_config_update_requester.default";
auto factory =
Expand Down
16 changes: 14 additions & 2 deletions source/common/router/router.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,17 @@ uint64_t FilterUtility::percentageOfTimeout(const std::chrono::milliseconds resp
return static_cast<uint64_t>(response_time.count() * TimeoutPrecisionFactor / timeout.count());
}

void FilterUtility::setUpstreamScheme(Http::RequestHeaderMap& headers, bool downstream_secure) {
void FilterUtility::setUpstreamScheme(Http::RequestHeaderMap& headers, bool downstream_secure,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so there's a difference in what we pass in here
I think we need to decide if we want to set scheme based on "is upstream ssl" or "is upstream secure"
e.g. ALTS is secure but not TLS

Copy link
Contributor Author

@wtzhang23 wtzhang23 May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the golang grpc and alts source code, it seems like this should only be set when the upstream is using TLS and not ATLS. The RFC also seems to couple it with TLS.

Given the number of transport sockets, I'm thinking of adding a new virtual method to the transport socket to allow each socket to select which scheme it would like to set. Let me know if you have any opinions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decided to keep it simple and just check if it has an SSL context (TLS) to set the scheme to https, otherwise use http.

bool upstream_secure, bool use_upstream) {
if (use_upstream) {
if (upstream_secure) {
headers.setReferenceScheme(Http::Headers::get().SchemeValues.Https);
} else {
headers.setReferenceScheme(Http::Headers::get().SchemeValues.Http);
}
return;
}

if (Http::Utility::schemeIsValid(headers.getSchemeValue())) {
return;
}
Expand Down Expand Up @@ -713,7 +723,9 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers,
route_entry_->finalizeRequestHeaders(headers, callbacks_->streamInfo(),
!config_->suppress_envoy_headers_);
FilterUtility::setUpstreamScheme(
headers, callbacks_->streamInfo().downstreamAddressProvider().sslConnection() != nullptr);
headers, callbacks_->streamInfo().downstreamAddressProvider().sslConnection() != nullptr,
host->transportSocketFactory().implementsSecureTransport(),
callbacks_->streamInfo().shouldSchemeMatchUpstream());

// Ensure an http transport scheme is selected before continuing with decoding.
ASSERT(headers.Scheme());
Expand Down
4 changes: 3 additions & 1 deletion source/common/router/router.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,13 @@ class FilterUtility {

/**
* Set the :scheme header using the best information available. In order this is
* - security of upstream connection if use_upstream is true
* - existing scheme header if valid
* - x-forwarded-proto header if valid
* - security of downstream connection
*/
static void setUpstreamScheme(Http::RequestHeaderMap& headers, bool downstream_secure);
static void setUpstreamScheme(Http::RequestHeaderMap& headers, bool downstream_secure,
bool upstream_secure, bool use_upstream);

/**
* Determine whether a request should be shadowed.
Expand Down
7 changes: 7 additions & 0 deletions source/common/stream_info/stream_info_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,12 @@ struct StreamInfoImpl : public StreamInfo {
return downstream_transport_failure_reason_;
}

bool shouldSchemeMatchUpstream() const override { return should_scheme_match_upstream_; }

void setShouldSchemeMatchUpstream(bool should_match_upstream) override {
should_scheme_match_upstream_ = should_match_upstream;
}

bool shouldDrainConnectionUponCompletion() const override { return should_drain_connection_; }

void setShouldDrainConnectionUponCompletion(bool should_drain) override {
Expand Down Expand Up @@ -479,6 +485,7 @@ struct StreamInfoImpl : public StreamInfo {
BytesMeterSharedPtr downstream_bytes_meter_;
bool is_shadow_{false};
std::string downstream_transport_failure_reason_;
bool should_scheme_match_upstream_{false};
bool should_drain_connection_{false};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,8 @@ HttpConnectionManagerConfig::HttpConnectionManagerConfig(
scheme_to_set_ = config.scheme_header_transformation().scheme_to_overwrite();
}

should_scheme_match_upstream_ = config.scheme_header_transformation().match_upstream();

if (!config.server_name().empty()) {
server_name_ = config.server_name();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class HttpConnectionManagerConfig : Logger::Loggable<Logger::Id::config>,
return server_transformation_;
}
const absl::optional<std::string>& schemeToSet() const override { return scheme_to_set_; }
bool shouldSchemeMatchUpstream() const override { return should_scheme_match_upstream_; }
Http::ConnectionManagerStats& stats() override { return stats_; }
Http::ConnectionManagerTracingStats& tracingStats() override { return tracing_stats_; }
bool useRemoteAddress() const override { return use_remote_address_; }
Expand Down Expand Up @@ -316,6 +317,7 @@ class HttpConnectionManagerConfig : Logger::Loggable<Logger::Id::config>,
HttpConnectionManagerProto::OVERWRITE};
std::string server_name_;
absl::optional<std::string> scheme_to_set_;
bool should_scheme_match_upstream_;
Tracing::TracerSharedPtr tracer_{std::make_shared<Tracing::NullTracer>()};
Http::TracingConnectionManagerConfigPtr tracing_config_;
absl::optional<std::string> user_agent_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ void GrpcHealthCheckerImpl::GrpcActiveHealthCheckSession::onInterval() {
headers_message->headers(),
// Here there is no downstream connection so scheme will be based on
// upstream crypto
host_->transportSocketFactory().implementsSecureTransport());
false, host_->transportSocketFactory().implementsSecureTransport(), true);

auto status = request_encoder_->encodeHeaders(headers_message->headers(), false);
// Encoding will only fail if required headers are missing.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ void HttpHealthCheckerImpl::HttpActiveHealthCheckSession::onInterval() {
*request_headers,
// Here there is no downstream connection so scheme will be based on
// upstream crypto
host_->transportSocketFactory().implementsSecureTransport());
false, host_->transportSocketFactory().implementsSecureTransport(), true);
StreamInfo::StreamInfoImpl stream_info(protocol_, parent_.dispatcher_.timeSource(),
local_connection_info_provider_,
StreamInfo::FilterState::LifeSpan::FilterChain);
Expand Down
2 changes: 2 additions & 0 deletions source/server/admin/admin.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class AdminImpl : public Admin,
OptRef<const Router::ScopeKeyBuilder> scopeKeyBuilder() override { return scope_key_builder_; }
const std::string& serverName() const override { return Http::DefaultServerString::get(); }
const absl::optional<std::string>& schemeToSet() const override { return scheme_; }
bool shouldSchemeMatchUpstream() const override { return scheme_match_upstream_; }
HttpConnectionManagerProto::ServerHeaderTransformation
serverHeaderTransformation() const override {
return HttpConnectionManagerProto::OVERWRITE;
Expand Down Expand Up @@ -494,6 +495,7 @@ class AdminImpl : public Admin,
const std::vector<Http::OriginalIPDetectionSharedPtr> detection_extensions_{};
const std::vector<Http::EarlyHeaderMutationPtr> early_header_mutations_{};
const absl::optional<std::string> scheme_{};
const bool scheme_match_upstream_ = true;
wtzhang23 marked this conversation as resolved.
Show resolved Hide resolved
const bool ignore_global_conn_limit_;
std::unique_ptr<HttpConnectionManagerProto::ProxyStatusConfig> proxy_status_config_;
const Http::HeaderValidatorFactoryPtr header_validator_factory_;
Expand Down
2 changes: 2 additions & 0 deletions test/common/http/conn_manager_impl_fuzz_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class FuzzConfig : public ConnectionManagerConfig {
return server_transformation_;
}
const absl::optional<std::string>& schemeToSet() const override { return scheme_; }
bool shouldSchemeMatchUpstream() const override { return scheme_match_upstream_; }
ConnectionManagerStats& stats() override { return stats_; }
ConnectionManagerTracingStats& tracingStats() override { return tracing_stats_; }
bool useRemoteAddress() const override { return use_remote_address_; }
Expand Down Expand Up @@ -265,6 +266,7 @@ class FuzzConfig : public ConnectionManagerConfig {
HttpConnectionManagerProto::ServerHeaderTransformation server_transformation_{
HttpConnectionManagerProto::OVERWRITE};
absl::optional<std::string> scheme_;
bool scheme_match_upstream_{};
Stats::IsolatedStoreImpl fake_stats_;
ConnectionManagerStats stats_;
ConnectionManagerTracingStats tracing_stats_;
Expand Down
58 changes: 58 additions & 0 deletions test/common/http/conn_manager_impl_test_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4210,5 +4210,63 @@ TEST_F(HttpConnectionManagerImplTest, DownstreamTimingsRecordWhenRequestHeaderPr
filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose);
}

TEST_F(HttpConnectionManagerImplTest, PassMatchUpstreamSchemeHintToStreamInfo) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High-level comment:
it seems that the relevant part of the test is essentially:
EXPECT_TRUE(filter->callbacks_->streamInfo().shouldSchemeMatchUpstream());

Can this test be refactored to minimize the non-relevant parts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cut as much of the test as I'm familiar with. Will need to dive deeper to see if any other cuts can be made.

setup(/*ssl=*/false, /*server_name=*/"", /*tracing=*/false);
scheme_match_upstream_ = true;

// Store the basic request encoder during filter chain setup.
std::shared_ptr<MockStreamDecoderFilter> filter(new NiceMock<MockStreamDecoderFilter>());

EXPECT_CALL(*filter, decodeHeaders(_, true))
.Times(1)
.WillRepeatedly(Invoke([&](RequestHeaderMap& headers, bool) -> FilterHeadersStatus {
EXPECT_NE(nullptr, headers.ForwardedFor());
EXPECT_EQ("http", headers.getForwardedProtoValue());
EXPECT_TRUE(filter->callbacks_->streamInfo().shouldSchemeMatchUpstream());

return FilterHeadersStatus::StopIteration;
}));

EXPECT_CALL(*filter, setDecoderFilterCallbacks(_));

EXPECT_CALL(filter_factory_, createFilterChain(_))
.Times(1)
.WillRepeatedly(Invoke([&](FilterChainManager& manager) -> bool {
auto factory = createDecoderFilterFactoryCb(filter);
manager.applyFilterFactoryCb({}, factory);
return true;
}));

EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_));

// When dispatch is called on the codec, we pretend to get a new stream and then fire a headers
// only request into it. Then we respond into the filter.
EXPECT_CALL(*codec_, dispatch(_))
.Times(1)
.WillRepeatedly(Invoke([&](Buffer::Instance& data) -> Http::Status {
decoder_ = &conn_manager_->newStream(response_encoder_);

RequestHeaderMapPtr headers{new TestRequestHeaderMapImpl{
{":authority", "host"}, {":path", "/"}, {":method", "GET"}}};
decoder_->decodeHeaders(std::move(headers), true);
ResponseHeaderMapPtr response_headers{new TestResponseHeaderMapImpl{{":status", "200"}}};
filter->callbacks_->streamInfo().setResponseCodeDetails("");
filter->callbacks_->encodeHeaders(std::move(response_headers), true, "details");
response_encoder_.stream_.codec_callbacks_->onCodecEncodeComplete();

data.drain(1);
return Http::okStatus();
}));

// Kick off the incoming data.
Buffer::OwnedImpl fake_input{};
conn_manager_->onData(fake_input, false);

EXPECT_EQ(1U, stats_.named_.downstream_rq_2xx_.value());
EXPECT_EQ(1U, listener_stats_.downstream_rq_2xx_.value());
EXPECT_EQ(1U, stats_.named_.downstream_rq_completed_.value());
EXPECT_EQ(1U, listener_stats_.downstream_rq_completed_.value());
wtzhang23 marked this conversation as resolved.
Show resolved Hide resolved
}

} // namespace Http
} // namespace Envoy
1 change: 1 addition & 0 deletions test/common/http/conn_manager_impl_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ConnectionManagerConfigProxyObject : public ConnectionManagerConfig {
return parent_.serverHeaderTransformation();
}
const absl::optional<std::string>& schemeToSet() const override { return parent_.schemeToSet(); }
bool shouldSchemeMatchUpstream() const override { return parent_.shouldSchemeMatchUpstream(); }
ConnectionManagerStats& stats() override { return parent_.stats(); }
ConnectionManagerTracingStats& tracingStats() override { return parent_.tracingStats(); }
bool useRemoteAddress() const override { return parent_.useRemoteAddress(); }
Expand Down
2 changes: 2 additions & 0 deletions test/common/http/conn_manager_impl_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class HttpConnectionManagerImplMixin : public ConnectionManagerConfig {
return server_transformation_;
}
const absl::optional<std::string>& schemeToSet() const override { return scheme_; }
bool shouldSchemeMatchUpstream() const override { return scheme_match_upstream_; }
ConnectionManagerStats& stats() override { return stats_; }
ConnectionManagerTracingStats& tracingStats() override { return tracing_stats_; }
bool useRemoteAddress() const override { return use_remote_address_; }
Expand Down Expand Up @@ -235,6 +236,7 @@ class HttpConnectionManagerImplMixin : public ConnectionManagerConfig {
HttpConnectionManagerProto::ServerHeaderTransformation server_transformation_{
HttpConnectionManagerProto::OVERWRITE};
absl::optional<std::string> scheme_;
bool scheme_match_upstream_{false};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: If this is not being overwritten, consider converting to const.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need this to be non-const to override in the PassMatchUpstreamSchemeHintToStreamInfo test I wrote.

Network::Address::Ipv4Instance local_address_{"127.0.0.1"};
bool use_remote_address_{true};
Http::DefaultInternalAddressConfig internal_address_config_;
Expand Down
28 changes: 28 additions & 0 deletions test/common/router/router_2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -978,5 +978,33 @@ TEST_F(RouterTestSupressGRPCStatsDisabled, IncludeHttpTimeoutStats) {
.value());
}

class RouterTestSchemeMatchUpstream : public RouterTestBase {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need to subclass? Can't you just put the expect_call in your test body?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed new class and moved code to router.cc where all the tests that use RouterTest live.

public:
RouterTestSchemeMatchUpstream()
: RouterTestBase(false, false, false, false, Protobuf::RepeatedPtrField<std::string>{}) {
EXPECT_CALL(callbacks_.stream_info_, shouldSchemeMatchUpstream()).WillRepeatedly(Return(true));
}
};

TEST_F(RouterTestSchemeMatchUpstream, OverwriteSchemeWithUpstreamTransportProtocol) {
EXPECT_CALL(cm_.thread_local_cluster_, httpConnPool(_, absl::optional<Http::Protocol>(), _));
EXPECT_CALL(cm_.thread_local_cluster_.conn_pool_, newStream(_, _, _))
.WillOnce(Return(&cancellable_));
expectResponseTimerCreate();

Http::TestRequestHeaderMapImpl headers;
HttpTestUtility::addDefaultHeaders(headers);
headers.setScheme("https");
router_->decodeHeaders(headers, true);
EXPECT_EQ(headers.getSchemeValue(), "http");

// When the router filter gets reset we should cancel the pool request.
EXPECT_CALL(cancellable_, cancel(_));
router_->onDestroy();
EXPECT_TRUE(verifyHostUpstreamStats(0, 0));
EXPECT_EQ(0U,
callbacks_.route_->route_entry_.virtual_cluster_.stats().upstream_rq_total_.value());
}

} // namespace Router
} // namespace Envoy
25 changes: 19 additions & 6 deletions test/common/router/router_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5806,31 +5806,44 @@ TEST(RouterFilterUtilityTest, FinalTimeoutSupressEnvoyHeaders) {
TEST(RouterFilterUtilityTest, SetUpstreamScheme) {
TestScopedRuntime scoped_runtime;

// With no scheme and x-forwarded-proto, set scheme based on encryption level
// With upstream scheme, set scheme based on upstream encryption level
{
Http::TestRequestHeaderMapImpl headers;
FilterUtility::setUpstreamScheme(headers, false);
FilterUtility::setUpstreamScheme(headers, false, false, true);
EXPECT_EQ("http", headers.get_(":scheme"));
}
{
Http::TestRequestHeaderMapImpl headers;
FilterUtility::setUpstreamScheme(headers, true);
FilterUtility::setUpstreamScheme(headers, false, true, true);
EXPECT_EQ("https", headers.get_(":scheme"));
}

// With no scheme and x-forwarded-proto, set scheme based on downstream
// encryption level
{
Http::TestRequestHeaderMapImpl headers;
FilterUtility::setUpstreamScheme(headers, false, false, false);
EXPECT_EQ("http", headers.get_(":scheme"));
}
{
Http::TestRequestHeaderMapImpl headers;
FilterUtility::setUpstreamScheme(headers, true, false, false);
EXPECT_EQ("https", headers.get_(":scheme"));
}

// With invalid x-forwarded-proto, still use scheme.
{
Http::TestRequestHeaderMapImpl headers;
headers.setForwardedProto("foo");
FilterUtility::setUpstreamScheme(headers, true);
FilterUtility::setUpstreamScheme(headers, true, false, false);
EXPECT_EQ("https", headers.get_(":scheme"));
}

// Use valid x-forwarded-proto.
{
Http::TestRequestHeaderMapImpl headers;
headers.setForwardedProto(Http::Headers::get().SchemeValues.Http);
FilterUtility::setUpstreamScheme(headers, true);
FilterUtility::setUpstreamScheme(headers, true, false, false);
EXPECT_EQ("http", headers.get_(":scheme"));
}

Expand All @@ -5839,7 +5852,7 @@ TEST(RouterFilterUtilityTest, SetUpstreamScheme) {
Http::TestRequestHeaderMapImpl headers;
headers.setScheme(Http::Headers::get().SchemeValues.Https);
headers.setForwardedProto(Http::Headers::get().SchemeValues.Http);
FilterUtility::setUpstreamScheme(headers, false);
FilterUtility::setUpstreamScheme(headers, false, false, false);
EXPECT_EQ("https", headers.get_(":scheme"));
}
}
Expand Down
Loading
Loading