Merged
109 changes: 103 additions & 6 deletions bridge_integration_test.go
@@ -504,7 +504,7 @@ func TestSimple(t *testing.T) {
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, mcp.NewServerProxyManager(nil))
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
@@ -655,7 +655,7 @@ func TestFallthrough(t *testing.T) {
fixture: oaiFallthrough,
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))
provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
require.NoError(t, err)
return provider, bridge
@@ -843,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
}

// Build the requirements & make the assertions which are common to all providers.
@@ -1046,7 +1046,7 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(resp *http.Response) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
@@ -1152,7 +1152,7 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(resp *http.Response) {
// Server responds first with 200 OK then starts streaming.
@@ -1246,7 +1246,7 @@ func TestStableRequestEncoding(t *testing.T) {
fixture: oaiSimple,
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
},
},
}
@@ -1334,6 +1334,103 @@ func TestStableRequestEncoding(t *testing.T) {
}
}

func TestEnvironmentDoNotLeak(t *testing.T) {
// NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.

// Test that environment variables containing API keys/tokens are not leaked to upstream requests.
// See https://github.com/coder/aibridge/issues/60.
testCases := []struct {
name string
fixture []byte
configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error)
createRequest func(*testing.T, string, []byte) *http.Request
envVars map[string]string
headerName string
}{
{
name: aibridge.ProviderAnthropic,
fixture: antSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, mcp.NewServerProxyManager(nil))
},
createRequest: createAnthropicMessagesReq,
envVars: map[string]string{
"ANTHROPIC_AUTH_TOKEN": "should-not-leak",
},
headerName: "Authorization", // We only send through the X-Api-Key, so this one should not be present.
},
{
name: aibridge.ProviderOpenAI,
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
},
createRequest: createOpenAIChatCompletionsReq,
envVars: map[string]string{
"OPENAI_ORG_ID": "should-not-leak",
},
headerName: "OpenAI-Organization",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution.

arc := txtar.Parse(tc.fixture)
files := filesMap(arc)
reqBody := files[fixtureRequest]

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

// Track headers received by the upstream server.
var receivedHeaders http.Header
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(files[fixtureNonStreamingResponse])
}))
srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
srv.Start()
t.Cleanup(srv.Close)

// Set environment variables that the SDK would automatically read.
// These should NOT leak into upstream requests.
for key, val := range tc.envVars {
t.Setenv(key, val)
}

recorderClient := &mockRecorderClient{}
b, err := tc.configureFunc(srv.URL, recorderClient)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
t.Cleanup(mockSrv.Close)
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibridge.AsActor(ctx, userID, nil)
}
mockSrv.Start()

req := tc.createRequest(t, mockSrv.URL, reqBody)
client := &http.Client{}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Verify that environment values did not leak.
require.NotNil(t, receivedHeaders)
require.Empty(t, receivedHeaders.Get(tc.headerName))
})
}
}

func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 {
var total int64
for _, el := range in {
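For context on what TestEnvironmentDoNotLeak exercises: the vendor SDKs' top-level client constructors layer environment-derived defaults onto every request (e.g. OPENAI_ORG_ID surfacing as an OpenAI-Organization header), which is how host credentials could leak to the upstream. A standalone Go illustration of that mechanism follows; it is not part of this PR, and the exact env-to-header mapping is an assumption drawn from the SDK defaults.

package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

func main() {
	// Simulate a credential sitting in the bridge host's environment.
	os.Setenv("OPENAI_ORG_ID", "should-not-leak")

	// Fake upstream that records the header under test.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Printf("OpenAI-Organization: %q\n", r.Header.Get("OpenAI-Organization"))
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{}`))
	}))
	defer upstream.Close()

	ctx := context.Background()
	params := openai.ChatCompletionNewParams{}

	// Full client: environment defaults apply, so the org header leaks (assumption, per SDK defaults).
	client := openai.NewClient(option.WithAPIKey("key"), option.WithBaseURL(upstream.URL))
	_, _ = client.Chat.Completions.New(ctx, params)

	// Narrow service: only the explicit options apply; the header stays empty.
	svc := openai.NewChatCompletionService(option.WithAPIKey("key"), option.WithBaseURL(upstream.URL))
	_, _ = svc.New(ctx, params)
}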
6 changes: 3 additions & 3 deletions intercept_anthropic_messages_base.go
@@ -96,7 +96,7 @@ func (i *AnthropicMessagesInterceptionBase) isSmallFastModel() bool {
return strings.Contains(string(i.req.Model), "haiku")
}

func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Context, opts ...option.RequestOption) (anthropic.Client, error) {
func (i *AnthropicMessagesInterceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) {
opts = append(opts, option.WithAPIKey(i.cfg.Key))
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))

@@ -105,7 +105,7 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte
defer cancel()
bedrockOpt, err := i.withAWSBedrock(ctx, i.bedrockCfg)
if err != nil {
return anthropic.Client{}, err
return anthropic.MessageService{}, err
}
opts = append(opts, bedrockOpt)
i.augmentRequestForBedrock()
@@ -122,7 +122,7 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte
}
}

return anthropic.NewClient(opts...), nil
return anthropic.NewMessageService(opts...), nil
}

func (i *AnthropicMessagesInterceptionBase) withAWSBedrock(ctx context.Context, cfg *AWSBedrockConfig) (option.RequestOption, error) {
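The rename from newAnthropicClient to newMessagesService is the substance of the Anthropic-side fix: anthropic.NewMessageService applies only the options handed to it, whereas the previous anthropic.NewClient also layered in environment-derived defaults (e.g. ANTHROPIC_AUTH_TOKEN becoming a bearer Authorization header; treat that mapping as an assumption from the SDK defaults). A standalone sketch of the narrowed construction, with placeholder base URL, model id, and prompt:

package main

import (
	"context"
	"fmt"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
)

func main() {
	// Only these explicit options reach the wire; nothing is read from env.
	svc := anthropic.NewMessageService(
		option.WithAPIKey("explicit-key"),
		option.WithBaseURL("https://upstream.example"), // stand-in base URL
	)

	// Call sites change mechanically: client.Messages.New becomes svc.New.
	resp, err := svc.New(context.Background(), anthropic.MessageNewParams{
		Model:     anthropic.Model("claude-3-5-haiku-latest"), // placeholder model id
		MaxTokens: 1024,
		Messages:  []anthropic.MessageParam{anthropic.NewUserMessage(anthropic.NewTextBlock("ping"))},
	})
	fmt.Println(resp, err)
}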
4 changes: 2 additions & 2 deletions intercept_anthropic_messages_blocking.go
@@ -58,7 +58,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr

opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout

client, err := i.newAnthropicClient(ctx, opts...)
svc, err := i.newMessagesService(ctx, opts...)
if err != nil {
err = fmt.Errorf("create anthropic client: %w", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand All @@ -73,7 +73,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
var cumulativeUsage anthropic.Usage

for {
resp, err = client.Messages.New(ctx, messages)
resp, err = svc.New(ctx, messages)
if err != nil {
if isConnError(err) {
// Can't write a response, just error out.
4 changes: 2 additions & 2 deletions intercept_anthropic_messages_streaming.go
@@ -88,7 +88,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW
streamCtx, streamCancel := context.WithCancelCause(ctx)
defer streamCancel(errors.New("deferred"))

client, err := i.newAnthropicClient(streamCtx)
svc, err := i.newMessagesService(streamCtx)
if err != nil {
err = fmt.Errorf("create anthropic client: %w", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -118,7 +118,7 @@ newStream:
break
}

stream := client.Messages.NewStreaming(streamCtx, messages)
stream := svc.NewStreaming(streamCtx, messages)

var message anthropic.Message
var lastToolName string
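The streaming path changes in lockstep: svc.NewStreaming replaces client.Messages.NewStreaming, and event consumption is untouched. A minimal drain helper for such a stream, assuming the SDK's ssestream iteration and Message.Accumulate helper (both per the anthropic-sdk-go README):

package main

import (
	"context"
	"fmt"

	"github.com/anthropics/anthropic-sdk-go"
)

// drainMessages folds streamed deltas into a single Message.
func drainMessages(ctx context.Context, svc anthropic.MessageService, params anthropic.MessageNewParams) (anthropic.Message, error) {
	stream := svc.NewStreaming(ctx, params)
	var message anthropic.Message
	for stream.Next() {
		// Accumulate merges each SSE event into the in-progress message.
		if err := message.Accumulate(stream.Current()); err != nil {
			return message, fmt.Errorf("accumulate event: %w", err)
		}
	}
	return message, stream.Err()
}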
8 changes: 3 additions & 5 deletions intercept_openai_chat_base.go
@@ -26,12 +26,10 @@ type OpenAIChatInterceptionBase struct {
mcpProxy mcp.ServerProxier
}

func (i *OpenAIChatInterceptionBase) newOpenAIClient(baseURL, key string) openai.Client {
var opts []option.RequestOption
opts = append(opts, option.WithAPIKey(key))
opts = append(opts, option.WithBaseURL(baseURL))
func (i *OpenAIChatInterceptionBase) newCompletionsService(baseURL, key string) openai.ChatCompletionService {
opts := []option.RequestOption{option.WithAPIKey(key), option.WithBaseURL(baseURL)}

return openai.NewClient(opts...)
return openai.NewChatCompletionService(opts...)
}

func (i *OpenAIChatInterceptionBase) ID() uuid.UUID {
4 changes: 2 additions & 2 deletions intercept_openai_chat_blocking.go
@@ -41,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
}

ctx := r.Context()
client := i.newOpenAIClient(i.baseURL, i.key)
svc := i.newCompletionsService(i.baseURL, i.key)
logger := i.logger.With(slog.F("model", i.req.Model))

var (
@@ -61,7 +61,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
var opts []option.RequestOption
opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout

completion, err = client.Chat.Completions.New(ctx, i.req.ChatCompletionNewParams, opts...)
completion, err = svc.New(ctx, i.req.ChatCompletionNewParams, opts...)
if err != nil {
break
}
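Per-request options such as the 60-second timeout above still compose with the narrowed service: identity (API key, base URL) is pinned at construction time, while call-scoped options ride on each New invocation. A short sketch under that assumption:

package main

import (
	"context"
	"time"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

func newCompletion(ctx context.Context, baseURL, key string, params openai.ChatCompletionNewParams) (*openai.ChatCompletion, error) {
	svc := openai.NewChatCompletionService(option.WithAPIKey(key), option.WithBaseURL(baseURL))
	// Credentials are fixed above; only the timeout varies per call.
	return svc.New(ctx, params, option.WithRequestTimeout(60*time.Second))
}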
4 changes: 2 additions & 2 deletions intercept_openai_chat_streaming.go
@@ -65,7 +65,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
defer cancel()
r = r.WithContext(ctx) // Rewire context for SSE cancellation.

client := i.newOpenAIClient(i.baseURL, i.key)
svc := i.newCompletionsService(i.baseURL, i.key)
logger := i.logger.With(slog.F("model", i.req.Model))

streamCtx, streamCancel := context.WithCancelCause(ctx)
@@ -100,7 +100,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
interceptionErr error
)
for {
stream = client.Chat.Completions.NewStreaming(streamCtx, i.req.ChatCompletionNewParams)
stream = svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams)
processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName)

var toolCall *openai.FinishedChatCompletionToolCall
Expand Down