From 6f7b1f1d96e62da78fb8c5c0d38e14f66e1a64f5 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 18 Nov 2025 09:48:11 +0200 Subject: [PATCH 1/3] chore: drive-by refactor Signed-off-by: Danny Kopping --- bridge_integration_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index e570dd9..5722f48 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -504,7 +504,7 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, mcp.NewServerProxyManager(nil)) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -655,7 +655,7 @@ func TestFallthrough(t *testing.T) { fixture: oaiFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey))) + provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) require.NoError(t, err) return provider, bridge @@ -843,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) } // Build the requirements & make the assertions which are common to all providers. @@ -1046,7 +1046,7 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1152,7 +1152,7 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1246,7 +1246,7 @@ func TestStableRequestEncoding(t *testing.T) { fixture: oaiSimple, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) }, }, } From 527d314b3a201e41533e42cbee390b6fe5bd5e60 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 18 Nov 2025 15:20:22 +0200 Subject: [PATCH 2/3] chore: prevent default options adding unwanted headers Signed-off-by: Danny Kopping --- intercept_anthropic_messages_base.go | 6 +++--- intercept_anthropic_messages_blocking.go | 4 ++-- intercept_anthropic_messages_streaming.go | 4 ++-- intercept_openai_chat_base.go | 8 +++----- intercept_openai_chat_blocking.go | 4 ++-- intercept_openai_chat_streaming.go | 4 ++-- 6 files changed, 14 insertions(+), 16 deletions(-) diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 2367933..5049e54 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -96,7 +96,7 @@ func (i *AnthropicMessagesInterceptionBase) isSmallFastModel() bool { return strings.Contains(string(i.req.Model), "haiku") } -func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Context, opts ...option.RequestOption) (anthropic.Client, error) { +func (i *AnthropicMessagesInterceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { opts = append(opts, option.WithAPIKey(i.cfg.Key)) opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) @@ -105,7 +105,7 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte defer cancel() bedrockOpt, err := i.withAWSBedrock(ctx, i.bedrockCfg) if err != nil { - return anthropic.Client{}, err + return anthropic.MessageService{}, err } opts = append(opts, bedrockOpt) i.augmentRequestForBedrock() @@ -122,7 +122,7 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte } } - return anthropic.NewClient(opts...), nil + return anthropic.NewMessageService(opts...), nil } func (i *AnthropicMessagesInterceptionBase) withAWSBedrock(ctx context.Context, cfg *AWSBedrockConfig) (option.RequestOption, error) { diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index f978113..9750d30 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -58,7 +58,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout - client, err := i.newAnthropicClient(ctx, opts...) + client, err := i.newMessagesService(ctx, opts...) if err != nil { err = fmt.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -73,7 +73,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr var cumulativeUsage anthropic.Usage for { - resp, err = client.Messages.New(ctx, messages) + resp, err = client.New(ctx, messages) if err != nil { if isConnError(err) { // Can't write a response, just error out. diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 15bb6d8..daa0352 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -88,7 +88,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW streamCtx, streamCancel := context.WithCancelCause(ctx) defer streamCancel(errors.New("deferred")) - client, err := i.newAnthropicClient(streamCtx) + svc, err := i.newMessagesService(streamCtx) if err != nil { err = fmt.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -118,7 +118,7 @@ newStream: break } - stream := client.Messages.NewStreaming(streamCtx, messages) + stream := svc.NewStreaming(streamCtx, messages) var message anthropic.Message var lastToolName string diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 44ef582..20db323 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -26,12 +26,10 @@ type OpenAIChatInterceptionBase struct { mcpProxy mcp.ServerProxier } -func (i *OpenAIChatInterceptionBase) newOpenAIClient(baseURL, key string) openai.Client { - var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(key)) - opts = append(opts, option.WithBaseURL(baseURL)) +func (i *OpenAIChatInterceptionBase) newCompletionsService(baseURL, key string) openai.ChatCompletionService { + opts := []option.RequestOption{option.WithAPIKey(key), option.WithBaseURL(baseURL)} - return openai.NewClient(opts...) + return openai.NewChatCompletionService(opts...) } func (i *OpenAIChatInterceptionBase) ID() uuid.UUID { diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 4c019a6..26a5dfc 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -41,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - client := i.newOpenAIClient(i.baseURL, i.key) + client := i.newCompletionsService(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) var ( @@ -61,7 +61,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout - completion, err = client.Chat.Completions.New(ctx, i.req.ChatCompletionNewParams, opts...) + completion, err = client.New(ctx, i.req.ChatCompletionNewParams, opts...) if err != nil { break } diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index cc1a64a..9798505 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -65,7 +65,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - client := i.newOpenAIClient(i.baseURL, i.key) + svc := i.newCompletionsService(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) @@ -100,7 +100,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, interceptionErr error ) for { - stream = client.Chat.Completions.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) + stream = svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName) var toolCall *openai.FinishedChatCompletionToolCall From 2d6ded83190c86c254edf642532cea49f08fd31a Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 18 Nov 2025 18:28:06 +0200 Subject: [PATCH 3/3] chore: add test Signed-off-by: Danny Kopping --- bridge_integration_test.go | 97 ++++++++++++++++++++++++ intercept_anthropic_messages_blocking.go | 4 +- intercept_openai_chat_blocking.go | 4 +- 3 files changed, 101 insertions(+), 4 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 5722f48..36f2a3e 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1334,6 +1334,103 @@ func TestStableRequestEncoding(t *testing.T) { } } +func TestEnvironmentDoNotLeak(t *testing.T) { + // NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution. + + // Test that environment variables containing API keys/tokens are not leaked to upstream requests. + // See https://github.com/coder/aibridge/issues/60. + testCases := []struct { + name string + fixture []byte + configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error) + createRequest func(*testing.T, string, []byte) *http.Request + envVars map[string]string + headerName string + }{ + { + name: aibridge.ProviderAnthropic, + fixture: antSimple, + configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, mcp.NewServerProxyManager(nil)) + }, + createRequest: createAnthropicMessagesReq, + envVars: map[string]string{ + "ANTHROPIC_AUTH_TOKEN": "should-not-leak", + }, + headerName: "Authorization", // We only send through the X-Api-Key, so this one should not be present. + }, + { + name: aibridge.ProviderOpenAI, + fixture: oaiSimple, + configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) + }, + createRequest: createOpenAIChatCompletionsReq, + envVars: map[string]string{ + "OPENAI_ORG_ID": "should-not-leak", + }, + headerName: "OpenAI-Organization", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution. + + arc := txtar.Parse(tc.fixture) + files := filesMap(arc) + reqBody := files[fixtureRequest] + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Track headers received by the upstream server. + var receivedHeaders http.Header + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureNonStreamingResponse]) + })) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + srv.Start() + t.Cleanup(srv.Close) + + // Set environment variables that the SDK would automatically read. + // These should NOT leak into upstream requests. + for key, val := range tc.envVars { + t.Setenv(key, val) + } + + recorderClient := &mockRecorderClient{} + b, err := tc.configureFunc(srv.URL, recorderClient) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(b) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + mockSrv.Start() + + req := tc.createRequest(t, mockSrv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + + // Verify that environment values did not leak. + require.NotNil(t, receivedHeaders) + require.Empty(t, receivedHeaders.Get(tc.headerName)) + }) + } +} + func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 { var total int64 for _, el := range in { diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index 9750d30..3aef2dd 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -58,7 +58,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout - client, err := i.newMessagesService(ctx, opts...) + svc, err := i.newMessagesService(ctx, opts...) if err != nil { err = fmt.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -73,7 +73,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr var cumulativeUsage anthropic.Usage for { - resp, err = client.New(ctx, messages) + resp, err = svc.New(ctx, messages) if err != nil { if isConnError(err) { // Can't write a response, just error out. diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 26a5dfc..d9bb50e 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -41,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - client := i.newCompletionsService(i.baseURL, i.key) + svc := i.newCompletionsService(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) var ( @@ -61,7 +61,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout - completion, err = client.New(ctx, i.req.ChatCompletionNewParams, opts...) + completion, err = svc.New(ctx, i.req.ChatCompletionNewParams, opts...) if err != nil { break }