Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions internal/testutil/mockprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,26 @@ import (
)

type MockProvider struct {
NameStr string
URL string
Bridged []string
Passthrough []string
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
NameStr string
URL string
Bridged []string
Passthrough []string
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
InjectAuthHeaderFunc func(h *http.Header)
}

func (m *MockProvider) Type() string { return m.NameStr }
func (m *MockProvider) Name() string { return m.NameStr }
func (m *MockProvider) BaseURL() string { return m.URL }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (*MockProvider) AuthHeader() string { return "Authorization" }
func (*MockProvider) InjectAuthHeader(_ *http.Header) {}
func (m *MockProvider) Type() string { return m.NameStr }
func (m *MockProvider) Name() string { return m.NameStr }
func (m *MockProvider) BaseURL() string { return m.URL }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (*MockProvider) AuthHeader() string { return "Authorization" }
func (m *MockProvider) InjectAuthHeader(h *http.Header) {
if m.InjectAuthHeaderFunc != nil {
m.InjectAuthHeaderFunc(h)
}
}
func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
func (*MockProvider) APIDumpDir() string { return "" }
func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) {
Expand Down
163 changes: 80 additions & 83 deletions passthrough.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package aibridge

import (
"net"
"context"
"net/http"
"net/http/httputil"
"net/url"
Expand All @@ -21,100 +21,97 @@ import (

// newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically
// by a [intercept.Provider].
// A single reverse proxy is created per provider and reused across all requests.
func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
provBaseURL, err := url.Parse(prov.BaseURL())
if err != nil {
return newInvalidBaseURLHandler(prov, logger, m, tracer, err)
}
if _, err := url.JoinPath(provBaseURL.Path, "/"); err != nil {
Comment thread
ssncferreira marked this conversation as resolved.
return newInvalidBaseURLHandler(prov, logger, m, tracer, err)
}

// Transport tuned for streaming (no response header timeout).
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}

// Build a reverse proxy to the upstream, reused across all requests for this provider.
// All request modifications happen in Rewrite.
proxy := &httputil.ReverseProxy{
Rewrite: func(pr *httputil.ProxyRequest) {
rewritePassthroughRequest(pr, provBaseURL, prov)
},
Transport: apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()),
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
http.Error(rw, "upstream proxy error", http.StatusBadGateway)
},
}

return func(w http.ResponseWriter, r *http.Request) {
if m != nil {
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
}

ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
attribute.String(tracing.PassthroughURL, r.URL.String()),
attribute.String(tracing.PassthroughMethod, r.Method),
))
ctx, span := startSpan(r, tracer)
defer span.End()

upURL, err := url.Parse(prov.BaseURL())
if err != nil {
logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err))
http.Error(w, "request error", http.StatusBadGateway)
span.SetStatus(codes.Error, "failed to parse provider base URL: "+err.Error())
return
}
proxy.ServeHTTP(w, r.WithContext(ctx))
}
}

// Append the request path to the upstream base path.
reqPath, err := url.JoinPath(upURL.Path, r.URL.Path)
if err != nil {
logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstream_path", upURL.Path), slog.F("request_path", r.URL.Path))
http.Error(w, "failed to join upstream path", http.StatusInternalServerError)
span.SetStatus(codes.Error, "failed to join upstream path: "+err.Error())
return
}
// Ensure leading slash, proxied requests should have absolute paths.
// JoinPath can return relative paths, eg. when upURL path is empty.
if len(reqPath) == 0 || reqPath[0] != '/' {
reqPath = "/" + reqPath
}
// rewritePassthroughRequest configures the outbound request for the upstream and
// applies proxy headers and provider auth.
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov provider.Provider) {
pr.SetURL(provBaseURL)

// Build a reverse proxy to the upstream.
proxy := &httputil.ReverseProxy{
Director: func(req *http.Request) {
// Set scheme/host to upstream.
req.URL.Scheme = upURL.Scheme
req.URL.Host = upURL.Host
req.URL.Path = reqPath
req.URL.RawPath = ""

// Preserve query string.
req.URL.RawQuery = r.URL.RawQuery

// Set Host header for upstream.
req.Host = upURL.Host
span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, req.URL.String()))

// Copy headers from client.
req.Header = r.Header.Clone()

// Standard proxy headers.
host, _, herr := net.SplitHostPort(r.RemoteAddr)
if herr != nil {
host = r.RemoteAddr
}
if prior := req.Header.Get("X-Forwarded-For"); prior != "" {
req.Header.Set("X-Forwarded-For", prior+", "+host)
} else {
req.Header.Set("X-Forwarded-For", host)
}
req.Header.Set("X-Forwarded-Host", r.Host)
if r.TLS != nil {
req.Header.Set("X-Forwarded-Proto", "https")
} else {
req.Header.Set("X-Forwarded-Proto", "http")
}
// Avoid default Go user-agent if none provided.
if _, ok := req.Header["User-Agent"]; !ok {
req.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
}

// Inject provider auth.
prov.InjectAuthHeader(&req.Header)
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
http.Error(rw, "upstream proxy error", http.StatusBadGateway)
},
}
// Rewrite sets "X-Forwarded-For" to just last hop (clients IP address).
// To preserve old Director behavior pr.In "X-Forwarded-For" header
// values need to be copied manually.
// https://pkg.go.dev/net/http/httputil#ProxyRequest.SetXForwarded
if prior, ok := pr.In.Header["X-Forwarded-For"]; ok {
pr.Out.Header["X-Forwarded-For"] = append([]string(nil), prior...)
}
pr.SetXForwarded()
Comment thread
ssncferreira marked this conversation as resolved.

span := trace.SpanFromContext(pr.Out.Context())
span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, pr.Out.URL.String()))

// Avoid default Go user-agent if none provided.
if _, ok := pr.Out.Header["User-Agent"]; !ok {
pr.Out.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
}

// Transport tuned for streaming (no response header timeout).
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// Inject provider auth.
prov.InjectAuthHeader(&pr.Out.Header)
}

// newInvalidBaseURLHandler returns a handler that always returns 502
// when the provider's base URL is invalid.
func newInvalidBaseURLHandler(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer, baseURLErr error) http.HandlerFunc {
Comment thread
ssncferreira marked this conversation as resolved.
return func(w http.ResponseWriter, r *http.Request) {
ctx, span := startSpan(r, tracer)
defer span.End()

if m != nil {
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
}
proxy.Transport = apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal())

proxy.ServeHTTP(w, r)
logger.Warn(ctx, "invalid provider base URL", slog.Error(baseURLErr))
http.Error(w, "invalid provider base URL", http.StatusBadGateway)
span.SetStatus(codes.Error, "invalid provider base URL: "+baseURLErr.Error())
}
}

func startSpan(r *http.Request, tracer trace.Tracer) (context.Context, trace.Span) {
return tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
attribute.String(tracing.PassthroughURL, r.URL.String()),
attribute.String(tracing.PassthroughMethod, r.Method),
))
}
Loading
Loading