fix t5 tokenizer and prompt token failures #1966

Merged · 1 commit · May 28, 2024
@@ -47,6 +47,8 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
:param properties (dict): other properties of the model, such as decoder strategy
"""
self.lmi_dist_config = LmiDistRbProperties(**properties)
+ self.model_type = getattr(kwargs.get("model_config", None), "model_type", None)

Contributor:

QQ: I see we are populating model_config, but where do we populate "model_type"?

Contributor (Author):

model_config here is the Hugging Face config.json, which has model_type populated.
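
For illustration (not part of this PR), model_type is the field Hugging Face writes into a model's config.json and can be inspected with transformers directly; the model id below is only an example:

```python
from transformers import AutoConfig

# config.json carries the architecture family under "model_type".
config = AutoConfig.from_pretrained("google/flan-t5-base")  # example model id
print(config.model_type)  # -> "t5"
```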

Contributor:

Oh, I misread it; model_type is read from model_config. Cool.

super().__init__(self.lmi_dist_config)
self.supports_speculative_decoding = supports_speculative_decoding()
engine_kwargs = {}
@@ -83,8 +85,6 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
kwargs["warmup_prefill_tokens"] = _WARMUP_PREFILL_TOKENS
self.engine = engine_from_args(engine_args, **kwargs)
self.request_cache = OrderedDict()
- self.model_type = getattr(kwargs.get("model_config", None), "model_type", None)
self.lora_ids = defaultdict(lambda: len(self.lora_ids) + 1)

def reset(self) -> None:
@@ -96,6 +96,8 @@ def reset(self) -> None:
super().reset()

def get_tokenizer(self):
if "t5" == self.model_type:
return self.engine.preprocessor.tokenizer
maaquib marked this conversation as resolved.
return self.engine.preprocessor.tokenizer.tokenizer

def translate_lmi_dist_params(self, parameters: dict):
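Side note (an alternative sketch, not part of this change): if the t5 preprocessor's tokenizer object has no .tokenizer attribute while the decoder-only wrapper does, the same effect could be achieved by duck-typing instead of branching on model_type. That assumption is not verified against the lmi-dist source, so the explicit model_type check in the PR is the safer choice.

```python
def get_tokenizer(self):
    tokenizer = self.engine.preprocessor.tokenizer
    # Assumed layout: for t5 this is already the Hugging Face tokenizer;
    # for decoder-only models it is a wrapper whose .tokenizer attribute
    # holds the Hugging Face tokenizer.
    return getattr(tokenizer, "tokenizer", tokenizer)
```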
@@ -75,6 +75,9 @@ def update_request_cache_with_output(request_cache: OrderedDict,
if "prompt_tokens_details" not in request_cache[
request_id] and request_output.prompt_logprobs:

Contributor:

Curious: request_output should not have prompt_logprobs either, right?

Contributor (Author):

From what I can read, request_output.prompt_logprobs is set to [float("nan")], and it seems that is enough to pass the check.
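
To see why that check passes (a minimal standalone snippet, not repo code): a non-empty list is truthy in Python regardless of its contents, so [float("nan")] satisfies the and request_output.prompt_logprobs condition in the hunk above.

```python
nan_logprobs = [float("nan")]

# Truthiness of a list depends on its length, not its contents,
# so a single-element list holding NaN still evaluates to True.
print(bool(nan_logprobs))  # True
print(bool([]))            # False; only an empty list would fail the check
```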

request_cache[request_id]["prompt_tokens_details"] = []
+ if not isinstance(request_output.prompt_token_ids, list):
+     # lmi-dist does not return prompt_token_ids for t5
+     request_output.prompt_token_ids = []
for index, prompt_token_id in enumerate(
request_output.prompt_token_ids):
prompt_token = Token(
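
A minimal standalone illustration of the new guard (the None value is an assumption based on the diff comment that lmi-dist returns no prompt_token_ids for t5):

```python
# If the engine hands back something that is not a list (e.g. None),
# enumerating it would raise a TypeError; normalizing to [] turns the
# prompt_tokens_details loop into a no-op instead.
prompt_token_ids = None  # what a t5 request is assumed to yield

if not isinstance(prompt_token_ids, list):
    prompt_token_ids = []

for index, prompt_token_id in enumerate(prompt_token_ids):
    pass  # never reached for t5; prompt_tokens_details stays empty
```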