
Commit 62ba557

refactor: move len_tokens and related code into gptme.util.tokens (#809)

* refactor: move len_tokens and related code into gptme.util.tokens
* fix: fixes to review comments

1 parent cdf548e · commit 62ba557

File tree

5 files changed: 92 additions & 64 deletions

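For orientation: after this commit, token counting lives in `gptme.util.tokens` rather than `gptme.message` and `gptme.util`. A minimal sketch of the new call sites (the example string is illustrative, not from the diff; requires tiktoken at runtime):

```python
# New home for the token utilities after this refactor.
from gptme.util.tokens import get_tokenizer, len_tokens

# Count tokens in a plain string.
n = len_tokens("Hello, world!", model="gpt-4")

# Or grab the underlying tiktoken encoding directly;
# len_tokens encodes with disallowed_special=[] internally.
enc = get_tokenizer("gpt-4")
assert n == len(enc.encode("Hello, world!", disallowed_special=[]))
```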

gptme/logmanager.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -44,6 +44,9 @@ def __getitem__(self, key):
     def __len__(self) -> int:
         return len(self.messages)
 
+    def len_tokens(self, model: str) -> int:
+        return len_tokens(self.messages, model)
+
     def __iter__(self) -> Generator[Message, None, None]:
         yield from self.messages
 
```
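A hedged sketch of how the new convenience method might be called. The surrounding class is not shown in this hunk, so the `Log` name and list-of-messages constructor below are assumptions:

```python
from gptme.logmanager import Log   # assumed class name; not visible in this hunk
from gptme.message import Message

log = Log([Message("user", "hello"), Message("assistant", "hi!")])
print(len(log))                  # number of messages, via __len__
print(log.len_tokens("gpt-4"))   # total tokens across messages, via the new method
```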

gptme/message.py

Lines changed: 5 additions & 42 deletions
```diff
@@ -1,5 +1,4 @@
 import dataclasses
-import hashlib
 import logging
 import shutil
 import sys
@@ -19,8 +18,9 @@
 
 from .codeblock import Codeblock
 from .constants import ROLE_COLOR
-from .util import console, get_tokenizer
+from .util import console
 from .util.prompt import rich_to_str
+from .util.tokens import len_tokens
 
 logger = logging.getLogger(__name__)
 
@@ -68,6 +68,9 @@ def __eq__(self, other):
             and self.timestamp == other.timestamp
         )
 
+    def len_tokens(self, model: str) -> int:
+        return len_tokens(self, model=model)
+
     def replace(self, **kwargs) -> Self:
         """Replace attributes of the message."""
         return dataclasses.replace(self, **kwargs)
@@ -326,43 +329,3 @@ def toml_to_msgs(toml: str) -> list[Message]:
 def msgs2dicts(msgs: list[Message]) -> list[dict]:
     """Convert a list of Message objects to a list of dicts ready to pass to an LLM."""
     return [msg.to_dict(keys=["role", "content", "files", "call_id"]) for msg in msgs]
-
-
-# Global cache mapping hashes to token counts
-_token_cache: dict[tuple[str, str], int] = {}
-
-
-def _hash_content(content: str) -> str:
-    """Create a hash of the content"""
-    return hashlib.sha256(content.encode()).hexdigest()
-
-
-def len_tokens(content: str | Message | list[Message], model: str) -> int:
-    """Get the number of tokens in a string, message, or list of messages.
-
-    Uses efficient caching with content hashing to minimize memory usage while
-    maintaining fast repeated calculations, which is especially important for
-    conversations with many messages.
-    """
-    if isinstance(content, list):
-        return sum(len_tokens(msg, model) for msg in content)
-    if isinstance(content, Message):
-        content = content.content
-
-    assert isinstance(content, str), content
-    # Check cache using hash
-    content_hash = _hash_content(content)
-    cache_key = (content_hash, model)
-    if cache_key in _token_cache:
-        return _token_cache[cache_key]
-
-    # Calculate and cache
-    count = len(get_tokenizer(model).encode(content, disallowed_special=[]))
-    _token_cache[cache_key] = count
-
-    # Limit cache size by removing oldest entries if needed
-    if len(_token_cache) > 1000:
-        # Remove first item (oldest in insertion order)
-        _token_cache.pop(next(iter(_token_cache)))
-
-    return count
```
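With `len_tokens` removed from `gptme.message`, callers use either the new `Message.len_tokens` method or the relocated module-level helper. A small sketch (Message is assumed to take role and content positionally):

```python
from gptme.message import Message
from gptme.util.tokens import len_tokens

msg = Message("user", "How many tokens is this?")

# Method added to Message in this commit:
print(msg.len_tokens("gpt-4"))

# Equivalent module-level call from the new location:
print(len_tokens(msg, model="gpt-4"))
print(len_tokens([msg, msg], model="gpt-4"))  # lists sum per-message counts
```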

gptme/tools/shell.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -27,9 +27,10 @@
 import bashlex
 
 from ..message import Message
-from ..util import get_installed_programs, get_tokenizer
+from ..util import get_installed_programs
 from ..util.ask_execute import execute_with_confirmation
 from ..util.output_storage import save_large_output
+from ..util.tokens import get_tokenizer
 from .base import (
     ConfirmFunc,
     Parameter,
```

gptme/util/__init__.py

Lines changed: 0 additions & 21 deletions
```diff
@@ -19,27 +19,6 @@
 logger = logging.getLogger(__name__)
 console = Console(log_path=False)
 
-_warned_models = set()
-
-
-@lru_cache
-def get_tokenizer(model: str):
-    import tiktoken  # fmt: skip
-
-    if "gpt-4o" in model:
-        return tiktoken.get_encoding("o200k_base")
-
-    try:
-        return tiktoken.encoding_for_model(model)
-    except KeyError:
-        global _warned_models
-        if model not in _warned_models:
-            logger.debug(
-                f"No tokenizer for '{model}'. Using tiktoken cl100k_base. Use results only as estimates."
-            )
-            _warned_models |= {model}
-        return tiktoken.get_encoding("cl100k_base")
-
 
 def epoch_to_age(epoch, incl_date=False):
     # takes epoch and returns "x minutes ago", "3 hours ago", "yesterday", etc.
```

gptme/util/tokens.py

Lines changed: 82 additions & 0 deletions
```diff
@@ -0,0 +1,82 @@
+import hashlib
+import logging
+import typing
+from functools import lru_cache
+
+if typing.TYPE_CHECKING:
+    import tiktoken  # fmt: skip
+
+    from ..message import Message  # fmt: skip
+
+
+# Global cache mapping hashes to token counts
+_token_cache: dict[tuple[str, str], int] = {}
+
+_warned_models = set()
+
+logger = logging.getLogger(__name__)
+
+
+@lru_cache
+def get_tokenizer(model: str) -> "tiktoken.Encoding":
+    """Get the tokenizer for a given model, with caching and fallbacks."""
+    import tiktoken  # fmt: skip
+
+    if "gpt-4o" in model:
+        return tiktoken.get_encoding("o200k_base")
+
+    try:
+        return tiktoken.encoding_for_model(model)
+    except KeyError:
+        global _warned_models
+        if model not in _warned_models:
+            logger.debug(
+                f"No tokenizer for '{model}'. Using tiktoken cl100k_base. Use results only as estimates."
+            )
+            _warned_models |= {model}
+        return tiktoken.get_encoding("cl100k_base")
+
+
+# perf trick: start background thread that pre-loads the gpt-4 and gpt-5 tokenizers
+# needs logic to wait for the tokenizer to be ready if requested before loaded
+# threading.Thread(target=get_tokenizer, args=("gpt-4",), daemon=True).start()
+# threading.Thread(target=get_tokenizer, args=("gpt-5",), daemon=True).start()
+
+
+def _hash_content(content: str) -> str:
+    """Create a hash of the content"""
+    return hashlib.sha256(content.encode()).hexdigest()
+
+
+def len_tokens(content: "str | Message | list[Message]", model: str) -> int:
+    """Get the number of tokens in a string, message, or list of messages.
+
+    Uses efficient caching with content hashing to minimize memory usage while
+    maintaining fast repeated calculations, which is especially important for
+    conversations with many messages.
+    """
+    from ..message import Message  # fmt: skip
+
+    if isinstance(content, list):
+        return sum(len_tokens(msg, model) for msg in content)
+    if isinstance(content, Message):
+        content = content.content
+
+    assert isinstance(content, str), content
+    # Check cache using hash
+    content_hash = _hash_content(content)
+    cache_key = (content_hash, model)
+    if cache_key in _token_cache:
+        return _token_cache[cache_key]
+
+    # Calculate and cache
+    tokenizer = get_tokenizer(model)
+    count = len(tokenizer.encode(content, disallowed_special=[]))
+    _token_cache[cache_key] = count
+
+    # Limit cache size by removing oldest entries if needed
+    if len(_token_cache) > 1000:
+        # Remove first item (oldest in insertion order)
+        _token_cache.pop(next(iter(_token_cache)))
+
+    return count
```
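To illustrate the hash-keyed cache in the new module, a quick sketch (it touches the private `_token_cache` purely for demonstration; requires tiktoken):

```python
from gptme.util.tokens import _token_cache, len_tokens

text = "The quick brown fox jumps over the lazy dog."
first = len_tokens(text, model="gpt-4")   # computed via tiktoken, then cached
second = len_tokens(text, model="gpt-4")  # same (sha256(text), model) key: cache hit
assert first == second

# The cache is bounded: past 1000 entries, the oldest insertion is
# evicted, so memory stays flat even for long conversations.
print(len(_token_cache))
```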
