clarify and enforce apis
mmoskal committed Jun 14, 2024
1 parent cf3984b commit c98b1aa
Showing 3 changed files with 34 additions and 17 deletions.
py/llguidance/python/llguidance/_tokenizer.py (27 changes: 16 additions & 11 deletions)
@@ -8,27 +8,32 @@ class TokenizerWrapper:
     tokens: Sequence[bytes]
 
     def __init__(self, gtokenizer) -> None:
-        self.gtokenizer = gtokenizer
+        # these are required by LLTokenizer:
         self.eos_token_id = gtokenizer.eos_token_id
         self.bos_token_id = gtokenizer.bos_token_id
         self.tokens = gtokenizer.tokens
-        self.accepts_bytes = True
+        self.is_tokenizer_wrapper = True
+
+        # more private stuff
+        self._gtokenizer = gtokenizer
+        self._accepts_bytes = True
         try:
             gtokenizer(b"test")
         except:
-            self.accepts_bytes = False
+            self._accepts_bytes = False
         # If the tokenizer used bytes, then b"\xff" would be better (since it's invalid UTF-8)
         # For now, we'll settle for "\x02" as assume it doesn't start any other token
-        self.prefix_string = "\x02"
-        self.prefix_tokens = self._encode_string(self.prefix_string)
+        self._prefix_string = "\x02"
+        self._prefix_tokens = self._encode_string(self._prefix_string)
 
     def _encode_string(self, s: str) -> List[TokenId]:
-        if self.accepts_bytes:
-            return self.gtokenizer(s.encode("utf-8"))
+        if self._accepts_bytes:
+            return self._gtokenizer(s.encode("utf-8"))
         else:
-            return self.gtokenizer(s)
+            return self._gtokenizer(s)
 
+    # required by LLTokenizer
     def __call__(self, s: str):
-        tokens = self._encode_string(self.prefix_string + s)
-        assert tokens[: len(self.prefix_tokens)] == self.prefix_tokens
-        return tokens[len(self.prefix_tokens) :]
+        tokens = self._encode_string(self._prefix_string + s)
+        assert tokens[: len(self._prefix_tokens)] == self._prefix_tokens
+        return tokens[len(self._prefix_tokens) :]
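
TokenizerWrapper only duck-types the object it is handed: it reads eos_token_id, bos_token_id and tokens, probes whether the callable accepts bytes, and prefixes every string with "\x02" so the leading characters of s tokenize the way they would mid-text before stripping the prefix tokens again. A minimal sketch of that contract, using a hypothetical byte-level tokenizer rather than anything from this repo, and not tested against the built extension:

```python
import llguidance


class ByteTokenizer:
    # attributes TokenizerWrapper exposes to the Rust LLTokenizer
    bos_token_id = 256
    eos_token_id = 257
    # one single-byte token per byte value, plus the two specials
    tokens = [bytes([b]) for b in range(256)] + [b"<s>", b"</s>"]

    def __call__(self, s):
        # accepts bytes, so the b"test" probe in __init__ succeeds
        if isinstance(s, str):
            s = s.encode("utf-8")
        return list(s)  # token id == byte value


wrapped = llguidance.TokenizerWrapper(ByteTokenizer())
assert wrapped.is_tokenizer_wrapper
# the "\x02" prefix tokens are stripped from the result
assert wrapped("hi") == [0x68, 0x69]
# accepted by the stricter LLTokenizer constructor (needs the native extension)
ll_tok = llguidance.LLTokenizer(wrapped)
```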
py/llguidance/rust/py.rs (21 changes: 16 additions & 5 deletions)
@@ -106,13 +106,24 @@ struct PyMidProcessResult
 #[pymethods]
 impl LLTokenizer {
     #[new]
-    fn py_new(gtokenizer: Bound<'_, PyAny>) -> PyResult<Self> {
-        let tok_eos = gtokenizer.getattr("eos_token_id")?.extract::<u32>()?;
-        let tok_bos = gtokenizer
+    fn py_new(tokenizer: Bound<'_, PyAny>) -> PyResult<Self> {
+        let is_tokenizer = tokenizer
+            .getattr("is_tokenizer_wrapper")
+            .map(|v| v.extract::<bool>())
+            .unwrap_or(Ok(false))
+            .unwrap_or(false);
+        if !is_tokenizer {
+            return Err(PyValueError::new_err(
+                "Expecting a TokenizerWrapper() class",
+            ));
+        }
+
+        let tok_eos = tokenizer.getattr("eos_token_id")?.extract::<u32>()?;
+        let tok_bos = tokenizer
             .getattr("bos_token_id")?
             .extract::<u32>()
             .map_or(None, |v| Some(v));
-        let tokens = gtokenizer.getattr("tokens")?.extract::<Vec<Vec<u8>>>()?;
+        let tokens = tokenizer.getattr("tokens")?.extract::<Vec<Vec<u8>>>()?;
         let info = TokRxInfo {
             vocab_size: tokens.len() as u32,
             tok_eos,
@@ -121,7 +132,7 @@ impl LLTokenizer {
         let tok_trie = TokTrie::from(&info, &tokens);
         Ok(LLTokenizer {
             tok_trie: Arc::new(tok_trie),
-            tokenizer_fun: gtokenizer.into(),
+            tokenizer_fun: tokenizer.into(),
             tok_bos,
         })
     }
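On the Rust side, py_new now refuses anything that does not advertise is_tokenizer_wrapper = True, so passing a raw engine tokenizer fails fast with a ValueError instead of misbehaving later. A rough sketch of the behavior this enforces (the commented lines assume a guidance model m, as in the t1.py change below):

```python
import llguidance

# anything that is not a TokenizerWrapper is rejected up front
try:
    llguidance.LLTokenizer(object())
except ValueError as e:
    print(e)  # Expecting a TokenizerWrapper() class

# the supported path, as in the t1.py change below:
# t = llguidance.TokenizerWrapper(m.engine.tokenizer)
# ll_tok = llguidance.LLTokenizer(t)
```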
py/llguidance/t1.py (3 changes: 2 additions & 1 deletion)
@@ -68,7 +68,8 @@ def run_constraint(tok: llguidance.LLTokenizer, e: LlamaCppEngine, grm: guidance
 def main():
     #m = guidance.models.Transformers(model="../../tmp/Phi-3-mini-128k-instruct/", trust_remote_code=True)
     m = guidance.models.LlamaCpp(model="../../tmp/Phi-3-mini-4k-instruct-q4.gguf")
-    t = llguidance.TokenizerWrapper(m.engine.tokenizer)
+    t = m.engine.tokenizer
+    t = llguidance.TokenizerWrapper(t)
     t = llguidance.LLTokenizer(t)
     assert t.tokenize_str("") == []
     assert t.tokenize_str(" ") == [29871]
