Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions aibtrace/aibtrace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package aibtrace

import (
"context"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)

type traceInterceptionAttrsContextKey struct{}

const (
// trace attribute key constants
InterceptionID = "interception_id"
UserID = "user_id"
Provider = "provider"
Model = "model"
Streaming = "streaming"
IsBedrock = "aws_bedrock"
MCPToolName = "mcp_tool_name"
PassthroughURL = "passthrough_url"
PassthroughMethod = "passthrough_method"
)

func WithTraceInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context {
return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs)
}

func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue {
attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue)
if !ok {
return nil
}

return attrs
}

func EndSpanErr(span trace.Span, err *error) {
if span == nil {
return
}

if err != nil && *err != nil {
span.SetStatus(codes.Error, (*err).Error())
}
span.End()
}
7 changes: 4 additions & 3 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"cdr.dev/slog"
"github.com/coder/aibridge/mcp"
"go.opentelemetry.io/otel/trace"

"github.com/hashicorp/go-multierror"
)
Expand Down Expand Up @@ -47,20 +48,20 @@ var _ http.Handler = &RequestBridge{}
// A [Recorder] is also required to record prompt, tool, and token use.
//
// mcpProxy will be closed when the [RequestBridge] is closed.
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) {
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer, logger slog.Logger) (*RequestBridge, error) {
mux := http.NewServeMux()

for _, provider := range providers {
// Add the known provider-specific routes which are bridged (i.e. intercepted and augmented).
for _, path := range provider.BridgedRoutes() {
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics))
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics, tracer))
}

// Any requests which passthrough to this will be reverse-proxied to the upstream.
//
// We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be
// configured, so we should just reverse-proxy known-safe routes.
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics)
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics, tracer)
for _, path := range provider.PassthroughRoutes() {
prefix := fmt.Sprintf("/%s", provider.Name())
route := fmt.Sprintf("%s%s", prefix, path)
Expand Down
87 changes: 47 additions & 40 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,22 @@ import (

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"go.uber.org/goleak"
"golang.org/x/tools/txtar"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
"github.com/coder/aibridge"
"github.com/coder/aibridge/mcp"
"github.com/google/uuid"
mcplib "github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/openai/openai-go/v2"
oaissestream "github.com/openai/openai-go/v2/packages/ssestream"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/openai/openai-go/v2"
oaissestream "github.com/openai/openai-go/v2/packages/ssestream"

mcplib "github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/tidwall/sjson"
"go.opentelemetry.io/otel"
"go.uber.org/goleak"
"golang.org/x/tools/txtar"
)

var (
Expand Down Expand Up @@ -65,6 +64,8 @@ var (
oaiMidStreamErr []byte
//go:embed fixtures/openai/non_stream_error.txtar
oaiNonStreamErr []byte

defaultTracer = otel.Tracer("github.com/coder/aibridge")
)

const (
Expand All @@ -90,8 +91,9 @@ func TestAnthropicMessages(t *testing.T) {
t.Parallel()

cases := []struct {
streaming bool
expectedInputTokens, expectedOutputTokens int
streaming bool
expectedInputTokens int
expectedOutputTokens int
}{
{
streaming: true,
Expand Down Expand Up @@ -133,7 +135,8 @@ func TestAnthropicMessages(t *testing.T) {
recorderClient := &mockRecorderClient{}

logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}
b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -214,7 +217,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{
aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg),
}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
}, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -312,7 +315,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(
ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)},
recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
require.NoError(t, err)

mockBridgeSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -399,7 +402,8 @@ func TestOpenAIChatCompletions(t *testing.T) {
recorderClient := &mockRecorderClient{}

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}
b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -466,7 +470,8 @@ func TestSimple(t *testing.T) {
fixture: antSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger)
provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
Expand Down Expand Up @@ -504,7 +509,8 @@ func TestSimple(t *testing.T) {
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
Expand Down Expand Up @@ -618,17 +624,8 @@ func TestSimple(t *testing.T) {
}

func setJSON(in []byte, key string, val bool) ([]byte, error) {
var body map[string]any
err := json.Unmarshal(in, &body)
if err != nil {
return nil, err
}
body[key] = val
out, err := json.Marshal(body)
if err != nil {
return nil, err
}
return out, nil
out, err := sjson.Set(string(in), key, val)
return []byte(out), err
}

func TestFallthrough(t *testing.T) {
Expand All @@ -645,7 +642,7 @@ func TestFallthrough(t *testing.T) {
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger)
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
require.NoError(t, err)
return provider, bridge
},
Expand All @@ -656,7 +653,7 @@ func TestFallthrough(t *testing.T) {
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger)
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
require.NoError(t, err)
return provider, bridge
},
Expand Down Expand Up @@ -762,7 +759,8 @@ func TestAnthropicInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
}

// Build the requirements & make the assertions which are common to all providers.
Expand Down Expand Up @@ -843,7 +841,8 @@ func TestOpenAIInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
}

// Build the requirements & make the assertions which are common to all providers.
Expand Down Expand Up @@ -1029,7 +1028,8 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
},
responseHandlerFn: func(resp *http.Response) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
Expand All @@ -1046,7 +1046,8 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
},
responseHandlerFn: func(resp *http.Response) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
Expand Down Expand Up @@ -1134,7 +1135,8 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
},
responseHandlerFn: func(resp *http.Response) {
// Server responds first with 200 OK then starts streaming.
Expand All @@ -1152,7 +1154,8 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
},
responseHandlerFn: func(resp *http.Response) {
// Server responds first with 200 OK then starts streaming.
Expand Down Expand Up @@ -1238,15 +1241,17 @@ func TestStableRequestEncoding(t *testing.T) {
fixture: antSimple,
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
},
},
{
name: aibridge.ProviderOpenAI,
fixture: oaiSimple,
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
},
},
}
Expand Down Expand Up @@ -1352,7 +1357,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
fixture: antSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger)
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
},
createRequest: createAnthropicMessagesReq,
envVars: map[string]string{
Expand All @@ -1365,7 +1371,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
},
createRequest: createOpenAIChatCompletionsReq,
envVars: map[string]string{
Expand Down
14 changes: 11 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ require (
github.com/openai/openai-go/v2 v2.7.0
)

require (
github.com/google/go-cmp v0.7.0
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
)

require (
github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
Expand All @@ -46,6 +53,8 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/lipgloss v0.7.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
Expand All @@ -61,14 +70,13 @@ require (
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/otel v1.33.0 // indirect
go.opentelemetry.io/otel/trace v1.33.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/term v0.34.0 // indirect
Expand Down
Loading