Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return numTokens only for non-streaming messages #26058

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 17 additions & 11 deletions lib/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Idx int `json:"idx"`
// NumTokens is the number of completion tokens for the (non-streaming) message
NumTokens int `json:"-"`
}

// Chat represents a conversation between a user and an assistant with context memory.
Expand Down Expand Up @@ -95,6 +97,8 @@ type CompletionCommand struct {
Command string `json:"command,omitempty"`
Nodes []string `json:"nodes,omitempty"`
Labels []Label `json:"labels,omitempty"`
// NumTokens is the number of completion tokens for the (non-streaming) message
NumTokens int `json:"-"`
}

// Summary creates a short summary for the given input.
Expand Down Expand Up @@ -135,7 +139,7 @@ type StreamingMessage struct {

// Complete completes the conversation with a message from the assistant based on the current context.
// On success, it returns the message; for non-streaming messages, the number of completion tokens is carried in the message's NumTokens field.
func (chat *Chat) Complete(ctx context.Context) (any, int, error) {
func (chat *Chat) Complete(ctx context.Context) (any, error) {
var numTokens int

// if the chat is empty, return the initial response we predefine instead of querying GPT-4
Expand All @@ -144,7 +148,7 @@ func (chat *Chat) Complete(ctx context.Context) (any, int, error) {
Role: openai.ChatMessageRoleAssistant,
Content: initialAIResponse,
Idx: len(chat.messages) - 1,
}, numTokens, nil
}, nil
}

// if not, copy the current chat log to a new slice and append the suffix instruction
Expand All @@ -167,15 +171,15 @@ func (chat *Chat) Complete(ctx context.Context) (any, int, error) {
},
)
if err != nil {
return nil, numTokens, trace.Wrap(err)
return nil, trace.Wrap(err)
}

// fetch the first delta to check for a possible JSON payload
var response openai.ChatCompletionStreamResponse
top:
response, err = stream.Recv()
if err != nil {
return nil, numTokens, trace.Wrap(err)
return nil, trace.Wrap(err)
}
numTokens++

Expand All @@ -194,7 +198,7 @@ top:
case errors.Is(err, io.EOF):
break outer
case err != nil:
return nil, numTokens, trace.Wrap(err)
return nil, trace.Wrap(err)
}
numTokens++

Expand All @@ -206,13 +210,15 @@ top:
err = json.Unmarshal([]byte(payload), &c)
switch err {
case nil:
return &c, numTokens, nil
c.NumTokens = numTokens
return &c, nil
default:
return &Message{
Role: openai.ChatMessageRoleAssistant,
Content: payload,
Idx: len(chat.messages) - 1,
}, numTokens, nil
Role: openai.ChatMessageRoleAssistant,
Content: payload,
Idx: len(chat.messages) - 1,
NumTokens: numTokens,
}, nil
}
}

Expand Down Expand Up @@ -246,5 +252,5 @@ top:
Idx: len(chat.messages) - 1,
Chunks: chunks,
Error: errCh,
}, numTokens, nil
}, nil
}
6 changes: 5 additions & 1 deletion lib/web/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,10 @@ func tryFindEmbeddedCommand(message string) *ai.CompletionCommand {
func processComplete(ctx context.Context, h *Handler, chat *ai.Chat, conversationID string,
ws *websocket.Conn, authClient auth.ClientI,
) (int, error) {
var numTokens int

// query the assistant and fetch an answer
message, numTokens, err := chat.Complete(ctx)
message, err := chat.Complete(ctx)
if err != nil {
return numTokens, trace.Wrap(err)
}
Expand Down Expand Up @@ -577,6 +579,7 @@ func processComplete(ctx context.Context, h *Handler, chat *ai.Chat, conversatio
}
}
case *ai.Message:
numTokens = message.NumTokens
// write assistant message to both in-memory chain and persistent storage
chat.Insert(message.Role, message.Content)
protoMsg := &proto.AssistantMessage{
Expand All @@ -593,6 +596,7 @@ func processComplete(ctx context.Context, h *Handler, chat *ai.Chat, conversatio
return numTokens, trace.Wrap(err)
}
case *ai.CompletionCommand:
numTokens = message.NumTokens
payload := commandPayload{
Command: message.Command,
Nodes: message.Nodes,
Expand Down