Skip to content

Commit 3fb2522

Browse files
committed
feat: add interception tracing
1 parent 9e9491b commit 3fb2522

20 files changed

+913
-144
lines changed

aibtrace/aibtrace.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package aibtrace
2+
3+
import (
4+
"context"
5+
6+
"go.opentelemetry.io/otel/attribute"
7+
"go.opentelemetry.io/otel/codes"
8+
"go.opentelemetry.io/otel/trace"
9+
)
10+
11+
type traceInterceptionAttrsContextKey struct{}
12+
13+
const (
14+
// trace attribute key constants
15+
InterceptionID = "interception_id"
16+
UserID = "user_id"
17+
Provider = "provider"
18+
Model = "model"
19+
Streaming = "streaming"
20+
IsBedrock = "aws_bedrock"
21+
MCPToolName = "mcp_tool_name"
22+
PassthroughURL = "passthrough_url"
23+
PassthroughMethod = "passthrough_method"
24+
)
25+
26+
func WithTraceInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context {
27+
return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs)
28+
}
29+
30+
func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue {
31+
attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue)
32+
if !ok {
33+
return nil
34+
}
35+
36+
return attrs
37+
}
38+
39+
func EndSpanErr(span trace.Span, err *error) {
40+
if err != nil && *err != nil {
41+
span.SetStatus(codes.Error, (*err).Error())
42+
}
43+
span.End()
44+
}

bridge.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"cdr.dev/slog"
1111
"github.com/coder/aibridge/mcp"
12+
"go.opentelemetry.io/otel/trace"
1213

1314
"github.com/hashicorp/go-multierror"
1415
)
@@ -39,28 +40,30 @@ type RequestBridge struct {
3940
closed chan struct{}
4041
}
4142

42-
var _ http.Handler = &RequestBridge{}
43+
var (
44+
_ http.Handler = &RequestBridge{}
45+
)
4346

4447
// NewRequestBridge creates a new *[RequestBridge] and registers the HTTP routes defined by the given providers.
4548
// Any routes which are requested but not registered will be reverse-proxied to the upstream service.
4649
//
4750
// A [Recorder] is also required to record prompt, tool, and token use.
4851
//
4952
// mcpProxy will be closed when the [RequestBridge] is closed.
50-
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) {
53+
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer, logger slog.Logger) (*RequestBridge, error) {
5154
mux := http.NewServeMux()
5255

5356
for _, provider := range providers {
5457
// Add the known provider-specific routes which are bridged (i.e. intercepted and augmented).
5558
for _, path := range provider.BridgedRoutes() {
56-
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics))
59+
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics, tracer))
5760
}
5861

5962
// Any requests which passthrough to this will be reverse-proxied to the upstream.
6063
//
6164
// We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be
6265
// configured, so we should just reverse-proxy known-safe routes.
63-
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics)
66+
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics, tracer)
6467
for _, path := range provider.PassthroughRoutes() {
6568
prefix := fmt.Sprintf("/%s", provider.Name())
6669
route := fmt.Sprintf("%s%s", prefix, path)
@@ -74,7 +77,7 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record
7477
http.Error(w, fmt.Sprintf("route not supported: %s %s", r.Method, r.URL.Path), http.StatusNotFound)
7578
})
7679

77-
inflightCtx, cancel := context.WithCancel(context.Background())
80+
inflightCtx, cancel := context.WithCancel(ctx)
7881
return &RequestBridge{
7982
mux: mux,
8083
logger: logger,

bridge_integration_test.go

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,22 @@ import (
2020

2121
"cdr.dev/slog"
2222
"cdr.dev/slog/sloggers/slogtest"
23-
"go.uber.org/goleak"
24-
"golang.org/x/tools/txtar"
25-
2623
"github.com/anthropics/anthropic-sdk-go"
2724
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
2825
"github.com/coder/aibridge"
2926
"github.com/coder/aibridge/mcp"
3027
"github.com/google/uuid"
28+
mcplib "github.com/mark3labs/mcp-go/mcp"
29+
"github.com/mark3labs/mcp-go/server"
30+
"github.com/openai/openai-go/v2"
31+
oaissestream "github.com/openai/openai-go/v2/packages/ssestream"
3132
"github.com/stretchr/testify/assert"
3233
"github.com/stretchr/testify/require"
3334
"github.com/tidwall/gjson"
34-
35-
"github.com/openai/openai-go/v2"
36-
oaissestream "github.com/openai/openai-go/v2/packages/ssestream"
37-
38-
mcplib "github.com/mark3labs/mcp-go/mcp"
39-
"github.com/mark3labs/mcp-go/server"
35+
"github.com/tidwall/sjson"
36+
"go.opentelemetry.io/otel"
37+
"go.uber.org/goleak"
38+
"golang.org/x/tools/txtar"
4039
)
4140

4241
var (
@@ -65,6 +64,8 @@ var (
6564
oaiMidStreamErr []byte
6665
//go:embed fixtures/openai/non_stream_error.txtar
6766
oaiNonStreamErr []byte
67+
68+
defaultTracer = otel.Tracer("github.com/coder/aibridge")
6869
)
6970

7071
const (
@@ -90,8 +91,9 @@ func TestAnthropicMessages(t *testing.T) {
9091
t.Parallel()
9192

9293
cases := []struct {
93-
streaming bool
94-
expectedInputTokens, expectedOutputTokens int
94+
streaming bool
95+
expectedInputTokens int
96+
expectedOutputTokens int
9597
}{
9698
{
9799
streaming: true,
@@ -133,7 +135,8 @@ func TestAnthropicMessages(t *testing.T) {
133135
recorderClient := &mockRecorderClient{}
134136

135137
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
136-
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
138+
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}
139+
b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
137140
require.NoError(t, err)
138141

139142
mockSrv := httptest.NewUnstartedServer(b)
@@ -214,7 +217,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
214217
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
215218
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{
216219
aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg),
217-
}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
220+
}, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
218221
require.NoError(t, err)
219222

220223
mockSrv := httptest.NewUnstartedServer(b)
@@ -312,7 +315,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
312315
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
313316
b, err := aibridge.NewRequestBridge(
314317
ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)},
315-
recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
318+
recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
316319
require.NoError(t, err)
317320

318321
mockBridgeSrv := httptest.NewUnstartedServer(b)
@@ -399,7 +402,8 @@ func TestOpenAIChatCompletions(t *testing.T) {
399402
recorderClient := &mockRecorderClient{}
400403

401404
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
402-
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger)
405+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}
406+
b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
403407
require.NoError(t, err)
404408

405409
mockSrv := httptest.NewUnstartedServer(b)
@@ -466,7 +470,8 @@ func TestSimple(t *testing.T) {
466470
fixture: antSimple,
467471
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
468472
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
469-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger)
473+
provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
474+
return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
470475
},
471476
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
472477
if streaming {
@@ -504,7 +509,8 @@ func TestSimple(t *testing.T) {
504509
fixture: oaiSimple,
505510
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
506511
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
507-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger)
512+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
513+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
508514
},
509515
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
510516
if streaming {
@@ -618,17 +624,8 @@ func TestSimple(t *testing.T) {
618624
}
619625

620626
func setJSON(in []byte, key string, val bool) ([]byte, error) {
621-
var body map[string]any
622-
err := json.Unmarshal(in, &body)
623-
if err != nil {
624-
return nil, err
625-
}
626-
body[key] = val
627-
out, err := json.Marshal(body)
628-
if err != nil {
629-
return nil, err
630-
}
631-
return out, nil
627+
out, err := sjson.Set(string(in), key, val)
628+
return []byte(out), err
632629
}
633630

634631
func TestFallthrough(t *testing.T) {
@@ -645,7 +642,7 @@ func TestFallthrough(t *testing.T) {
645642
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
646643
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
647644
provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)
648-
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger)
645+
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
649646
require.NoError(t, err)
650647
return provider, bridge
651648
},
@@ -656,7 +653,7 @@ func TestFallthrough(t *testing.T) {
656653
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
657654
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
658655
provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))
659-
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger)
656+
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
660657
require.NoError(t, err)
661658
return provider, bridge
662659
},
@@ -762,7 +759,8 @@ func TestAnthropicInjectedTools(t *testing.T) {
762759

763760
configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
764761
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
765-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
762+
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
763+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
766764
}
767765

768766
// Build the requirements & make the assertions which are common to all providers.
@@ -843,7 +841,8 @@ func TestOpenAIInjectedTools(t *testing.T) {
843841

844842
configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
845843
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
846-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
844+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
845+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
847846
}
848847

849848
// Build the requirements & make the assertions which are common to all providers.
@@ -1029,7 +1028,8 @@ func TestErrorHandling(t *testing.T) {
10291028
createRequestFunc: createAnthropicMessagesReq,
10301029
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
10311030
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
1032-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
1031+
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1032+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
10331033
},
10341034
responseHandlerFn: func(resp *http.Response) {
10351035
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
@@ -1046,7 +1046,8 @@ func TestErrorHandling(t *testing.T) {
10461046
createRequestFunc: createOpenAIChatCompletionsReq,
10471047
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
10481048
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
1049-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
1049+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1050+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
10501051
},
10511052
responseHandlerFn: func(resp *http.Response) {
10521053
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
@@ -1134,7 +1135,8 @@ func TestErrorHandling(t *testing.T) {
11341135
createRequestFunc: createAnthropicMessagesReq,
11351136
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
11361137
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
1137-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
1138+
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1139+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
11381140
},
11391141
responseHandlerFn: func(resp *http.Response) {
11401142
// Server responds first with 200 OK then starts streaming.
@@ -1152,7 +1154,8 @@ func TestErrorHandling(t *testing.T) {
11521154
createRequestFunc: createOpenAIChatCompletionsReq,
11531155
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
11541156
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
1155-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
1157+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1158+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
11561159
},
11571160
responseHandlerFn: func(resp *http.Response) {
11581161
// Server responds first with 200 OK then starts streaming.
@@ -1238,15 +1241,17 @@ func TestStableRequestEncoding(t *testing.T) {
12381241
fixture: antSimple,
12391242
createRequestFunc: createAnthropicMessagesReq,
12401243
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
1241-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger)
1244+
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1245+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
12421246
},
12431247
},
12441248
{
12451249
name: aibridge.ProviderOpenAI,
12461250
fixture: oaiSimple,
12471251
createRequestFunc: createOpenAIChatCompletionsReq,
12481252
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
1249-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger)
1253+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1254+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
12501255
},
12511256
},
12521257
}
@@ -1352,7 +1357,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
13521357
fixture: antSimple,
13531358
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
13541359
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
1355-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger)
1360+
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1361+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
13561362
},
13571363
createRequest: createAnthropicMessagesReq,
13581364
envVars: map[string]string{
@@ -1365,7 +1371,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
13651371
fixture: oaiSimple,
13661372
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
13671373
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
1368-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger)
1374+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1375+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger)
13691376
},
13701377
createRequest: createOpenAIChatCompletionsReq,
13711378
envVars: map[string]string{

0 commit comments

Comments
 (0)