Skip to content

Commit

Permalink
Fixes for worker prompt truncation in ChatML case (#3673)
Browse files Browse the repository at this point in the history
Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
  • Loading branch information
olliestanley and andreaskoepf committed Aug 29, 2023
1 parent d613c81 commit adcf8dc
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 10 deletions.
2 changes: 1 addition & 1 deletion inference/worker/__main__.py
Expand Up @@ -34,7 +34,7 @@ def main():
tokenizer = None
else:
tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")
logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {len(tokenizer)}")

inference_http = utils.HttpClient(
base_url=settings.inference_server_url,
Expand Down
2 changes: 1 addition & 1 deletion inference/worker/basic_hf_server.py
Expand Up @@ -138,7 +138,7 @@ def load_models():
hf_config = transformers.AutoConfig.from_pretrained(model_config.model_id)
logger.warning(f"Loading model {model_config.model_id}...")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {tokenizer.vocab_size}")
logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {len(tokenizer)}")

# see `decode_token` method, taken from HF text-generation-inference
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
Expand Down
21 changes: 13 additions & 8 deletions inference/worker/utils.py
Expand Up @@ -95,11 +95,14 @@ def get_max_input_length(worker_config: inference.WorkerConfig, plugin_used: boo
return max_input_length


def get_tokens_until(tokens: list[int], target: int | list[int]) -> list[int]:
if isinstance(target, int):
return tokens[: tokens.index(target)]
else:
return next((i for i in range(len(tokens) - len(target) + 1) if tokens[i : i + len(target)] == target))
def get_tokens_until(tokens: list[int], target: list[int]) -> list[int]:
if len(target) == 1:
return tokens[: tokens.index(target[0])]

for i in range(len(tokens) - len(target)):
if tokens[i : i + len(target)] == target:
break
return tokens[:i]


def truncate_prompt(
Expand All @@ -118,8 +121,8 @@ def truncate_prompt(
"""
with shared_tokenizer_lock:
ids = tokenizer.encode(prompt)
# prompter_prefix_ids could be int or list of ints
prompter_prefix_ids = tokenizer.convert_tokens_to_ids(special_tokens["prompter"])
# list of int IDs
prompter_prefix_ids = tokenizer.encode(special_tokens["prompter"])

system_prompt: str | None = None
system_tokens: list[int] | None = None
Expand All @@ -134,7 +137,9 @@ def truncate_prompt(

num_system_tokens = len(system_tokens) if system_tokens else 0
# Maximum token allowed for the conversation, ex system prompt
max_conversation_length = max_input_length - num_system_tokens
# We incorporate a buffer to allow for final inference tokenization differing from ours
# This is a slightly hacky workaround and it would be better to find a cleaner solution
max_conversation_length = max_input_length - num_system_tokens - int(0.01 * max_input_length)
ids = ids[-(max_conversation_length - 1) :]

with shared_tokenizer_lock:
Expand Down
5 changes: 5 additions & 0 deletions oasst-shared/oasst_shared/model_configs.py
Expand Up @@ -150,4 +150,9 @@ def compat_hash(self) -> str:
max_input_length=3072,
max_total_length=4096,
),
"OA_SFT_CodeLlama_13B_10": ModelConfig(
model_id="OpenAssistant/codellama-13b-oasst-sft-v10",
max_input_length=8192,
max_total_length=12288,
),
}

0 comments on commit adcf8dc

Please sign in to comment.