
Generate: add CJK support to TextStreamer #22664

Merged · 2 commits · Apr 15, 2023

Conversation

bcol23
Contributor

@bcol23 bcol23 commented Apr 8, 2023

What does this PR do?

This pull request adds support for streaming CJK (Chinese, Japanese, Korean) characters to the TextStreamer class. It now flushes the token cache if the last token is a CJK character, in addition to flushing when the text ends with "\n" or " ". This prevents CJK characters from getting stuck in token_cache.
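As a rough sketch of the flush rule described above (hypothetical helper names, not the actual `TextStreamer` code, and with the CJK check reduced to the main BMP ideograph block for brevity):

```python
# Minimal sketch of the new flush rule -- hypothetical helpers,
# not the actual TextStreamer implementation.
def is_cjk(cp: int) -> bool:
    # Only the main CJK Unified Ideographs block, for brevity;
    # the real check also covers the extension blocks.
    return 0x4E00 <= cp <= 0x9FFF

def should_flush(text: str) -> bool:
    # Flush on a trailing newline or space, as before...
    if text.endswith("\n") or text.endswith(" "):
        return True
    # ...and now also whenever the last decoded character is CJK,
    # so ideographs are printed as soon as they are complete.
    return len(text) > 0 and is_cjk(ord(text[-1]))
```

Under this sketch, `should_flush("你好")` is `True` while `should_flush("hel")` is `False`, so Chinese output streams character by character instead of waiting for whitespace.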

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 8, 2023

The documentation is not available anymore as the PR was closed or merged.

@amyeroberts
Collaborator

cc @gante

Member

@gante gante left a comment


Thank you for opening the PR! It looks good to me, but I confess I am unknowledgeable about Chinese characters.

Before we merge, would you be able to post here a screen recording of streaming before/after the change? It would be useful for future reference, as testing this feature is very tricky 🙏

# like the all of the other languages.
if (
    (cp >= 0x4E00 and cp <= 0x9FFF)
    or (cp >= 0x3400 and cp <= 0x4DBF)  #
Member


Please remove the trailing # character :)

(here and on the subsequent lines)

Contributor Author

@bcol23 bcol23 Apr 15, 2023


I'm not quite sure. I just copied it from tokenization_bert.py, and it looks like all the _is_chinese_char functions in the repo are the same. Maybe we should keep it consistent with the others?

def _is_chinese_char(self, cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and handled
    # like the all of the other languages.
    if (
        (cp >= 0x4E00 and cp <= 0x9FFF)
        or (cp >= 0x3400 and cp <= 0x4DBF)  #
        or (cp >= 0x20000 and cp <= 0x2A6DF)  #
        or (cp >= 0x2A700 and cp <= 0x2B73F)  #
        or (cp >= 0x2B740 and cp <= 0x2B81F)  #
        or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
        or (cp >= 0xF900 and cp <= 0xFAFF)
        or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
    ):  #
        return True
    return False

Member


hah, I didn't know of this previous code block!

In that case yes, let's leave the two equal 👍

@bcol23
Contributor Author

bcol23 commented Apr 15, 2023

Yes, of course. I used this tiny script for the screen recording.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer


class TextCJKStreamer(TextStreamer):
    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decodes the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False


tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
# Use CPU to make generation slow
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
streamer = TextStreamer(tokenizer, skip_prompt=True)
cjk_streamer = TextCJKStreamer(tokenizer, skip_prompt=True)

prompt = "一个传奇的开端,一个不灭的神话,这不仅仅是一部电影,"
tokenized_inputs = tokenizer([prompt], return_tensors="pt")
print("Origin TextStreamer:")
tokenized_inputs = tokenized_inputs.to(model.device)
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)
print("CJK TextStreamer")
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=cjk_streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)

prompt = "Suggest at least five related search terms to 'Mạng neural nhân tạo'."
tokenized_inputs = tokenizer([prompt], return_tensors="pt")
print("Origin TextStreamer:")
tokenized_inputs = tokenized_inputs.to(model.device)
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)
print("CJK TextStreamer")
_ = model.generate(
    **tokenized_inputs,
    do_sample=False,
    streamer=cjk_streamer,
    min_new_tokens=64,
    max_new_tokens=128,
)

The model and prompts are from https://huggingface.co/bigscience/bloomz-560m, and here is the comparison. As you can see, before the change the Chinese text is only printed when a "。" is reached.

Screen.Recording.2023-04-15.at.10.04.58.mov
Screen.Recording.2023-04-15.at.10.05.23.mov

@gante
Member

gante commented Apr 15, 2023

@bcol23 thank you for adding the screen recordings for future reference 🙏

And thank you for making transformers a little bit more inclusive 🤗

@gante gante merged commit 28f26c1 into huggingface:main Apr 15, 2023
2 checks passed
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023