diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go new file mode 100644 index 0000000..9e45c0c --- /dev/null +++ b/aibtrace/aibtrace.go @@ -0,0 +1,48 @@ +package aibtrace + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +type traceInterceptionAttrsContextKey struct{} + +const ( + // trace attribute key constants + InterceptionID = "interception_id" + UserID = "user_id" + Provider = "provider" + Model = "model" + Streaming = "streaming" + IsBedrock = "aws_bedrock" + MCPToolName = "mcp_tool_name" + PassthroughURL = "passthrough_url" + PassthroughMethod = "passthrough_method" +) + +func WithTraceInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context { + return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs) +} + +func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue { + attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue) + if !ok { + return nil + } + + return attrs +} + +func EndSpanErr(span trace.Span, err *error) { + if span == nil { + return + } + + if err != nil && *err != nil { + span.SetStatus(codes.Error, (*err).Error()) + } + span.End() +} diff --git a/bridge.go b/bridge.go index 4f5428d..0f723ea 100644 --- a/bridge.go +++ b/bridge.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" + "go.opentelemetry.io/otel/trace" "github.com/hashicorp/go-multierror" ) @@ -47,20 +48,20 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. 
-func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) { +func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer, logger slog.Logger) (*RequestBridge, error) { mux := http.NewServeMux() for _, provider := range providers { // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics)) + mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics, tracer)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. // // We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be // configured, so we should just reverse-proxy known-safe routes. 
- ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics) + ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics, tracer) for _, path := range provider.PassthroughRoutes() { prefix := fmt.Sprintf("/%s", provider.Name()) route := fmt.Sprintf("%s%s", prefix, path) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index f8ef57e..2ea570a 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -20,23 +20,22 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" - "go.uber.org/goleak" - "golang.org/x/tools/txtar" - "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/coder/aibridge" "github.com/coder/aibridge/mcp" "github.com/google/uuid" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/openai/openai-go/v2" + oaissestream "github.com/openai/openai-go/v2/packages/ssestream" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - - "github.com/openai/openai-go/v2" - oaissestream "github.com/openai/openai-go/v2/packages/ssestream" - - mcplib "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel" + "go.uber.org/goleak" + "golang.org/x/tools/txtar" ) var ( @@ -65,6 +64,8 @@ var ( oaiMidStreamErr []byte //go:embed fixtures/openai/non_stream_error.txtar oaiNonStreamErr []byte + + defaultTracer = otel.Tracer("github.com/coder/aibridge") ) const ( @@ -90,8 +91,9 @@ func TestAnthropicMessages(t *testing.T) { t.Parallel() cases := []struct { - streaming bool - expectedInputTokens, expectedOutputTokens int + streaming bool + expectedInputTokens int + expectedOutputTokens int }{ { streaming: true, @@ -133,7 +135,8 @@ func TestAnthropicMessages(t *testing.T) { recorderClient := &mockRecorderClient{} logger := 
slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)} + b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -214,7 +217,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + }, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -312,7 +315,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -399,7 +402,8 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + providers := 
[]aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))} + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -466,7 +470,8 @@ func TestSimple(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) + provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -504,7 +509,8 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -618,17 +624,8 @@ func TestSimple(t *testing.T) { } func setJSON(in []byte, key string, val bool) ([]byte, error) { - var body map[string]any - err := json.Unmarshal(in, &body) - if err != nil { - return 
nil, err - } - body[key] = val - out, err := json.Marshal(body) - if err != nil { - return nil, err - } - return out, nil + out, err := sjson.Set(string(in), key, val) + return []byte(out), err } func TestFallthrough(t *testing.T) { @@ -645,7 +642,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -656,7 +653,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -762,7 +759,8 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, 
client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) } // Build the requirements & make the assertions which are common to all providers. @@ -843,7 +841,8 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) } // Build the requirements & make the assertions which are common to all providers. 
@@ -1029,7 +1028,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1046,7 +1046,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1134,7 +1135,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), 
[]aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1152,7 +1154,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. 
@@ -1238,7 +1241,8 @@ func TestStableRequestEncoding(t *testing.T) { fixture: antSimple, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, }, { @@ -1246,7 +1250,8 @@ func TestStableRequestEncoding(t *testing.T) { fixture: oaiSimple, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, }, } @@ -1352,7 +1357,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, createRequest: createAnthropicMessagesReq, 
envVars: map[string]string{ @@ -1365,7 +1371,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ diff --git a/go.mod b/go.mod index 47fd45d..cfe3a4a 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,13 @@ require ( github.com/openai/openai-go/v2 v2.7.0 ) +require ( + github.com/google/go-cmp v0.7.0 + go.opentelemetry.io/otel v1.38.0 + go.opentelemetry.io/otel/sdk v1.38.0 + go.opentelemetry.io/otel/trace v1.38.0 +) + require ( github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect @@ -46,6 +53,8 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/lipgloss v0.7.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect @@ -61,14 +70,13 @@ require ( github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/rivo/uniseg v0.4.4 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect 
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - go.opentelemetry.io/otel v1.33.0 // indirect - go.opentelemetry.io/otel/trace v1.33.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/term v0.34.0 // indirect diff --git a/go.sum b/go.sum index d0b79c8..385345d 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= -github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -130,14 +131,16 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.33.0 
h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw= -go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I= -go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ= -go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M= -go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE= -go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= -go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= -go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 5049e54..0ee12ea 100644 --- a/intercept_anthropic_messages_base.go +++ 
b/intercept_anthropic_messages_base.go @@ -14,8 +14,11 @@ import ( "github.com/anthropics/anthropic-sdk-go/option" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -27,6 +30,7 @@ type AnthropicMessagesInterceptionBase struct { cfg AnthropicConfig bedrockCfg *AWSBedrockConfig + tracer trace.Tracer logger slog.Logger recorder Recorder @@ -59,6 +63,17 @@ func (i *AnthropicMessagesInterceptionBase) Model() string { return string(i.req.Model) } +func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(ctx context.Context, streaming bool) []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(aibtrace.Provider, ProviderAnthropic), + attribute.String(aibtrace.InterceptionID, s.id.String()), + attribute.String(aibtrace.Model, s.Model()), + attribute.String(aibtrace.UserID, actorFromContext(ctx).id), + attribute.Bool(aibtrace.Streaming, streaming), + attribute.Bool(aibtrace.IsBedrock, s.bedrockCfg != nil), + } +} + func (i *AnthropicMessagesInterceptionBase) injectTools() { if i.req == nil || i.mcpProxy == nil { return diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index a1f71e6..0773a13 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -1,6 +1,7 @@ package aibridge import ( + "context" "encoding/json" "fmt" "net/http" @@ -10,7 +11,10 @@ import ( "github.com/anthropics/anthropic-sdk-go/option" "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" // TODO: abstract this away so callers need no knowledge of underlying lib. 
+ "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "cdr.dev/slog" @@ -22,29 +26,35 @@ type AnthropicMessagesBlockingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesBlockingInterception { +func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesBlockingInterception { return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, cfg: cfg, bedrockCfg: bedrockCfg, + tracer: tracer, }} } -func (s *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) { - s.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +func (i *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) { + i.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +} + +func (i *AnthropicMessagesBlockingInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return i.AnthropicMessagesInterceptionBase.baseTraceAttributes(ctx, false) } func (s *AnthropicMessagesBlockingInterception) Streaming() bool { return false } -func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } - ctx := r.Context() + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", 
trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) i.injectTools() @@ -77,7 +87,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr var cumulativeUsage anthropic.Usage for { - resp, err = svc.New(ctx, messages) + resp, err = i.traceNewMessage(ctx, svc, messages) // traces client.Messages.New(ctx, msgParams) call if err != nil { if isConnError(err) { // Can't write a response, just error out. @@ -166,7 +176,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr continue } - res, err := tool.Call(ctx, tc.Input) + res, err := tool.Call(ctx, i.tracer, tc.Input) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -285,3 +295,10 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr return nil } + +func (i *AnthropicMessagesBlockingInterception) traceNewMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + + return svc.New(ctx, msgParams) +} diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index ef8aabd..139f7c2 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -11,10 +11,14 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ 
-25,12 +29,13 @@ type AnthropicMessagesStreamingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesStreamingInterception { +func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesStreamingInterception { return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, cfg: cfg, bedrockCfg: bedrockCfg, + tracer: tracer, }} } @@ -42,6 +47,10 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool { return true } +func (s *AnthropicMessagesStreamingInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return s.AnthropicMessagesInterceptionBase.baseTraceAttributes(ctx, true) +} + // ProcessRequest handles a request to /v1/messages. // This API has a state-machine behind it, which is described in https://docs.claude.com/en/docs/build-with-claude/streaming#event-types. // @@ -61,13 +70,16 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool { // b) if the tool is injected, it will be invoked by the [mcp.ServerProxier] in the remote MCP server, and its // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. 
-func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) + // Allow us to interrupt watch via cancel. - ctx, cancel := context.WithCancel(r.Context()) + ctx, cancel := context.WithCancel(ctx) defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. @@ -129,7 +141,7 @@ newStream: pendingToolCalls := make(map[string]string) - for stream.Next() { + for i.traceStreamNext(ctx, stream) { // traces stream.Next() call event := stream.Current() if err := message.Accumulate(event); err != nil { logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON())) @@ -269,7 +281,7 @@ newStream: continue } - res, err := tool.Call(streamCtx, input) + res, err := tool.Call(streamCtx, i.tracer, input) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -508,3 +520,10 @@ func (s *AnthropicMessagesStreamingInterception) encodeForStream(payload []byte, buf.WriteString("\n\n") return buf.Bytes() } + +func (s *AnthropicMessagesStreamingInterception) traceStreamNext(ctx context.Context, stream *ssestream.Stream[anthropic.MessageStreamEventUnion]) bool { + _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer span.End() + + return stream.Next() +} diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 20db323..f81b245 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go 
@@ -6,21 +6,26 @@ import ( "net/http" "strings" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" "github.com/openai/openai-go/v2/shared" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) type OpenAIChatInterceptionBase struct { - id uuid.UUID - req *ChatCompletionNewParamsWrapper + id uuid.UUID + req *ChatCompletionNewParamsWrapper + baseURL string + key string - baseURL, key string - logger slog.Logger + tracer trace.Tracer + logger slog.Logger recorder Recorder mcpProxy mcp.ServerProxier @@ -42,6 +47,16 @@ func (i *OpenAIChatInterceptionBase) Setup(logger slog.Logger, recorder Recorder i.mcpProxy = mcpProxy } +func (s *OpenAIChatInterceptionBase) baseTraceAttributes(ctx context.Context, streaming bool) []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(aibtrace.Provider, ProviderOpenAI), + attribute.String(aibtrace.InterceptionID, s.id.String()), + attribute.String(aibtrace.Model, s.Model()), + attribute.String(aibtrace.UserID, actorFromContext(ctx).id), + attribute.Bool(aibtrace.Streaming, streaming), + } +} + func (i *OpenAIChatInterceptionBase) Model() string { if i.req == nil { return "coder-aibridge-unknown" diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 757c933..f4d1db3 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -2,16 +2,20 @@ package aibridge import ( "bytes" + "context" "encoding/json" "fmt" "net/http" "strings" "time" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -22,12 +26,13 @@ type OpenAIBlockingChatInterception struct { OpenAIChatInterceptionBase } 
-func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIBlockingChatInterception { +func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *OpenAIBlockingChatInterception { return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ id: id, req: req, baseURL: baseURL, key: key, + tracer: tracer, }} } @@ -39,12 +44,18 @@ func (s *OpenAIBlockingChatInterception) Streaming() bool { return false } -func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (s *OpenAIBlockingChatInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(ctx, false) +} + +func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } - ctx := r.Context() + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) + svc := i.newCompletionsService(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) @@ -65,7 +76,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout - completion, err = svc.New(ctx, i.req.ChatCompletionNewParams, opts...) 
+ completion, err = i.traceChatCompletionsNew(ctx, svc, opts) // traces svc.New call if err != nil { break } @@ -145,7 +156,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r ) _ = json.NewEncoder(&buf).Encode(tc.Function.Arguments) _ = json.NewDecoder(&buf).Decode(&args) - res, err := tool.Call(ctx, args) + res, err := tool.Call(ctx, i.tracer, args) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -227,3 +238,10 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r return nil } + +func (i *OpenAIBlockingChatInterception) traceChatCompletionsNew(ctx context.Context, client openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + + return client.New(ctx, i.req.ChatCompletionNewParams, opts...) 
+} diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index ccabb35..f1ec165 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -10,11 +10,14 @@ import ( "strings" "time" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/ssestream" "github.com/tidwall/sjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -25,12 +28,13 @@ type OpenAIStreamingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIStreamingChatInterception { +func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *OpenAIStreamingChatInterception { return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ id: id, req: req, baseURL: baseURL, key: key, + tracer: tracer, }} } @@ -42,6 +46,10 @@ func (i *OpenAIStreamingChatInterception) Streaming() bool { return true } +func (s *OpenAIStreamingChatInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(ctx, true) +} + // ProcessRequest handles a request to /v1/chat/completions. // See https://platform.openai.com/docs/api-reference/chat-streaming/streaming. // @@ -54,18 +62,21 @@ func (i *OpenAIStreamingChatInterception) Streaming() bool { // b) if the tool is injected, it will be invoked by the [mcp.ServerProxier] in the remote MCP server, and its // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. 
-func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) + // Include token usage. i.req.StreamOptions.IncludeUsage = openai.Bool(true) i.injectTools() // Allow us to interrupt watch via cancel. - ctx, cancel := context.WithCancel(r.Context()) + ctx, cancel := context.WithCancel(ctx) defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. @@ -109,7 +120,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, var toolCall *openai.FinishedChatCompletionToolCall - for stream.Next() { + for i.traceStreamNext(ctx, stream) { // traces stream.Next() call chunk := stream.Current() canRelay := processor.process(chunk) @@ -230,7 +241,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, id := toolCall.ID args := i.unmarshalArgs(toolCall.Arguments) - toolRes, toolErr := tool.Call(streamCtx, args) + toolRes, toolErr := tool.Call(streamCtx, i.tracer, args) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -336,6 +347,13 @@ func (i *OpenAIStreamingChatInterception) encodeForStream(payload []byte) []byte return buf.Bytes() } +func (i *OpenAIStreamingChatInterception) traceStreamNext(ctx context.Context, stream *ssestream.Stream[openai.ChatCompletionChunk]) bool { + _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer span.End() + + return stream.Next() +} + type openAIStreamProcessor struct { ctx context.Context logger 
slog.Logger diff --git a/interception.go b/interception.go index 8210c41..7af954d 100644 --- a/interception.go +++ b/interception.go @@ -1,6 +1,7 @@ package aibridge import ( + "context" "errors" "fmt" "net/http" @@ -8,8 +9,12 @@ import ( "time" "cdr.dev/slog" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) // Interceptor describes a (potentially) stateful interaction with an AI provider. @@ -25,6 +30,8 @@ type Interceptor interface { ProcessRequest(w http.ResponseWriter, r *http.Request) error // Specifies whether an interceptor handles streaming or not. Streaming() bool + // TraceAttributes returns tracing attributes for this [Interceptor] + TraceAttributes(context.Context) []attribute.KeyValue } var UnknownRoute = errors.New("unknown route") @@ -34,11 +41,15 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. 
-func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics) http.HandlerFunc { +func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - interceptor, err := p.CreateInterceptor(w, r) + ctx, span := tracer.Start(r.Context(), "Intercept") + defer span.End() + + interceptor, err := p.CreateInterceptor(tracer, w, r.WithContext(ctx)) if err != nil { - logger.Warn(r.Context(), "failed to create interceptor", slog.Error(err), slog.F("path", r.URL.Path)) + span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) + logger.Warn(ctx, "failed to create interceptor", slog.Error(err), slog.F("path", r.URL.Path)) http.Error(w, fmt.Sprintf("failed to create %q interceptor", r.URL.Path), http.StatusInternalServerError) return } @@ -50,13 +61,18 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, }() } - actor := actorFromContext(r.Context()) + actor := actorFromContext(ctx) if actor == nil { - logger.Warn(r.Context(), "no actor found in context") + logger.Warn(ctx, "no actor found in context") http.Error(w, "no actor found", http.StatusBadRequest) return } + traceAttrs := interceptor.TraceAttributes(ctx) + span.SetAttributes(traceAttrs...) + ctx = aibtrace.WithTraceInterceptionAttributesInContext(ctx, traceAttrs) + r = r.WithContext(ctx) + // Record usage in the background to not block request flow. 
asyncRecorder := NewAsyncRecorder(logger, recorder, recordingTimeout) asyncRecorder.WithMetrics(metrics) @@ -65,14 +81,15 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, asyncRecorder.WithInitiatorID(actor.id) interceptor.Setup(logger, asyncRecorder, mcpProxy) - if err := recorder.RecordInterception(r.Context(), &InterceptionRecord{ + if err := recorder.RecordInterception(ctx, &InterceptionRecord{ ID: interceptor.ID().String(), Metadata: actor.metadata, InitiatorID: actor.id, Provider: p.Name(), Model: interceptor.Model(), }); err != nil { - logger.Warn(r.Context(), "failed to record interception", slog.Error(err)) + span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err)) + logger.Warn(ctx, "failed to record interception", slog.Error(err)) http.Error(w, "failed to record interception", http.StatusInternalServerError) return } @@ -86,27 +103,27 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, slog.F("streaming", interceptor.Streaming()), ) - log.Debug(r.Context(), "interception started") + log.Debug(ctx, "interception started") if metrics != nil { metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Add(1) + defer func() { + metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + }() } if err := interceptor.ProcessRequest(w, r); err != nil { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusFailed, route, r.Method, actor.id).Add(1) } - log.Warn(r.Context(), "interception failed", slog.Error(err)) + span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) + log.Warn(ctx, "interception failed", slog.Error(err)) } else { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusCompleted, route, r.Method, actor.id).Add(1) } - log.Debug(r.Context(), "interception 
ended") - } - asyncRecorder.RecordInterceptionEnded(r.Context(), &InterceptionRecordEnded{ID: interceptor.ID().String()}) - - if metrics != nil { - metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + log.Debug(ctx, "interception ended") } + asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/mcp/tool.go b/mcp/tool.go index 2c01535..23dffc4 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -7,7 +7,10 @@ import ( "strings" "cdr.dev/slog" + "github.com/coder/aibridge/aibtrace" "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) const ( @@ -34,14 +37,21 @@ type Tool struct { Required []string } -func (t *Tool) Call(ctx context.Context, input any) (*mcp.CallToolResult, error) { +func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp.CallToolResult, outErr error) { if t == nil { - return nil, errors.New("nil tool!") + return nil, errors.New("nil tool") } if t.Client == nil { - return nil, errors.New("nil client!") + return nil, errors.New("nil client") } + spanAttrs := append( + aibtrace.TraceInterceptionAttributesFromContext(ctx), + attribute.String(aibtrace.MCPToolName, t.Name), + ) + ctx, span := tracer.Start(ctx, "Intercept.RecordInterception.ToolCall", trace.WithAttributes(spanAttrs...)) + defer aibtrace.EndSpanErr(span, &outErr) + return t.Client.CallTool(ctx, mcp.CallToolRequest{ Params: mcp.CallToolParams{ Name: t.Name, diff --git a/metrics_integration_test.go b/metrics_integration_test.go index 3696de2..1b17c6f 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -16,6 +16,7 @@ import ( "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" 
"golang.org/x/tools/txtar" ) @@ -48,7 +49,7 @@ func TestMetrics_Interception(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -89,7 +90,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil) - bridgeSrv := newTestSrv(t, ctx, provider, metrics) + bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) // Make request in background. doneCh := make(chan struct{}) @@ -141,7 +142,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) - srv := newTestSrv(t, t.Context(), provider, metrics) + srv, _ := newTestSrv(t, t.Context(), provider, metrics, defaultTracer) req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) require.NoError(t, err) @@ -170,7 +171,7 @@ func TestMetrics_PromptCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -198,7 +199,7 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, 
provider, metrics, defaultTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -240,7 +241,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { mcpMgr := mcp.NewServerProxyManager(tools) require.NoError(t, mcpMgr.Init(ctx)) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, defaultTracer, logger) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -272,13 +273,17 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { require.Equal(t, 1.0, count) } -func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics) *httptest.Server { +func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics, tracer trace.Tracer) (*httptest.Server, *mockRecorderClient) { t.Helper() - recorder := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + mockRecorder := &mockRecorderClient{} + clientFn := func() (aibridge.Recorder, error) { + return mockRecorder, nil + } + wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcp.NewServerProxyManager(nil), metrics, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil), metrics, tracer, logger) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -288,5 +293,5 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m srv.Start() t.Cleanup(srv.Close) - return srv + return srv, mockRecorder } diff --git a/passthrough.go b/passthrough.go index 6788672..2c9f45b 100644 --- a/passthrough.go +++ b/passthrough.go @@ -8,19 +8,28 @@ import 
( "time" "cdr.dev/slog" + "github.com/coder/aibridge/aibtrace" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically // by a [Provider]. -func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metrics) http.HandlerFunc { +func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if metrics != nil { metrics.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1) } + ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( + attribute.String(aibtrace.PassthroughURL, r.URL.Path), + attribute.String(aibtrace.PassthroughMethod, r.Method), + )) + defer span.End() + upURL, err := url.Parse(provider.BaseURL()) if err != nil { - logger.Warn(r.Context(), "failed to parse provider base URL", slog.Error(err)) + logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err)) http.Error(w, "request error", http.StatusBadGateway) return } diff --git a/provider.go b/provider.go index 3787af6..78bd771 100644 --- a/provider.go +++ b/provider.go @@ -2,6 +2,8 @@ package aibridge import ( "net/http" + + "go.opentelemetry.io/otel/trace" ) // Provider describes an AI provider client's behaviour. @@ -14,7 +16,7 @@ type Provider interface { // CreateInterceptor starts a new [Interceptor] which is responsible for intercepting requests, // communicating with the upstream provider and formulating a response to be sent to the requesting client. - CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) + CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (Interceptor, error) // BridgedRoutes returns a slice of [http.ServeMux]-compatible routes which will have special handling. 
// See https://pkg.go.dev/net/http#hdr-Patterns-ServeMux. diff --git a/provider_anthropic.go b/provider_anthropic.go index 7e9c99f..009c30f 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -11,7 +11,10 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared" "github.com/anthropics/anthropic-sdk-go/shared/constant" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var _ Provider = &AnthropicProvider{} @@ -58,14 +61,16 @@ func (p *AnthropicProvider) PassthroughRoutes() []string { } } -func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) { +func (p *AnthropicProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (_ Interceptor, outErr error) { + id := uuid.New() + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer aibtrace.EndSpanErr(span, &outErr) + payload, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("read body: %w", err) } - id := uuid.New() - switch r.URL.Path { case routeMessages: var req MessageNewParamsWrapper @@ -73,13 +78,17 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req return nil, fmt.Errorf("failed to unmarshal request: %w", err) } + var interceptor Interceptor if req.Stream { - return NewAnthropicMessagesStreamingInterception(id, &req, p.cfg, p.bedrockCfg), nil + interceptor = NewAnthropicMessagesStreamingInterception(id, &req, p.cfg, p.bedrockCfg, tracer) + } else { + interceptor = NewAnthropicMessagesBlockingInterception(id, &req, p.cfg, p.bedrockCfg, tracer) } - - return NewAnthropicMessagesBlockingInterception(id, &req, p.cfg, p.bedrockCfg), nil + span.SetAttributes(interceptor.TraceAttributes(r.Context())...) 
+ return interceptor, nil } + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) return nil, UnknownRoute } diff --git a/provider_openai.go b/provider_openai.go index 0fc31a6..7cc154a 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -7,7 +7,10 @@ import ( "net/http" "os" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var _ Provider = &OpenAIProvider{} @@ -58,14 +61,17 @@ func (p *OpenAIProvider) PassthroughRoutes() []string { } } -func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) { +func (p *OpenAIProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (_ Interceptor, outErr error) { + id := uuid.New() + + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer aibtrace.EndSpanErr(span, &outErr) + payload, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("read body: %w", err) } - id := uuid.New() - switch r.URL.Path { case routeChatCompletions: var req ChatCompletionNewParamsWrapper @@ -73,13 +79,17 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques return nil, fmt.Errorf("unmarshal request body: %w", err) } + var interceptor Interceptor if req.Stream { - return NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key), nil + interceptor = NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key, tracer) } else { - return NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key), nil + interceptor = NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key, tracer) } + span.SetAttributes(interceptor.TraceAttributes(r.Context())...) 
+ return interceptor, nil } + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) return nil, UnknownRoute } diff --git a/recorder.go b/recorder.go index edd68f7..375ef23 100644 --- a/recorder.go +++ b/recorder.go @@ -7,18 +7,28 @@ import ( "time" "cdr.dev/slog" + + aibtrace "github.com/coder/aibridge/aibtrace" + "go.opentelemetry.io/otel/trace" ) -var _ Recorder = &RecorderWrapper{} +var ( + _ Recorder = &RecorderWrapper{} + _ Recorder = &AsyncRecorder{} +) // RecorderWrapper is a convenience struct which implements RecorderClient and resolves a client before calling each method. // It also sets the start/creation time of each record. type RecorderWrapper struct { logger slog.Logger + tracer trace.Tracer clientFn func() (Recorder, error) } -func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) error { +func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -33,7 +43,10 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept return err } -func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error { +func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -48,7 +61,10 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte return err } -func (r 
*RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { +func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -63,7 +79,10 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag return err } -func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error { +func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -78,7 +97,10 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR return err } -func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error { +func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -93,12 +115,14 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec return err } -func NewRecorder(logger slog.Logger, clientFn func() (Recorder, error)) *RecorderWrapper { - return &RecorderWrapper{logger: logger, clientFn: clientFn} +func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() 
(Recorder, error)) *RecorderWrapper { + return &RecorderWrapper{ + logger: logger, + tracer: tracer, + clientFn: clientFn, + } } -var _ Recorder = &AsyncRecorder{} - // AsyncRecorder calls [Recorder] methods asynchronously and logs any errors which may occur. type AsyncRecorder struct { logger slog.Logger @@ -141,7 +165,7 @@ func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *Interc a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordInterceptionEnded(timedCtx, req) @@ -153,11 +177,11 @@ func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *Interc return nil // Caller is not interested in error. } -func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRecord) error { +func (a *AsyncRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordPromptUsage(timedCtx, req) @@ -173,11 +197,11 @@ func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRec return nil // Caller is not interested in error. } -func (a *AsyncRecorder) RecordTokenUsage(_ context.Context, req *TokenUsageRecord) error { +func (a *AsyncRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordTokenUsage(timedCtx, req) @@ -197,11 +221,11 @@ func (a *AsyncRecorder) RecordTokenUsage(_ context.Context, req *TokenUsageRecor return nil // Caller is not interested in error. 
} -func (a *AsyncRecorder) RecordToolUsage(_ context.Context, req *ToolUsageRecord) error { +func (a *AsyncRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordToolUsage(timedCtx, req) @@ -228,3 +252,10 @@ func (a *AsyncRecorder) RecordToolUsage(_ context.Context, req *ToolUsageRecord) func (a *AsyncRecorder) Wait() { a.wg.Wait() } + +// returns a detached context with tracing information copied from the provided context +func (a *AsyncRecorder) timedContext(ctx context.Context) (context.Context, context.CancelFunc) { + timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx = aibtrace.WithTraceInterceptionAttributesInContext(timedCtx, aibtrace.TraceInterceptionAttributesFromContext(ctx)) + return timedCtx, cancel +} diff --git a/trace_integration_test.go b/trace_integration_test.go new file mode 100644 index 0000000..16b1744 --- /dev/null +++ b/trace_integration_test.go @@ -0,0 +1,514 @@ +package aibridge_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/aibridge" + "github.com/coder/aibridge/aibtrace" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "golang.org/x/tools/txtar" +) + +// expect 'count' number of traces named 'name' with status 'status' +type expectTrace struct { + name string + count int + status codes.Code +} + +func TestTraceAnthropic(t *testing.T) { + expectNonStreaming := []expectTrace{ + {"Intercept", 1, codes.Unset}, + 
{"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.RecordToolUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + } + + cases := []struct { + name string + streaming bool + bedrock bool + expectTraceCounts []expectTrace + }{ + { + name: "trace_anthr_non_streaming", + expectTraceCounts: expectNonStreaming, + }, + { + name: "trace_bedrock_non_streaming", + bedrock: true, + expectTraceCounts: expectNonStreaming, + }, + { + name: "trace_anthr_streaming", + streaming: true, + expectTraceCounts: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 2, codes.Unset}, + {"Intercept.RecordToolUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 9, codes.Unset}, + }, + }, + { + name: "trace_bedrock_streaming", + streaming: true, + bedrock: true, + expectTraceCounts: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + }, + }, + } + + arc := txtar.Parse(antSingleBuiltinTool) + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + fixtureReqBody := 
files[fixtureRequest] + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + var bedrockCfg *aibridge.AWSBedrockConfig + if tc.bedrock { + bedrockCfg = &aibridge.AWSBedrockConfig{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + EndpointOverride: mockAPI.URL, + } + } + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createAnthropicMessagesReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + model := gjson.Get(string(reqBody), "model").Str + if tc.bedrock { + model = "beddel" + } + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), + attribute.String(aibtrace.Model, model), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + attribute.Bool(aibtrace.IsBedrock, tc.bedrock), + } + + verifyCommonTraceAttrs(t, sr, tc.expectTraceCounts, attrs) + }) + } +} + +func TestTraceAnthropicErr(t 
*testing.T) { + cases := []struct { + name string + streaming bool + expect []expectTrace + }{ + { + name: "trace_anthr_non_streaming_err", + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + }, + }, + { + name: "trace_anthr_streaming_err", + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 3, codes.Unset}, + }, + }, + } + + for _, tc := range cases { + t.Run(t.Name(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + var arc *txtar.Archive + if tc.streaming { + arc = txtar.Parse(antMidStreamErr) + } else { + arc = txtar.Parse(antNonStreamErr) + } + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + if tc.streaming { + require.Contains(t, files, fixtureStreamingResponse) + } else { + require.Contains(t, files, fixtureNonStreamingResponse) + } + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) + srv, recorder := 
newTestSrv(t, ctx, provider, nil, tracer) + + req := createAnthropicMessagesReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + if tc.streaming { + require.Equal(t, http.StatusOK, resp.StatusCode) + } else { + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + } + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), + attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + attribute.Bool(aibtrace.IsBedrock, false), + } + + verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + }) + } +} + +func TestTraceOpenAI(t *testing.T) { + cases := []struct { + name string + fixture []byte + streaming bool + expect []expectTrace + }{ + { + name: "trace_openai_streaming", + fixture: oaiSimple, + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 242, codes.Unset}, + }, + }, + { + name: "trace_openai_non_streaming", + fixture: oaiSimple, + streaming: false, + expect: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + 
{"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + }, + }, + } + + for _, tc := range cases { + t.Run(t.Name(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + arc := txtar.Parse(tc.fixture) + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), + attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + } + + verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + }) + } +} + +func TestTraceOpenAIErr(t *testing.T) { + cases := []struct { + name string + streaming bool + expect []expectTrace + }{ + { + name: "trace_openai_streaming_err", + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + 
{"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 5, codes.Unset}, + }, + }, + { + name: "trace_openai_non_streaming_err", + streaming: false, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + }, + }, + } + + for _, tc := range cases { + t.Run(t.Name(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + var arc *txtar.Archive + if tc.streaming { + arc = txtar.Parse(oaiMidStreamErr) + } else { + arc = txtar.Parse(oaiNonStreamErr) + } + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + if tc.streaming { + require.Contains(t, files, fixtureStreamingResponse) + } else { + require.Contains(t, files, fixtureNonStreamingResponse) + } + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + if tc.streaming { + require.Equal(t, http.StatusOK, resp.StatusCode) + } else 
{ + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + } + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), + attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + } + + verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + }) + } +} + +func TestTracePassthrough(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(oaiFallthrough) + files := filesMap(arc) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureResponse]) + })) + t.Cleanup(upstream.Close) + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) + srv, _ := newTestSrv(t, t.Context(), provider, nil, tracer) + + req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + srv.Close() + + spans := sr.Ended() + require.Equal(t, len(spans), 1) + assert.Equal(t, spans[0].Name(), "Passthrough") + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.PassthroughURL, "/v1/models"), + attribute.String(aibtrace.PassthroughMethod, "GET"), + } + if attrDiff := cmp.Diff(spans[0].Attributes(), attrs, cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); 
attrDiff != "" { + t.Errorf("unexpected attrs diff: %s", attrDiff) + } +} + +func cmpAttrKeyVal(a attribute.KeyValue, b attribute.KeyValue) bool { + return a.Key < b.Key +} + +func verifyCommonTraceAttrs(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []expectTrace, attrs []attribute.KeyValue) { + spans := spanRecorder.Ended() + + totalCount := 0 + for _, e := range expect { + totalCount += e.count + } + assert.Equal(t, totalCount, len(spans)) + + for _, e := range expect { + found := 0 + for _, s := range spans { + if s.Name() != e.name || s.Status().Code != e.status { + continue + } + found++ + if attrDiff := cmp.Diff(s.Attributes(), attrs, cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { + t.Errorf("unexpected attrs for span named: %v, diff: %s", e.name, attrDiff) + } + assert.Equalf(t, e.status, s.Status().Code, "unexpected status for trace named: %v got: %v want: %v", e.name, s.Status().Code, e.status) + } + if found != e.count { + t.Errorf("found unexpected number of spans named: %v with status %v, got: %v want: %v", e.name, e.status, found, e.count) + } + } +}