Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/mark3labs/mcp-go v0.38.0
github.com/stretchr/testify v1.10.0
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tidwall/sjson v1.2.5
go.uber.org/goleak v1.3.0
go.uber.org/mock v0.6.0
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
Expand Down
45 changes: 32 additions & 13 deletions intercept_openai_chat_streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/google/uuid"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/packages/ssestream"
"github.com/tidwall/sjson"

"cdr.dev/slog"
)
Expand Down Expand Up @@ -117,16 +118,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
continue
}

// If usage information is available, relay the cumulative usage once all tool invocations have completed.
if chunk.Usage.CompletionTokens > 0 {
chunk.Usage = processor.getCumulativeUsage()
}

// Overwrite response identifier since proxy obscures injected tool call invocations.
chunk.ID = i.ID().String()

// Marshal and relay chunk to client.
payload, err := i.marshal(chunk)
payload, err := i.marshalChunk(&chunk, i.ID(), processor)
if err != nil {
logger.Warn(ctx, "failed to marshal chunk", slog.Error(err), chunk.RawJSON())
lastErr = fmt.Errorf("marshal chunk: %w", err)
Expand Down Expand Up @@ -202,7 +195,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
}

if interceptionErr != nil {
payload, err := i.marshal(interceptionErr)
payload, err := i.marshalErr(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
} else if err := events.Send(streamCtx, payload); err != nil {
Expand Down Expand Up @@ -291,10 +284,36 @@ func (i *OpenAIStreamingChatInterception) getInjectedToolByName(name string) *mc
return i.mcpProxy.GetTool(name)
}

func (i *OpenAIStreamingChatInterception) marshal(payload any) ([]byte, error) {
data, err := json.Marshal(payload)
// marshalChunk marshals a received stream chunk for relay to the client.
// It overrides the chunk id (since the proxy obscures injected tool call
// invocations) and, if the usage field was set in the original chunk,
// replaces it with the cumulative usage.
//
// sjson is used instead of normal struct marshaling so the forwarded data
// is as close to the original as possible. Structs from the openai library
// lack `omitzero/omitempty` annotations, which adds additional empty fields
// when marshaling structs; those extra empty fields can break the Codex client.
func (i *OpenAIStreamingChatInterception) marshalChunk(chunk *openai.ChatCompletionChunk, id uuid.UUID, prc *openAIStreamProcessor) ([]byte, error) {
	// Rewrite the id directly on the raw JSON so every other field passes
	// through byte-for-byte.
	sj, err := sjson.Set(chunk.RawJSON(), "id", id.String())
	if err != nil {
		return nil, fmt.Errorf("set chunk id: %w", err)
	}

	// If usage information is available, relay the cumulative usage once all
	// tool invocations have completed.
	if chunk.JSON.Usage.Valid() {
		u := prc.getCumulativeUsage()
		sj, err = sjson.Set(sj, "usage", u)
		if err != nil {
			return nil, fmt.Errorf("set chunk usage: %w", err)
		}
	}

	return i.encodeForStream([]byte(sj)), nil
}

func (i *OpenAIStreamingChatInterception) marshalErr(err error) ([]byte, error) {
data, err := json.Marshal(err)
if err != nil {
return nil, fmt.Errorf("marshal payload: %w", err)
return nil, fmt.Errorf("marshal error failed: %w", err)
}

return i.encodeForStream(data), nil
Expand Down