Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,19 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
// Create per-provider circuit breaker if configured
cfg := prov.CircuitBreakerConfig()
providerName := prov.Name()
onChange := func(providerName, endpoint, model string, from, to gobreaker.State) {
onChange := func(upstreamName, endpoint, model string, from, to gobreaker.State) {
logger.Info(context.Background(), "circuit breaker state change",
slog.F("provider", providerName),
slog.F("upstream", upstreamName),
slog.F("endpoint", endpoint),
slog.F("model", model),
slog.F("from", from.String()),
slog.F("to", to.String()),
)
if m != nil {
m.CircuitBreakerState.WithLabelValues(providerName, endpoint, model).Set(circuitbreaker.StateToGaugeValue(to))
m.CircuitBreakerState.WithLabelValues(providerName, upstreamName, endpoint, model).Set(circuitbreaker.StateToGaugeValue(to))
if to == gobreaker.StateOpen {
m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint, model).Inc()
m.CircuitBreakerTrips.WithLabelValues(providerName, upstreamName, endpoint, model).Inc()
}
}
}
Expand Down Expand Up @@ -165,12 +166,12 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
return
}

providerName := interceptor.ProviderName()
upstreamName := interceptor.ProviderName()

if m != nil {
start := time.Now()
defer func() {
m.InterceptionDuration.WithLabelValues(providerName, interceptor.Model()).Observe(time.Since(start).Seconds())
m.InterceptionDuration.WithLabelValues(p.Name(), upstreamName, interceptor.Model()).Observe(time.Since(start).Seconds())
}()
}

Expand All @@ -189,7 +190,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
// Record usage in the background to not block request flow.
asyncRecorder := recorder.NewAsyncRecorder(logger, rec, recordingTimeout)
asyncRecorder.WithMetrics(m)
asyncRecorder.WithProvider(providerName)
asyncRecorder.WithProvider(p.Name())
asyncRecorder.WithUpstream(upstreamName)
asyncRecorder.WithModel(interceptor.Model())
asyncRecorder.WithInitiatorID(actor.ID)
asyncRecorder.WithClient(string(client))
Expand All @@ -200,7 +202,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
InitiatorID: actor.ID,
Metadata: actor.Metadata,
Model: interceptor.Model(),
Provider: providerName,
Provider: p.Name(),
Upstream: upstreamName,
UserAgent: r.UserAgent(),
Client: string(client),
ClientSessionID: sessionID,
Expand All @@ -215,32 +218,33 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name()))
log := logger.With(
slog.F("route", route),
slog.F("provider", providerName),
slog.F("provider", p.Name()),
slog.F("upstream", upstreamName),
slog.F("interception_id", interceptor.ID()),
slog.F("user_agent", r.UserAgent()),
slog.F("streaming", interceptor.Streaming()),
)

log.Debug(ctx, "interception started")
if m != nil {
m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Add(1)
m.InterceptionsInflight.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), route).Add(1)
defer func() {
m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Sub(1)
m.InterceptionsInflight.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), route).Sub(1)
}()
}

// Process request with circuit breaker protection if configured
if err := cbs.Execute(providerName, route, interceptor.Model(), w, func(rw http.ResponseWriter) error {
if err := cbs.Execute(upstreamName, route, interceptor.Model(), w, func(rw http.ResponseWriter) error {
return interceptor.ProcessRequest(rw, r)
}); err != nil {
if m != nil {
m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1)
m.InterceptionCount.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1)
}
span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err))
log.Warn(ctx, "interception failed", slog.Error(err))
} else {
if m != nil {
m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1)
m.InterceptionCount.WithLabelValues(p.Name(), upstreamName, interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1)
}
log.Debug(ctx, "interception ended")
}
Expand Down
20 changes: 10 additions & 10 deletions circuitbreaker/circuitbreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ func DefaultIsFailure(statusCode int) bool {
type ProviderCircuitBreakers struct {
provider string
config config.CircuitBreaker
breakers sync.Map // "providerName:endpoint:model" -> *gobreaker.CircuitBreaker[struct{}]
onChange func(providerName, endpoint, model string, from, to gobreaker.State)
breakers sync.Map // "upstreamName:endpoint:model" -> *gobreaker.CircuitBreaker[struct{}]
onChange func(upstreamName, endpoint, model string, from, to gobreaker.State)
metrics *metrics.Metrics
}

// NewProviderCircuitBreakers creates circuit breakers for a single provider.
// Returns nil if cfg is nil (no circuit breaker protection).
// onChange is called when circuit state changes.
// metrics is used to record circuit breaker reject counts (can be nil).
func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(providerName, endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers {
func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(upstreamName, endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers {
if cfg == nil {
return nil
}
Expand Down Expand Up @@ -71,9 +71,9 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte {
return []byte(`{"error":"circuit breaker is open"}`)
}

// Get returns the circuit breaker for a providerName/endpoint/model tuple, creating it if needed.
func (p *ProviderCircuitBreakers) Get(providerName, endpoint, model string) *gobreaker.CircuitBreaker[struct{}] {
key := providerName + ":" + endpoint + ":" + model
// Get returns the circuit breaker for an upstreamName/endpoint/model tuple, creating it if needed.
func (p *ProviderCircuitBreakers) Get(upstreamName, endpoint, model string) *gobreaker.CircuitBreaker[struct{}] {
key := upstreamName + ":" + endpoint + ":" + model
if v, ok := p.breakers.Load(key); ok {
return v.(*gobreaker.CircuitBreaker[struct{}])
}
Expand All @@ -88,7 +88,7 @@ func (p *ProviderCircuitBreakers) Get(providerName, endpoint, model string) *gob
},
OnStateChange: func(_ string, from, to gobreaker.State) {
if p.onChange != nil {
p.onChange(providerName, endpoint, model, from, to)
p.onChange(upstreamName, endpoint, model, from, to)
}
},
}
Expand Down Expand Up @@ -139,12 +139,12 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter {
// Otherwise, it returns the handler's error (or nil on success).
// The handler receives a wrapped ResponseWriter that captures the status code.
// If the receiver is nil (no circuit breaker configured), the handler is called directly.
func (p *ProviderCircuitBreakers) Execute(providerName, endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error {
func (p *ProviderCircuitBreakers) Execute(upstreamName, endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error {
if p == nil {
return handler(w)
}

cb := p.Get(providerName, endpoint, model)
cb := p.Get(upstreamName, endpoint, model)

// Wrap response writer to capture status code
sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK}
Expand All @@ -160,7 +160,7 @@ func (p *ProviderCircuitBreakers) Execute(providerName, endpoint, model string,

if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) {
if p.metrics != nil {
p.metrics.CircuitBreakerRejects.WithLabelValues(providerName, endpoint, model).Inc()
p.metrics.CircuitBreakerRejects.WithLabelValues(p.provider, upstreamName, endpoint, model).Inc()
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds())))
Expand Down
16 changes: 8 additions & 8 deletions circuitbreaker/circuitbreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
Interval: time.Minute,
Timeout: time.Minute,
MaxRequests: 1,
}, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil)
}, func(upstreamName, endpoint, model string, from, to gobreaker.State) {}, nil)

endpoint := "/v1/messages"
sonnetModel := "claude-sonnet-4-20250514"
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
Interval: time.Minute,
Timeout: time.Minute,
MaxRequests: 1,
}, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil)
}, func(upstreamName, endpoint, model string, from, to gobreaker.State) {}, nil)

model := "test-model"

Expand Down Expand Up @@ -123,7 +123,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
IsFailure: func(statusCode int) bool {
return statusCode == http.StatusBadGateway
},
}, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil)
}, func(upstreamName, endpoint, model string, from, to gobreaker.State) {}, nil)

// First request returns 502, trips circuit
w := httptest.NewRecorder()
Expand Down Expand Up @@ -151,7 +151,7 @@ func TestExecute_OnStateChange(t *testing.T) {
t.Parallel()

var stateChanges []struct {
providerName string
upstreamName string
endpoint string
model string
from gobreaker.State
Expand All @@ -163,14 +163,14 @@ func TestExecute_OnStateChange(t *testing.T) {
Interval: time.Minute,
Timeout: time.Minute,
MaxRequests: 1,
}, func(providerName, endpoint, model string, from, to gobreaker.State) {
}, func(upstreamName, endpoint, model string, from, to gobreaker.State) {
stateChanges = append(stateChanges, struct {
providerName string
upstreamName string
endpoint string
model string
from gobreaker.State
to gobreaker.State
}{providerName, endpoint, model, from, to})
}{upstreamName, endpoint, model, from, to})
}, nil)

endpoint := "/v1/messages"
Expand All @@ -185,7 +185,7 @@ func TestExecute_OnStateChange(t *testing.T) {

// Verify state change callback was called with correct parameters
assert.Len(t, stateChanges, 1)
assert.Equal(t, config.ProviderAnthropic, stateChanges[0].providerName)
assert.Equal(t, config.ProviderAnthropic, stateChanges[0].upstreamName)
assert.Equal(t, endpoint, stateChanges[0].endpoint)
assert.Equal(t, model, stateChanges[0].model)
assert.Equal(t, gobreaker.StateClosed, stateChanges[0].from)
Expand Down
15 changes: 7 additions & 8 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ import (
)

type interceptionBase struct {
id uuid.UUID
providerName string
baseURL string
apiDumpDir string
req *ChatCompletionNewParamsWrapper
cfg config.OpenAIInterceptor
id uuid.UUID
upstream intercept.ResolvedUpstream
apiDumpDir string
req *ChatCompletionNewParamsWrapper
cfg config.OpenAIInterceptor

// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
Expand All @@ -45,7 +44,7 @@ type interceptionBase struct {
}

func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.baseURL)}
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.upstream.URL)}

// Add extra headers if configured.
// Some providers require additional headers that are not added by the SDK.
Expand Down Expand Up @@ -77,7 +76,7 @@ func (i *interceptionBase) ID() uuid.UUID {
}

func (i *interceptionBase) ProviderName() string {
return i.providerName
return i.upstream.Name
}

func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
Expand Down
6 changes: 2 additions & 4 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ type BlockingInterception struct {
func NewBlockingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
baseURL string,
upstream intercept.ResolvedUpstream,
apiDumpDir string,
cfg config.OpenAIInterceptor,
clientHeaders http.Header,
Expand All @@ -41,8 +40,7 @@ func NewBlockingInterceptor(
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
baseURL: baseURL,
upstream: upstream,
apiDumpDir: apiDumpDir,
req: req,
cfg: cfg,
Expand Down
6 changes: 2 additions & 4 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ type StreamingInterception struct {
func NewStreamingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
baseURL string,
upstream intercept.ResolvedUpstream,
apiDumpDir string,
cfg config.OpenAIInterceptor,
clientHeaders http.Header,
Expand All @@ -46,8 +45,7 @@ func NewStreamingInterceptor(
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
baseURL: baseURL,
upstream: upstream,
apiDumpDir: apiDumpDir,
req: req,
cfg: cfg,
Expand Down
3 changes: 2 additions & 1 deletion intercept/chatcompletions/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/internal/testutil"
"github.com/google/uuid"
"github.com/openai/openai-go/v3"
Expand Down Expand Up @@ -85,7 +86,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, mockServer.URL, "", cfg, httpReq.Header, "Authorization", tracer)
interceptor := NewStreamingInterceptor(uuid.New(), req, intercept.ResolvedUpstream{Name: config.ProviderOpenAI, URL: mockServer.URL}, "", cfg, httpReq.Header, "Authorization", tracer)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
Expand Down
10 changes: 10 additions & 0 deletions intercept/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,13 @@ type Interceptor interface {
// parent interception.
CorrelatingToolCallID() *string
}

// ResolvedUpstream represents the resolved upstream destination for a request.
// Providers resolve this per-request, allowing a single provider to route to
// different upstreams based on request context.
type ResolvedUpstream struct {
// Name identifies the upstream provider domain.
Name string
// URL is the URL of the upstream provider.
URL string
}
8 changes: 4 additions & 4 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ var bedrockSupportedBetaFlags = map[string]bool{
}

type interceptionBase struct {
id uuid.UUID
providerName string
reqPayload MessagesRequestPayload
id uuid.UUID
upstream intercept.ResolvedUpstream
reqPayload MessagesRequestPayload

cfg aibconfig.Anthropic
bedrockCfg *aibconfig.AWSBedrock
Expand All @@ -86,7 +86,7 @@ func (i *interceptionBase) ID() uuid.UUID {
}

func (i *interceptionBase) ProviderName() string {
return i.providerName
return i.upstream.Name
}

func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
Expand Down
4 changes: 2 additions & 2 deletions intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type BlockingInterception struct {
func NewBlockingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
providerName string,
upstream intercept.ResolvedUpstream,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
clientHeaders http.Header,
Expand All @@ -40,7 +40,7 @@ func NewBlockingInterceptor(
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
upstream: upstream,
reqPayload: reqPayload,
cfg: cfg,
bedrockCfg: bedrockCfg,
Expand Down
4 changes: 2 additions & 2 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type StreamingInterception struct {
func NewStreamingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
providerName string,
upstream intercept.ResolvedUpstream,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
clientHeaders http.Header,
Expand All @@ -46,7 +46,7 @@ func NewStreamingInterceptor(
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
upstream: upstream,
reqPayload: reqPayload,
cfg: cfg,
bedrockCfg: bedrockCfg,
Expand Down
Loading
Loading