Generate: add CJK support to TextStreamer #22664
Conversation
The documentation is not available anymore as the PR was closed or merged.

Force-pushed from 72ab28e to 58a4dfe

cc @gante
Thank you for opening the PR! It looks good to me, but I confess I am unknowledgeable about Chinese characters.
Before we merge, would you be able to post here a screen recording of streaming before/after the change? It would be useful for future reference, as testing this feature is very tricky 🙏
```python
                # like the all of the other languages.
                if (
                    (cp >= 0x4E00 and cp <= 0x9FFF)
                    or (cp >= 0x3400 and cp <= 0x4DBF)  #
```
Please remove the trailing `#` character :) (here and on the subsequent lines)
I'm not quite sure. I just copied it from tokenization_bert.py, and it looks like all the _is_chinese_char functions in the repo are the same. Maybe we should keep it consistent with the others?
transformers/src/transformers/models/bert/tokenization_bert.py
Lines 481 to 503 in fb3aa06
```python
    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True
        return False
```
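As a quick sanity check, the Unicode ranges above can be exercised with a standalone sketch (the `is_cjk_char` helper and `CJK_RANGES` below are hypothetical names written here only to mirror the repo's `_is_chinese_char`, not part of transformers):

```python
# Standalone mirror of the range check in _is_chinese_char.
CJK_RANGES = [
    (0x4E00, 0x9FFF),    # CJK Unified Ideographs
    (0x3400, 0x4DBF),    # Extension A
    (0x20000, 0x2A6DF),  # Extension B
    (0x2A700, 0x2B73F),  # Extension C
    (0x2B740, 0x2B81F),  # Extension D
    (0x2B820, 0x2CEAF),  # Extension E
    (0xF900, 0xFAFF),    # CJK Compatibility Ideographs
    (0x2F800, 0x2FA1F),  # Compatibility Ideographs Supplement
]


def is_cjk_char(cp: int) -> bool:
    """Return True if the codepoint falls in any of the CJK blocks above."""
    return any(lo <= cp <= hi for lo, hi in CJK_RANGES)


print(is_cjk_char(ord("中")))  # True: U+4E2D, CJK Unified Ideographs
print(is_cjk_char(ord("a")))   # False: ASCII
print(is_cjk_char(ord("あ")))  # False: Hiragana is outside these blocks
```

Note the last case: as the docstring in the repo says, Hiragana, Katakana, and Hangul are deliberately not covered, since those scripts use space-separated words.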
hah, I didn't know of this previous code block!
In that case yes, let's leave the two equal 👍
Yes, of course. I use this tiny script for the screen recording:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer


class TextCJKStreamer(TextStreamer):
    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decodes the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True
        return False


tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
# Use CPU to make generation slow
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

streamer = TextStreamer(tokenizer, skip_prompt=True)
cjk_streamer = TextCJKStreamer(tokenizer, skip_prompt=True)

prompt = "一个传奇的开端,一个不灭的神话,这不仅仅是一部电影,"
tokenized_inputs = tokenizer([prompt], return_tensors="pt")
print("Origin TextStreamer:")
tokenized_inputs = tokenized_inputs.to(model.device)
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)
print("CJK TextStreamer")
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=cjk_streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)

prompt = "Suggest at least five related search terms to 'Mạng neural nhân tạo'."
tokenized_inputs = tokenizer([prompt], return_tensors="pt")
print("Origin TextStreamer:")
tokenized_inputs = tokenized_inputs.to(model.device)
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)
print("CJK TextStreamer")
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=cjk_streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)
```

The model and prompts are from https://huggingface.co/bigscience/bloomz-560m. And here is the comparison. As you can see, before the change the Chinese text only prints when it reaches "。".

Screen.Recording.2023-04-15.at.10.04.58.mov
Screen.Recording.2023-04-15.at.10.05.23.mov
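The difference can also be seen without running a model, by simulating the two flushing strategies on a stream of decoded text. This is a sketch: `simulate_stream` and `is_cjk_char` are hypothetical helpers that mirror the `put()` logic above; they are not part of transformers.

```python
def is_cjk_char(cp):
    # Same Unicode ranges as _is_chinese_char above.
    return any(lo <= cp <= hi for lo, hi in [
        (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F), (0x2B740, 0x2B81F), (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF), (0x2F800, 0x2FA1F),
    ])


def simulate_stream(chunks, cjk_aware):
    """Mimic TextStreamer.put(): buffer decoded text, flush only at 'safe' points."""
    text, print_len, flushed = "", 0, []
    for chunk in chunks:
        text += chunk
        if text.endswith("\n"):
            printable = text[print_len:]
            text, print_len = "", 0
        elif cjk_aware and text and is_cjk_char(ord(text[-1])):
            printable = text[print_len:]
            print_len += len(printable)
        else:
            # Flush only up to the last space -- which never occurs in spaceless CJK text.
            printable = text[print_len : text.rfind(" ") + 1]
            print_len += len(printable)
        if printable:
            flushed.append(printable)
    return flushed


chunks = ["一", "个", "传", "奇"]
print(simulate_stream(chunks, cjk_aware=False))  # [] -- nothing flushed until a space or newline
print(simulate_stream(chunks, cjk_aware=True))   # ['一', '个', '传', '奇'] -- flushed as they arrive
```

This is exactly the "stuck in token_cache" behavior the PR fixes: the space-based heuristic never finds a flush point in Chinese text, so output stalls until a newline or sentence terminator happens to be decoded.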
@bcol23 thank you for adding the screen recordings for future reference 🙏 And thank you for making
What does this PR do?
This pull request adds support for streaming CJK (Chinese, Japanese, Korean) characters to the TextStreamer class. It now flushes the token cache when the last token is a CJK character, in addition to flushing when the text ends with "\n" or " ". This prevents CJK characters from getting stuck in token_cache.
Who can review?
@gante