From 055fdb8f369fae5bdf382d239cf2c6037df13928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Tue, 21 Oct 2025 12:19:35 +0000 Subject: [PATCH 1/3] fix: fix openai stream chunk marshaling --- go.mod | 2 +- intercept_openai_chat_streaming.go | 25 +++++++++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index e9545f3..6c241fe 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/mark3labs/mcp-go v0.38.0 github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 // indirect - github.com/tidwall/sjson v1.2.5 // indirect + github.com/tidwall/sjson v1.2.5 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 931b131..92fd402 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/ssestream" + "github.com/tidwall/sjson" "cdr.dev/slog" ) @@ -126,7 +127,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, chunk.ID = i.ID().String() // Marshal and relay chunk to client. - payload, err := i.marshal(chunk) + payload, err := i.marshalChunk(chunk) if err != nil { logger.Warn(ctx, "failed to marshal chunk", slog.Error(err), chunk.RawJSON()) lastErr = fmt.Errorf("marshal chunk: %w", err) @@ -202,7 +203,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, } if interceptionErr != nil { - payload, err := i.marshal(interceptionErr) + payload, err := i.marshalErr(interceptionErr) if err != nil { logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) } else if err := events.Send(streamCtx, payload); err != nil { @@ -291,8 +292,24 @@ func (i *OpenAIStreamingChatInterception) getInjectedToolByName(name string) *mc return i.mcpProxy.GetTool(name) } -func (i *OpenAIStreamingChatInterception) marshal(payload any) ([]byte, error) { - data, err := json.Marshal(payload) +func (i *OpenAIStreamingChatInterception) marshalChunk(chunk openai.ChatCompletionChunk) ([]byte, error) { + sj, err := sjson.Set(chunk.RawJSON(), "id", chunk.ID) + if err != nil { + return nil, err + } + + if chunk.JSON.Usage.Valid() { + sj, err = sjson.Set(sj, "usage", chunk.Usage) + if err != nil { + return nil, err + } + } + + return i.encodeForStream([]byte(sj)), nil +} + +func (i *OpenAIStreamingChatInterception) marshalErr(err error) ([]byte, error) { + data, err := json.Marshal(err) if err != nil { return nil, fmt.Errorf("marshal payload: %w", err) } From 65be39da67cef2dba0aead709a7320a525216bc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 22 Oct 2025 14:12:34 +0000 Subject: [PATCH 2/3] pr comments: added comments and changed function signature a bit --- intercept_openai_chat_streaming.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 92fd402..a7b0cfd 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -118,16 +118,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, continue } - // If usage information is available, relay the cumulative usage once all tool invocations have completed. - if chunk.Usage.CompletionTokens > 0 { - chunk.Usage = processor.getCumulativeUsage() - } - - // Overwrite response identifier since proxy obscures injected tool call invocations. - chunk.ID = i.ID().String() - // Marshal and relay chunk to client. - payload, err := i.marshalChunk(chunk) + payload, err := i.marshalChunk(&chunk, i.ID(), processor) if err != nil { logger.Warn(ctx, "failed to marshal chunk", slog.Error(err), chunk.RawJSON()) lastErr = fmt.Errorf("marshal chunk: %w", err) @@ -292,14 +284,23 @@ func (i *OpenAIStreamingChatInterception) getInjectedToolByName(name string) *mc return i.mcpProxy.GetTool(name) } -func (i *OpenAIStreamingChatInterception) marshalChunk(chunk openai.ChatCompletionChunk) ([]byte, error) { - sj, err := sjson.Set(chunk.RawJSON(), "id", chunk.ID) +// Mashals received stream chunk. +// Overrides id (since proxy obscures injected tool call invocations). +// If usage field was set in original chunk overrides it to culminative usage. +// +// sjson is used instead of normal struct marshaling so forwarded data +// is as close to the original as possible. Structs from openai library lack +// `ommitzero/ommitempty` annotations which adds additional empty fields +// when marshaling structs. Those additional empty fields can break Codex client. +func (i *OpenAIStreamingChatInterception) marshalChunk(chunk *openai.ChatCompletionChunk, id uuid.UUID, prc *openAIStreamProcessor) ([]byte, error) { + sj, err := sjson.Set(chunk.RawJSON(), "id", id.String()) if err != nil { return nil, err } if chunk.JSON.Usage.Valid() { - sj, err = sjson.Set(sj, "usage", chunk.Usage) + u := prc.getCumulativeUsage() + sj, err = sjson.Set(sj, "usage", u) if err != nil { return nil, err } From b82b95752ee2920110a94874f4e0a053eea8df0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 22 Oct 2025 15:20:45 +0000 Subject: [PATCH 3/3] review 2 --- intercept_openai_chat_streaming.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index a7b0cfd..0c5f554 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -290,19 +290,20 @@ func (i *OpenAIStreamingChatInterception) getInjectedToolByName(name string) *mc // // sjson is used instead of normal struct marshaling so forwarded data // is as close to the original as possible. Structs from openai library lack -// `ommitzero/ommitempty` annotations which adds additional empty fields +// `omitzero/omitempty` annotations which adds additional empty fields // when marshaling structs. Those additional empty fields can break Codex client. func (i *OpenAIStreamingChatInterception) marshalChunk(chunk *openai.ChatCompletionChunk, id uuid.UUID, prc *openAIStreamProcessor) ([]byte, error) { sj, err := sjson.Set(chunk.RawJSON(), "id", id.String()) if err != nil { - return nil, err + return nil, fmt.Errorf("marshal chunk id failed: %w", err) } + // If usage information is available, relay the cumulative usage once all tool invocations have completed. if chunk.JSON.Usage.Valid() { u := prc.getCumulativeUsage() sj, err = sjson.Set(sj, "usage", u) if err != nil { - return nil, err + return nil, fmt.Errorf("marshal chunk usage failed: %w", err) } } @@ -312,7 +313,7 @@ func (i *OpenAIStreamingChatInterception) marshalChunk(chunk *openai.ChatComplet func (i *OpenAIStreamingChatInterception) marshalErr(err error) ([]byte, error) { data, err := json.Marshal(err) if err != nil { - return nil, fmt.Errorf("marshal payload: %w", err) + return nil, fmt.Errorf("marshal error failed: %w", err) } return i.encodeForStream(data), nil