Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apidump_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestAPIDump(t *testing.T) {
reqBody := files[fixtureRequest]

// Setup mock upstream server.
srv := newMockServer(ctx, t, files, nil)
srv := newMockServer(ctx, t, files, nil, nil)
t.Cleanup(srv.Close)

// Create temp dir for API dumps.
Expand Down
259 changes: 168 additions & 91 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestAnthropicMessages(t *testing.T) {

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)
srv := newMockServer(ctx, t, files, nil)
srv := newMockServer(ctx, t, files, nil, nil)
t.Cleanup(srv.Close)

recorderClient := &testutil.MockRecorder{}
Expand Down Expand Up @@ -379,7 +379,7 @@ func TestOpenAIChatCompletions(t *testing.T) {

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)
srv := newMockServer(ctx, t, files, nil)
srv := newMockServer(ctx, t, files, nil, nil)
t.Cleanup(srv.Close)

recorderClient := &testutil.MockRecorder{}
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
t.Cleanup(cancel)

// Setup mock server with response mutator for multi-turn interaction.
srv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte {
srv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte {
if reqCount == 1 {
// First request gets the tool call response
return resp
Expand Down Expand Up @@ -556,91 +556,128 @@ func TestOpenAIChatCompletions(t *testing.T) {
func TestSimple(t *testing.T) {
t.Parallel()

getAnthropicResponseID := func(streaming bool, resp *http.Response) (string, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: let's move these into helper funcs; they're probably reusable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be done in: #73

if streaming {
decoder := ssestream.NewDecoder(resp)
stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil)
var message anthropic.Message
for stream.Next() {
event := stream.Current()
if err := message.Accumulate(event); err != nil {
return "", fmt.Errorf("accumulate event: %w", err)
}
}
if stream.Err() != nil {
return "", fmt.Errorf("stream error: %w", stream.Err())
}
return message.ID, nil
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read body: %w", err)
}

var message anthropic.Message
if err := json.Unmarshal(body, &message); err != nil {
return "", fmt.Errorf("unmarshal response: %w", err)
}
return message.ID, nil
}

getOpenAIResponseID := func(streaming bool, resp *http.Response) (string, error) {
if streaming {
// Parse the response stream.
decoder := oaissestream.NewDecoder(resp)
stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil)
var message openai.ChatCompletionAccumulator
for stream.Next() {
chunk := stream.Current()
message.AddChunk(chunk)
}
if stream.Err() != nil {
return "", fmt.Errorf("stream error: %w", stream.Err())
}
return message.ID, nil
}

// Parse & unmarshal the response.
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read body: %w", err)
}

var message openai.ChatCompletion
if err := json.Unmarshal(body, &message); err != nil {
return "", fmt.Errorf("unmarshal response: %w", err)
}
return message.ID, nil
}

// Common configuration functions for each provider type.
configureAnthropic := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
}

configureOpenAI := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
}

testCases := []struct {
name string
fixture []byte
configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error)
getResponseIDFunc func(bool, *http.Response) (string, error)
basePath string
expectedPath string
configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error)
getResponseIDFunc func(streaming bool, resp *http.Response) (string, error)
createRequest func(*testing.T, string, []byte) *http.Request
expectedMsgID string
}{
{
name: config.ProviderAnthropic,
fixture: fixtures.AntSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
decoder := ssestream.NewDecoder(resp)
stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil)
var message anthropic.Message
for stream.Next() {
event := stream.Current()
if err := message.Accumulate(event); err != nil {
return "", fmt.Errorf("accumulate event: %w", err)
}
}
if stream.Err() != nil {
return "", fmt.Errorf("stream error: %w", stream.Err())
}
return message.ID, nil
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read body: %w", err)
}

var message anthropic.Message
if err := json.Unmarshal(body, &message); err != nil {
return "", fmt.Errorf("unmarshal response: %w", err)
}
return message.ID, nil
},
createRequest: createAnthropicMessagesReq,
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
name: config.ProviderAnthropic,
fixture: fixtures.AntSimple,
basePath: "",
expectedPath: "/v1/messages",
configureFunc: configureAnthropic,
getResponseIDFunc: getAnthropicResponseID,
createRequest: createAnthropicMessagesReq,
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
},
{
name: config.ProviderOpenAI,
fixture: fixtures.OaiChatSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
// Parse the response stream.
decoder := oaissestream.NewDecoder(resp)
stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil)
var message openai.ChatCompletionAccumulator
for stream.Next() {
chunk := stream.Current()
message.AddChunk(chunk)
}
if stream.Err() != nil {
return "", fmt.Errorf("stream error: %w", stream.Err())
}
return message.ID, nil
}

// Parse & unmarshal the response.
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read body: %w", err)
}

var message openai.ChatCompletion
if err := json.Unmarshal(body, &message); err != nil {
return "", fmt.Errorf("unmarshal response: %w", err)
}
return message.ID, nil
},
createRequest: createOpenAIChatCompletionsReq,
expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N",
name: config.ProviderOpenAI,
fixture: fixtures.OaiChatSimple,
basePath: "",
expectedPath: "/chat/completions",
configureFunc: configureOpenAI,
getResponseIDFunc: getOpenAIResponseID,
createRequest: createOpenAIChatCompletionsReq,
expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N",
},
{
name: config.ProviderAnthropic + "_baseURL_path",
fixture: fixtures.AntSimple,
basePath: "/api",
expectedPath: "/api/v1/messages",
configureFunc: configureAnthropic,
getResponseIDFunc: getAnthropicResponseID,
createRequest: createAnthropicMessagesReq,
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
},
{
name: config.ProviderOpenAI + "_baseURL_path",
fixture: fixtures.OaiChatSimple,
basePath: "/api",
expectedPath: "/api/chat/completions",
configureFunc: configureOpenAI,
getResponseIDFunc: getOpenAIResponseID,
createRequest: createOpenAIChatCompletionsReq,
expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N",
},
}

Expand Down Expand Up @@ -671,12 +708,14 @@ func TestSimple(t *testing.T) {
// Given: a mock API server and a Bridge through which the requests will flow.
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)
srv := newMockServer(ctx, t, files, nil)
srv := newMockServer(ctx, t, files, func(r *http.Request) {
require.Equal(t, tc.expectedPath, r.URL.Path)
}, nil)
t.Cleanup(srv.Close)

recorderClient := &testutil.MockRecorder{}

b, err := tc.configureFunc(srv.URL, recorderClient)
b, err := tc.configureFunc(t, srv.URL+tc.basePath, recorderClient)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -734,12 +773,16 @@ func TestFallthrough(t *testing.T) {

testCases := []struct {
name string
providerName string
fixture []byte
basePath string
configureFunc func(string, aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge)
}{
{
name: config.ProviderAnthropic,
fixture: fixtures.AntFallthrough,
name: "ant_empty_base_url_path",
providerName: config.ProviderAnthropic,
fixture: fixtures.AntFallthrough,
basePath: "",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)
Expand All @@ -749,8 +792,36 @@ func TestFallthrough(t *testing.T) {
},
},
{
name: config.ProviderOpenAI,
fixture: fixtures.OaiChatFallthrough,
name: "oai_empty_base_url_path",
providerName: config.ProviderOpenAI,
fixture: fixtures.OaiChatFallthrough,
basePath: "",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewOpenAI(openaiCfg(addr, apiKey))
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
require.NoError(t, err)
return provider, bridge
},
},
{
name: "ant_some_base_url_path",
providerName: config.ProviderAnthropic,
fixture: fixtures.AntFallthrough,
basePath: "/api",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
require.NoError(t, err)
return provider, bridge
},
},
{
name: "oai_some_base_url_path",
providerName: config.ProviderOpenAI,
fixture: fixtures.OaiChatFallthrough,
basePath: "/api",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewOpenAI(openaiCfg(addr, apiKey))
Expand All @@ -770,11 +841,12 @@ func TestFallthrough(t *testing.T) {

files := filesMap(arc)
require.Contains(t, files, fixtureResponse)
expectedPath := tc.basePath + "/v1/models"

var receivedHeaders *http.Header
respBody := files[fixtureResponse]
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/models" {
if r.URL.Path != expectedPath {
t.Errorf("unexpected request path: %q", r.URL.Path)
t.FailNow()
}
Expand All @@ -789,7 +861,8 @@ func TestFallthrough(t *testing.T) {

recorderClient := &testutil.MockRecorder{}

provider, bridge := tc.configureFunc(upstream.URL, recorderClient)
upstreamURL := upstream.URL + tc.basePath
provider, bridge := tc.configureFunc(upstreamURL, recorderClient)

bridgeSrv := httptest.NewUnstartedServer(bridge)
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
Expand All @@ -798,7 +871,7 @@ func TestFallthrough(t *testing.T) {
bridgeSrv.Start()
t.Cleanup(bridgeSrv.Close)

req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.name), nil)
req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.providerName), nil)
require.NoError(t, err)

resp, err := http.DefaultClient.Do(req)
Expand Down Expand Up @@ -1074,7 +1147,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
t.Cleanup(cancel)

// Setup mock server with response mutator for multi-turn interaction.
mockSrv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte {
mockSrv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte {
if reqCount == 1 {
return resp // First request gets the normal response (with tool call).
}
Expand Down Expand Up @@ -1310,7 +1383,7 @@ func TestErrorHandling(t *testing.T) {
reqBody := files[fixtureRequest]

// Setup mock server.
mockSrv := newMockServer(ctx, t, files, nil)
mockSrv := newMockServer(ctx, t, files, nil, nil)
mockSrv.statusCode = http.StatusInternalServerError
t.Cleanup(mockSrv.Close)

Expand Down Expand Up @@ -1983,11 +2056,15 @@ type mockServer struct {
statusCode int
}

func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer {
func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requestValidatorFn func(*http.Request), responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For later: we should refactor this to take a variadic set of options.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be done in: #73

t.Helper()

ms := &mockServer{}
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if requestValidatorFn != nil {
requestValidatorFn(r)
}

statusCode := http.StatusOK
if ms.statusCode != 0 {
statusCode = ms.statusCode
Expand Down
Loading