From f0f1da93be4264312f438b0026c17844ee5c73a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 15 Apr 2026 15:14:36 +0000 Subject: [PATCH 1/3] chore: create proxy instance once per provider instead of on each request --- internal/testutil/mockprovider.go | 31 +++-- passthrough.go | 149 ++++++++++---------- passthrough_test.go | 223 ++++++++++++++++++++++++++++-- 3 files changed, 301 insertions(+), 102 deletions(-) diff --git a/internal/testutil/mockprovider.go b/internal/testutil/mockprovider.go index 3e1db798..8eb69fed 100644 --- a/internal/testutil/mockprovider.go +++ b/internal/testutil/mockprovider.go @@ -11,21 +11,26 @@ import ( ) type MockProvider struct { - NameStr string - URL string - Bridged []string - Passthrough []string - InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) + NameStr string + URL string + Bridged []string + Passthrough []string + InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) + InjectAuthHeaderFunc func(h *http.Header) } -func (m *MockProvider) Type() string { return m.NameStr } -func (m *MockProvider) Name() string { return m.NameStr } -func (m *MockProvider) BaseURL() string { return m.URL } -func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) } -func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } -func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough } -func (*MockProvider) AuthHeader() string { return "Authorization" } -func (*MockProvider) InjectAuthHeader(_ *http.Header) {} +func (m *MockProvider) Type() string { return m.NameStr } +func (m *MockProvider) Name() string { return m.NameStr } +func (m *MockProvider) BaseURL() string { return m.URL } +func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) } +func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } +func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough } +func (*MockProvider) AuthHeader() string { return "Authorization" } +func (m *MockProvider) InjectAuthHeader(h *http.Header) { + if m.InjectAuthHeaderFunc != nil { + m.InjectAuthHeaderFunc(h) + } +} func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil } func (*MockProvider) APIDumpDir() string { return "" } func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) { diff --git a/passthrough.go b/passthrough.go index b94ae128..a9cc8069 100644 --- a/passthrough.go +++ b/passthrough.go @@ -1,7 +1,6 @@ package aibridge import ( - "net" "net/http" "net/http/httputil" "net/url" @@ -21,7 +20,39 @@ import ( // newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically // by a [intercept.Provider]. +// A single reverse proxy is created per provider and reused across all requests. func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { + provBaseURL, err := url.Parse(prov.BaseURL()) + if err != nil { + return newInvalidBaseURLHandler(prov, logger, m, tracer, err) + } + if _, err := url.JoinPath(provBaseURL.Path, "/"); err != nil { + return newInvalidBaseURLHandler(prov, logger, m, tracer, err) + } + + // Transport tuned for streaming (no response header timeout). + t := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + // Build a reverse proxy to the upstream, reused across all requests for this provider. + // All request modifications happen in Rewrite. + proxy := &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + rewritePassthroughRequest(pr, provBaseURL, prov) + }, + Transport: apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()), + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) { + logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path)) + http.Error(rw, "upstream proxy error", http.StatusBadGateway) + }, + } + return func(w http.ResponseWriter, r *http.Request) { if m != nil { m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1) @@ -33,88 +64,48 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics )) defer span.End() - upURL, err := url.Parse(prov.BaseURL()) - if err != nil { - logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err)) - http.Error(w, "request error", http.StatusBadGateway) - span.SetStatus(codes.Error, "failed to parse provider base URL: "+err.Error()) - return - } + proxy.ServeHTTP(w, r.WithContext(ctx)) + } +} - // Append the request path to the upstream base path. - reqPath, err := url.JoinPath(upURL.Path, r.URL.Path) - if err != nil { - logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstream_path", upURL.Path), slog.F("request_path", r.URL.Path)) - http.Error(w, "failed to join upstream path", http.StatusInternalServerError) - span.SetStatus(codes.Error, "failed to join upstream path: "+err.Error()) - return - } - // Ensure leading slash, proxied requests should have absolute paths. - // JoinPath can return relative paths, eg. when upURL path is empty. - if len(reqPath) == 0 || reqPath[0] != '/' { - reqPath = "/" + reqPath - } +// rewritePassthroughRequest configures the outbound request for the upstream and +// applies proxy headers and provider auth. +func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov provider.Provider) { + pr.SetURL(provBaseURL) - // Build a reverse proxy to the upstream. - proxy := &httputil.ReverseProxy{ - Director: func(req *http.Request) { - // Set scheme/host to upstream. - req.URL.Scheme = upURL.Scheme - req.URL.Host = upURL.Host - req.URL.Path = reqPath - req.URL.RawPath = "" - - // Preserve query string. - req.URL.RawQuery = r.URL.RawQuery - - // Set Host header for upstream. - req.Host = upURL.Host - span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, req.URL.String())) - - // Copy headers from client. - req.Header = r.Header.Clone() - - // Standard proxy headers. - host, _, herr := net.SplitHostPort(r.RemoteAddr) - if herr != nil { - host = r.RemoteAddr - } - if prior := req.Header.Get("X-Forwarded-For"); prior != "" { - req.Header.Set("X-Forwarded-For", prior+", "+host) - } else { - req.Header.Set("X-Forwarded-For", host) - } - req.Header.Set("X-Forwarded-Host", r.Host) - if r.TLS != nil { - req.Header.Set("X-Forwarded-Proto", "https") - } else { - req.Header.Set("X-Forwarded-Proto", "http") - } - // Avoid default Go user-agent if none provided. - if _, ok := req.Header["User-Agent"]; !ok { - req.Header.Set("User-Agent", "aibridge") // TODO: use build tag. - } - - // Inject provider auth. - prov.InjectAuthHeader(&req.Header) - }, - ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) { - logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path)) - http.Error(rw, "upstream proxy error", http.StatusBadGateway) - }, - } + if prior, ok := pr.In.Header["X-Forwarded-For"]; ok { + pr.Out.Header["X-Forwarded-For"] = append([]string(nil), prior...) + } + pr.SetXForwarded() - // Transport tuned for streaming (no response header timeout). - t := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, + span := trace.SpanFromContext(pr.Out.Context()) + span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, pr.Out.URL.String())) + + // Avoid default Go user-agent if none provided. + if _, ok := pr.Out.Header["User-Agent"]; !ok { + pr.Out.Header.Set("User-Agent", "aibridge") // TODO: use build tag. + } + + // Inject provider auth. + prov.InjectAuthHeader(&pr.Out.Header) +} + +// newInvalidBaseURLHandler returns a handler that always returns 502 because +// the provider's base URL is invalid. +func newInvalidBaseURLHandler(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer, baseURLErr error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if m != nil { + m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1) } - proxy.Transport = apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()) - proxy.ServeHTTP(w, r) + ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( + attribute.String(tracing.PassthroughURL, r.URL.String()), + attribute.String(tracing.PassthroughMethod, r.Method), + )) + defer span.End() + + logger.Warn(ctx, "invalid provider base URL", slog.Error(baseURLErr)) + http.Error(w, "invalid provider base URL", http.StatusBadGateway) + span.SetStatus(codes.Error, "invalid provider base URL: "+baseURLErr.Error()) } } diff --git a/passthrough_test.go b/passthrough_test.go index 85600d65..9376892a 100644 --- a/passthrough_test.go +++ b/passthrough_test.go @@ -1,8 +1,14 @@ package aibridge //nolint:testpackage // tests unexported newPassthroughRouter import ( + "crypto/tls" + "maps" + "net" "net/http" "net/http/httptest" + "net/http/httputil" + "net/url" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -21,14 +27,19 @@ func TestPassthroughRoutes(t *testing.T) { tests := []struct { name string baseURLPath string - passthroughRoute string + reqPath string + reqHost string + reqRemoteAddr string + reqHeaders http.Header expectRequestPath string + expectQuery string + expectHeaders http.Header expectRespStatus int expectRespBody string }{ { name: "passthrough_route_no_path", - passthroughRoute: "/v1/conversations", + reqPath: "/v1/conversations", expectRequestPath: "/v1/conversations", expectRespStatus: http.StatusOK, expectRespBody: upstreamRespBody, @@ -36,7 +47,7 @@ func TestPassthroughRoutes(t *testing.T) { { name: "base_URL_path_is_preserved_in_passthrough_routes", baseURLPath: "/api/v2", - passthroughRoute: "/v1/models", + reqPath: "/v1/models", expectRequestPath: "/api/v2/v1/models", expectRespStatus: http.StatusOK, expectRespBody: upstreamRespBody, @@ -44,16 +55,43 @@ func TestPassthroughRoutes(t *testing.T) { { name: "passthrough_route_break_parse_base_url", baseURLPath: "/%zz", - passthroughRoute: "/v1/models/", + reqPath: "/v1/models/", expectRespStatus: http.StatusBadGateway, - expectRespBody: "request error", + expectRespBody: "invalid provider base URL", }, { - name: "passthrough_route_break_join_path", + name: "passthrough_route_rejects_invalid_base_url_path", baseURLPath: "/%25", - passthroughRoute: "/v1/models", - expectRespStatus: http.StatusInternalServerError, - expectRespBody: "failed to join upstream path", + reqPath: "/v1/models", + expectRespStatus: http.StatusBadGateway, + expectRespBody: "invalid provider base URL", + }, + { + name: "proxy_headers_are_set_and_forwarded_chain_is_appended", + reqPath: "/v1/models", + reqHost: "client.example.com", + reqRemoteAddr: "1.1.1.1:1111", + reqHeaders: http.Header{ + "X-Forwarded-For": {"2.2.2.2, 3.3.3.3"}, + }, + expectRequestPath: "/v1/models", + expectRespStatus: http.StatusOK, + expectRespBody: upstreamRespBody, + expectHeaders: http.Header{ + "Accept-Encoding": {"gzip"}, + "User-Agent": {"aibridge"}, + "X-Forwarded-For": {"2.2.2.2, 3.3.3.3, 1.1.1.1"}, + "X-Forwarded-Host": {"client.example.com"}, + "X-Forwarded-Proto": {"http"}, + }, + }, + { + name: "query_string_is_preserved", + reqPath: "/v1/models?search=gpt&limit=10", + expectRequestPath: "/v1/models", + expectQuery: "search=gpt&limit=10", + expectRespStatus: http.StatusOK, + expectRespBody: upstreamRespBody, }, } @@ -65,6 +103,10 @@ func TestPassthroughRoutes(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, tc.expectRequestPath, r.URL.Path) + assert.Equal(t, tc.expectQuery, r.URL.RawQuery) + if tc.expectHeaders != nil { + assert.Equal(t, tc.expectHeaders, r.Header) + } w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(upstreamRespBody)) })) @@ -76,7 +118,10 @@ func TestPassthroughRoutes(t *testing.T) { handler := newPassthroughRouter(prov, logger, nil, testTracer) - req := httptest.NewRequest("", tc.passthroughRoute, nil) + req := httptest.NewRequest("", tc.reqPath, nil) + maps.Copy(req.Header, tc.reqHeaders) + req.Host = tc.reqHost + req.RemoteAddr = tc.reqRemoteAddr resp := httptest.NewRecorder() handler.ServeHTTP(resp, req) @@ -85,3 +130,161 @@ func TestPassthroughRoutes(t *testing.T) { }) } } + +func TestRewritePassthroughRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + reqPath string + reqRemoteAddr string + reqHeaders http.Header + reqTLS bool + baseURL string + provider *testutil.MockProvider + expectURL string + expectHeaders http.Header + }{ + { + name: "sets_upstream_url_and_forwarded_headers_from_client_peer", + reqPath: "http://client-host/chat?stream=true", + reqRemoteAddr: "1.1.1.1:1111", + baseURL: "https://upstream-host/base", + provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, + expectURL: "https://upstream-host/base/chat?stream=true", + expectHeaders: http.Header{ + "X-Forwarded-Host": {"client-host"}, + "X-Forwarded-Proto": {"http"}, + "X-Forwarded-For": {"1.1.1.1"}, + "User-Agent": {"aibridge"}, + }, + }, + { + name: "preserves_client_user_agent", + reqPath: "http://client-host/chat", + reqRemoteAddr: "1.1.1.1:1111", + reqHeaders: http.Header{"User-Agent": {"custom-agent/1.0"}}, + baseURL: "https://upstream-host/base", + provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, + expectURL: "https://upstream-host/base/chat", + expectHeaders: http.Header{ + "X-Forwarded-Host": {"client-host"}, + "X-Forwarded-Proto": {"http"}, + "X-Forwarded-For": {"1.1.1.1"}, + "User-Agent": {"custom-agent/1.0"}, + }, + }, + { + name: "injects_auth_header", + reqPath: "http://client-host/chat", + reqRemoteAddr: "1.1.1.1:1111", + baseURL: "https://upstream-host/base", + provider: &testutil.MockProvider{ + URL: "https://upstream-host/base", + InjectAuthHeaderFunc: func(h *http.Header) { + h.Set("Authorization", "Bearer test-token") + }, + }, + expectURL: "https://upstream-host/base/chat", + expectHeaders: http.Header{ + "X-Forwarded-Host": {"client-host"}, + "X-Forwarded-Proto": {"http"}, + "X-Forwarded-For": {"1.1.1.1"}, + "User-Agent": {"aibridge"}, + "Authorization": {"Bearer test-token"}, + }, + }, + { + name: "appends_remote_addr_to_existing_forwarded_for_chain", + reqPath: "http://client-host/chat", + reqRemoteAddr: "1.1.1.1:1111", + reqHeaders: http.Header{ + "X-Forwarded-For": {"2.2.2.2, 3.3.3.3"}, + }, + baseURL: "https://upstream-host/base", + provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, + expectURL: "https://upstream-host/base/chat", + expectHeaders: http.Header{ + "X-Forwarded-Host": {"client-host"}, + "X-Forwarded-Proto": {"http"}, + "X-Forwarded-For": {"2.2.2.2, 3.3.3.3, 1.1.1.1"}, + "User-Agent": {"aibridge"}, + }, + }, + { + name: "omits_forwarded_for_when_remote_addr_is_not_parseable", + reqPath: "http://client-host/chat", + reqRemoteAddr: "not-a-socket-address", + reqHeaders: http.Header{ + "X-Forwarded-For": {"1.1.1.1"}, + }, + baseURL: "https://upstream-host/base", + provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, + expectURL: "https://upstream-host/base/chat", + expectHeaders: http.Header{ + "X-Forwarded-Host": {"client-host"}, + "X-Forwarded-Proto": {"http"}, + "User-Agent": {"aibridge"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, tc.reqPath, nil) + maps.Copy(r.Header, tc.reqHeaders) + r.RemoteAddr = tc.reqRemoteAddr + if tc.reqTLS { + r.TLS = &tls.ConnectionState{} + } + provBaseURL, err := url.Parse(tc.baseURL) + assert.NoError(t, err) + + pr := &httputil.ProxyRequest{ + In: r, + Out: r.Clone(r.Context()), + } + + rewritePassthroughRequest(pr, provBaseURL, tc.provider) + + assert.Equal(t, tc.expectURL, pr.Out.URL.String()) + assert.Equal(t, "", pr.Out.Host) + assert.Equal(t, tc.expectHeaders, pr.Out.Header) + }) + } +} + +func TestPassthroughRouterReusesProxyInstance(t *testing.T) { + t.Parallel() + + var newConnections atomic.Int32 + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + upstream.Config.ConnState = func(_ net.Conn, state http.ConnState) { + if state == http.StateNew { + newConnections.Add(1) + } + } + upstream.Start() + t.Cleanup(upstream.Close) + + logger := slogtest.Make(t, nil) + prov := &testutil.MockProvider{URL: upstream.URL} + handler := newPassthroughRouter(prov, logger, nil, testTracer) + + for i := range 2 { + req := httptest.NewRequest(http.MethodGet, "http://proxy.example.test/v1/models", nil) + resp := httptest.NewRecorder() + + handler.ServeHTTP(resp, req) + + assert.Equalf(t, http.StatusOK, resp.Code, "request %d", i+1) + assert.Equal(t, "ok", resp.Body.String()) + } + + assert.EqualValues(t, 1, newConnections.Load()) +} From 1334e6bf2a58a9ee532207fd8529bad6dc8f4289 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 16 Apr 2026 16:01:16 +0000 Subject: [PATCH 2/3] review: TLS test case, comments, span start extraction --- passthrough.go | 30 ++++++++++++++++++------------ passthrough_test.go | 22 +++++++++++++++------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/passthrough.go b/passthrough.go index a9cc8069..bf7ec7bd 100644 --- a/passthrough.go +++ b/passthrough.go @@ -1,6 +1,7 @@ package aibridge import ( + "context" "net/http" "net/http/httputil" "net/url" @@ -58,10 +59,7 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1) } - ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( - attribute.String(tracing.PassthroughURL, r.URL.String()), - attribute.String(tracing.PassthroughMethod, r.Method), - )) + ctx, span := startSpan(r, tracer) defer span.End() proxy.ServeHTTP(w, r.WithContext(ctx)) @@ -73,6 +71,10 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov provider.Provider) { pr.SetURL(provBaseURL) + // Rewrite sets "X-Forwarded-For" to just last hop (clients IP address). + // To preserve old Director behavior pr.In "X-Forwarded-For" header + // values need to be copied manually. + // https://pkg.go.dev/net/http/httputil#ProxyRequest.SetXForwarded if prior, ok := pr.In.Header["X-Forwarded-For"]; ok { pr.Out.Header["X-Forwarded-For"] = append([]string(nil), prior...) } @@ -90,22 +92,26 @@ func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov.InjectAuthHeader(&pr.Out.Header) } -// newInvalidBaseURLHandler returns a handler that always returns 502 because -// the provider's base URL is invalid. +// newInvalidBaseURLHandler returns a handler that always returns 502 +// when the provider's base URL is invalid. func newInvalidBaseURLHandler(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer, baseURLErr error) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + ctx, span := startSpan(r, tracer) + defer span.End() + if m != nil { m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1) } - ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( - attribute.String(tracing.PassthroughURL, r.URL.String()), - attribute.String(tracing.PassthroughMethod, r.Method), - )) - defer span.End() - logger.Warn(ctx, "invalid provider base URL", slog.Error(baseURLErr)) http.Error(w, "invalid provider base URL", http.StatusBadGateway) span.SetStatus(codes.Error, "invalid provider base URL: "+baseURLErr.Error()) } } + +func startSpan(r *http.Request, tracer trace.Tracer) (context.Context, trace.Span) { + return tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( + attribute.String(tracing.PassthroughURL, r.URL.String()), + attribute.String(tracing.PassthroughMethod, r.Method), + )) +} diff --git a/passthrough_test.go b/passthrough_test.go index 9376892a..b0567c89 100644 --- a/passthrough_test.go +++ b/passthrough_test.go @@ -140,7 +140,6 @@ func TestRewritePassthroughRequest(t *testing.T) { reqRemoteAddr string reqHeaders http.Header reqTLS bool - baseURL string provider *testutil.MockProvider expectURL string expectHeaders http.Header @@ -149,7 +148,6 @@ func TestRewritePassthroughRequest(t *testing.T) { name: "sets_upstream_url_and_forwarded_headers_from_client_peer", reqPath: "http://client-host/chat?stream=true", reqRemoteAddr: "1.1.1.1:1111", - baseURL: "https://upstream-host/base", provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, expectURL: "https://upstream-host/base/chat?stream=true", expectHeaders: http.Header{ @@ -164,7 +162,6 @@ func TestRewritePassthroughRequest(t *testing.T) { reqPath: "http://client-host/chat", reqRemoteAddr: "1.1.1.1:1111", reqHeaders: http.Header{"User-Agent": {"custom-agent/1.0"}}, - baseURL: "https://upstream-host/base", provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, expectURL: "https://upstream-host/base/chat", expectHeaders: http.Header{ @@ -178,7 +175,6 @@ func TestRewritePassthroughRequest(t *testing.T) { name: "injects_auth_header", reqPath: "http://client-host/chat", reqRemoteAddr: "1.1.1.1:1111", - baseURL: "https://upstream-host/base", provider: &testutil.MockProvider{ URL: "https://upstream-host/base", InjectAuthHeaderFunc: func(h *http.Header) { @@ -201,7 +197,6 @@ func TestRewritePassthroughRequest(t *testing.T) { reqHeaders: http.Header{ "X-Forwarded-For": {"2.2.2.2, 3.3.3.3"}, }, - baseURL: "https://upstream-host/base", provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, expectURL: "https://upstream-host/base/chat", expectHeaders: http.Header{ @@ -211,6 +206,20 @@ func TestRewritePassthroughRequest(t *testing.T) { "User-Agent": {"aibridge"}, }, }, + { + name: "tls_request_sets_forwarded_proto_to_https", + reqPath: "http://client-host/chat", + reqRemoteAddr: "1.1.1.1:1111", + reqTLS: true, + provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, + expectURL: "https://upstream-host/base/chat", + expectHeaders: http.Header{ + "X-Forwarded-Host": {"client-host"}, + "X-Forwarded-Proto": {"https"}, + "X-Forwarded-For": {"1.1.1.1"}, + "User-Agent": {"aibridge"}, + }, + }, { name: "omits_forwarded_for_when_remote_addr_is_not_parseable", reqPath: "http://client-host/chat", @@ -218,7 +227,6 @@ func TestRewritePassthroughRequest(t *testing.T) { reqHeaders: http.Header{ "X-Forwarded-For": {"1.1.1.1"}, }, - baseURL: "https://upstream-host/base", provider: &testutil.MockProvider{URL: "https://upstream-host/base"}, expectURL: "https://upstream-host/base/chat", expectHeaders: http.Header{ @@ -239,7 +247,7 @@ func TestRewritePassthroughRequest(t *testing.T) { if tc.reqTLS { r.TLS = &tls.ConnectionState{} } - provBaseURL, err := url.Parse(tc.baseURL) + provBaseURL, err := url.Parse(tc.provider.URL) assert.NoError(t, err) pr := &httputil.ProxyRequest{ From f50ea084e457a6b67e45c48e79ffd9a3eb016135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 17 Apr 2026 10:20:59 +0000 Subject: [PATCH 3/3] review 2: added comment about not parsable remote addr edge case --- passthrough_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/passthrough_test.go b/passthrough_test.go index b0567c89..702098be 100644 --- a/passthrough_test.go +++ b/passthrough_test.go @@ -221,6 +221,11 @@ func TestRewritePassthroughRequest(t *testing.T) { }, }, { + // This is an edge case where whole `X-Forwarded-For` header + // is dropped if last hop (remote addr) is not parseable. + // This is how library handles this case and is not directly + // related to our code. Added it to verify that we + // don't accidentally break this behavior. name: "omits_forwarded_for_when_remote_addr_is_not_parseable", reqPath: "http://client-host/chat", reqRemoteAddr: "not-a-socket-address",