Noam/dee 521 unknown token check #2483

Merged: 10 commits, May 2, 2023
3 changes: 2 additions & 1 deletion deepchecks/nlp/checks/__init__.py
@@ -11,7 +11,7 @@
 """Module importing all nlp checks."""

 from deepchecks.nlp.checks.data_integrity import (ConflictingLabels, PropertyLabelCorrelation, SpecialCharacters,
-                                                   TextDuplicates, TextPropertyOutliers)
+                                                   TextDuplicates, TextPropertyOutliers, UnknownTokens)
 from deepchecks.nlp.checks.model_evaluation import (ConfusionMatrixReport, MetadataSegmentsPerformance, PredictionDrift,
                                                      PropertySegmentsPerformance, SingleDatasetPerformance,
                                                      TrainTestPerformance)
@@ -24,6 +24,7 @@
     'TextDuplicates',
     'ConflictingLabels',
     'SpecialCharacters',
+    'UnknownTokens',

     # Model Evaluation
     'SingleDatasetPerformance',
4 changes: 3 additions & 1 deletion deepchecks/nlp/checks/data_integrity/__init__.py
@@ -15,11 +15,13 @@
 from .special_characters import SpecialCharacters
 from .text_duplicates import TextDuplicates
 from .text_property_outliers import TextPropertyOutliers
+from .unknown_tokens import UnknownTokens

 __all__ = [
     'PropertyLabelCorrelation',
     'TextPropertyOutliers',
     'TextDuplicates',
     'ConflictingLabels',
-    'SpecialCharacters'
+    'SpecialCharacters',
+    'UnknownTokens',
 ]
23 changes: 1 addition & 22 deletions deepchecks/nlp/checks/data_integrity/special_characters.py
@@ -45,8 +45,7 @@ class SpecialCharacters(SingleDatasetCheck):
     ----------
     special_characters_whitelist: Union[str, Sequence[str]] , default ' ' + string.punctuation
         set of special characters to ignore. Punctuation (string.punctuation) is whitelisted by default.
-    {text_normalization_params:1*indent}
-    n_most_common : int , default: 2
+    n_most_common : int , default: 10
         Number of most common special-only samples to show in results
     n_samples: int, default: 10_000_000
         number of samples to use for this check.
@@ -61,11 +60,6 @@
     def __init__(
         self,
         special_characters_whitelist: t.Union[str, t.Sequence[str], None] = None,
-        ignore_case: bool = True,
-        remove_punctuation: bool = True,
-        normalize_unicode: bool = True,
-        remove_stopwords: bool = True,
-        ignore_whitespace: bool = False,
         n_most_common: int = 10,
         n_samples: int = 10_000_000,
         random_state: int = 42,
@@ -81,26 +75,11 @@ def __init__(
         self.special_characters = self.SPECIAL_CHARACTERS.difference(
             self.special_characters_whitelist
         )
-        self.ignore_case = ignore_case
-        self.remove_punctuation = remove_punctuation
-        self.normalize_unicode = normalize_unicode
-        self.remove_stopwords = remove_stopwords
-        self.ignore_whitespace = ignore_whitespace
         self.n_most_common = n_most_common
         self.n_samples = n_samples
         self.random_state = random_state
         self.max_text_length_for_display = max_text_length_for_display
-
-    @property
-    def _text_normalization_kwargs(self):
-        return {
-            'ignore_case': self.ignore_case,
-            'ignore_whitespace': self.ignore_whitespace,
-            'normalize_uni': self.normalize_unicode,
-            'remove_punct': self.remove_punctuation,
-            'remove_stops': self.remove_stopwords,
-        }

     def run_logic(self, context: Context, dataset_kind) -> CheckResult:
         """Run check."""
         dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
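Note for downstream users: the diff above removes the text-normalization arguments (ignore_case, remove_punctuation, normalize_unicode, remove_stopwords, ignore_whitespace) from SpecialCharacters, so call sites that pass them must be updated. A minimal sketch of the surviving constructor surface (argument values are illustrative, not the defaults):

from deepchecks.nlp.checks import SpecialCharacters

# Only the whitelist, display and sampling arguments remain after this PR;
# the removed normalization flags should no longer be passed.
check = SpecialCharacters(special_characters_whitelist=' !?',
                          n_most_common=10,
                          n_samples=10_000,
                          random_state=42)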
190 changes: 190 additions & 0 deletions deepchecks/nlp/checks/data_integrity/unknown_tokens.py
@@ -0,0 +1,190 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2023 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module contains the Unknown Tokens check."""
import typing as t
from collections import Counter

import nltk
import plotly.graph_objects as go
from transformers import BertTokenizer

from deepchecks.core import CheckResult, ConditionCategory, ConditionResult
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.nlp import Context, SingleDatasetCheck
from deepchecks.nlp._shared_docs import docstrings
from deepchecks.nlp.text_data import TextData
from deepchecks.utils.strings import format_list, format_percent
from deepchecks.utils.strings import get_ellipsis as truncate_string

__all__ = ['UnknownTokens']

OTHER_CAT_NAME = 'Other Unknown Words'


@docstrings
class UnknownTokens(SingleDatasetCheck):
    """Find samples that contain tokens unsupported by your tokenizer.

    Parameters
    ----------
    tokenizer: t.Any , default: None
        Tokenizer from the HuggingFace transformers library to use for tokenization. If None,
        BertTokenizer.from_pretrained('bert-base-uncased') will be used.
    group_singleton_words: bool, default: False
        If True, group all words that appear only once in the data into the "Other" category in the display.
    n_most_common : int , default: 5
        Number of most common words with unknown tokens to show in the display.
    n_samples: int, default: 1_000_000
        number of samples to use for this check.
    random_state : int, default: 42
        random seed for all check internals.
    {max_text_length_for_display_param:1*indent}
    """

    def __init__(
        self,
        tokenizer: t.Any = None,
        group_singleton_words: bool = False,
        n_most_common: int = 5,
        n_samples: int = 1_000_000,
        random_state: int = 42,
        max_text_length_for_display: int = 30,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        if self.tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        if not hasattr(self.tokenizer, 'tokenize'):
            raise DeepchecksValueError('tokenizer must have a "tokenize" method')
        if not hasattr(self.tokenizer, 'unk_token_id'):
            raise DeepchecksValueError('tokenizer must have an "unk_token_id" attribute')
        if not hasattr(self.tokenizer, 'convert_tokens_to_ids'):
            raise DeepchecksValueError('tokenizer must have a "convert_tokens_to_ids" method')
        self.group_singleton_words = group_singleton_words
        self.n_most_common = n_most_common
        self.n_samples = n_samples
        self.random_state = random_state
        self.max_text_length_for_display = max_text_length_for_display

    def run_logic(self, context: Context, dataset_kind) -> CheckResult:
        """Run check."""
        dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
        dataset = t.cast(TextData, dataset)
        samples = dataset.text
        if len(samples) == 0:
            raise DeepchecksValueError('Dataset cannot be empty')

        indices = dataset.get_original_text_indexes()

        all_unknown_words_counter, total_words, unknown_word_indexes = self.find_unknown_words(samples, indices)
        if len(all_unknown_words_counter) == 0:
            display = None
            value = {'unknown_word_ratio': 0,
                     'unknown_word_details': {}}
        else:
            fig = self.create_pie_chart(all_unknown_words_counter, total_words)
            percent_explanation = (
                '<p style="font-size:0.9em;line-height:1;"><i>'
                'Percentages shown above are the percentage of each word (or group of words) out of all words '
                'in the data.</i></p>'
            )
            display = [fig, percent_explanation]

        # The value contains two fields - unknown_word_ratio and unknown_word_details.
        # The latter contains a dict, in which for each word we have its ratio of the data and the list of indexes
        # of the samples that contain it.
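        # Illustrative shape of the returned value (hypothetical words and numbers):
        # {'unknown_word_ratio': 0.02,
        #  'unknown_word_details': {'churros': {'ratio': 0.015, 'indexes': [3, 17]},
        #                           'knaidel': {'ratio': 0.005, 'indexes': [42]}}}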
        unknown_word_details = {}
        for word, indexes in unknown_word_indexes.items():
            unknown_word_details[word] = {'ratio': all_unknown_words_counter[word] / total_words,
                                          'indexes': indexes}
        value = {'unknown_word_ratio': sum(all_unknown_words_counter.values()) / total_words,
                 'unknown_word_details': unknown_word_details}

        return CheckResult(value, display=display)

    def find_unknown_words(self, samples, indices):
        """Find words with unknown tokens in samples."""
        # Choose the word-level tokenizer: use nltk's punkt if it is available (nltk.download
        # returns True when the resource could be fetched or already exists), else fall back to str.split
        if nltk.download('punkt'):
            tokenize = nltk.word_tokenize
        else:
            tokenize = str.split

        # Tokenize samples and count unknown words
        words_array = [tokenize(sample) for sample in samples]
        all_unknown_words = []
        unknown_word_indexes = {}
        total_words = 0
        for idx, words in zip(indices, words_array):
            total_words += len(words)
            for word in words:
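                # A word counts as unknown if any of its sub-tokens maps to the tokenizer's
                # UNK id. For example, most emoji are absent from BERT's WordPiece vocabulary,
                # so a word like 'hello👋' would yield an [UNK] sub-token (illustrative).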
                tokens = self.tokenizer.tokenize(word)
                if any(self.tokenizer.convert_tokens_to_ids(token) == self.tokenizer.unk_token_id
                       for token in tokens):
                    all_unknown_words.append(word)
                    unknown_word_indexes.setdefault(word, []).append(idx)

        return Counter(all_unknown_words), total_words, unknown_word_indexes

    def create_pie_chart(self, all_unknown_words_counter, total_words):
        """Create pie chart with most common unknown words."""
        most_common_unknown_words = [x[0] for x in all_unknown_words_counter.most_common(self.n_most_common) if
                                     ((x[1] > 1) or (not self.group_singleton_words))]
        other_words = [x for x in all_unknown_words_counter if x not in most_common_unknown_words]

        # Calculate percentages for each category
        other_words_count = sum(all_unknown_words_counter[word] for word in other_words)
        other_words_percentage = (other_words_count * 1. / total_words) * 100.
        labels = most_common_unknown_words
        percentages = [all_unknown_words_counter[word] * 1. / total_words * 100. for word in most_common_unknown_words]

        # Add the "Other Unknown Words" category if there is anything to group into it
        if other_words_percentage > 0:
            labels.append(OTHER_CAT_NAME)
            percentages.append(other_words_percentage)

        # Truncate labels for display
        labels = [truncate_string(label, self.max_text_length_for_display) for label in labels]

        # Create pie chart with hover text and custom hover template
        fig = go.Figure(data=[go.Pie(
            labels=labels, values=percentages, texttemplate='%{label}<br>%{value}%',
            hovertext=[format_list(other_words, max_string_length=self.max_text_length_for_display)
                       if label == OTHER_CAT_NAME else label for label in labels],
            hovertemplate=['<b>Unknown Word</b>: %{hovertext}<br><b>Percent of All Words</b>: %{value}%<extra></extra>'
                           if label != OTHER_CAT_NAME else
                           '<b>Other Unknown Words</b>: %{hovertext}<br>'
                           '<b>Percent of All Words</b>: %{value}%<extra></extra>'
                           for label in labels],
            pull=[0.1 if label == OTHER_CAT_NAME else 0 for label in labels]
        )])

        # Customize chart appearance
        fig.update_layout(title=f'Words containing Unknown Tokens - {self.tokenizer.name_or_path} Tokenizer',
                          legend_title='Words with Unknown Tokens')

        return fig

    def add_condition_ratio_of_unknown_words_less_or_equal(self, ratio: float = 0):
        """Add condition that checks if the ratio of unknown words is less than or equal to a given ratio.

        Parameters
        ----------
        ratio : float
            Maximal allowed ratio of unknown words.
        """
        def condition(result):
            passed = result['unknown_word_ratio'] <= ratio
            condition_result = ConditionCategory.FAIL if not passed else ConditionCategory.PASS
            details = f'Ratio was {format_percent(result["unknown_word_ratio"])}'
            return ConditionResult(condition_result, details)

        return self.add_condition(f'Ratio of unknown words is less or equal to {format_percent(ratio)}',
                                  condition)
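For reviewers who want to try the new check locally, a minimal usage sketch (not part of this PR; the sample texts are made up, and the exact TextData and run signatures follow my reading of the existing deepchecks.nlp API):

from deepchecks.nlp import TextData
from deepchecks.nlp.checks import UnknownTokens

# Two toy samples; the emoji is unlikely to be in BERT's WordPiece vocabulary,
# so it should surface as a word containing unknown tokens.
data = TextData(['I love deepchecks', 'never seen 🦆 before'],
                task_type='text_classification')

# tokenizer=None falls back to BertTokenizer.from_pretrained('bert-base-uncased');
# any object exposing tokenize, convert_tokens_to_ids and unk_token_id also works.
check = UnknownTokens().add_condition_ratio_of_unknown_words_less_or_equal(0.05)
result = check.run(data)

print(result.value['unknown_word_ratio'])           # overall ratio of unknown words
print(list(result.value['unknown_word_details']))   # the offending words themselves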
6 changes: 4 additions & 2 deletions deepchecks/nlp/suites/default_suites.py
@@ -18,7 +18,8 @@
 from deepchecks.nlp import Suite
 from deepchecks.nlp.checks import (ConflictingLabels, LabelDrift, MetadataSegmentsPerformance, PredictionDrift,
                                    PropertyLabelCorrelation, PropertySegmentsPerformance, SingleDatasetPerformance,
-                                   SpecialCharacters, TextDuplicates, TextPropertyOutliers, TrainTestSamplesMix)
+                                   SpecialCharacters, TextDuplicates, TextPropertyOutliers, TrainTestSamplesMix,
+                                   UnknownTokens)

 __all__ = ['data_integrity', 'train_test_validation',
            'model_evaluation', 'full_suite']
@@ -61,7 +62,8 @@ def data_integrity(n_samples: int = None,
         TextPropertyOutliers(),
         TextDuplicates().add_condition_ratio_less_or_equal(),
         ConflictingLabels().add_condition_ratio_of_conflicting_labels_less_or_equal(),
-        SpecialCharacters().add_condition_ratio_of_samples_with_special_characters_less_or_equal()
+        SpecialCharacters().add_condition_ratio_of_samples_with_special_characters_less_or_equal(),
+        UnknownTokens().add_condition_ratio_of_unknown_words_less_or_equal()
     )
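Since data_integrity() now instantiates UnknownTokens with its default condition, the default suite picks the check up automatically. A sketch of the suite-level flow (assuming `data` as in the earlier sketch; the run call follows the existing nlp Suite API):

from deepchecks.nlp.suites import data_integrity

suite = data_integrity(n_samples=10_000)  # UnknownTokens is now included by default
suite_result = suite.run(data)
suite_result.show()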
3 changes: 2 additions & 1 deletion requirements/dev-requirements.txt
@@ -48,4 +48,5 @@ beautifulsoup4>=4.11.1
 nltk<=3.6.7
 datasets
 langdetect
-textblob
+textblob
+transformers
5 changes: 4 additions & 1 deletion spelling-allowlist.txt
@@ -144,4 +144,7 @@ NLP
 embeddings
 ONNX
 f1
-multiindex
+multiindex
+tokenizer
+nltk
+Tokenize