diff --git a/bridge_integration_test.go b/bridge_integration_test.go
index c2f307e..a5df3e0 100644
--- a/bridge_integration_test.go
+++ b/bridge_integration_test.go
@@ -736,11 +736,13 @@ func TestOpenAIInjectedTools(t *testing.T) {
 		require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned.
 
 		// Check the token usage from the client's perspective.
-		assert.EqualValues(t, 9911, message.Usage.PromptTokens)
+		// This *should* work, but the openai SDK doesn't accumulate the prompt token details :(.
+		// See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147.
+		// assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens)
 		assert.EqualValues(t, 105, message.Usage.CompletionTokens)
 
 		// Ensure tokens used during injected tool invocation are accounted for.
-		require.EqualValues(t, 9911, calculateTotalInputTokens(recorderClient.tokenUsages))
+		require.EqualValues(t, 5047, calculateTotalInputTokens(recorderClient.tokenUsages))
 		require.EqualValues(t, 105, calculateTotalOutputTokens(recorderClient.tokenUsages))
 	})
 }
diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go
index 7cefc5c..b6a0343 100644
--- a/intercept_openai_chat_blocking.go
+++ b/intercept_openai_chat_blocking.go
@@ -81,7 +81,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
 	_ = i.recorder.RecordTokenUsage(ctx, &TokenUsageRecord{
 		InterceptionID: i.ID().String(),
 		MsgID:          completion.ID,
-		Input:          lastUsage.PromptTokens,
+		Input:          calculateActualInputTokenUsage(lastUsage),
 		Output:         lastUsage.CompletionTokens,
 		Metadata: Metadata{
 			"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,
diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go
index 314d20c..f10c26e 100644
--- a/intercept_openai_chat_streaming.go
+++ b/intercept_openai_chat_streaming.go
@@ -165,7 +165,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 	_ = i.recorder.RecordTokenUsage(streamCtx, &TokenUsageRecord{
 		InterceptionID: i.ID().String(),
 		MsgID:          processor.getMsgID(),
-		Input:          lastUsage.PromptTokens,
+		Input:          calculateActualInputTokenUsage(lastUsage),
 		Output:         lastUsage.CompletionTokens,
 		Metadata: Metadata{
 			"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,
diff --git a/openai.go b/openai.go
index b296123..dc3abc8 100644
--- a/openai.go
+++ b/openai.go
@@ -96,6 +96,15 @@ func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
 	}
 }
 
+// calculateActualInputTokenUsage accounts for cached tokens, which are included in [openai.CompletionUsage].PromptTokens.
+func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
+	// Input *includes* the cached tokens, so we subtract them here to reflect actual input token usage.
+	// The original value can be reconstructed by referencing the "prompt_cached" field in metadata.
+	// See https://platform.openai.com/docs/api-reference/usage/completions_object#usage/completions_object-input_tokens.
+	return in.PromptTokens /* The aggregated number of text input tokens used, including cached tokens. */ -
+		in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that have been cached from previous requests. */
+}
+
 func getOpenAIErrorResponse(err error) *OpenAIErrorResponse {
 	var apierr *openai.Error
 	if !errors.As(err, &apierr) {
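
As a sanity check on the arithmetic, here is a minimal standalone sketch of the new helper, using the figures from the integration test above: 9911 reported prompt tokens minus 4864 cache hits (9911-5047) leaves the 5047 actual input tokens the test now asserts. The package main wrapper, the import path (assumed to be github.com/openai/openai-go/v2, matching the v2.7.0 link in the test comment), and the concrete numbers are illustrative assumptions; only the field paths (PromptTokens, PromptTokensDetails.CachedTokens) come straight from the diff.

package main

import (
	"fmt"

	openai "github.com/openai/openai-go/v2" // assumed import path for the v2 SDK
)

// calculateActualInputTokenUsage mirrors the helper added in openai.go:
// cached tokens are a subset of PromptTokens, so subtracting them yields
// the tokens that were actually processed as fresh input.
func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
	return in.PromptTokens - in.PromptTokensDetails.CachedTokens
}

func main() {
	// Hypothetical usage payload mirroring the integration test: the API
	// reports 9911 prompt tokens, of which 9911-5047=4864 were cache hits.
	var usage openai.CompletionUsage
	usage.PromptTokens = 9911
	usage.PromptTokensDetails.CachedTokens = 4864

	fmt.Println(calculateActualInputTokenUsage(usage)) // prints 5047
}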