Noam/dee 521 unknown token check #2483

Merged
merged 10 commits into from May 2, 2023
Changes from 5 commits
3 changes: 2 additions & 1 deletion deepchecks/nlp/checks/__init__.py
@@ -11,7 +11,7 @@
 """Module importing all nlp checks."""

 from deepchecks.nlp.checks.data_integrity import (ConflictingLabels, PropertyLabelCorrelation, SpecialCharacters,
-                                                  TextDuplicates, TextPropertyOutliers)
+                                                  TextDuplicates, TextPropertyOutliers, UnknownTokens)
 from deepchecks.nlp.checks.model_evaluation import (ConfusionMatrixReport, MetadataSegmentsPerformance, PredictionDrift,
                                                     PropertySegmentsPerformance, SingleDatasetPerformance,
                                                     TrainTestPerformance)
@@ -24,6 +24,7 @@
     'TextDuplicates',
     'ConflictingLabels',
     'SpecialCharacters',
+    'UnknownTokens',

     # Model Evaluation
     'SingleDatasetPerformance',
4 changes: 3 additions & 1 deletion deepchecks/nlp/checks/data_integrity/__init__.py
@@ -15,11 +15,13 @@
 from .special_characters import SpecialCharacters
 from .text_duplicates import TextDuplicates
 from .text_property_outliers import TextPropertyOutliers
+from .unknown_tokens import UnknownTokens

 __all__ = [
     'PropertyLabelCorrelation',
     'TextPropertyOutliers',
     'TextDuplicates',
     'ConflictingLabels',
-    'SpecialCharacters'
+    'SpecialCharacters',
+    'UnknownTokens',
 ]
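
A quick note on the hunk above: the trailing comma added after 'SpecialCharacters' is load-bearing. Had 'UnknownTokens' been appended without it, Python's implicit concatenation of adjacent string literals would have silently merged the two names inside __all__. A minimal illustration (hypothetical snippet, not part of the diff):

exports = ['SpecialCharacters' 'UnknownTokens']   # missing comma -> ['SpecialCharactersUnknownTokens']
exports = ['SpecialCharacters', 'UnknownTokens']  # with comma    -> two separate exported names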
23 changes: 1 addition & 22 deletions deepchecks/nlp/checks/data_integrity/special_characters.py
@@ -45,8 +45,7 @@ class SpecialCharacters(SingleDatasetCheck):
     ----------
     special_characters_whitelist: Union[str, Sequence[str]] , default ' ' + string.punctuation
         set of special characters to ignore. Punctuation (string.punctuation) is whitelisted by default.
-    {text_normalization_params:1*indent}
-    n_most_common : int , default: 2
+    n_most_common : int , default: 10
         Number of most common special-only samples to show in results
     n_samples: int, default: 10_000_000
         number of samples to use for this check.
@@ -61,11 +60,6 @@
     def __init__(
         self,
         special_characters_whitelist: t.Union[str, t.Sequence[str], None] = None,
-        ignore_case: bool = True,
-        remove_punctuation: bool = True,
-        normalize_unicode: bool = True,
-        remove_stopwords: bool = True,
-        ignore_whitespace: bool = False,
         n_most_common: int = 10,
         n_samples: int = 10_000_000,
         random_state: int = 42,
@@ -81,26 +75,11 @@ def __init__(
         self.special_characters = self.SPECIAL_CHARACTERS.difference(
             self.special_characters_whitelist
         )
-        self.ignore_case = ignore_case
-        self.remove_punctuation = remove_punctuation
-        self.normalize_unicode = normalize_unicode
-        self.remove_stopwords = remove_stopwords
-        self.ignore_whitespace = ignore_whitespace
         self.n_most_common = n_most_common
         self.n_samples = n_samples
         self.random_state = random_state
         self.max_text_length_for_display = max_text_length_for_display

-    @property
-    def _text_normalization_kwargs(self):
-        return {
-            'ignore_case': self.ignore_case,
-            'ignore_whitespace': self.ignore_whitespace,
-            'normalize_uni': self.normalize_unicode,
-            'remove_punct': self.remove_punctuation,
-            'remove_stops': self.remove_stopwords,
-        }
-
     def run_logic(self, context: Context, dataset_kind) -> CheckResult:
         """Run check."""
         dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
175 changes: 175 additions & 0 deletions deepchecks/nlp/checks/data_integrity/unknown_tokens.py
@@ -0,0 +1,175 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2023 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module contains the Unknown Tokens check."""
import typing as t
from collections import Counter

import nltk
import plotly.graph_objects as go
from transformers import BertTokenizer

from deepchecks.core import CheckResult, ConditionCategory, ConditionResult
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.nlp import Context, SingleDatasetCheck
from deepchecks.nlp._shared_docs import docstrings
from deepchecks.nlp.text_data import TextData
from deepchecks.utils.strings import format_percent
from deepchecks.utils.strings import get_ellipsis as truncate_string

__all__ = ['UnknownTokens']


@docstrings
class UnknownTokens(SingleDatasetCheck):
"""Find samples that contain tokens unsupported by your tokenizer.

Parameters
----------
tokenizer: t.Any , default: None
        Transformers tokenizer to use for tokenization. If None, BertTokenizer.from_pretrained('bert-base-uncased')
        will be used.
    n_most_common : int , default: 5
        Number of most common words with unknown tokens to show in results
    n_samples: int, default: 1_000_000
        number of samples to use for this check.
    random_state : int, default: 42
        random seed for all check internals.
    {max_text_length_for_display_param:1*indent}
    """

    def __init__(
        self,
        tokenizer: t.Any = None,
        n_most_common: int = 5,
        n_samples: int = 1_000_000,
        random_state: int = 42,
        max_text_length_for_display: int = 30,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        if self.tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        if not hasattr(self.tokenizer, 'tokenize'):
            raise DeepchecksValueError('tokenizer must have a "tokenize" method')
        self.n_most_common = n_most_common
        self.n_samples = n_samples
        self.random_state = random_state
        self.max_text_length_for_display = max_text_length_for_display

    def run_logic(self, context: Context, dataset_kind) -> CheckResult:
        """Run check."""
        dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
        dataset = t.cast(TextData, dataset)
        samples = dataset.text
        if len(samples) == 0:
            raise DeepchecksValueError('Dataset cannot be empty')

        indices = dataset.get_original_text_indexes()

        all_unknown_words_counter, total_words, unknown_word_indexes = self.find_unknown_words(samples, indices)
        if len(all_unknown_words_counter) == 0:
            display = None
            value = {'unknown_word_ratio': 0,
                     'unknown_word_details': {}}
        else:
            fig = self.create_pie_chart(all_unknown_words_counter, total_words)
            display = [fig]

            # The value contains two fields - unknown_word_ratio and unknown_word_details.
            # The latter is a dict mapping each unknown word to its ratio out of all words in the data
            # and to the list of indexes of the samples that contain it.
            unknown_word_details = {}
            for word, indexes in unknown_word_indexes.items():
                unknown_word_details[word] = {'ratio': all_unknown_words_counter[word] / total_words,
                                              'indexes': indexes}
            value = {'unknown_word_ratio': sum(all_unknown_words_counter.values()) / total_words,
                     'unknown_word_details': unknown_word_details}

        return CheckResult(value, display=display)

    def find_unknown_words(self, samples, indices):
        """Find words with unknown tokens in samples."""
        # Use nltk's word tokenizer if its 'punkt' model is available, otherwise fall back to whitespace splitting
        if nltk.download('punkt'):
            tokenize = nltk.word_tokenize
        else:
            tokenize = str.split

        # Tokenize samples and count unknown words
        words_array = [tokenize(sample) for sample in samples]
        all_unknown_words = []
        unknown_word_indexes = {}
        total_words = 0
        for idx, words in zip(indices, words_array):
            total_words += len(words)
            for word in words:
                tokens = self.tokenizer.tokenize(word)
                if any(self.tokenizer.convert_tokens_to_ids(token) == self.tokenizer.unk_token_id for token in tokens):
                    all_unknown_words.append(word)
                    unknown_word_indexes.setdefault(word, []).append(idx)

        return Counter(all_unknown_words), total_words, unknown_word_indexes

    def create_pie_chart(self, all_unknown_words_counter, total_words):
        """Create pie chart with unknown words distribution."""
        # Separate most common unknown words and other unknown words
        most_common_unknown_words = [x[0] for x in all_unknown_words_counter.most_common(self.n_most_common)
                                     if x[1] > 1]
        other_words = [x for x in all_unknown_words_counter if x not in most_common_unknown_words]

        # Calculate percentages for each category
        other_words_count = sum(all_unknown_words_counter[word] for word in other_words)
        other_words_percentage = (other_words_count * 1. / total_words) * 100.
        labels = most_common_unknown_words
        percentages = [all_unknown_words_counter[word] * 1. / total_words * 100. for word in most_common_unknown_words]

        # Add "Other Unknown Words" and "Known Words" categories
        if other_words_percentage > 0:
            labels.append('Other Unknown Words')
            percentages.append(other_words_percentage)
        labels.append('Known Words')
        percentages.append(100. - sum(percentages))

        # Truncate labels for display
        labels = [truncate_string(label, self.max_text_length_for_display) for label in labels]

        # Create pie chart with hover text and custom hover template
        fig = go.Figure(data=[go.Pie(
            labels=labels, values=percentages, textinfo='label+percent',
            hovertext=[' '.join(other_words[:self.max_text_length_for_display]) if label == 'Other Unknown Words'
                       else label for label in labels],
            hovertemplate='%{hovertext}<br>%{percent}<extra></extra>',
            pull=[0.1 if label != 'Known Words' else 0 for label in labels]
        )])

        # Customize chart appearance
        fig.update_layout(title=f'Words containing Unknown Tokens - {self.tokenizer.name_or_path} Tokenizer',
                          legend_title='Words')

        return fig

    def add_condition_ratio_of_unknown_words_less_or_equal(self, ratio: float = 0):
        """Add condition that checks if the ratio of unknown words is less than or equal to a given ratio.

        Parameters
        ----------
        ratio : float
            Maximal allowed ratio of unknown words.
        """
        def condition(result):
            passed = result['unknown_word_ratio'] <= ratio
            condition_result = ConditionCategory.PASS if passed else ConditionCategory.FAIL
            details = f'Ratio was {format_percent(result["unknown_word_ratio"])}'
            return ConditionResult(condition_result, details)

        return self.add_condition(f'Ratio of unknown words is less or equal to {format_percent(ratio)}',
                                  condition)
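
For reviewers who want to exercise the new check locally, here is a minimal usage sketch. The corpus is hypothetical, the TextData constructor is assumed to accept a raw_text keyword, and the 5% condition threshold is illustrative only:

from deepchecks.nlp import TextData
from deepchecks.nlp.checks import UnknownTokens

# Toy corpus; the emoji is likely to map to [UNK] under the default bert-base-uncased tokenizer
data = TextData(raw_text=['Deepchecks finds unknown tokens', 'This sample has an emoji 🤓'])

check = UnknownTokens().add_condition_ratio_of_unknown_words_less_or_equal(0.05)
result = check.run(data)

# result.value mirrors the structure built in run_logic above
print(result.value['unknown_word_ratio'])
for word, details in result.value['unknown_word_details'].items():
    print(word, details['ratio'], details['indexes'])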
6 changes: 4 additions & 2 deletions deepchecks/nlp/suites/default_suites.py
@@ -18,7 +18,8 @@
 from deepchecks.nlp import Suite
 from deepchecks.nlp.checks import (ConflictingLabels, LabelDrift, MetadataSegmentsPerformance, PredictionDrift,
                                    PropertyLabelCorrelation, PropertySegmentsPerformance, SingleDatasetPerformance,
-                                   SpecialCharacters, TextDuplicates, TextPropertyOutliers, TrainTestSamplesMix)
+                                   SpecialCharacters, TextDuplicates, TextPropertyOutliers, TrainTestSamplesMix,
+                                   UnknownTokens)

 __all__ = ['data_integrity', 'train_test_validation',
            'model_evaluation', 'full_suite']
@@ -61,7 +62,8 @@ def data_integrity(n_samples: int = None,
         TextPropertyOutliers(),
         TextDuplicates().add_condition_ratio_less_or_equal(),
         ConflictingLabels().add_condition_ratio_of_conflicting_labels_less_or_equal(),
-        SpecialCharacters().add_condition_ratio_of_samples_with_special_characters_less_or_equal()
+        SpecialCharacters().add_condition_ratio_of_samples_with_special_characters_less_or_equal(),
+        UnknownTokens().add_condition_ratio_of_unknown_words_less_or_equal()
     )


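Since the check is now wired into the default data-integrity suite, it also runs with the stock suite — a sketch under the same assumptions as above:

from deepchecks.nlp.suites import data_integrity

suite = data_integrity()
suite_result = suite.run(data)  # `data` is the TextData object from the earlier sketch
suite_result.show()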
5 changes: 4 additions & 1 deletion spelling-allowlist.txt
@@ -144,4 +144,7 @@ NLP
 embeddings
 ONNX
 f1
-multiindex
+multiindex
+tokenizer
+nltk
+Tokenize