Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -736,11 +736,13 @@ func TestOpenAIInjectedTools(t *testing.T) {
require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned.

// Check the token usage from the client's perspective.
assert.EqualValues(t, 9911, message.Usage.PromptTokens)
// This *should* work but the openai SDK doesn't accumulate the prompt token details :(.
// See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147.
// assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens)
assert.EqualValues(t, 105, message.Usage.CompletionTokens)

// Ensure tokens used during injected tool invocation are accounted for.
require.EqualValues(t, 9911, calculateTotalInputTokens(recorderClient.tokenUsages))
require.EqualValues(t, 5047, calculateTotalInputTokens(recorderClient.tokenUsages))
require.EqualValues(t, 105, calculateTotalOutputTokens(recorderClient.tokenUsages))
})
}
Expand Down
2 changes: 1 addition & 1 deletion intercept_openai_chat_blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
_ = i.recorder.RecordTokenUsage(ctx, &TokenUsageRecord{
InterceptionID: i.ID().String(),
MsgID: completion.ID,
Input: lastUsage.PromptTokens,
Input: calculateActualInputTokenUsage(lastUsage),
Output: lastUsage.CompletionTokens,
Metadata: Metadata{
"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,
Expand Down
2 changes: 1 addition & 1 deletion intercept_openai_chat_streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
_ = i.recorder.RecordTokenUsage(streamCtx, &TokenUsageRecord{
InterceptionID: i.ID().String(),
MsgID: processor.getMsgID(),
Input: lastUsage.PromptTokens,
Input: calculateActualInputTokenUsage(lastUsage),
Output: lastUsage.CompletionTokens,
Metadata: Metadata{
"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,
Expand Down
9 changes: 9 additions & 0 deletions openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
}
}

// calculateActualInputTokenUsage returns the number of input tokens that were
// actually processed for a request, excluding tokens served from the prompt cache.
//
// [openai.CompletionUsage].PromptTokens aggregates ALL text input tokens,
// including those cached from previous requests; the cached portion is reported
// separately in PromptTokensDetails.CachedTokens. The original (inclusive)
// value can be reconstructed by adding back the "prompt_cached" metadata field.
// See https://platform.openai.com/docs/api-reference/usage/completions_object#usage/completions_object-input_tokens.
func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
	cached := in.PromptTokensDetails.CachedTokens
	return in.PromptTokens - cached
}

func getOpenAIErrorResponse(err error) *OpenAIErrorResponse {
var apierr *openai.Error
if !errors.As(err, &apierr) {
Expand Down