long docs unknown tokens takes a long time #2569

Merged 3 commits on May 31, 2023
93 changes: 68 additions & 25 deletions deepchecks/nlp/checks/data_integrity/unknown_tokens.py
@@ -9,6 +9,7 @@
# ----------------------------------------------------------------------------
#
"""Module contains the Unknown Tokens check."""
import os
import typing as t
import warnings
from collections import Counter
@@ -38,7 +39,7 @@ class UnknownTokens(SingleDatasetCheck):
----------
tokenizer: t.Any , default: None
Tokenizer from the HuggingFace transformers library to use for tokenization. If None,
BertTokenizer.from_pretrained('bert-base-uncased') will be used.
AutoTokenizer.from_pretrained('bert-base-uncased') will be used.
group_singleton_words: bool, default: False
If True, group all words that appear only once in the data into the "Other" category in the display.
n_most_common : int , default: 5
@@ -64,15 +65,17 @@ def __init__(
self.tokenizer = tokenizer
if tokenizer is None:
try:
from transformers import BertTokenizer # pylint: disable=W0611,C0415 # noqa
from transformers import AutoTokenizer # pylint: disable=W0611,C0415 # noqa
except ImportError as e:
raise DeepchecksProcessError(
'Tokenizer was not provided. In order to use the default '
'tokenizer (BertTokenizer), please run:\n>> pip install transformers>=4.27.4.'
'tokenizer (bert-base-uncased), please run:\n>> pip install transformers>=4.27.4.'
) from e
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
else:
self._validate_tokenizer()
self._use_fast_method = self.tokenizer.is_fast

self.group_singleton_words = group_singleton_words
self.n_most_common = n_most_common
self.n_samples = n_samples
@@ -87,6 +90,8 @@ def _validate_tokenizer(self):
raise DeepchecksValueError('tokenizer must have an "unk_token_id" attribute')
if not hasattr(self.tokenizer, 'convert_tokens_to_ids'):
raise DeepchecksValueError('tokenizer must have a "convert_tokens_to_ids" method')
if not hasattr(self.tokenizer, 'is_fast'):
raise DeepchecksValueError('tokenizer must have an "is_fast" attribute')

def run_logic(self, context: Context, dataset_kind) -> CheckResult:
"""Run check."""
@@ -123,30 +128,68 @@ def run_logic(self, context: Context, dataset_kind) -> CheckResult:

return CheckResult(value, display=display)

def _get_non_text_token_ids(self):
"""Get ids of all non-text tokens in the tokenizer.

These include notably the [CLS] token marking the beginning of the sequence, the [SEP] token marking the end
of the sequence, and the [PAD] token used for padding.
"""
non_text_token_ids = []
for token_name, token in self.tokenizer.special_tokens_map.items():
if token_name not in ['unk_token']:
non_text_token_ids.append(self.tokenizer.convert_tokens_to_ids(token))
return non_text_token_ids

def find_unknown_words(self, samples, indices):
"""Find words with unknown tokens in samples."""
# Choose tokenizer based on availability of nltk
if nltk.download('punkt', quiet=True):
tokenize = nltk.word_tokenize
all_unknown_tokens = []
unknown_token_indexes = {}
total_tokens = 0

if self._use_fast_method:
non_text_token_ids = self._get_non_text_token_ids()

# Batch tokenization
# ------------------
# Needed to avoid warning when used after loading a hub dataset
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
tokenized_samples = self.tokenizer(list(samples), return_offsets_mapping=True, is_split_into_words=False,
truncation=False)

for idx, (tokens, offsets_mapping, sample) in zip(indices, zip(tokenized_samples['input_ids'],
tokenized_samples['offset_mapping'],
samples)):
for token_id, offset_mapping in zip(tokens, offsets_mapping):
if token_id == self.tokenizer.unk_token_id:
start, end = offset_mapping
token = sample[start:end]
all_unknown_tokens.append(token)
unknown_token_indexes.setdefault(token, []).append(idx)
if token_id not in non_text_token_ids:
total_tokens += 1
else:
warnings.warn('nltk punkt is not available, using str.split instead to identify individual words. '
'Please check your internet connection.')
tokenize = str.split

# Tokenize samples and count unknown words
words_array = [tokenize(sample) for sample in samples]
all_unknown_words = []
unknown_word_indexes = {}
total_words = 0
for idx, words in zip(indices, words_array):
total_words += len(words)
for word in words:
tokens = self.tokenizer.tokenize(word)
if any(self.tokenizer.convert_tokens_to_ids(token) == self.tokenizer.unk_token_id for token in tokens):
all_unknown_words.append(word)
unknown_word_indexes.setdefault(word, []).append(idx)

return Counter(all_unknown_words), total_words, unknown_word_indexes
# Tokenization for each word
# --------------------------
# Choose tokenizer based on availability of nltk
if nltk.download('punkt', quiet=True):
tokenize = nltk.word_tokenize
else:
warnings.warn('nltk punkt is not available, using str.split instead to identify individual words. '
'Please check your internet connection.')
tokenize = str.split

# Tokenize samples and count unknown words
words_array = [tokenize(sample) for sample in samples]
for idx, words in zip(indices, words_array):
total_tokens += len(words)
for word in words:
tokens = self.tokenizer.tokenize(word)
if any(self.tokenizer.convert_tokens_to_ids(token) == self.tokenizer.unk_token_id for token in
tokens):
all_unknown_tokens.append(word)
unknown_token_indexes.setdefault(word, []).append(idx)

return Counter(all_unknown_tokens), total_tokens, unknown_token_indexes

def create_pie_chart(self, all_unknown_words_counter, total_words):
"""Create pie chart with most common unknown words."""
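The heart of this file's change is the fast-tokenizer path: instead of re-tokenizing every whitespace-split word, each sample is tokenized once in a batch with `return_offsets_mapping=True`, and the character offsets are used to recover the original text behind every `[UNK]` token. Below is a minimal standalone sketch of that idea (not the check's exact code); it assumes `transformers>=4.27` is installed, and the emoji sample string is only an illustrative guess at something outside the BERT vocabulary.

```python
from transformers import AutoTokenizer

# AutoTokenizer returns a Rust-backed "fast" tokenizer by default when one is available.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
assert tokenizer.is_fast

samples = ['plain ascii text', 'an emoji 🤔 is probably unknown to the BERT vocabulary']

# Ids of all special tokens except [UNK] (e.g. [CLS], [SEP], [PAD]); these are
# excluded from the token count, mirroring _get_non_text_token_ids above.
non_text_ids = {tokenizer.convert_tokens_to_ids(tok)
                for name, tok in tokenizer.special_tokens_map.items()
                if name != 'unk_token'}

encoded = tokenizer(samples, return_offsets_mapping=True, truncation=False)

unknown_words, total_tokens = [], 0
for sample, ids, offsets in zip(samples, encoded['input_ids'], encoded['offset_mapping']):
    for token_id, (start, end) in zip(ids, offsets):
        if token_id == tokenizer.unk_token_id:
            # The offset mapping points back to the original characters, so the
            # raw word can be reported instead of the opaque [UNK] token string.
            unknown_words.append(sample[start:end])
        if token_id not in non_text_ids:
            total_tokens += 1

print(unknown_words, total_tokens)
```

Because the fast tokenizer handles the whole batch in a single call, this avoids the per-word `tokenize`/`convert_tokens_to_ids` round trips that made long documents slow.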
@@ -42,7 +42,7 @@
# The check has several key parameters that affect its behavior and output:
#
# * `tokenizer`: Tokenizer from the HuggingFace transformers library to use for tokenization. If None,
# BertTokenizer.from_pretrained('bert-base-uncased') will be used.
# AutoTokenizer.from_pretrained('bert-base-uncased') will be used. It's highly recommended to use a fast tokenizer.
# * `group_singleton_words`: If True, group all words that appear only once in the data into the "Other" category in
# the display.

@@ -66,8 +66,8 @@
#
# We can also use a different tokenizer, such as the GPT2 tokenizer, to see how the results change.

from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')

UnknownTokens(tokenizer=tokenizer).run(dataset)

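Since the speedup only applies when the supplied tokenizer is a fast (Rust-backed) one, it can be worth confirming `is_fast` before handing a custom tokenizer to the check. A short sketch, assuming the `dataset` object from the example above:

```python
from transformers import AutoTokenizer, BertTokenizer

fast_tok = AutoTokenizer.from_pretrained('gpt2')                # fast (Rust) implementation by default
slow_tok = BertTokenizer.from_pretrained('bert-base-uncased')   # pure-Python implementation

print(fast_tok.is_fast)   # True  -> check takes the batched offset-mapping path
print(slow_tok.is_fast)   # False -> check falls back to per-word tokenization

# UnknownTokens(tokenizer=fast_tok).run(dataset)  # with `dataset` from the example above
```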
1 change: 1 addition & 0 deletions spelling-allowlist.txt
@@ -154,3 +154,4 @@ fasttext
misclassified
Uncomment
dimensionality
tokenization
17 changes: 15 additions & 2 deletions tests/nlp/checks/data_integrity/unknown_tokens_test.py
@@ -11,7 +11,7 @@
"""Test for the NLP UnknownTokens check."""
import pytest
from hamcrest import *
from transformers import GPT2Tokenizer
from transformers import AutoTokenizer, BertTokenizer

from deepchecks.core.errors import DeepchecksValueError
from deepchecks.nlp.checks import UnknownTokens
@@ -129,6 +129,19 @@ def test_with_unknown_tokens(dataset_with_unknown_tokens):
}))


def test_compare_fast_to_slow_tokenizer(dataset_with_unknown_tokens):
# Arrange
check = UnknownTokens()
check_slow = UnknownTokens(tokenizer=BertTokenizer.from_pretrained("bert-base-uncased"))

# Act
result = check.run(dataset=dataset_with_unknown_tokens)
result_slow = check_slow.run(dataset=dataset_with_unknown_tokens)

# Assert
assert_that(result.value, equal_to(result_slow.value))


def test_group_singleton_words_true(dataset_with_reoccurring_unknown_words):
# Arrange
check = UnknownTokens(group_singleton_words=True).add_condition_ratio_of_unknown_words_less_or_equal()
Expand Down Expand Up @@ -158,7 +171,7 @@ def test_group_singleton_words_true(dataset_with_reoccurring_unknown_words):

def test_with_more_robust_tokenizer(dataset_with_unknown_tokens):
# Arrange
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
check = UnknownTokens(tokenizer=tokenizer)

# Act
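To reproduce the slowdown the PR title refers to, one can time the two tokenization strategies directly on synthetic long documents. The rough sketch below compares a single batched fast-tokenizer call against per-word tokenization with the slow implementation; it is a benchmark sketch only, not part of the test suite, and the numbers will vary with hardware and corpus size.

```python
import time

from transformers import AutoTokenizer, BertTokenizer

texts = ['a moderately long document with many repeated words ' * 200] * 20  # synthetic long docs

fast_tok = AutoTokenizer.from_pretrained('bert-base-uncased')
slow_tok = BertTokenizer.from_pretrained('bert-base-uncased')

# New path: one batched call over the full documents (warns about length > 512, which is fine here).
start = time.perf_counter()
fast_tok(texts, return_offsets_mapping=True, truncation=False)
print(f'batched fast tokenizer:  {time.perf_counter() - start:.2f}s')

# Old path: tokenize every whitespace-split word on its own, as the slow fallback does.
start = time.perf_counter()
for text in texts:
    for word in text.split():
        slow_tok.tokenize(word)
print(f'per-word slow tokenizer: {time.perf_counter() - start:.2f}s')
```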