diff --git a/go.mod b/go.mod
index e9545f3..6c241fe 100644
--- a/go.mod
+++ b/go.mod
@@ -10,7 +10,7 @@ require (
 	github.com/mark3labs/mcp-go v0.38.0
 	github.com/stretchr/testify v1.10.0
 	github.com/tidwall/gjson v1.18.0 // indirect
-	github.com/tidwall/sjson v1.2.5 // indirect
+	github.com/tidwall/sjson v1.2.5
 	go.uber.org/goleak v1.3.0
 	go.uber.org/mock v0.6.0
 	golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go
index 931b131..0c5f554 100644
--- a/intercept_openai_chat_streaming.go
+++ b/intercept_openai_chat_streaming.go
@@ -14,6 +14,7 @@ import (
 	"github.com/google/uuid"
 	"github.com/openai/openai-go/v2"
 	"github.com/openai/openai-go/v2/packages/ssestream"
+	"github.com/tidwall/sjson"
 
 	"cdr.dev/slog"
 )
@@ -117,16 +118,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 			continue
 		}
 
-		// If usage information is available, relay the cumulative usage once all tool invocations have completed.
-		if chunk.Usage.CompletionTokens > 0 {
-			chunk.Usage = processor.getCumulativeUsage()
-		}
-
-		// Overwrite response identifier since proxy obscures injected tool call invocations.
-		chunk.ID = i.ID().String()
-
 		// Marshal and relay chunk to client.
-		payload, err := i.marshal(chunk)
+		payload, err := i.marshalChunk(&chunk, i.ID(), processor)
 		if err != nil {
 			logger.Warn(ctx, "failed to marshal chunk", slog.Error(err), chunk.RawJSON())
 			lastErr = fmt.Errorf("marshal chunk: %w", err)
@@ -202,7 +195,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 	}
 
 	if interceptionErr != nil {
-		payload, err := i.marshal(interceptionErr)
+		payload, err := i.marshalErr(interceptionErr)
 		if err != nil {
 			logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
 		} else if err := events.Send(streamCtx, payload); err != nil {
@@ -291,10 +284,36 @@ func (i *OpenAIStreamingChatInterception) getInjectedToolByName(name string) *mc
 	return i.mcpProxy.GetTool(name)
 }
 
-func (i *OpenAIStreamingChatInterception) marshal(payload any) ([]byte, error) {
-	data, err := json.Marshal(payload)
+// Marshals the received stream chunk.
+// Overrides the id (since the proxy obscures injected tool call invocations).
+// If the usage field was set in the original chunk, overrides it with the cumulative usage.
+//
+// sjson is used instead of normal struct marshaling so the forwarded data
+// stays as close to the original as possible. Structs from the openai library lack
+// `omitzero`/`omitempty` annotations, which adds additional empty fields
+// when marshaling structs. Those additional empty fields can break the Codex client.
+func (i *OpenAIStreamingChatInterception) marshalChunk(chunk *openai.ChatCompletionChunk, id uuid.UUID, prc *openAIStreamProcessor) ([]byte, error) {
+	sj, err := sjson.Set(chunk.RawJSON(), "id", id.String())
+	if err != nil {
+		return nil, fmt.Errorf("marshal chunk id failed: %w", err)
+	}
+
+	// If usage information is available, relay the cumulative usage once all tool invocations have completed.
+	if chunk.JSON.Usage.Valid() {
+		u := prc.getCumulativeUsage()
+		sj, err = sjson.Set(sj, "usage", u)
+		if err != nil {
+			return nil, fmt.Errorf("marshal chunk usage failed: %w", err)
+		}
+	}
+
+	return i.encodeForStream([]byte(sj)), nil
+}
+
+func (i *OpenAIStreamingChatInterception) marshalErr(err error) ([]byte, error) {
+	data, err := json.Marshal(err)
 	if err != nil {
-		return nil, fmt.Errorf("marshal payload: %w", err)
+		return nil, fmt.Errorf("marshal error failed: %w", err)
 	}
 	return i.encodeForStream(data), nil
 }
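For reference, a minimal sketch of the sjson behavior the new marshalChunk comment relies on: sjson.Set rewrites a single path inside the original JSON string and leaves every other byte untouched, so no empty openai struct fields are introduced on re-encode. The chunk payload, id, and token counts below are made-up illustrative values, assuming only the documented tidwall/sjson API:

```go
package main

import (
	"fmt"

	"github.com/tidwall/sjson"
)

func main() {
	// A chunk as it might arrive from upstream: only the fields the
	// provider actually sent are present, with no empty placeholders.
	raw := `{"id":"chatcmpl-abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hi"}}]}`

	// Overwrite just the id; all other bytes are forwarded verbatim.
	out, err := sjson.Set(raw, "id", "11111111-2222-3333-4444-555555555555")
	if err != nil {
		panic(err)
	}

	// Setting a non-string value marshals it in place under the given path,
	// mirroring how the cumulative usage replaces the original usage field.
	out, err = sjson.Set(out, "usage", map[string]int64{
		"prompt_tokens":     12,
		"completion_tokens": 34,
		"total_tokens":      46,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
```

By contrast, round-tripping the chunk through json.Marshal on the openai structs would emit every zero-valued field, which is exactly what the comment flags as breaking the Codex client.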