From 8190b506db4461f2850e1a54f3e98d3f67f9bc65 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Sat, 28 Mar 2026 18:18:26 +0000 Subject: [PATCH] feat: introduce provider upstream concept and per-upstream routing --- bridge.go | 30 +++++++++++-------- circuitbreaker/circuitbreaker.go | 20 ++++++------- circuitbreaker/circuitbreaker_test.go | 16 +++++----- intercept/chatcompletions/base.go | 15 +++++----- intercept/chatcompletions/blocking.go | 6 ++-- intercept/chatcompletions/streaming.go | 6 ++-- intercept/chatcompletions/streaming_test.go | 3 +- intercept/interceptor.go | 10 +++++++ intercept/messages/base.go | 8 ++--- intercept/messages/blocking.go | 4 +-- intercept/messages/streaming.go | 4 +-- intercept/responses/base.go | 11 ++++--- intercept/responses/blocking.go | 6 ++-- intercept/responses/streaming.go | 6 ++-- .../integrationtest/circuit_breaker_test.go | 28 ++++++++--------- internal/integrationtest/metrics_test.go | 24 +++++++-------- internal/testutil/mockprovider.go | 9 ++++-- metrics/metrics.go | 8 ++--- passthrough.go | 3 +- provider/anthropic.go | 10 +++++-- provider/copilot.go | 14 ++++++--- provider/openai.go | 14 ++++++--- provider/provider.go | 4 +++ recorder/recorder.go | 17 +++++++---- recorder/types.go | 1 + 25 files changed, 157 insertions(+), 120 deletions(-) diff --git a/bridge.go b/bridge.go index aeb73af7..d4a70e46 100644 --- a/bridge.go +++ b/bridge.go @@ -73,18 +73,19 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // Create per-provider circuit breaker if configured cfg := prov.CircuitBreakerConfig() providerName := prov.Name() - onChange := func(providerName, endpoint, model string, from, to gobreaker.State) { + onChange := func(upstreamName, endpoint, model string, from, to gobreaker.State) { logger.Info(context.Background(), "circuit breaker state change", slog.F("provider", providerName), + slog.F("upstream", upstreamName), slog.F("endpoint", endpoint), slog.F("model", model), slog.F("from", from.String()), slog.F("to", to.String()), ) if m != nil { - m.CircuitBreakerState.WithLabelValues(providerName, endpoint, model).Set(circuitbreaker.StateToGaugeValue(to)) + m.CircuitBreakerState.WithLabelValues(providerName, upstreamName, endpoint, model).Set(circuitbreaker.StateToGaugeValue(to)) if to == gobreaker.StateOpen { - m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint, model).Inc() + m.CircuitBreakerTrips.WithLabelValues(providerName, upstreamName, endpoint, model).Inc() } } } @@ -165,12 +166,12 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC return } - providerName := interceptor.ProviderName() + upstreamName := interceptor.ProviderName() if m != nil { start := time.Now() defer func() { - m.InterceptionDuration.WithLabelValues(providerName, interceptor.Model()).Observe(time.Since(start).Seconds()) + m.InterceptionDuration.WithLabelValues(p.Name(), upstreamName, interceptor.Model()).Observe(time.Since(start).Seconds()) }() } @@ -189,7 +190,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC // Record usage in the background to not block request flow. asyncRecorder := recorder.NewAsyncRecorder(logger, rec, recordingTimeout) asyncRecorder.WithMetrics(m) - asyncRecorder.WithProvider(providerName) + asyncRecorder.WithProvider(p.Name()) + asyncRecorder.WithUpstream(upstreamName) asyncRecorder.WithModel(interceptor.Model()) asyncRecorder.WithInitiatorID(actor.ID) asyncRecorder.WithClient(string(client)) @@ -200,7 +202,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC InitiatorID: actor.ID, Metadata: actor.Metadata, Model: interceptor.Model(), - Provider: providerName, + Provider: p.Name(), + Upstream: upstreamName, UserAgent: r.UserAgent(), Client: string(client), ClientSessionID: sessionID, @@ -215,7 +218,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) log := logger.With( slog.F("route", route), - slog.F("provider", providerName), + slog.F("provider", p.Name()), + slog.F("upstream", upstreamName), slog.F("interception_id", interceptor.ID()), slog.F("user_agent", r.UserAgent()), slog.F("streaming", interceptor.Streaming()), @@ -223,24 +227,24 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC log.Debug(ctx, "interception started") if m != nil { - m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Add(1) + m.InterceptionsInflight.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), route).Add(1) defer func() { - m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Sub(1) + m.InterceptionsInflight.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), route).Sub(1) }() } // Process request with circuit breaker protection if configured - if err := cbs.Execute(providerName, route, interceptor.Model(), w, func(rw http.ResponseWriter) error { + if err := cbs.Execute(upstreamName, route, interceptor.Model(), w, func(rw http.ResponseWriter) error { return interceptor.ProcessRequest(rw, r) }); err != nil { if m != nil { - m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1) + m.InterceptionCount.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1) } span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) log.Warn(ctx, "interception failed", slog.Error(err)) } else { if m != nil { - m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1) + m.InterceptionCount.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1) } log.Debug(ctx, "interception ended") } diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index a3f06a7b..f3fb3d1b 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -33,8 +33,8 @@ func DefaultIsFailure(statusCode int) bool { type ProviderCircuitBreakers struct { provider string config config.CircuitBreaker - breakers sync.Map // "providerName:endpoint:model" -> *gobreaker.CircuitBreaker[struct{}] - onChange func(providerName, endpoint, model string, from, to gobreaker.State) + breakers sync.Map // "upstreamName:endpoint:model" -> *gobreaker.CircuitBreaker[struct{}] + onChange func(upstreamName, endpoint, model string, from, to gobreaker.State) metrics *metrics.Metrics } @@ -42,7 +42,7 @@ type ProviderCircuitBreakers struct { // Returns nil if cfg is nil (no circuit breaker protection). // onChange is called when circuit state changes. // metrics is used to record circuit breaker reject counts (can be nil). -func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(providerName, endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers { +func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(upstreamName, endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers { if cfg == nil { return nil } @@ -71,9 +71,9 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte { return []byte(`{"error":"circuit breaker is open"}`) } -// Get returns the circuit breaker for a providerName/endpoint/model tuple, creating it if needed. -func (p *ProviderCircuitBreakers) Get(providerName, endpoint, model string) *gobreaker.CircuitBreaker[struct{}] { - key := providerName + ":" + endpoint + ":" + model +// Get returns the circuit breaker for an upstreamName/endpoint/model tuple, creating it if needed. +func (p *ProviderCircuitBreakers) Get(upstreamName, endpoint, model string) *gobreaker.CircuitBreaker[struct{}] { + key := upstreamName + ":" + endpoint + ":" + model if v, ok := p.breakers.Load(key); ok { return v.(*gobreaker.CircuitBreaker[struct{}]) } @@ -88,7 +88,7 @@ func (p *ProviderCircuitBreakers) Get(providerName, endpoint, model string) *gob }, OnStateChange: func(_ string, from, to gobreaker.State) { if p.onChange != nil { - p.onChange(providerName, endpoint, model, from, to) + p.onChange(upstreamName, endpoint, model, from, to) } }, } @@ -139,12 +139,12 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { // Otherwise, it returns the handler's error (or nil on success). // The handler receives a wrapped ResponseWriter that captures the status code. // If the receiver is nil (no circuit breaker configured), the handler is called directly. -func (p *ProviderCircuitBreakers) Execute(providerName, endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error { +func (p *ProviderCircuitBreakers) Execute(upstreamName, endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error { if p == nil { return handler(w) } - cb := p.Get(providerName, endpoint, model) + cb := p.Get(upstreamName, endpoint, model) // Wrap response writer to capture status code sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} @@ -160,7 +160,7 @@ func (p *ProviderCircuitBreakers) Execute(providerName, endpoint, model string, if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { if p.metrics != nil { - p.metrics.CircuitBreakerRejects.WithLabelValues(providerName, endpoint, model).Inc() + p.metrics.CircuitBreakerRejects.WithLabelValues(p.provider, upstreamName, endpoint, model).Inc() } w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 96eafe8d..b5c0f6f9 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -24,7 +24,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil) + }, func(upstreamName, endpoint, model string, from, to gobreaker.State) {}, nil) endpoint := "/v1/messages" sonnetModel := "claude-sonnet-4-20250514" @@ -73,7 +73,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil) + }, func(upstreamName, endpoint, model string, from, to gobreaker.State) {}, nil) model := "test-model" @@ -123,7 +123,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { IsFailure: func(statusCode int) bool { return statusCode == http.StatusBadGateway }, - }, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil) + }, func(upstreamName, endpoint, model string, from, to gobreaker.State) {}, nil) // First request returns 502, trips circuit w := httptest.NewRecorder() @@ -151,7 +151,7 @@ func TestExecute_OnStateChange(t *testing.T) { t.Parallel() var stateChanges []struct { - providerName string + upstreamName string endpoint string model string from gobreaker.State @@ -163,14 +163,14 @@ func TestExecute_OnStateChange(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(providerName, endpoint, model string, from, to gobreaker.State) { + }, func(upstreamName, endpoint, model string, from, to gobreaker.State) { stateChanges = append(stateChanges, struct { - providerName string + upstreamName string endpoint string model string from gobreaker.State to gobreaker.State - }{providerName, endpoint, model, from, to}) + }{upstreamName, endpoint, model, from, to}) }, nil) endpoint := "/v1/messages" @@ -185,7 +185,7 @@ func TestExecute_OnStateChange(t *testing.T) { // Verify state change callback was called with correct parameters assert.Len(t, stateChanges, 1) - assert.Equal(t, config.ProviderAnthropic, stateChanges[0].providerName) + assert.Equal(t, config.ProviderAnthropic, stateChanges[0].upstreamName) assert.Equal(t, endpoint, stateChanges[0].endpoint) assert.Equal(t, model, stateChanges[0].model) assert.Equal(t, gobreaker.StateClosed, stateChanges[0].from) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 73456988..5aa378a4 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -26,12 +26,11 @@ import ( ) type interceptionBase struct { - id uuid.UUID - providerName string - baseURL string - apiDumpDir string - req *ChatCompletionNewParamsWrapper - cfg config.OpenAIInterceptor + id uuid.UUID + upstream intercept.ResolvedUpstream + apiDumpDir string + req *ChatCompletionNewParamsWrapper + cfg config.OpenAIInterceptor // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header @@ -45,7 +44,7 @@ type interceptionBase struct { } func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { - opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.baseURL)} + opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.upstream.URL)} // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. @@ -77,7 +76,7 @@ func (i *interceptionBase) ID() uuid.UUID { } func (i *interceptionBase) ProviderName() string { - return i.providerName + return i.upstream.Name } func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 5644726c..efdfc774 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -31,8 +31,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, - providerName string, - baseURL string, + upstream intercept.ResolvedUpstream, apiDumpDir string, cfg config.OpenAIInterceptor, clientHeaders http.Header, @@ -41,8 +40,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, - providerName: providerName, - baseURL: baseURL, + upstream: upstream, apiDumpDir: apiDumpDir, req: req, cfg: cfg, diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 144888c6..d8871197 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -36,8 +36,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, - providerName string, - baseURL string, + upstream intercept.ResolvedUpstream, apiDumpDir string, cfg config.OpenAIInterceptor, clientHeaders http.Header, @@ -46,8 +45,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, - providerName: providerName, - baseURL: baseURL, + upstream: upstream, apiDumpDir: apiDumpDir, req: req, cfg: cfg, diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 7b6214a1..d7bcd889 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge/config" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/internal/testutil" "github.com/google/uuid" "github.com/openai/openai-go/v3" @@ -85,7 +86,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, mockServer.URL, "", cfg, httpReq.Header, "Authorization", tracer) + interceptor := NewStreamingInterceptor(uuid.New(), req, intercept.ResolvedUpstream{Name: config.ProviderOpenAI, URL: mockServer.URL}, "", cfg, httpReq.Header, "Authorization", tracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) diff --git a/intercept/interceptor.go b/intercept/interceptor.go index e8db619a..f4f4a60c 100644 --- a/intercept/interceptor.go +++ b/intercept/interceptor.go @@ -37,3 +37,13 @@ type Interceptor interface { // parent interception. CorrelatingToolCallID() *string } + +// ResolvedUpstream represents the resolved upstream destination for a request. +// Providers resolve this per-request, allowing a single provider to route to +// different upstreams based on request context. +type ResolvedUpstream struct { + // Name identifies the upstream provider domain. + Name string + // URL is the URL of the upstream provider. + URL string +} diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 214fb968..b4326277 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -63,9 +63,9 @@ var bedrockSupportedBetaFlags = map[string]bool{ } type interceptionBase struct { - id uuid.UUID - providerName string - reqPayload MessagesRequestPayload + id uuid.UUID + upstream intercept.ResolvedUpstream + reqPayload MessagesRequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -86,7 +86,7 @@ func (i *interceptionBase) ID() uuid.UUID { } func (i *interceptionBase) ProviderName() string { - return i.providerName + return i.upstream.Name } func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 3072b125..7dcc6aa6 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -31,7 +31,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, reqPayload MessagesRequestPayload, - providerName string, + upstream intercept.ResolvedUpstream, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -40,7 +40,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, - providerName: providerName, + upstream: upstream, reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 395c6021..a028f0e8 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -37,7 +37,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, reqPayload MessagesRequestPayload, - providerName string, + upstream intercept.ResolvedUpstream, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -46,7 +46,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, - providerName: providerName, + upstream: upstream, reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, diff --git a/intercept/responses/base.go b/intercept/responses/base.go index e63d314c..76427a74 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -36,10 +36,9 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID - providerName string - baseURL string - apiDumpDir string + id uuid.UUID + upstream intercept.ResolvedUpstream + apiDumpDir string // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header authHeaderName string @@ -54,7 +53,7 @@ type responsesInterceptionBase struct { } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { - opts := []option.RequestOption{option.WithBaseURL(i.baseURL), option.WithAPIKey(i.cfg.Key)} + opts := []option.RequestOption{option.WithBaseURL(i.upstream.URL), option.WithAPIKey(i.cfg.Key)} // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. @@ -86,7 +85,7 @@ func (i *responsesInterceptionBase) ID() uuid.UUID { } func (i *responsesInterceptionBase) ProviderName() string { - return i.providerName + return i.upstream.Name } func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index f2e0a55c..bfdb5d57 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -28,8 +28,7 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, - providerName string, - baseURL string, + upstream intercept.ResolvedUpstream, apiDumpDir string, cfg config.OpenAIInterceptor, clientHeaders http.Header, @@ -39,8 +38,7 @@ func NewBlockingInterceptor( return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, - providerName: providerName, - baseURL: baseURL, + upstream: upstream, apiDumpDir: apiDumpDir, reqPayload: reqPayload, cfg: cfg, diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 606c6ee2..2d295a25 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -35,8 +35,7 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, - providerName string, - baseURL string, + upstream intercept.ResolvedUpstream, apiDumpDir string, cfg config.OpenAIInterceptor, clientHeaders http.Header, @@ -46,8 +45,7 @@ func NewStreamingInterceptor( return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, - providerName: providerName, - baseURL: baseURL, + upstream: upstream, apiDumpDir: apiDumpDir, reqPayload: reqPayload, cfg: cfg, diff --git a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 4e392649..48e5e2ef 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -157,13 +157,13 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") // Verify metrics show circuit is open - trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") - state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") - rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") // Phase 3: Wait for timeout to transition to half-open @@ -179,7 +179,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Verify circuit is now closed - state = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed) after recovery") // Phase 5: Verify circuit is fully functional again @@ -193,7 +193,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+4, upstreamCalls.Load(), "All requests should reach upstream after circuit closes") // Rejects count should not have increased - rejects = promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects = promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should still be 1 (no new rejects)") }) } @@ -305,7 +305,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { resp := doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") // Phase 2: Wait for half-open state @@ -323,10 +323,10 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") // Verify metrics: trips should be 2 now (tripped twice) - trips = promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips = promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 2.0, trips, "CircuitBreakerTrips should be 2 after half-open failure") - state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open) after half-open failure") }) } @@ -495,7 +495,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { "%d requests should be rejected (ErrTooManyRequests)", totalRequests-maxRequests) // Verify rejects metric increased - rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, float64(1+totalRequests-maxRequests), rejects, "CircuitBreakerRejects should include half-open rejections") }) @@ -577,10 +577,10 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open") // Verify sonnet metrics show circuit is open - sonnetTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + sonnetTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) assert.Equal(t, 1.0, sonnetTrips, "Sonnet CircuitBreakerTrips should be 1") - sonnetState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + sonnetState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) assert.Equal(t, 1.0, sonnetState, "Sonnet CircuitBreakerState should be 1 (open)") // Phase 2: Haiku model should still work (independent circuit) @@ -596,10 +596,10 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream") // Verify haiku circuit is still closed (no trips) - haikuTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) + haikuTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) assert.Equal(t, 0.0, haikuTrips, "Haiku CircuitBreakerTrips should be 0") - haikuState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) + haikuState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) assert.Equal(t, 0.0, haikuState, "Haiku CircuitBreakerState should be 0 (closed)") // Phase 3: Sonnet recovers after timeout @@ -610,6 +610,6 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout") // Verify sonnet circuit is now closed - sonnetState = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + sonnetState = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) assert.Equal(t, 0.0, sonnetState, "Sonnet CircuitBreakerState should be 0 (closed) after recovery") } diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index 69e4eea5..a84657b1 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -158,7 +158,7 @@ func TestMetrics_Interception(t *testing.T) { require.NoError(t, err) count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues( - tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", defaultActorID, string(tc.expectClient))) + tc.expectProvider, tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", defaultActorID, string(tc.expectClient))) require.Equal(t, 1.0, count) require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionDuration)) require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionCount)) @@ -204,7 +204,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { // Wait until request is detected as inflight. require.Eventually(t, func() bool { return promtest.ToFloat64( - m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 1 }, time.Second*10, time.Millisecond*50) @@ -219,7 +219,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { // Metric is not updated immediately after request completes, so wait until it is. require.Eventually(t, func() bool { return promtest.ToFloat64( - m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 0 }, time.Second*10, time.Millisecond*50) } @@ -263,7 +263,7 @@ func TestMetrics_PromptCount(t *testing.T) { require.NoError(t, err) prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues( - config.ProviderOpenAI, "gpt-4.1", defaultActorID, string(aibridge.ClientClaudeCode))) + config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", defaultActorID, string(aibridge.ClientClaudeCode))) require.Equal(t, 1.0, prompts) } @@ -290,16 +290,16 @@ func TestMetrics_TokenUseCount(t *testing.T) { // Token metrics are recorded asynchronously; wait for them to appear. require.Eventually(t, func() bool { return promtest.ToFloat64(m.TokenUseCount.WithLabelValues( - config.ProviderOpenAI, "gpt-4.1", "input", defaultActorID, clientLabel)) > 0 + config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", "input", defaultActorID, clientLabel)) > 0 }, time.Second*10, time.Millisecond*50) - require.Equal(t, 129.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, "gpt-4.1", "input", defaultActorID, clientLabel))) // 12033 - 11904 (cached) - require.Equal(t, 44.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, "gpt-4.1", "output", defaultActorID, clientLabel))) + require.Equal(t, 129.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", "input", defaultActorID, clientLabel))) // 12033 - 11904 (cached) + require.Equal(t, 44.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", "output", defaultActorID, clientLabel))) // ExtraTokenTypes - require.Equal(t, 11904.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, "gpt-4.1", "input_cached", defaultActorID, clientLabel))) - require.Equal(t, 0.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, "gpt-4.1", "output_reasoning", defaultActorID, clientLabel))) - require.Equal(t, 12077.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, "gpt-4.1", "total_tokens", defaultActorID, clientLabel))) + require.Equal(t, 11904.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", "input_cached", defaultActorID, clientLabel))) + require.Equal(t, 0.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", "output_reasoning", defaultActorID, clientLabel))) + require.Equal(t, 12077.0, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", "total_tokens", defaultActorID, clientLabel))) } func TestMetrics_NonInjectedToolUseCount(t *testing.T) { @@ -322,7 +322,7 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { require.NoError(t, err) count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues( - config.ProviderOpenAI, "gpt-4.1", "read_file")) + config.ProviderOpenAI, config.ProviderOpenAI, "gpt-4.1", "read_file")) require.Equal(t, 1.0, count) } @@ -363,6 +363,6 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { actualServerURL := *recorder.ToolUsages()[0].ServerURL count := promtest.ToFloat64(m.InjectedToolUseCount.WithLabelValues( - config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) + config.ProviderAnthropic, config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) require.Equal(t, 1.0, count) } diff --git a/internal/testutil/mockprovider.go b/internal/testutil/mockprovider.go index a21ac6a4..c7cb722d 100644 --- a/internal/testutil/mockprovider.go +++ b/internal/testutil/mockprovider.go @@ -17,9 +17,12 @@ type MockProvider struct { InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) } -func (m *MockProvider) Name() string { return m.Name_ } -func (m *MockProvider) BaseURL() string { return m.URL } -func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.Name_) } +func (m *MockProvider) Name() string { return m.Name_ } +func (m *MockProvider) BaseURL() string { return m.URL } +func (m *MockProvider) ResolveUpstream(_ *http.Request) intercept.ResolvedUpstream { + return intercept.ResolvedUpstream{Name: m.Name_, URL: m.URL} +} +func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.Name_) } func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough } func (m *MockProvider) AuthHeader() string { return "Authorization" } diff --git a/metrics/metrics.go b/metrics/metrics.go index 6d14ab9d..bc5ab763 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -5,7 +5,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) -var baseLabels []string = []string{"provider", "model"} +var baseLabels []string = []string{"provider", "upstream", "model"} const ( InterceptionCountStatusFailed = "failed" @@ -115,18 +115,18 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Subsystem: "circuit_breaker", Name: "state", Help: "Current state of the circuit breaker (0=closed, 0.5=half-open, 1=open).", - }, []string{"provider", "endpoint", "model"}), + }, []string{"provider", "upstream", "endpoint", "model"}), // Pessimistic cardinality: 3 providers, 2 endpoints, 5 models = up to 30. CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "trips_total", Help: "Total number of times the circuit breaker transitioned to open state.", - }, []string{"provider", "endpoint", "model"}), + }, []string{"provider", "upstream", "endpoint", "model"}), // Pessimistic cardinality: 3 providers, 2 endpoints, 5 models = up to 30. CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "rejects_total", Help: "Total number of requests rejected due to open circuit breaker.", - }, []string{"provider", "endpoint", "model"}), + }, []string{"provider", "upstream", "endpoint", "model"}), } } diff --git a/passthrough.go b/passthrough.go index c6b59edd..084b5623 100644 --- a/passthrough.go +++ b/passthrough.go @@ -32,7 +32,8 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met )) defer span.End() - upURL, err := url.Parse(provider.BaseURL()) + upstream := provider.ResolveUpstream(r) + upURL, err := url.Parse(upstream.URL) if err != nil { logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err)) http.Error(w, "request error", http.StatusBadGateway) diff --git a/provider/anthropic.go b/provider/anthropic.go index e980a163..769483ca 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -135,11 +135,13 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr // TODO(ssncferreira): when Bedrock is added as a separate provider, pass the // resolved provider name instead of p.Name() here. + upstream := p.ResolveUpstream(r) + var interceptor intercept.Interceptor if reqPayload.Stream() { - interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) + interceptor = messages.NewStreamingInterceptor(id, reqPayload, upstream, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) + interceptor = messages.NewBlockingInterceptor(id, reqPayload, upstream, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil @@ -153,6 +155,10 @@ func (p *Anthropic) BaseURL() string { return p.cfg.BaseURL } +func (p *Anthropic) ResolveUpstream(_ *http.Request) intercept.ResolvedUpstream { + return intercept.ResolvedUpstream{Name: p.Name(), URL: p.cfg.BaseURL} +} + func (p *Anthropic) AuthHeader() string { return "X-Api-Key" } diff --git a/provider/copilot.go b/provider/copilot.go index 00be8436..e414311d 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -75,6 +75,10 @@ func (p *Copilot) BaseURL() string { return p.cfg.BaseURL } +func (p *Copilot) ResolveUpstream(_ *http.Request) intercept.ResolvedUpstream { + return intercept.ResolvedUpstream{Name: p.Name(), URL: p.cfg.BaseURL} +} + func (p *Copilot) RoutePrefix() string { return fmt.Sprintf("/%s", p.Name()) } @@ -134,6 +138,8 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac ExtraHeaders: extractCopilotHeaders(r), } + upstream := p.ResolveUpstream(r) + var interceptor intercept.Interceptor path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix()) @@ -145,9 +151,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } case routeCopilotResponses: @@ -161,9 +167,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/openai.go b/provider/openai.go index ce841706..ad2401ea 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -115,6 +115,8 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace SendActorHeaders: p.cfg.SendActorHeaders, } + upstream := p.ResolveUpstream(r) + path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix()) switch path { case routeChatCompletions: @@ -124,9 +126,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } case routeResponses: @@ -139,9 +141,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, upstream, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } default: @@ -156,6 +158,10 @@ func (p *OpenAI) BaseURL() string { return p.cfg.BaseURL } +func (p *OpenAI) ResolveUpstream(_ *http.Request) intercept.ResolvedUpstream { + return intercept.ResolvedUpstream{Name: p.Name(), URL: p.cfg.BaseURL} +} + func (p *OpenAI) AuthHeader() string { return "Authorization" } diff --git a/provider/provider.go b/provider/provider.go index f2a70f18..de99e2c2 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -48,6 +48,10 @@ type Provider interface { Name() string // BaseURL defines the base URL endpoint for this provider's API. BaseURL() string + // ResolveUpstream returns the resolved upstream for the given request. + // Providers that support multiple upstreams can inspect the request to + // determine the correct destination. + ResolveUpstream(*http.Request) intercept.ResolvedUpstream // CreateInterceptor starts a new [Interceptor] which is responsible for intercepting requests, // communicating with the upstream provider and formulating a response to be sent to the requesting client. diff --git a/recorder/recorder.go b/recorder/recorder.go index c4f427c5..9bb0d576 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -150,6 +150,7 @@ type AsyncRecorder struct { metrics *metrics.Metrics provider string + upstream string model string initiatorID string client string @@ -171,6 +172,10 @@ func (a *AsyncRecorder) WithProvider(provider string) { a.provider = provider } +func (a *AsyncRecorder) WithUpstream(upstream string) { + a.upstream = upstream +} + func (a *AsyncRecorder) WithModel(model string) { a.model = model } @@ -218,7 +223,7 @@ func (a *AsyncRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageR } if a.metrics != nil && req.Prompt != "" { // TODO: will be irrelevant once https://github.com/coder/aibridge/issues/55 is fixed. - a.metrics.PromptCount.WithLabelValues(a.provider, a.model, a.initiatorID, a.client).Add(1) + a.metrics.PromptCount.WithLabelValues(a.provider, a.upstream, a.model, a.initiatorID, a.client).Add(1) } }() @@ -238,10 +243,10 @@ func (a *AsyncRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRec } if a.metrics != nil { - a.metrics.TokenUseCount.WithLabelValues(a.provider, a.model, "input", a.initiatorID, a.client).Add(float64(req.Input)) - a.metrics.TokenUseCount.WithLabelValues(a.provider, a.model, "output", a.initiatorID, a.client).Add(float64(req.Output)) + a.metrics.TokenUseCount.WithLabelValues(a.provider, a.upstream, a.model, "input", a.initiatorID, a.client).Add(float64(req.Input)) + a.metrics.TokenUseCount.WithLabelValues(a.provider, a.upstream, a.model, "output", a.initiatorID, a.client).Add(float64(req.Output)) for k, v := range req.ExtraTokenTypes { - a.metrics.TokenUseCount.WithLabelValues(a.provider, a.model, k, a.initiatorID, a.client).Add(float64(v)) + a.metrics.TokenUseCount.WithLabelValues(a.provider, a.upstream, a.model, k, a.initiatorID, a.client).Add(float64(v)) } } }() @@ -267,9 +272,9 @@ func (a *AsyncRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecor if req.ServerURL != nil { srvURL = *req.ServerURL } - a.metrics.InjectedToolUseCount.WithLabelValues(a.provider, a.model, srvURL, req.Tool).Add(1) + a.metrics.InjectedToolUseCount.WithLabelValues(a.provider, a.upstream, a.model, srvURL, req.Tool).Add(1) } else { - a.metrics.NonInjectedToolUseCount.WithLabelValues(a.provider, a.model, req.Tool).Add(1) + a.metrics.NonInjectedToolUseCount.WithLabelValues(a.provider, a.upstream, a.model, req.Tool).Add(1) } } }() diff --git a/recorder/types.go b/recorder/types.go index 20e735f4..eb0176d3 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -33,6 +33,7 @@ type InterceptionRecord struct { Metadata Metadata Model string Provider string + Upstream string StartedAt time.Time ClientSessionID *string Client string