Skip to content

Commit

Permalink
Generate: add CJK support to TextStreamer (#22664)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcol23 committed Apr 15, 2023
1 parent fb3aa06 commit 28f26c1
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/transformers/generation/streamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def put(self, value):
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len :]
self.print_len += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
Expand All @@ -127,6 +131,30 @@ def on_finalized_text(self, text: str, stream_end: bool = False):
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
print(text, flush=True, end="" if not stream_end else None)

def _is_chinese_char(self, cp):
    """Return True if the codepoint `cp` falls in one of the CJK ideograph ranges.

    A "chinese character" here means any codepoint inside the CJK ideograph
    blocks listed below; see
    https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)

    Despite the "CJK" name, these blocks do NOT contain every Japanese or
    Korean character: modern Korean Hangul, as well as Japanese Hiragana and
    Katakana, live in separate blocks. Those scripts write space-separated
    words, so they get no special treatment and are handled like all other
    languages.

    Args:
        cp: Integer Unicode codepoint to classify (e.g. the result of `ord(char)`).

    Returns:
        `True` when `cp` lies inside any of the ranges, `False` otherwise.
    """
    # Inclusive (first, last) codepoint bounds of each range that is checked.
    ideograph_ranges = (
        (0x4E00, 0x9FFF),  # CJK Unified Ideographs
        (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
        (0x20000, 0x2A6DF),  # CJK Unified Ideographs Extension B
        (0x2A700, 0x2B73F),  # CJK Unified Ideographs Extension C
        (0x2B740, 0x2B81F),  # CJK Unified Ideographs Extension D
        (0x2B820, 0x2CEAF),  # CJK Unified Ideographs Extension E
        (0xF900, 0xFAFF),  # CJK Compatibility Ideographs
        (0x2F800, 0x2FA1F),  # CJK Compatibility Ideographs Supplement
    )

    for first, last in ideograph_ranges:
        if first <= cp <= last:
            return True

    return False


class TextIteratorStreamer(TextStreamer):
"""
Expand Down

0 comments on commit 28f26c1

Please sign in to comment.