-
Notifications
You must be signed in to change notification settings - Fork 3
feat: add path of providers base url to pass through requests #159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pawbana
wants to merge
6
commits into
main
Choose a base branch
from
pb/base_url_with_path
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+311
−110
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
41548c5
fix: add base url path as a prefix to passed through requests
pawbana 1fc0682
remove accidentally added new line
pawbana 3a57519
review 1: added passthrough test
pawbana 4a65a3f
fmt fix
pawbana 0f1cb26
review 2: fixed test case names in TestFallthrough
pawbana 5ca5be8
review 2: extended comment about leading slash
pawbana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,7 +109,7 @@ func TestAnthropicMessages(t *testing.T) { | |
|
|
||
| ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) | ||
| t.Cleanup(cancel) | ||
| srv := newMockServer(ctx, t, files, nil) | ||
| srv := newMockServer(ctx, t, files, nil, nil) | ||
| t.Cleanup(srv.Close) | ||
|
|
||
| recorderClient := &testutil.MockRecorder{} | ||
|
|
@@ -379,7 +379,7 @@ func TestOpenAIChatCompletions(t *testing.T) { | |
|
|
||
| ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) | ||
| t.Cleanup(cancel) | ||
| srv := newMockServer(ctx, t, files, nil) | ||
| srv := newMockServer(ctx, t, files, nil, nil) | ||
| t.Cleanup(srv.Close) | ||
|
|
||
| recorderClient := &testutil.MockRecorder{} | ||
|
|
@@ -483,7 +483,7 @@ func TestOpenAIChatCompletions(t *testing.T) { | |
| t.Cleanup(cancel) | ||
|
|
||
| // Setup mock server with response mutator for multi-turn interaction. | ||
| srv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte { | ||
| srv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte { | ||
| if reqCount == 1 { | ||
| // First request gets the tool call response | ||
| return resp | ||
|
|
@@ -556,91 +556,128 @@ func TestOpenAIChatCompletions(t *testing.T) { | |
| func TestSimple(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| getAnthropicResponseID := func(streaming bool, resp *http.Response) (string, error) { | ||
| if streaming { | ||
| decoder := ssestream.NewDecoder(resp) | ||
| stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) | ||
| var message anthropic.Message | ||
| for stream.Next() { | ||
| event := stream.Current() | ||
| if err := message.Accumulate(event); err != nil { | ||
| return "", fmt.Errorf("accumulate event: %w", err) | ||
| } | ||
| } | ||
| if stream.Err() != nil { | ||
| return "", fmt.Errorf("stream error: %w", stream.Err()) | ||
| } | ||
| return message.ID, nil | ||
| } | ||
|
|
||
| body, err := io.ReadAll(resp.Body) | ||
| if err != nil { | ||
| return "", fmt.Errorf("read body: %w", err) | ||
| } | ||
|
|
||
| var message anthropic.Message | ||
| if err := json.Unmarshal(body, &message); err != nil { | ||
| return "", fmt.Errorf("unmarshal response: %w", err) | ||
| } | ||
| return message.ID, nil | ||
| } | ||
|
|
||
| getOpenAIResponseID := func(streaming bool, resp *http.Response) (string, error) { | ||
| if streaming { | ||
| // Parse the response stream. | ||
| decoder := oaissestream.NewDecoder(resp) | ||
| stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) | ||
| var message openai.ChatCompletionAccumulator | ||
| for stream.Next() { | ||
| chunk := stream.Current() | ||
| message.AddChunk(chunk) | ||
| } | ||
| if stream.Err() != nil { | ||
| return "", fmt.Errorf("stream error: %w", stream.Err()) | ||
| } | ||
| return message.ID, nil | ||
| } | ||
|
|
||
| // Parse & unmarshal the response. | ||
| body, err := io.ReadAll(resp.Body) | ||
| if err != nil { | ||
| return "", fmt.Errorf("read body: %w", err) | ||
| } | ||
|
|
||
| var message openai.ChatCompletion | ||
| if err := json.Unmarshal(body, &message); err != nil { | ||
| return "", fmt.Errorf("unmarshal response: %w", err) | ||
| } | ||
| return message.ID, nil | ||
| } | ||
|
|
||
| // Common configuration functions for each provider type. | ||
| configureAnthropic := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { | ||
| t.Helper() | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} | ||
| return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) | ||
| } | ||
|
|
||
| configureOpenAI := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { | ||
| t.Helper() | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} | ||
| return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) | ||
| } | ||
|
|
||
| testCases := []struct { | ||
| name string | ||
| fixture []byte | ||
| configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error) | ||
| getResponseIDFunc func(bool, *http.Response) (string, error) | ||
| basePath string | ||
| expectedPath string | ||
| configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error) | ||
| getResponseIDFunc func(streaming bool, resp *http.Response) (string, error) | ||
| createRequest func(*testing.T, string, []byte) *http.Request | ||
| expectedMsgID string | ||
| }{ | ||
| { | ||
| name: config.ProviderAnthropic, | ||
| fixture: fixtures.AntSimple, | ||
| configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| provider := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} | ||
| return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) | ||
| }, | ||
| getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { | ||
| if streaming { | ||
| decoder := ssestream.NewDecoder(resp) | ||
| stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) | ||
| var message anthropic.Message | ||
| for stream.Next() { | ||
| event := stream.Current() | ||
| if err := message.Accumulate(event); err != nil { | ||
| return "", fmt.Errorf("accumulate event: %w", err) | ||
| } | ||
| } | ||
| if stream.Err() != nil { | ||
| return "", fmt.Errorf("stream error: %w", stream.Err()) | ||
| } | ||
| return message.ID, nil | ||
| } | ||
|
|
||
| body, err := io.ReadAll(resp.Body) | ||
| if err != nil { | ||
| return "", fmt.Errorf("read body: %w", err) | ||
| } | ||
|
|
||
| var message anthropic.Message | ||
| if err := json.Unmarshal(body, &message); err != nil { | ||
| return "", fmt.Errorf("unmarshal response: %w", err) | ||
| } | ||
| return message.ID, nil | ||
| }, | ||
| createRequest: createAnthropicMessagesReq, | ||
| expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", | ||
| name: config.ProviderAnthropic, | ||
| fixture: fixtures.AntSimple, | ||
| basePath: "", | ||
| expectedPath: "/v1/messages", | ||
| configureFunc: configureAnthropic, | ||
| getResponseIDFunc: getAnthropicResponseID, | ||
| createRequest: createAnthropicMessagesReq, | ||
| expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", | ||
| }, | ||
| { | ||
| name: config.ProviderOpenAI, | ||
| fixture: fixtures.OaiChatSimple, | ||
| configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} | ||
| return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) | ||
| }, | ||
| getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { | ||
| if streaming { | ||
| // Parse the response stream. | ||
| decoder := oaissestream.NewDecoder(resp) | ||
| stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) | ||
| var message openai.ChatCompletionAccumulator | ||
| for stream.Next() { | ||
| chunk := stream.Current() | ||
| message.AddChunk(chunk) | ||
| } | ||
| if stream.Err() != nil { | ||
| return "", fmt.Errorf("stream error: %w", stream.Err()) | ||
| } | ||
| return message.ID, nil | ||
| } | ||
|
|
||
| // Parse & unmarshal the response. | ||
| body, err := io.ReadAll(resp.Body) | ||
| if err != nil { | ||
| return "", fmt.Errorf("read body: %w", err) | ||
| } | ||
|
|
||
| var message openai.ChatCompletion | ||
| if err := json.Unmarshal(body, &message); err != nil { | ||
| return "", fmt.Errorf("unmarshal response: %w", err) | ||
| } | ||
| return message.ID, nil | ||
| }, | ||
| createRequest: createOpenAIChatCompletionsReq, | ||
| expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", | ||
| name: config.ProviderOpenAI, | ||
| fixture: fixtures.OaiChatSimple, | ||
| basePath: "", | ||
| expectedPath: "/chat/completions", | ||
| configureFunc: configureOpenAI, | ||
| getResponseIDFunc: getOpenAIResponseID, | ||
| createRequest: createOpenAIChatCompletionsReq, | ||
| expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", | ||
| }, | ||
| { | ||
| name: config.ProviderAnthropic + "_baseURL_path", | ||
| fixture: fixtures.AntSimple, | ||
| basePath: "/api", | ||
| expectedPath: "/api/v1/messages", | ||
| configureFunc: configureAnthropic, | ||
| getResponseIDFunc: getAnthropicResponseID, | ||
| createRequest: createAnthropicMessagesReq, | ||
| expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", | ||
| }, | ||
| { | ||
| name: config.ProviderOpenAI + "_baseURL_path", | ||
| fixture: fixtures.OaiChatSimple, | ||
| basePath: "/api", | ||
| expectedPath: "/api/chat/completions", | ||
| configureFunc: configureOpenAI, | ||
| getResponseIDFunc: getOpenAIResponseID, | ||
| createRequest: createOpenAIChatCompletionsReq, | ||
| expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", | ||
| }, | ||
| } | ||
|
|
||
|
|
@@ -671,12 +708,14 @@ func TestSimple(t *testing.T) { | |
| // Given: a mock API server and a Bridge through which the requests will flow. | ||
| ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) | ||
| t.Cleanup(cancel) | ||
| srv := newMockServer(ctx, t, files, nil) | ||
| srv := newMockServer(ctx, t, files, func(r *http.Request) { | ||
| require.Equal(t, tc.expectedPath, r.URL.Path) | ||
| }, nil) | ||
| t.Cleanup(srv.Close) | ||
|
|
||
| recorderClient := &testutil.MockRecorder{} | ||
|
|
||
| b, err := tc.configureFunc(srv.URL, recorderClient) | ||
| b, err := tc.configureFunc(t, srv.URL+tc.basePath, recorderClient) | ||
| require.NoError(t, err) | ||
|
|
||
| mockSrv := httptest.NewUnstartedServer(b) | ||
|
|
@@ -734,12 +773,16 @@ func TestFallthrough(t *testing.T) { | |
|
|
||
| testCases := []struct { | ||
| name string | ||
| providerName string | ||
| fixture []byte | ||
| basePath string | ||
| configureFunc func(string, aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) | ||
| }{ | ||
| { | ||
| name: config.ProviderAnthropic, | ||
| fixture: fixtures.AntFallthrough, | ||
| name: "ant_empty_base_url_path", | ||
| providerName: config.ProviderAnthropic, | ||
| fixture: fixtures.AntFallthrough, | ||
| basePath: "", | ||
| configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) | ||
|
|
@@ -749,8 +792,36 @@ func TestFallthrough(t *testing.T) { | |
| }, | ||
| }, | ||
| { | ||
| name: config.ProviderOpenAI, | ||
| fixture: fixtures.OaiChatFallthrough, | ||
| name: "oai_empty_base_url_path", | ||
| providerName: config.ProviderOpenAI, | ||
| fixture: fixtures.OaiChatFallthrough, | ||
| basePath: "", | ||
| configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) | ||
| bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) | ||
| require.NoError(t, err) | ||
| return provider, bridge | ||
| }, | ||
| }, | ||
| { | ||
| name: "ant_some_base_url_path", | ||
| providerName: config.ProviderAnthropic, | ||
| fixture: fixtures.AntFallthrough, | ||
| basePath: "/api", | ||
| configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) | ||
| bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) | ||
| require.NoError(t, err) | ||
| return provider, bridge | ||
| }, | ||
| }, | ||
| { | ||
| name: "oai_some_base_url_path", | ||
| providerName: config.ProviderOpenAI, | ||
| fixture: fixtures.OaiChatFallthrough, | ||
| basePath: "/api", | ||
| configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { | ||
| logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) | ||
| provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) | ||
|
|
@@ -770,11 +841,12 @@ func TestFallthrough(t *testing.T) { | |
|
|
||
| files := filesMap(arc) | ||
| require.Contains(t, files, fixtureResponse) | ||
| expectedPath := tc.basePath + "/v1/models" | ||
|
|
||
| var receivedHeaders *http.Header | ||
| respBody := files[fixtureResponse] | ||
| upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| if r.URL.Path != "/v1/models" { | ||
| if r.URL.Path != expectedPath { | ||
| t.Errorf("unexpected request path: %q", r.URL.Path) | ||
| t.FailNow() | ||
| } | ||
|
|
@@ -789,7 +861,8 @@ func TestFallthrough(t *testing.T) { | |
|
|
||
| recorderClient := &testutil.MockRecorder{} | ||
|
|
||
| provider, bridge := tc.configureFunc(upstream.URL, recorderClient) | ||
| upstreamURL := upstream.URL + tc.basePath | ||
| provider, bridge := tc.configureFunc(upstreamURL, recorderClient) | ||
|
|
||
| bridgeSrv := httptest.NewUnstartedServer(bridge) | ||
| bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { | ||
|
|
@@ -798,7 +871,7 @@ func TestFallthrough(t *testing.T) { | |
| bridgeSrv.Start() | ||
| t.Cleanup(bridgeSrv.Close) | ||
|
|
||
| req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.name), nil) | ||
| req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.providerName), nil) | ||
| require.NoError(t, err) | ||
|
|
||
| resp, err := http.DefaultClient.Do(req) | ||
|
|
@@ -1074,7 +1147,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu | |
| t.Cleanup(cancel) | ||
|
|
||
| // Setup mock server with response mutator for multi-turn interaction. | ||
| mockSrv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte { | ||
| mockSrv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte { | ||
| if reqCount == 1 { | ||
| return resp // First request gets the normal response (with tool call). | ||
| } | ||
|
|
@@ -1310,7 +1383,7 @@ func TestErrorHandling(t *testing.T) { | |
| reqBody := files[fixtureRequest] | ||
|
|
||
| // Setup mock server. | ||
| mockSrv := newMockServer(ctx, t, files, nil) | ||
| mockSrv := newMockServer(ctx, t, files, nil, nil) | ||
| mockSrv.statusCode = http.StatusInternalServerError | ||
| t.Cleanup(mockSrv.Close) | ||
|
|
||
|
|
@@ -1983,11 +2056,15 @@ type mockServer struct { | |
| statusCode int | ||
| } | ||
|
|
||
| func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer { | ||
| func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requestValidatorFn func(*http.Request), responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For later: we should refactor this to take a variadic set of options.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will be done in: #73 |
||
| t.Helper() | ||
|
|
||
| ms := &mockServer{} | ||
| srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| if requestValidatorFn != nil { | ||
| requestValidatorFn(r) | ||
| } | ||
|
|
||
| statusCode := http.StatusOK | ||
| if ms.statusCode != 0 { | ||
| statusCode = ms.statusCode | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: let's move these into helper funcs; they're probably reusable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will be done in: #73