@@ -1,5 +1,4 @@
 import dataclasses
-import hashlib
 import logging
 import shutil
 import sys
@@ -19,8 +18,9 @@
 
 from .codeblock import Codeblock
 from .constants import ROLE_COLOR
-from .util import console, get_tokenizer
+from .util import console
 from .util.prompt import rich_to_str
+from .util.tokens import len_tokens
 
 logger = logging.getLogger(__name__)
 
@@ -68,6 +68,9 @@ def __eq__(self, other):
             and self.timestamp == other.timestamp
         )
 
+    def len_tokens(self, model: str) -> int:
+        return len_tokens(self, model=model)
+
     def replace(self, **kwargs) -> Self:
         """Replace attributes of the message."""
         return dataclasses.replace(self, **kwargs)
@@ -326,43 +329,3 @@ def toml_to_msgs(toml: str) -> list[Message]:
 def msgs2dicts(msgs: list[Message]) -> list[dict]:
     """Convert a list of Message objects to a list of dicts ready to pass to an LLM."""
     return [msg.to_dict(keys=["role", "content", "files", "call_id"]) for msg in msgs]
-
-
-# Global cache mapping hashes to token counts
-_token_cache: dict[tuple[str, str], int] = {}
-
-
-def _hash_content(content: str) -> str:
-    """Create a hash of the content"""
-    return hashlib.sha256(content.encode()).hexdigest()
-
-
-def len_tokens(content: str | Message | list[Message], model: str) -> int:
-    """Get the number of tokens in a string, message, or list of messages.
-
-    Uses efficient caching with content hashing to minimize memory usage while
-    maintaining fast repeated calculations, which is especially important for
-    conversations with many messages.
-    """
-    if isinstance(content, list):
-        return sum(len_tokens(msg, model) for msg in content)
-    if isinstance(content, Message):
-        content = content.content
-
-    assert isinstance(content, str), content
-    # Check cache using hash
-    content_hash = _hash_content(content)
-    cache_key = (content_hash, model)
-    if cache_key in _token_cache:
-        return _token_cache[cache_key]
-
-    # Calculate and cache
-    count = len(get_tokenizer(model).encode(content, disallowed_special=[]))
-    _token_cache[cache_key] = count
-
-    # Limit cache size by removing oldest entries if needed
-    if len(_token_cache) > 1000:
-        # Remove first item (oldest in insertion order)
-        _token_cache.pop(next(iter(_token_cache)))
-
-    return count