diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 5bf519e..e570dd9 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -30,6 +30,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "github.com/openai/openai-go/v2" oaissestream "github.com/openai/openai-go/v2/packages/ssestream" @@ -47,8 +48,10 @@ var ( antSingleInjectedTool []byte //go:embed fixtures/anthropic/fallthrough.txtar antFallthrough []byte - //go:embed fixtures/anthropic/error.txtar - antErr []byte + //go:embed fixtures/anthropic/stream_error.txtar + antMidStreamErr []byte + //go:embed fixtures/anthropic/non_stream_error.txtar + antNonStreamErr []byte //go:embed fixtures/openai/simple.txtar oaiSimple []byte @@ -58,8 +61,10 @@ var ( oaiSingleInjectedTool []byte //go:embed fixtures/openai/fallthrough.txtar oaiFallthrough []byte - //go:embed fixtures/openai/error.txtar - oaiErr []byte + //go:embed fixtures/openai/stream_error.txtar + oaiMidStreamErr []byte + //go:embed fixtures/openai/non_stream_error.txtar + oaiNonStreamErr []byte ) const ( @@ -676,11 +681,11 @@ func TestFallthrough(t *testing.T) { t.FailNow() } + receivedHeaders = &r.Header + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write(respBody) - - receivedHeaders = &r.Header })) t.Cleanup(upstream.Close) @@ -1009,23 +1014,129 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu func TestErrorHandling(t *testing.T) { t.Parallel() - cases := []struct { - name string - fixture []byte - createRequestFunc createRequestFunc - configureFunc configureFunc - responseHandlerFn func(streaming bool, resp *http.Response) - }{ - { - name: aibridge.ProviderAnthropic, - fixture: antErr, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected. 
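+ // The fixtures return the upstream error as a plain JSON body (HTTP 400) in both the streaming and non-streaming request variants; the bridge is expected to surface it as a JSON error response rather than as an SSE stream.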
+ t.Run("non-stream error", func(t *testing.T) { + cases := []struct { + name string + fixture []byte + createRequestFunc createRequestFunc + configureFunc configureFunc + responseHandlerFn func(resp *http.Response) + }{ + { + name: aibridge.ProviderAnthropic, + fixture: antNonStreamErr, + createRequestFunc: createAnthropicMessagesReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + }, + responseHandlerFn: func(resp *http.Response) { + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "error", gjson.GetBytes(body, "type").Str) + require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str) + require.Contains(t, gjson.GetBytes(body, "error.message").Str, "prompt is too long") + }, }, - responseHandlerFn: func(streaming bool, resp *http.Response) { - if streaming { + { + name: aibridge.ProviderOpenAI, + fixture: oaiNonStreamErr, + createRequestFunc: createOpenAIChatCompletionsReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + }, + responseHandlerFn: func(resp *http.Response) { + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "context_length_exceeded", gjson.GetBytes(body, "error.code").Str) + require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str) + require.Contains(t, gjson.GetBytes(body, "error.message").Str, "Input tokens exceed the configured limit") + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + arc := txtar.Parse(tc.fixture) + t.Logf("%s: %s", t.Name(), arc.Comment) + + files := filesMap(arc) + require.Len(t, files, 3) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + reqBody := files[fixtureRequest] + // Add the stream param to the request. + newBody, err := setJSON(reqBody, "stream", streaming) + require.NoError(t, err) + reqBody = newBody + + // Setup mock server. + mockResp := files[fixtureStreamingResponse] + if !streaming { + mockResp = files[fixtureNonStreamingResponse] + } + mockSrv := newMockHTTPReflector(ctx, t, mockResp) + t.Cleanup(mockSrv.Close) + + recorderClient := &mockRecorderClient{} + + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) + require.NoError(t, err) + + // Invoke request to mocked API via aibridge. 
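+ // The bridge is served via httptest; its base context attaches the test actor (userID) through aibridge.AsActor so the request is intercepted on behalf of that user.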
+ bridgeSrv := httptest.NewUnstartedServer(b) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + t.Cleanup(bridgeSrv.Close) + + req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) + resp, err := http.DefaultClient.Do(req) + t.Cleanup(func() { _ = resp.Body.Close() }) + require.NoError(t, err) + + tc.responseHandlerFn(resp) + recorderClient.verifyAllInterceptionsEnded(t) + }) + } + }) + } + }) + + // Tests that errors which occur *during* a streaming response are handled as expected. + t.Run("mid-stream error", func(t *testing.T) { + cases := []struct { + name string + fixture []byte + createRequestFunc createRequestFunc + configureFunc configureFunc + responseHandlerFn func(resp *http.Response) + }{ + { + name: aibridge.ProviderAnthropic, + fixture: antMidStreamErr, + createRequestFunc: createAnthropicMessagesReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + }, + responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1033,24 +1144,17 @@ func TestErrorHandling(t *testing.T) { require.NoError(t, sp.Parse(resp.Body)) require.Len(t, sp.EventsByType("error"), 1) require.Contains(t, sp.EventsByType("error")[0].Data, "Overloaded") - } else { - require.Equal(t, resp.StatusCode, http.StatusInternalServerError) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Contains(t, string(body), "Overloaded") - } + }, }, - }, - { - name: aibridge.ProviderOpenAI, - fixture: oaiErr, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) - }, - responseHandlerFn: func(streaming bool, resp *http.Response) { - if streaming { + { + name: aibridge.ProviderOpenAI, + fixture: oaiMidStreamErr, + createRequestFunc: createOpenAIChatCompletionsReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + }, + responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1063,72 +1167,55 @@ func TestErrorHandling(t *testing.T) { errEvent := sp.MessageEvents()[len(sp.MessageEvents())-2] // Last event is termination marker ("[DONE]"). require.NotEmpty(t, errEvent) require.Contains(t, errEvent.Data, "The server had an error while processing your request. 
Sorry about that!") - } else { - require.Equal(t, resp.StatusCode, http.StatusInternalServerError) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Contains(t, string(body), "The server had an error while processing your request. Sorry about that") - } + }, }, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() + } - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) + arc := txtar.Parse(tc.fixture) + t.Logf("%s: %s", t.Name(), arc.Comment) - reqBody := files[fixtureRequest] + files := filesMap(arc) + require.Len(t, files, 2) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody + reqBody := files[fixtureRequest] - // Setup mock server. - mockSrv := newMockServer(ctx, t, files, nil) - mockSrv.statusCode = http.StatusInternalServerError - t.Cleanup(mockSrv.Close) + // Setup mock server. + mockSrv := newMockServer(ctx, t, files, nil) + mockSrv.statusCode = http.StatusInternalServerError + t.Cleanup(mockSrv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) - require.NoError(t, err) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) + require.NoError(t, err) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + // Invoke request to mocked API via aibridge. 
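+ // Same setup as the non-stream cases: serve the bridge via httptest with the test actor attached to its base context.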
+ bridgeSrv := httptest.NewUnstartedServer(b) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + t.Cleanup(bridgeSrv.Close) - req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) - resp, err := http.DefaultClient.Do(req) - t.Cleanup(func() { _ = resp.Body.Close() }) - require.NoError(t, err) + req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) + resp, err := http.DefaultClient.Do(req) + t.Cleanup(func() { _ = resp.Body.Close() }) + require.NoError(t, err) - tc.responseHandlerFn(streaming, resp) - recorderClient.verifyAllInterceptionsEnded(t) - }) - } - }) - } + tc.responseHandlerFn(resp) + recorderClient.verifyAllInterceptionsEnded(t) + }) + } + }) } // TestStableRequestEncoding validates that a given intercepted request and a @@ -1297,6 +1384,44 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) return req } +type mockHTTPReflector struct { + *httptest.Server +} + +func newMockHTTPReflector(ctx context.Context, t *testing.T, resp []byte) *mockHTTPReflector { + ref := &mockHTTPReflector{} + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mock, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(resp)), r) + require.NoError(t, err) + defer mock.Body.Close() + + // Copy headers from the mocked response. + for key, values := range mock.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Write the status code. + w.WriteHeader(mock.StatusCode) + + // Copy the body. + _, err = io.Copy(w, mock.Body) + require.NoError(t, err) + })) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + + srv.Start() + t.Cleanup(srv.Close) + + ref.Server = srv + return ref +} + +// TODO: replace this with mockHTTPReflector. type mockServer struct { *httptest.Server diff --git a/fixtures/anthropic/non_stream_error.txtar b/fixtures/anthropic/non_stream_error.txtar new file mode 100644 index 0000000..76a9347 --- /dev/null +++ b/fixtures/anthropic/non_stream_error.txtar @@ -0,0 +1,35 @@ +Simple request + error which occurs before streaming begins (where applicable). + +-- request -- +{ + "max_tokens": 8192, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "yo" + } + ] + } + ], + "model": "claude-sonnet-4-0", + "temperature": 1 +} + +-- streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 164 +Content-Type: application/json + +{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"} + + +-- non-streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 164 +Content-Type: application/json + +{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"} + diff --git a/fixtures/anthropic/error.txtar b/fixtures/anthropic/stream_error.txtar similarity index 85% rename from fixtures/anthropic/error.txtar rename to fixtures/anthropic/stream_error.txtar index 81ed89d..8b63444 100644 --- a/fixtures/anthropic/error.txtar +++ b/fixtures/anthropic/stream_error.txtar @@ -15,7 +15,8 @@ Simple request + error. 
} ], "model": "claude-sonnet-4-0", - "temperature": 1 + "temperature": 1, + "stream": true } -- streaming -- @@ -31,12 +32,3 @@ data: {"type": "ping"} event: error data: {"type": "error", "error": {"type": "api_error", "message": "Overloaded"}} --- non-streaming -- -{ - "type": "error", - "error": { - "type": "api_error", - "message": "Overloaded" - }, - "request_id": null -} \ No newline at end of file diff --git a/fixtures/openai/non_stream_error.txtar b/fixtures/openai/non_stream_error.txtar new file mode 100644 index 0000000..e84ce09 --- /dev/null +++ b/fixtures/openai/non_stream_error.txtar @@ -0,0 +1,43 @@ +Simple request + error which occurs before streaming begins (where applicable). + +-- request -- +{ + "messages": [ + { + "role": "user", + "content": "how many angels can dance on the head of a pin\n" + } + ], + "model": "gpt-4.1", + "stream": true +} + +-- streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 281 +Content-Type: application/json + +{ + "error": { + "message": "Input tokens exceed the configured limit of 272000 tokens. Your messages resulted in 3148588 tokens. Please reduce the length of the messages.", + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded" + } +} + + +-- non-streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 281 +Content-Type: application/json + +{ + "error": { + "message": "Input tokens exceed the configured limit of 272000 tokens. Your messages resulted in 3148588 tokens. Please reduce the length of the messages.", + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded" + } +} + diff --git a/fixtures/openai/error.txtar b/fixtures/openai/stream_error.txtar similarity index 89% rename from fixtures/openai/error.txtar rename to fixtures/openai/stream_error.txtar index 8e9efae..678800b 100644 --- a/fixtures/openai/error.txtar +++ b/fixtures/openai/stream_error.txtar @@ -8,7 +8,8 @@ Simple request + error. "content": "how many angels can dance on the head of a pin\n" } ], - "model": "gpt-4.1" + "model": "gpt-4.1", + "stream": true } -- streaming -- @@ -22,10 +23,3 @@ data: {"id":"chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N","object":"chat.completion.c data: {"error": {"message": "The server had an error while processing your request. Sorry about that!", "type": "server_error"}} --- non-streaming -- -{ - "error": { - "message": "The server had an error while processing your request. Sorry about that!", - "type": "server_error" - } -} \ No newline at end of file diff --git a/go.mod b/go.mod index a0e1ffa..827a224 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 github.com/stretchr/testify v1.10.0 - github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 8459618..2367933 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -2,6 +2,7 @@ package aibridge import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -173,6 +174,33 @@ func (i *AnthropicMessagesInterceptionBase) augmentRequestForBedrock() { i.req.MessageNewParams.Model = anthropic.Model(i.Model()) } +// writeUpstreamError marshals and writes a given error. 
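+// If marshaling fails, a minimal fallback payload in Anthropic's documented error shape is written instead, so the client always receives a well-formed error body.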
+func (i *AnthropicMessagesInterceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *AnthropicErrorResponse) { + if antErr == nil { + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(antErr.StatusCode) + + out, err := json.Marshal(antErr) + if err != nil { + i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", antErr))) + // Response has to match expected format. + // See https://docs.claude.com/en/api/errors#error-shapes. + _, _ = w.Write([]byte(fmt.Sprintf(`{ + "type":"error", + "error": { + "type": "error", + "message":"error marshaling upstream error" + }, + "request_id": "%s" +}`, i.ID().String()))) + } else { + _, _ = w.Write(out) + } +} + // redirectTransport is an HTTP RoundTripper that redirects requests to a different endpoint. // This is useful for testing when we need to redirect AWS Bedrock requests to a mock server. type redirectTransport struct { diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index cc2daaf..f978113 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -76,19 +76,17 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr resp, err = client.Messages.New(ctx, messages) if err != nil { if isConnError(err) { - logger.Warn(ctx, "upstream connection closed", slog.Error(err)) + // Can't write a response, just error out. return fmt.Errorf("upstream connection closed: %w", err) } - logger.Warn(ctx, "anthropic API error", slog.Error(err)) if antErr := getAnthropicErrorResponse(err); antErr != nil { - http.Error(w, antErr.Error(), antErr.StatusCode) - return fmt.Errorf("api error: %w", err) + i.writeUpstreamError(w, antErr) + return fmt.Errorf("anthropic API error: %w", err) } - logger.Warn(ctx, "upstream API error", slog.Error(err)) http.Error(w, "internal error", http.StatusInternalServerError) - return fmt.Errorf("upstream API error: %w", err) + return fmt.Errorf("internal error: %w", err) } if prompt != nil { diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 403d1e6..15bb6d8 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -97,7 +97,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. events := newEventStream(streamCtx, logger.Named("sse-sender"), i.pingPayload()) - go events.run(w, r) + go events.start(w, r) defer func() { _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() @@ -414,35 +414,40 @@ newStream: prompt = nil } - // Check if the stream encountered any errors. - if streamErr := stream.Err(); streamErr != nil { - if isUnrecoverableError(streamErr) { - logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) - // We can't reflect an error back if there's a connection error or the request context was canceled. - } else if antErr := getAnthropicErrorResponse(streamErr); antErr != nil { - logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) - interceptionErr = fmt.Errorf("stream error: %w", antErr) - } else { - logger.Warn(ctx, "unknown error", slog.Error(streamErr)) - // Unfortunately, the Anthropic SDK does not support parsing errors received in the stream - // into known types (i.e. 
[shared.OverloadedError]). - // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 - // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newAnthropicErr(fmt.Errorf("unknown stream error: %w", streamErr)) + if events.isStreaming() { + // Check if the stream encountered any errors. + if streamErr := stream.Err(); streamErr != nil { + if isUnrecoverableError(streamErr) { + logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) + // We can't reflect an error back if there's a connection error or the request context was canceled. + } else if antErr := getAnthropicErrorResponse(streamErr); antErr != nil { + logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) + interceptionErr = antErr + } else { + logger.Warn(ctx, "unknown error", slog.Error(streamErr)) + // Unfortunately, the Anthropic SDK does not support parsing errors received in the stream + // into known types (i.e. [shared.OverloadedError]). + // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 + // All it does is wrap the payload in an error - which is all we can return, currently. + interceptionErr = newAnthropicErr(fmt.Errorf("unknown stream error: %w", streamErr)) + } + } else if lastErr != nil { + // Otherwise check if any logical errors occurred during processing. + logger.Warn(ctx, "stream failed", slog.Error(lastErr)) + interceptionErr = newAnthropicErr(fmt.Errorf("processing error: %w", lastErr)) } - } else if lastErr != nil { - // Otherwise check if any logical errors occurred during processing. - logger.Warn(ctx, "stream failed", slog.Error(lastErr)) - interceptionErr = newAnthropicErr(fmt.Errorf("processing error: %w", lastErr)) - } - if interceptionErr != nil { - payload, err := i.marshal(interceptionErr) - if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) - } else if err := events.Send(streamCtx, payload); err != nil { - logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + if interceptionErr != nil { + payload, err := i.marshal(interceptionErr) + if err != nil { + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) + } else if err := events.Send(streamCtx, payload); err != nil { + logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + } } + } else { + // Stream has not started yet; write to response if present. 
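+ // getAnthropicErrorResponse returns nil for errors that are not Anthropic API errors, in which case writeUpstreamError writes nothing.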
+ i.writeUpstreamError(w, getAnthropicErrorResponse(stream.Err())) } shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30) diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 36b8ff0..44ef582 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -3,11 +3,13 @@ package aibridge import ( "context" "encoding/json" + "net/http" "strings" "github.com/coder/aibridge/mcp" "github.com/google/uuid" "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/option" "github.com/openai/openai-go/v2/shared" "cdr.dev/slog" @@ -24,6 +26,14 @@ type OpenAIChatInterceptionBase struct { mcpProxy mcp.ServerProxier } +func (i *OpenAIChatInterceptionBase) newOpenAIClient(baseURL, key string) openai.Client { + var opts []option.RequestOption + opts = append(opts, option.WithAPIKey(key)) + opts = append(opts, option.WithBaseURL(baseURL)) + + return openai.NewClient(opts...) +} + func (i *OpenAIChatInterceptionBase) ID() uuid.UUID { return i.id } @@ -92,3 +102,28 @@ func (i *OpenAIChatInterceptionBase) unmarshalArgs(in string) (args ToolArgs) { return args } + +// writeUpstreamError marshals and writes a given error. +func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *OpenAIErrorResponse) { + if oaiErr == nil { + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(oaiErr.StatusCode) + + out, err := json.Marshal(oaiErr) + if err != nil { + i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", oaiErr))) + // Response has to match expected format. + _, _ = w.Write([]byte(`{ + "error": { + "type": "error", + "message":"error marshaling upstream error", + "code": "server_error" + }, +}`)) + } else { + _, _ = w.Write(out) + } +} diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 3b1fa7e..4c019a6 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -3,7 +3,6 @@ package aibridge import ( "bytes" "encoding/json" - "errors" "fmt" "net/http" "strings" @@ -42,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - client := newOpenAIClient(i.baseURL, i.key) + client := i.newOpenAIClient(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) var ( @@ -184,18 +183,15 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } } - // TODO: these probably have to be formatted as JSON errs? if err != nil { if isConnError(err) { http.Error(w, err.Error(), http.StatusInternalServerError) return fmt.Errorf("upstream connection closed: %w", err) } - logger.Warn(ctx, "openai API error", slog.Error(err)) - var apierr *openai.Error - if errors.As(err, &apierr) { - http.Error(w, apierr.Message, apierr.StatusCode) - return fmt.Errorf("api error: %w", apierr) + if apiErr := getOpenAIErrorResponse(err); apiErr != nil { + i.writeUpstreamError(w, apiErr) + return fmt.Errorf("openai API error: %w", err) } http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 0c5f554..cc1a64a 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -65,7 +65,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. 
- client := newOpenAIClient(i.baseURL, i.key) + client := i.newOpenAIClient(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) @@ -73,7 +73,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. events := newEventStream(streamCtx, logger.Named("sse-sender"), nil) - go events.run(w, r) + go events.start(w, r) defer func() { _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() @@ -172,35 +172,40 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, }) } - // Check if the stream encountered any errors. - if streamErr := stream.Err(); streamErr != nil { - if isUnrecoverableError(streamErr) { - logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) - // We can't reflect an error back if there's a connection error or the request context was canceled. - } else if oaiErr := getOpenAIErrorResponse(streamErr); oaiErr != nil { - logger.Warn(ctx, "openai stream error", slog.Error(streamErr)) - interceptionErr = oaiErr - } else { - logger.Warn(ctx, "unknown error", slog.Error(streamErr)) - // Unfortunately, the OpenAI SDK does not support parsing errors received in the stream - // into known types (i.e. [shared.OverloadedError]). - // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 - // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newOpenAIErr(fmt.Errorf("unknown stream error: %w", streamErr)) + if events.isStreaming() { + // Check if the stream encountered any errors. + if streamErr := stream.Err(); streamErr != nil { + if isUnrecoverableError(streamErr) { + logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) + // We can't reflect an error back if there's a connection error or the request context was canceled. + } else if oaiErr := getOpenAIErrorResponse(streamErr); oaiErr != nil { + logger.Warn(ctx, "openai stream error", slog.Error(streamErr)) + interceptionErr = oaiErr + } else { + logger.Warn(ctx, "unknown error", slog.Error(streamErr)) + // Unfortunately, the OpenAI SDK does not support parsing errors received in the stream + // into known types (i.e. [shared.OverloadedError]). + // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 + // All it does is wrap the payload in an error - which is all we can return, currently. + interceptionErr = newOpenAIErr(fmt.Errorf("unknown stream error: %w", streamErr)) + } + } else if lastErr != nil { + // Otherwise check if any logical errors occurred during processing. + logger.Warn(ctx, "stream failed", slog.Error(lastErr)) + interceptionErr = newOpenAIErr(fmt.Errorf("processing error: %w", lastErr)) } - } else if lastErr != nil { - // Otherwise check if any logical errors occurred during processing. 
- logger.Warn(ctx, "stream failed", slog.Error(lastErr)) - interceptionErr = newOpenAIErr(fmt.Errorf("processing error: %w", lastErr)) - } - if interceptionErr != nil { - payload, err := i.marshalErr(interceptionErr) - if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) - } else if err := events.Send(streamCtx, payload); err != nil { - logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + if interceptionErr != nil { + payload, err := i.marshalErr(interceptionErr) + if err != nil { + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) + } else if err := events.Send(streamCtx, payload); err != nil { + logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + } } + } else { + // Stream has not started yet; write to response if present. + i.writeUpstreamError(w, getOpenAIErrorResponse(stream.Err())) } // No tool call, nothing more to do. diff --git a/openai.go b/openai.go index dc3abc8..3a02fb9 100644 --- a/openai.go +++ b/openai.go @@ -1,14 +1,12 @@ package aibridge import ( - "encoding/json" "errors" - "github.com/anthropics/anthropic-sdk-go/shared" - "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/param" + "github.com/openai/openai-go/v2/shared" ) // ChatCompletionNewParamsWrapper exists because the "stream" param is not included in openai.ChatCompletionNewParams. @@ -106,57 +104,41 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 { } func getOpenAIErrorResponse(err error) *OpenAIErrorResponse { - var apierr *openai.Error - if !errors.As(err, &apierr) { + var apiErr *openai.Error + if !errors.As(err, &apiErr) { return nil } - msg := apierr.Error() - typ := string(constant.ValueOf[constant.APIError]()) - - var detail *shared.APIErrorObject - if field, ok := apierr.JSON.ExtraFields["error"]; ok { - _ = json.Unmarshal([]byte(field.Raw()), &detail) - } - if detail != nil { - msg = detail.Message - typ = string(detail.Type) - } - return &OpenAIErrorResponse{ - ErrorResponse: &shared.ErrorResponse{ - Error: shared.ErrorObjectUnion{ - Message: msg, - Type: typ, - }, - Type: constant.ValueOf[constant.Error](), + ErrorObject: &shared.ErrorObject{ + Code: apiErr.Code, + Message: apiErr.Message, + Type: apiErr.Type, }, - StatusCode: apierr.StatusCode, + StatusCode: apiErr.StatusCode, } } var _ error = &OpenAIErrorResponse{} type OpenAIErrorResponse struct { - *shared.ErrorResponse - - StatusCode int `json:"-"` + ErrorObject *shared.ErrorObject `json:"error"` + StatusCode int `json:"-"` } func newOpenAIErr(msg error) *OpenAIErrorResponse { return &OpenAIErrorResponse{ - ErrorResponse: &shared.ErrorResponse{ - Error: shared.ErrorObjectUnion{ - Message: msg.Error(), - Type: "error", - }, + ErrorObject: &shared.ErrorObject{ + Code: "error", + Message: msg.Error(), + Type: "error", }, } } func (a *OpenAIErrorResponse) Error() string { - if a.ErrorResponse == nil { + if a.ErrorObject == nil { return "" } - return a.ErrorResponse.Error.Message + return a.ErrorObject.Message } diff --git a/provider_openai.go b/provider_openai.go index 3a3db45..0fc31a6 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -8,8 +8,6 @@ import ( "os" "github.com/google/uuid" - "github.com/openai/openai-go/v2" - "github.com/openai/openai-go/v2/option" ) var _ Provider = 
&OpenAIProvider{} @@ -100,11 +98,3 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.key) } - -func newOpenAIClient(baseURL, key string) openai.Client { - var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(key)) - opts = append(opts, option.WithBaseURL(baseURL)) - - return openai.NewClient(opts...) -} diff --git a/streaming.go b/streaming.go index e6fe72d..a3216b5 100644 --- a/streaming.go +++ b/streaming.go @@ -9,6 +9,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -27,11 +28,14 @@ type eventStream struct { pingPayload []byte + initiated atomic.Bool + initiateOnce sync.Once + closeOnce sync.Once shutdownOnce sync.Once eventsCh chan event - // doneCh is closed when the run loop exits. + // doneCh is closed when the start loop exits. doneCh chan struct{} } @@ -48,27 +52,17 @@ func newEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte) } } -// run handles sending Server-Sent Event to the client. -func (s *eventStream) run(w http.ResponseWriter, r *http.Request) { +// start handles sending Server-Sent Event to the client. +func (s *eventStream) start(w http.ResponseWriter, r *http.Request) { // Signal completion on exit so senders don't block indefinitely after closure. defer close(s.doneCh) ctx := r.Context() - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - - // Send initial flush to ensure connection is established. - if err := flush(w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - // Send periodic pings to keep connections alive. // The upstream provider may also send their own pings, but we can't rely on this. - tick := time.NewTicker(pingInterval) + tick := time.NewTicker(time.Nanosecond) + tick.Stop() // Ticker will start after stream initiation. defer tick.Stop() for { @@ -83,10 +77,38 @@ func (s *eventStream) run(w http.ResponseWriter, r *http.Request) { case <-ctx.Done(): s.logger.Debug(ctx, "request context canceled", slog.Error(ctx.Err())) return - case ev, open = <-s.eventsCh: + case ev, open = <-s.eventsCh: // Once closed, the buffered channel will drain all buffered values before showing as closed. if !open { + s.logger.Debug(ctx, "events channel closed") return } + + // Initiate the stream once the first event is received. + s.initiateOnce.Do(func() { + s.initiated.Store(true) + s.logger.Debug(ctx, "stream initiated") + + // Send headers for Server-Sent Event stream. + // + // We only send these once an event is processed because an error can occur in the upstream + // request prior to the stream starting, in which case the SSE headers are inappropriate to + // send to the client. + // + // See use of isStreaming(). + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + // Send initial flush to ensure connection is established. + if err := flush(w); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Start ping ticker. 
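+ // The ticker was created stopped, so keep-alive pings are only sent once the stream has been initiated and the SSE headers above have been written.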
+ tick.Reset(pingInterval) + }) case <-tick.C: ev = s.pingPayload if ev == nil { @@ -150,20 +172,31 @@ func (s *eventStream) Shutdown(shutdownCtx context.Context) error { s.shutdownOnce.Do(func() { s.logger.Debug(shutdownCtx, "shutdown initiated", slog.F("outstanding_events", len(s.eventsCh))) - // Now it is safe to close the events channel; the run loop will exit + // Now it is safe to close the events channel; the start() loop will exit // after draining remaining events and receivers will stop ranging. close(s.eventsCh) }) + var err error select { case <-shutdownCtx.Done(): // If shutdownCtx completes, shutdown likely exceeded its timeout. - return fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err()) + err = fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err()) case <-s.ctx.Done(): - return fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err()) + err = fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err()) case <-s.doneCh: return nil } + + // Even if the context is canceled, we need to wait for start() to complete. + <-s.doneCh + return err +} + +// isStreaming checks if the stream has been initiated, or +// when events are buffered which - when processed - will initiate the stream. +func (s *eventStream) isStreaming() bool { + return s.initiated.Load() || len(s.eventsCh) > 0 } // isConnError checks if an error is related to client disconnection or context cancellation.