/*
 * Teleport
 * Copyright (C) 2023 Gravitational, Inc.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

package tokens

import (
	"sync"

	"github.com/gravitational/trace"
	"github.com/sashabaranov/go-openai"
	"github.com/tiktoken-go/tokenizer/codec"
)

// defaultTokenizer is the cl100k_base tokenizer, the encoding used by OpenAI's
// GPT-3.5 and GPT-4 chat models.
var defaultTokenizer = codec.NewCl100kBase()

// TokenCount holds TokenCounters for both Prompt and Completion tokens.
// As the agent performs multiple calls to the model, each call creates its own
// prompt and completion TokenCounter.
//
// Prompt TokenCounters can be created before doing the call, as we already
// know the full prompt and can tokenize it. This is what NewPromptTokenCounter
// does.
//
// Completion TokenCounters can only be created after receiving the model
// response. Depending on the response type, we either already have the full
// result or get a stream that will provide the completion result in the
// future. In the latter case, the token count is evaluated lazily and
// asynchronously. StaticTokenCounter counts tokens synchronously, while
// AsynchronousTokenCounter supports the streaming use case.
type TokenCount struct {
	Prompt     TokenCounters
	Completion TokenCounters
}

// AddPromptCounter adds a TokenCounter to the Prompt list.
func (tc *TokenCount) AddPromptCounter(prompt TokenCounter) {
	if prompt != nil {
		tc.Prompt = append(tc.Prompt, prompt)
	}
}

// AddCompletionCounter adds a TokenCounter to the Completion list.
func (tc *TokenCount) AddCompletionCounter(completion TokenCounter) {
	if completion != nil {
		tc.Completion = append(tc.Completion, completion)
	}
}

// CountAll iterates over all counters and returns how many prompt and
// completion tokens were used. Counting is non-blocking: a completion counter
// backed by a still-running stream reports the tokens counted so far and marks
// itself as finished.
func (tc *TokenCount) CountAll() (int, int) {
	return tc.Prompt.CountAll(), tc.Completion.CountAll()
}

// NewTokenCount initializes a new TokenCount struct.
func NewTokenCount() *TokenCount {
	return &TokenCount{
		Prompt:     TokenCounters{},
		Completion: TokenCounters{},
	}
}

// TokenCounter is an interface for all token counters, regardless of the kind
// of token they count (prompt/completion) or the tokenizer used.
// TokenCount() must be idempotent.
type TokenCounter interface {
	TokenCount() int
}

// TokenCounters is a list of TokenCounter and offers functions to iterate over
// all counters and compute the total.
type TokenCounters []TokenCounter

// CountAll iterates over a list of TokenCounter and returns the sum of the
// results of all counters.
func (tc TokenCounters) CountAll() int {
	var total int
	for _, counter := range tc {
		total += counter.TokenCount()
	}
	return total
}

// StaticTokenCounter is a token counter whose count has already been evaluated.
// This can be used to count prompt tokens (we already know the exact count),
// or to count how many tokens were used by an already finished completion
// request.
type StaticTokenCounter int

// TokenCount implements the TokenCounter interface.
func (tc *StaticTokenCounter) TokenCount() int {
	return int(*tc)
}

// NewPromptTokenCounter takes a list of openai.ChatCompletionMessage and
// computes how many tokens are used by sending those messages to the model.
func NewPromptTokenCounter(prompt []openai.ChatCompletionMessage) (*StaticTokenCounter, error) {
	var promptCount int
	for _, message := range prompt {
		promptTokens, _, err := defaultTokenizer.Encode(message.Content)
		if err != nil {
			return nil, trace.Wrap(err)
		}

		// perMessage and perRole are the fixed per-message framing overheads;
		// they are package-level constants defined elsewhere in this package.
		promptCount = promptCount + perMessage + perRole + len(promptTokens)
	}
	tc := StaticTokenCounter(promptCount)

	return &tc, nil
}

// NewSynchronousTokenCounter takes the completion request output and
// computes how many tokens were used by the model to generate this result.
func NewSynchronousTokenCounter(completion string) (*StaticTokenCounter, error) {
	completionTokens, _, err := defaultTokenizer.Encode(completion)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	completionCount := perRequest + len(completionTokens)
	tc := StaticTokenCounter(completionCount)

	return &tc, nil
}
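
// The sketch below is illustrative and not part of the upstream file: it shows
// how the pieces above are typically assembled around a single, non-streaming
// model call. The message and completion values are assumed to come from the
// caller.
func exampleTokenCountUsage(messages []openai.ChatCompletionMessage, completion string) (int, int, error) {
	tc := NewTokenCount()

	// The prompt is fully known before the call, so its tokens are counted up front.
	promptCounter, err := NewPromptTokenCounter(messages)
	if err != nil {
		return 0, 0, trace.Wrap(err)
	}
	tc.AddPromptCounter(promptCounter)

	// The completion is counted once the full response is available.
	completionCounter, err := NewSynchronousTokenCounter(completion)
	if err != nil {
		return 0, 0, trace.Wrap(err)
	}
	tc.AddCompletionCounter(completionCounter)

	promptTokens, completionTokens := tc.CountAll()
	return promptTokens, completionTokens, nil
}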

// AsynchronousTokenCounter counts completion tokens that are used by a
// streamed completion request. When an AsynchronousTokenCounter is created,
// the streaming might not be finished, and we can't yet evaluate how many
// tokens will be used. In this case, the streaming routine must register each
// streamed completion token with the Add() method. Calling TokenCount()
// returns the tokens counted so far and marks the counter as finished; after
// that, no more tokens can be added.
type AsynchronousTokenCounter struct {
	count int

	// mutex protects all fields of the AsynchronousTokenCounter, it must be
	// acquired before any read or write operation.
	mutex sync.Mutex

	// finished tells if the count is finished or not.
	// TokenCount() finishes the count. Once the count is finished, Add()
	// returns an error.
	finished bool
}

// TokenCount implements the TokenCounter interface.
// It returns how many tokens have been counted so far. It also marks the
// counter as finished; once a counter is finished, tokens cannot be added
// anymore.
func (tc *AsynchronousTokenCounter) TokenCount() int {
	tc.mutex.Lock()
	defer tc.mutex.Unlock()

	// Mark the count as finished so that subsequent Add() calls fail.
	tc.finished = true

	return tc.count + perRequest
}

// Add a streamed token to the count.
func (tc *AsynchronousTokenCounter) Add() error {
	tc.mutex.Lock()
	defer tc.mutex.Unlock()

	if tc.finished {
		return trace.Errorf("Count is already finished, cannot add more content")
	}
	tc.count += 1
	return nil
}

// NewAsynchronousTokenCounter takes the partial completion request output
// and creates a token counter that can be returned immediately, even if not
// all the content has been streamed yet. Streamed tokens can be added a
// posteriori with Add(). Once all the content is streamed, TokenCount() can
// be called to close the counter and retrieve the total.
func NewAsynchronousTokenCounter(completionStart string) (*AsynchronousTokenCounter, error) {
	completionTokens, _, err := defaultTokenizer.Encode(completionStart)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	return &AsynchronousTokenCounter{
		count:    len(completionTokens),
		mutex:    sync.Mutex{},
		finished: false,
	}, nil
}

// CountTokens is a helper that calls tc.CountAll() on a TokenCount pointer,
// but also returns 0, 0 when receiving a nil pointer. This makes token
// counting less awkward in cases where we don't know whether a completion
// happened or not.
func CountTokens(tc *TokenCount) (int, int) {
	if tc != nil {
		return tc.CountAll()
	}
	return 0, 0
}