Merge 50d4365 into 5e65669
haifeng-jin committed Feb 8, 2019
2 parents 5e65669 + 50d4365 commit 4addb54
Showing 1 changed file with 31 additions and 25 deletions.
56 changes: 31 additions & 25 deletions autokeras/text/pretrained_bert/tokenization.py
@@ -32,8 +32,10 @@
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
- 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
- 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual"
+     "-uncased-vocab.txt",
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased"
+     "-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
VOCAB_NAME = 'vocab.txt'
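
A side note on the mechanism the split relies on: adjacent string literals in Python are concatenated at compile time, so the two-line form above produces the same single URL as before. A minimal standalone sketch:

    # Adjacent string literals concatenate without "+", which lets a long
    # URL be split across lines inside a dict, list, or parentheses.
    url = ("https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual"
           "-uncased-vocab.txt")
    assert url.endswith("bert-base-multilingual-uncased-vocab.txt")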
@@ -65,6 +67,7 @@ def whitespace_tokenize(text):

class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""

def __init__(self, vocab_file, do_lower_case=True):
if not os.path.isfile(vocab_file):
raise ValueError(
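
For orientation, a minimal usage sketch of the class being touched here, assuming a vocab.txt has already been downloaded from one of the URLs in PRETRAINED_VOCAB_ARCHIVE_MAP (the local path is hypothetical):

    from autokeras.text.pretrained_bert.tokenization import BertTokenizer

    # "vocab.txt" is a placeholder path to a downloaded vocabulary file.
    tokenizer = BertTokenizer("vocab.txt", do_lower_case=True)
    tokens = tokenizer.tokenize("Hello, world!")
    # With the standard uncased vocab this yields something like
    # ['hello', ',', 'world', '!']; exact output depends on the vocab file.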
@@ -163,8 +166,8 @@ def tokenize(self, text):
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens

- @classmethod
- def _run_strip_accents(self, text):
+ @staticmethod
+ def _run_strip_accents(text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
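
The loop body that filters combining marks is collapsed in this view; as a reference, the standard NFD-based accent-stripping idiom this method implements can be sketched standalone (same unicodedata calls, hypothetical free function):

    import unicodedata

    def strip_accents(text):
        # NFD decomposes each accented character into a base character
        # followed by combining marks (Unicode category "Mn").
        text = unicodedata.normalize("NFD", text)
        # Dropping the combining marks leaves the unaccented text.
        return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")

    assert strip_accents("café naïve") == "cafe naive"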
@@ -175,8 +178,8 @@ def _run_strip_accents(self, text):
output.append(char)
return "".join(output)

- @classmethod
- def _run_split_on_punc(self, text):
+ @staticmethod
+ def _run_split_on_punc(text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
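
The recurring change in this commit is @classmethod with a misleadingly named self parameter becoming @staticmethod where the class object is never used (and a properly named cls where it is, as in _tokenize_chinese_chars below). A minimal illustration of the distinction, with hypothetical names:

    class Demo:
        @staticmethod
        def helper(text):
            # No implicit first argument: the right choice when neither
            # the instance nor the class is needed.
            return text.lower()

        @classmethod
        def dispatch(cls, text):
            # First argument is the class itself, so sibling helpers are
            # reachable without an instance.
            return cls.helper(text)

    assert Demo.dispatch("ABC") == "abc"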
@@ -196,21 +199,21 @@ def _run_split_on_punc(self, text):
return ["".join(x) for x in output]

@classmethod
- def _tokenize_chinese_chars(self, text):
+ def _tokenize_chinese_chars(cls, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
- if self._is_chinese_char(cp):
+ if cls._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)

- @classmethod
- def _is_chinese_char(self, cp):
+ @staticmethod
+ def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
@@ -220,19 +223,22 @@ def _is_chinese_char(self, cp):
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
- if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
-         (cp >= 0x3400 and cp <= 0x4DBF) or #
-         (cp >= 0x20000 and cp <= 0x2A6DF) or #
-         (cp >= 0x2A700 and cp <= 0x2B73F) or #
-         (cp >= 0x2B740 and cp <= 0x2B81F) or #
-         (cp >= 0x2B820 and cp <= 0x2CEAF) or
-         (cp >= 0xF900 and cp <= 0xFAFF) or #
-         (cp >= 0x2F800 and cp <= 0x2FA1F)): #
-     return True
+ chinese_character_ranges = [
+     (0x4E00, 0x9FFF),
+     (0x3400, 0x4DBF),
+     (0xF900, 0xFAFF),
+     (0x20000, 0x2A6DF),
+     (0x2A700, 0x2B73F),
+     (0x2B740, 0x2B81F),
+     (0x2B820, 0x2CEAF),
+     (0x2F800, 0x2FA1F)]
+ for start, end in chinese_character_ranges:
+     if start <= cp <= end:
+         return True
return False

- @classmethod
- def _clean_text(self, text):
+ @staticmethod
+ def _clean_text(text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
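
For a quick sanity check of the new range-table form, a standalone sketch that mirrors _is_chinese_char (any() collapses the same linear scan; the names are hypothetical):

    CJK_RANGES = [
        (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0xF900, 0xFAFF),
        (0x20000, 0x2A6DF), (0x2A700, 0x2B73F), (0x2B740, 0x2B81F),
        (0x2B820, 0x2CEAF), (0x2F800, 0x2FA1F),
    ]

    def is_chinese_char(cp):
        # any() expresses the linear scan over (start, end) pairs in one line.
        return any(start <= cp <= end for start, end in CJK_RANGES)

    assert is_chinese_char(ord("中")) is True
    assert is_chinese_char(ord("a")) is False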
@@ -337,11 +343,11 @@ def _is_punctuation(char):
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
- if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-         (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
-     return True
+ punctuation_ranges = [(33, 47), (58, 64), (91, 96), (123, 126)]
+ for start, end in punctuation_ranges:
+     if start <= cp <= end:
+         return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
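
To make the rewritten check concrete, a standalone sketch mirroring _is_punctuation (the ASCII range table plus the Unicode "P" category fallback; the free-function name is hypothetical):

    import unicodedata

    def is_punctuation(char):
        cp = ord(char)
        # Characters such as "^", "$", and "`" sit in these ASCII ranges but
        # not in the Unicode punctuation class; treat them as punctuation.
        if any(start <= cp <= end
               for start, end in [(33, 47), (58, 64), (91, 96), (123, 126)]):
            return True
        # Every Unicode punctuation category starts with "P" (Pc, Pd, Po, ...).
        return unicodedata.category(char).startswith("P")

    assert is_punctuation("$") is True        # caught by the ASCII ranges
    assert is_punctuation("\uff1f") is True   # fullwidth "？", category Po
    assert is_punctuation("a") is False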
