remove dependency on API-based tokenizer
lbeurerkellner committed Apr 17, 2023
1 parent 9766bbc commit 3736fcb
Showing 6 changed files with 36 additions and 125 deletions.
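In summary, this commit removes the HTTP round trip to the inference server for tokenization: DcModel now receives the tokenizer object directly and reads the BOS/EOS ids from it, while tokenize()/detokenize() call the locally loaded Hugging Face tokenizer instead of the server's /queue endpoint. A minimal sketch of the new in-process flow (the import path and model identifier are illustrative assumptions, not part of the diff):

from lmql.runtime.tokenizer import load_tokenizer  # import path assumed

# Tokenize and detokenize locally, without POSTing to the inference server.
tokenizer = load_tokenizer("gpt2")                  # illustrative model identifier
input_ids = tokenizer("Say 'this is a test'")["input_ids"]
print(input_ids, tokenizer.decode(input_ids))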
110 changes: 0 additions & 110 deletions src/lmql/model/served_model.py
@@ -122,79 +122,6 @@ def _timeout(self, sample_id):
self.pending[sample_id].set_exception(TimeoutError("Timed out waiting for result of sample {}".format(sample_id)))
del self.pending[sample_id]

async def detokenize(self, input_ids):
self.create_result_processor_task_if_required()
loop = asyncio.get_event_loop()

# handle torch tensors
if type(input_ids) is torch.Tensor:
input_ids = input_ids.tolist()

sample_id = self.sample_id
self.sample_id += 1

payload = {
"action": "detokenize",
"client_id": self.client_id,
"input_ids": input_ids,
"sample_id": sample_id
}

try:
assert sample_id not in self.pending, "sample_id {} already in self.pending".format(sample_id)
# setup future and timeout
self.pending[sample_id] = loop.create_future()
r = requests.post(f"{self.host}/queue", json=payload)

if r.status_code != 200:
raise Exception(f"Error posting to {self.host}/queue: {r.status_code}")

loop.call_later(self.timeout, self._timeout, sample_id)

return (await self.pending[sample_id])["text"]
except requests.exceptions.ConnectionError as e:
# check for connection refused
if "Connection refused" in str(e):
raise Exception(f"Error connecting to {self.host}/queue. Please make sure an instance of the LMQL inference API for this model is running. To start it, run `python -m serve <MODEL>`.")
else:
raise e

async def tokenize(self, text):
self.create_result_processor_task_if_required()
loop = asyncio.get_event_loop()

assert type(text) == str, "The provided text for tokenize() must be str."

sample_id = self.sample_id
self.sample_id += 1

payload = {
"action": "tokenize",
"text": text,
"client_id": self.client_id,
"sample_id": sample_id
}

try:
assert sample_id not in self.pending, "sample_id {} already in self.pending".format(sample_id)
# setup future and timeout
self.pending[sample_id] = loop.create_future()

r = requests.post(f"{self.host}/queue", json=payload)

if r.status_code != 200:
raise Exception(f"Error posting to {self.host}/queue: {r.status_code}")

loop.call_later(self.timeout, self._timeout, sample_id)

return (await self.pending[sample_id])["input_ids"]
except requests.exceptions.ConnectionError as e:
# check for connection refused
if "Connection refused" in str(e):
raise Exception(f"Error connecting to {self.host}/queue. Please make sure an instance of the LMQL inference API for this model is running. To start it, run `python -m serve <MODEL>`.")
else:
raise e

def reset_stats(self):
self.consumed_tokens = 0
self.num_queries = 0
@@ -252,43 +179,6 @@ def forward(self, input_ids, attention_mask=None):
else:
raise e


class Sample:
def __init__(self, model, seq=None, input_ids=None):
assert not (self.seq is None and self.input_ids), "Either seq or input_ids must be provided for a Sample()"

self.model = model
self._seq = seq
self._input_ids = input_ids

async def seq(self):
if self._seq is None:
assert self._input_ids is not None
self._seq = await self.model.detokenize(self._input_ids)
return self._seq

async def input_ids(self):
if self._input_ids is None:
assert self.seq is not None
self._input_ids = await self.model.tokenize(self._seq)
return self._input_ids

async def successor(self, token_id) -> 'Sample':
return Sample(self.model, input_ids=(await self.input_ids()) + [token_id])

async def distribution(self):
res = await self.model.forward(await self.input_ids())
return res["next_token_logits"]

def __repr__(self) -> str:
return "<Sample seq={} input_ids={}>".format(self._seq, self._input_ids)

async def concat(self, text):
seq = (await self.seq()) + text
input_ids = await self.model.tokenize(seq)

return Sample(self.model, seq=seq, input_ids=input_ids)

@dataclass
class ServedPretrainedModelOutput:
logits: torch.Tensor
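For context, the removed tokenize()/detokenize() methods above each POSTed a payload such as {"action": "detokenize", ...} to the server's /queue endpoint and awaited a future with a timeout. After this commit the same work happens in-process; a minimal local equivalent of the removed detokenize round trip, assuming a plain Hugging Face tokenizer for illustration:

from transformers import AutoTokenizer

# Decode token ids locally instead of sending them to the inference server.
tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative; any HF tokenizer works
ids = tokenizer("lmql is a query language")["input_ids"]
print(tokenizer.decode([i for i in ids if i is not None]))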
7 changes: 4 additions & 3 deletions src/lmql/runtime/dclib/dclib_model.py
@@ -98,7 +98,7 @@ def stack_logit_masks(self, logit_masks):
ModelQueue._instances = {}

class DcModel:
def __init__(self, model, bos_token_id, eos_token_id, truncation_threshold=-3e38, init_workers=True, **kwargs):
def __init__(self, model, tokenizer, truncation_threshold=-3e38, init_workers=True, **kwargs):
"""
Parameters:
@@ -108,11 +108,12 @@ def __init__(self, model, bos_token_id, eos_token_id, truncation_threshold=-3e38
truncation_threshold: The threshold to use for logit truncation (cf. DecoderSequence.truncation_threshold). Logits below this threshold are considered to be -inf and will never be considered as next token.
"""
self.model = model
self.tokenizer = tokenizer
self.model_identifier = model.model_identifier
self.model_args = kwargs

self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.bos_token_id = tokenizer.bos_token_id
self.eos_token_id = tokenizer.eos_token_id
self.truncation_threshold = truncation_threshold

self.stats = Stats("dcmodel")
35 changes: 26 additions & 9 deletions src/lmql/runtime/hf_integration.py
@@ -1,4 +1,5 @@
import numpy as np
import asyncio

from lmql.model.served_model import ServedPretrainedModel
import lmql.runtime.dclib as dc
@@ -9,6 +10,16 @@ def transformers_model(endpoint, model_identifier):
import torch

class NumpyBridgedServedPretrainedModel(ServedPretrainedModel):
def __init__(self, transformers_model: 'TransformersModel', *args, **kwargs):
super().__init__(*args, **kwargs)
self.transformers_model = transformers_model

async def tokenize(self, text):
return await self.transformers_model.tokenize(text)

async def detokenize(self, input_ids):
return await self.transformers_model.detokenize(input_ids)

def __getattribute__(self, __name: str):
if __name == "__dict__":
return super().__getattribute__(__name)
@@ -35,26 +46,32 @@ def __init__(self):
local = self.model_identifier.startswith("local:")
if local:
self.model_identifier = self.model_identifier.split(":")[1]
self.served_model = NumpyBridgedServedPretrainedModel(endpoint, self.model_identifier, use_tq=False, local=local)
self.served_model = NumpyBridgedServedPretrainedModel(self, endpoint, self.model_identifier, use_tq=False, local=local)

self.tokenizer = load_tokenizer(self.model_identifier)

def get_tokenizer(self):
return self.tokenizer

async def tokenize(self, text):
input_ids = self.get_tokenizer()(text)["input_ids"]
# strip off bos if present, LMQL handles this internally
if len(input_ids) > 0 and input_ids[0] == self.bos_token_id:
input_ids = input_ids[1:]
return [i for i in input_ids if i is not None]
async def task(text):
input_ids = self.get_tokenizer()(text)["input_ids"]
# strip off bos if present, LMQL handles this internally
if len(input_ids) > 0 and input_ids[0] == self.tokenizer.bos_token_id:
input_ids = input_ids[1:]
return [i for i in input_ids if i is not None]
t = asyncio.create_task(task(text))
return (await t)

async def detokenize(self, input_ids):
input_ids = [i for i in input_ids if i is not None]
return self.get_tokenizer().decode(input_ids)
async def task(input_ids):
input_ids = [i for i in input_ids if i is not None]
return self.get_tokenizer().decode(input_ids)
t = asyncio.create_task(task(input_ids))
return (await t)

def get_dclib_model(self):
dc.set_dclib_tokenizer(dc.tokenizer("lmql-adapter-tokenizer", self.tokenize, self.detokenize, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id))
return dc.DcModel(self.served_model, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
return dc.DcModel(self.served_model, self.tokenizer)

return TransformersModel
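The bridged model above now forwards tokenize()/detokenize() to the TransformersModel wrapper, which runs the synchronous Hugging Face tokenizer inside a small asyncio task so callers can keep awaiting the result. A self-contained sketch of that pattern, using transformers' AutoTokenizer directly for illustration (the wrapper itself goes through load_tokenizer):

import asyncio
from transformers import AutoTokenizer

async def tokenize(tokenizer, text):
    # Run the synchronous tokenizer call in a task and strip a leading BOS id,
    # since LMQL prepends it again itself.
    async def task(text):
        input_ids = tokenizer(text)["input_ids"]
        if len(input_ids) > 0 and input_ids[0] == tokenizer.bos_token_id:
            input_ids = input_ids[1:]
        return [i for i in input_ids if i is not None]
    return await asyncio.create_task(task(text))

async def main():
    tok = AutoTokenizer.from_pretrained("gpt2")      # illustrative model identifier
    print(await tokenize(tok, "this is a test"))

asyncio.run(main())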
4 changes: 3 additions & 1 deletion src/lmql/runtime/interpreter.py
@@ -481,6 +481,8 @@ async def rewrite_for_sequence(self, seq: dc.DecoderSequence, needs_rewrite):
return RewrittenInputIds(appended_input_ids=None, strip_eos=False, user_data=user_data)

async def tokenize(self, *args):
# tokenize should be specific to the current model in use (infer from currently process
# dc.seq, interpreter should not be tokenizer-specific)
async def task():
return self.tokenizer(*args)["input_ids"]
t = asyncio.create_task(task())
@@ -538,7 +540,7 @@ async def run(self, fct, **kwargs):
print("warning: no_repeat_ngram_size is known to cause issues when used with constrained decoding, including non-termination.")

# tokenize initial prompt
prompt_ids = await self.dcmodel.tokenize(self.root_state.prompt)
prompt_ids = await self.tokenize(self.root_state.prompt)
if self.dcmodel.bos_token_id is not None:
prompt_ids = [self.dcmodel.bos_token_id] + prompt_ids
n = len(prompt_ids)
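The initial prompt is now tokenized through the interpreter's own tokenize() rather than dcmodel.tokenize(), and the BOS id is prepended explicitly when the model defines one. A tiny sketch of that preparation step (the token ids below are arbitrary illustrative values):

prompt_ids = [1212, 318, 257, 1332]      # stands in for `await self.tokenize(prompt)`
bos_token_id = 50256                     # e.g. GPT-2's BOS id; None for models without BOS
if bos_token_id is not None:
    prompt_ids = [bos_token_id] + prompt_ids
n = len(prompt_ids)
print(prompt_ids, n)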
3 changes: 1 addition & 2 deletions src/lmql/runtime/openai_integration.py
@@ -820,7 +820,6 @@ async def tokenize(self, *args, **kwargs):
def task():
return self.tokenizer.encode(*args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, task)
return self.tokenizer.tokenize(*args, **kwargs)

async def detokenize(self, *args, **kwargs):
def task():
@@ -947,7 +946,7 @@ def get_dclib_model(self):

dc.set_dclib_tokenizer(dc.tokenizer("lmql-adapter-tokenizer", self.tokenize, self.detokenize, bos_token_id, eos_token_id))

return DclibOpenAiModel(self.served_model, bos_token_id, eos_token_id)
return DclibOpenAiModel(self.served_model, self.get_tokenizer())

async def tokenize(self, text):
return self.get_tokenizer()(text)["input_ids"]
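Here the unreachable return after the run_in_executor() call is dropped and DclibOpenAiModel, like DcModel, now receives the tokenizer object instead of individual BOS/EOS ids. A sketch of the executor-offloading pattern used in tokenize() above, with a GPT-2 tokenizer standing in for the integration's own tokenizer:

import asyncio
from transformers import GPT2TokenizerFast

async def tokenize_async(tokenizer, text):
    # Run the blocking encode() call in the default thread pool so the event loop stays free.
    def task():
        return tokenizer.encode(text)
    return await asyncio.get_running_loop().run_in_executor(None, task)

async def main():
    tok = GPT2TokenizerFast.from_pretrained("gpt2")  # illustrative stand-in tokenizer
    print(await tokenize_async(tok, "this is a test"))

asyncio.run(main())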
2 changes: 2 additions & 0 deletions src/lmql/runtime/tokenizer.py
@@ -1,3 +1,5 @@
import asyncio

def get_js_tokenizer(model_identifier):
import js
from pyodide.ffi import to_js
