Cargo.toml (2 additions, 2 deletions)

@@ -14,13 +14,13 @@ python = [
 ]
 
 [dependencies]
-pyo3 = { version = "0.26", default-features = false, features = [
+pyo3 = { version = "0.26.0", default-features = false, features = [
     "extension-module",
     "macros",
 ], optional = true }
 
 # tiktoken dependencies
-fancy-regex = "0.16"
+fancy-regex = "0.13.0"
 regex = "1.10.3"
 rustc-hash = "2"
 bstr = "1.5.0"
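
Note: under Cargo's caret semantics, "0.26" and "0.26.0" both resolve to the range >=0.26.0, <0.27.0, so the pyo3 edit only spells the existing requirement out in full. The fancy-regex edit is substantive: it moves the compatible range from 0.16.x back to 0.13.x.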
pyproject.toml (0 additions, 1 deletion)

@@ -31,7 +31,6 @@ skip = [
"*-manylinux_i686",
"*-musllinux_i686",
"*-win32",
"*-musllinux_aarch64",
]
macos.archs = ["x86_64", "arm64"]
# When cross-compiling on Intel, it is not possible to test arm64 wheels.
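
Removing "*-musllinux_aarch64" from the cibuildwheel skip list means musl-based aarch64 wheels (e.g. Alpine on ARM) are now built alongside the other targets.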
src/py.rs (6 additions, 6 deletions)

@@ -28,7 +28,7 @@ impl CoreBPE {
 
     #[pyo3(name = "encode_ordinary")]
     fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
-        py.allow_threads(|| self.encode_ordinary(text))
+        py.detach(|| self.encode_ordinary(text))
     }
 
     #[pyo3(name = "encode")]
@@ -38,7 +38,7 @@ impl CoreBPE {
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
     ) -> PyResult<Vec<Rank>> {
-        py.allow_threads(|| {
+        py.detach(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
             match self.encode(text, &allowed_special) {
@@ -54,7 +54,7 @@ impl CoreBPE {
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
     ) -> PyResult<Py<PyAny>> {
-        let tokens_res = py.allow_threads(|| {
+        let tokens_res = py.detach(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
             self.encode(text, &allowed_special)
@@ -70,7 +70,7 @@ impl CoreBPE {
     }
 
     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
-        py.allow_threads(|| {
+        py.detach(|| {
             match std::str::from_utf8(bytes) {
                 // Straightforward case
                 Ok(text) => self.encode_ordinary(text),
@@ -121,7 +121,7 @@ impl CoreBPE {
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
     ) -> PyResult<(Vec<Rank>, Py<PyList>)> {
-        let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.allow_threads(|| {
+        let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.detach(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
             self._encode_unstable_native(text, &allowed_special)
@@ -155,7 +155,7 @@ impl CoreBPE {

     #[pyo3(name = "decode_bytes")]
     fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
-        match py.allow_threads(|| self.decode_bytes(&tokens)) {
+        match py.detach(|| self.decode_bytes(&tokens)) {
             Ok(bytes) => Ok(PyBytes::new(py, &bytes).into()),
             Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
         }
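
Background for the rename: pyo3 0.26 deprecates Python::allow_threads in favor of Python::detach as part of its free-threaded CPython support. Behavior at these call sites is unchanged: the calling thread detaches from the interpreter (releasing the GIL on builds that have one) while the pure-Rust BPE code runs, which is what lets callers tokenize in parallel from ordinary Python threads. A minimal sketch of the effect, assuming tiktoken is installed and the encoding file is fetchable; the snippet is illustrative and not part of this PR:

    from concurrent.futures import ThreadPoolExecutor

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    texts = ["hello world"] * 10_000

    # encode_ordinary detaches from the interpreter while the Rust BPE
    # runs, so the four workers can make progress in parallel.
    with ThreadPoolExecutor(max_workers=4) as pool:
        token_lists = list(pool.map(enc.encode_ordinary, texts))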
tiktoken/core.py (12 additions, 5 deletions)

@@ -4,11 +4,11 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence
 
-import regex
-
 from tiktoken import _tiktoken
 
 if TYPE_CHECKING:
+    import re
+
     import numpy as np
     import numpy.typing as npt
 
@@ -391,6 +391,9 @@ def _encode_single_piece(self, text_or_bytes: str | bytes) -> list[int]:

     def _encode_only_native_bpe(self, text: str) -> list[int]:
         """Encodes a string into tokens, but do regex splitting in Python."""
+        # We need specifically `regex` in order to compile pat_str due to e.g. \p
+        import regex
+
         _unused_pat = regex.compile(self._pat_str)
         ret = []
         for piece in regex.findall(_unused_pat, text):
@@ -423,9 +426,13 @@ def __setstate__(self, value: object) -> None:


 @functools.lru_cache(maxsize=128)
-def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
-    inner = "|".join(regex.escape(token) for token in tokens)
-    return regex.compile(f"({inner})")
+def _special_token_regex(tokens: frozenset[str]) -> re.Pattern[str]:
+    try:
+        import regex as re
+    except ImportError:
+        import re
+    inner = "|".join(re.escape(token) for token in tokens)
+    return re.compile(f"({inner})")
 
 
 def raise_disallowed_special_token(token: str) -> NoReturn:
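
The try/except import makes the third-party regex package optional for this helper: the special-token pattern is built entirely from escape()d literals, which the stdlib engine compiles just as well, whereas pat_str still requires regex for its \p{...} Unicode property classes (hence the function-local import in _encode_only_native_bpe above). A stdlib-only illustration of the split:

    import re

    # Escaped literals compile fine in the stdlib engine...
    re.compile(re.escape("<|endoftext|>"))

    # ...but Unicode property classes such as \p{L}, which tiktoken's
    # pat_str relies on, are rejected by stdlib re and need `regex`.
    try:
        re.compile(r"\p{L}+")
    except re.error as exc:
        print("stdlib re:", exc)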
tiktoken/load.py (18 additions, 14 deletions)

@@ -6,22 +6,26 @@


 def read_file(blobpath: str) -> bytes:
-    if not blobpath.startswith("http://") and not blobpath.startswith("https://"):
-        try:
-            import blobfile
-        except ImportError as e:
-            raise ImportError(
-                "blobfile is not installed. Please install it by running `pip install blobfile`."
-            ) from e
-        with blobfile.BlobFile(blobpath, "rb") as f:
+    if "://" not in blobpath:
+        with open(blobpath, "rb", buffering=0) as f:
             return f.read()
 
-    # avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
-    import requests
+    if blobpath.startswith(("http://", "https://")):
+        # avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
+        import requests
+
+        resp = requests.get(blobpath)
+        resp.raise_for_status()
+        return resp.content
 
-    resp = requests.get(blobpath)
-    resp.raise_for_status()
-    return resp.content
+    try:
+        import blobfile
+    except ImportError as e:
+        raise ImportError(
+            "blobfile is not installed. Please install it by running `pip install blobfile`."
+        ) from e
+    with blobfile.BlobFile(blobpath, "rb") as f:
+        return f.read()


 def check_hash(data: bytes, expected_hash: str) -> bool:
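
With this restructuring, read_file dispatches on the shape of the path: anything without a scheme is read straight from the local filesystem with no optional imports, public http(s) URLs go through requests, and only the remaining scheme-prefixed paths fall through to the lazily imported blobfile. Illustrative calls (the paths are hypothetical):

    read_file("/tmp/enc.tiktoken")                 # local file, stdlib open()
    read_file("https://example.com/enc.tiktoken")  # requests
    read_file("gs://bucket/enc.tiktoken")          # blobfile, imported on demand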
@@ -49,7 +53,7 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:

     cache_path = os.path.join(cache_dir, cache_key)
     if os.path.exists(cache_path):
-        with open(cache_path, "rb") as f:
+        with open(cache_path, "rb", buffering=0) as f:
             data = f.read()
         if expected_hash is None or check_hash(data, expected_hash):
             return data
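
As in read_file above, buffering=0 opens the cache file in unbuffered binary mode; since the contents are consumed by a single f.read(), skipping Python's default block buffer presumably saves one redundant copy of the data.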