fix t5 tokenizer and prompt token failures (#1966)
rohithkrn authored May 28, 2024
1 parent e78aa57 commit 714fabf
Showing 2 changed files with 7 additions and 2 deletions.
In the first changed file (the lmi-dist rolling batch handler), the model_type lookup moves ahead of the base-class initialization:

@@ -47,6 +47,8 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
         :param properties (dict): other properties of the model, such as decoder strategy
         """
         self.lmi_dist_config = LmiDistRbProperties(**properties)
+        self.model_type = getattr(kwargs.get("model_config", None),
+                                  "model_type", None)
         super().__init__(self.lmi_dist_config)
         self.supports_speculative_decoding = supports_speculative_decoding()
         engine_kwargs = {}
@@ -83,8 +85,6 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
         kwargs["warmup_prefill_tokens"] = _WARMUP_PREFILL_TOKENS
         self.engine = engine_from_args(engine_args, **kwargs)
         self.request_cache = OrderedDict()
-        self.model_type = getattr(kwargs.get("model_config", None),
-                                  "model_type", None)
         self.lora_ids = defaultdict(lambda: len(self.lora_ids) + 1)

     def reset(self) -> None:
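
A plausible reading of the reordering above (an assumption; the commit message does not say): the base-class initializer, or work it triggers, ends up calling get_tokenizer(), which now reads self.model_type, so the attribute must exist before super().__init__() runs. A minimal, self-contained sketch of that failure mode, with hypothetical Base/Handler classes:

class Base:
    def __init__(self):
        # Stand-in for base-class init work that touches the tokenizer.
        self.get_tokenizer()

    def get_tokenizer(self):
        raise NotImplementedError


class Handler(Base):
    def __init__(self, model_type, set_before_super=True):
        if set_before_super:
            self.model_type = model_type  # order used by the fix
            super().__init__()
        else:
            super().__init__()            # pre-fix order: get_tokenizer() runs first
            self.model_type = model_type

    def get_tokenizer(self):
        # Reads self.model_type, so the attribute must already exist.
        return "raw" if self.model_type == "t5" else "wrapped"


Handler("t5")  # works: model_type is set before the base init runs
# Handler("t5", set_before_super=False)  # AttributeError: no attribute 'model_type'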
@@ -96,6 +96,8 @@ def reset(self) -> None:
         super().reset()

     def get_tokenizer(self):
+        if "t5" == self.model_type:
+            return self.engine.preprocessor.tokenizer
         return self.engine.preprocessor.tokenizer.tokenizer

     def translate_lmi_dist_params(self, parameters: dict):
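The get_tokenizer change reflects that, per this diff, lmi-dist's t5 preprocessor exposes the tokenizer directly, while other models get a wrapper whose .tokenizer attribute holds the underlying tokenizer. A rough illustration using stand-in objects (SimpleNamespace here is purely illustrative, not the real lmi-dist types):

from types import SimpleNamespace

hf_tokenizer = SimpleNamespace(decode=lambda ids: "<decoded>")

# Most models: the preprocessor holds a wrapper whose .tokenizer attribute
# is the underlying tokenizer.
wrapped = SimpleNamespace(
    preprocessor=SimpleNamespace(
        tokenizer=SimpleNamespace(tokenizer=hf_tokenizer)))

# t5: the preprocessor holds the tokenizer directly, with no wrapper.
direct = SimpleNamespace(preprocessor=SimpleNamespace(tokenizer=hf_tokenizer))

def get_tokenizer(engine, model_type):
    # Mirrors the branching added in the commit.
    if model_type == "t5":
        return engine.preprocessor.tokenizer
    return engine.preprocessor.tokenizer.tokenizer

assert get_tokenizer(direct, "t5") is hf_tokenizer
assert get_tokenizer(wrapped, "llama") is hf_tokenizer

Either shape resolves to the same underlying tokenizer, which is why callers go through get_tokenizer() rather than reaching into engine.preprocessor themselves.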
In the second changed file, update_request_cache_with_output gains a guard on prompt_token_ids:

@@ -75,6 +75,9 @@ def update_request_cache_with_output(request_cache: OrderedDict,
     if "prompt_tokens_details" not in request_cache[
             request_id] and request_output.prompt_logprobs:
         request_cache[request_id]["prompt_tokens_details"] = []
+    if not isinstance(request_output.prompt_token_ids, list):
+        ## lmi-dist does not return prompt_token_ids for t5
+        request_output.prompt_token_ids = []
     for index, prompt_token_id in enumerate(
             request_output.prompt_token_ids):
         prompt_token = Token(
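Before this guard, a None prompt_token_ids (which, per the comment, lmi-dist returns for t5) would make the enumerate call raise TypeError; normalizing it to an empty list turns the loop into a no-op. A condensed, runnable sketch (the cache layout and token records are simplified stand-ins for the real Token objects):

from collections import OrderedDict

def record_prompt_tokens(request_cache, request_id, prompt_token_ids):
    # Condensed version of the guarded loop from the commit.
    entry = request_cache.setdefault(request_id, {})
    entry.setdefault("prompt_tokens_details", [])
    if not isinstance(prompt_token_ids, list):
        # lmi-dist does not return prompt_token_ids for t5
        prompt_token_ids = []
    for index, token_id in enumerate(prompt_token_ids):
        entry["prompt_tokens_details"].append({"index": index, "id": token_id})

cache = OrderedDict()
record_prompt_tokens(cache, "req-1", [101, 2023])  # normal model: two entries
record_prompt_tokens(cache, "req-2", None)         # t5: loop is a no-op
assert len(cache["req-1"]["prompt_tokens_details"]) == 2
assert cache["req-2"]["prompt_tokens_details"] == []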
