Add unknown token check (deepchecks#2483)
noamzbr authored and kishore-s-15 committed May 7, 2023
1 parent 351c351 commit c49075d
Showing 8 changed files with 391 additions and 28 deletions.
3 changes: 2 additions & 1 deletion deepchecks/nlp/checks/__init__.py
@@ -11,7 +11,7 @@
"""Module importing all nlp checks."""

from deepchecks.nlp.checks.data_integrity import (ConflictingLabels, PropertyLabelCorrelation, SpecialCharacters,
-                                                  TextDuplicates, TextPropertyOutliers)
+                                                  TextDuplicates, TextPropertyOutliers, UnknownTokens)
from deepchecks.nlp.checks.model_evaluation import (ConfusionMatrixReport, MetadataSegmentsPerformance, PredictionDrift,
                                                    PropertySegmentsPerformance, SingleDatasetPerformance,
                                                    TrainTestPerformance)
@@ -24,6 +24,7 @@
    'TextDuplicates',
    'ConflictingLabels',
    'SpecialCharacters',
+    'UnknownTokens',

    # Model Evaluation
    'SingleDatasetPerformance',
4 changes: 3 additions & 1 deletion deepchecks/nlp/checks/data_integrity/__init__.py
@@ -15,11 +15,13 @@
from .special_characters import SpecialCharacters
from .text_duplicates import TextDuplicates
from .text_property_outliers import TextPropertyOutliers
+from .unknown_tokens import UnknownTokens

__all__ = [
    'PropertyLabelCorrelation',
    'TextPropertyOutliers',
    'TextDuplicates',
    'ConflictingLabels',
-    'SpecialCharacters'
+    'SpecialCharacters',
+    'UnknownTokens',
]
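As an aside (not part of the commit): once both __init__.py files above are updated, the new check resolves from either namespace. A minimal import sanity check:

from deepchecks.nlp.checks import UnknownTokens
from deepchecks.nlp.checks.data_integrity import UnknownTokens  # same class, re-exported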
23 changes: 1 addition & 22 deletions deepchecks/nlp/checks/data_integrity/special_characters.py
@@ -45,8 +45,7 @@ class SpecialCharacters(SingleDatasetCheck):
    ----------
    special_characters_whitelist: Union[str, Sequence[str]] , default ' ' + string.punctuation
        set of special characters to ignore. Punctuation (string.punctuation) is whitelisted by default.
-    {text_normalization_params:1*indent}
-    n_most_common : int , default: 2
+    n_most_common : int , default: 10
        Number of most common special-only samples to show in results
    n_samples: int, default: 10_000_000
        number of samples to use for this check.
@@ -61,11 +60,6 @@
    def __init__(
        self,
        special_characters_whitelist: t.Union[str, t.Sequence[str], None] = None,
-        ignore_case: bool = True,
-        remove_punctuation: bool = True,
-        normalize_unicode: bool = True,
-        remove_stopwords: bool = True,
-        ignore_whitespace: bool = False,
        n_most_common: int = 10,
        n_samples: int = 10_000_000,
        random_state: int = 42,
@@ -81,26 +75,11 @@ def __init__(
        self.special_characters = self.SPECIAL_CHARACTERS.difference(
            self.special_characters_whitelist
        )
-        self.ignore_case = ignore_case
-        self.remove_punctuation = remove_punctuation
-        self.normalize_unicode = normalize_unicode
-        self.remove_stopwords = remove_stopwords
-        self.ignore_whitespace = ignore_whitespace
        self.n_most_common = n_most_common
        self.n_samples = n_samples
        self.random_state = random_state
        self.max_text_length_for_display = max_text_length_for_display

-    @property
-    def _text_normalization_kwargs(self):
-        return {
-            'ignore_case': self.ignore_case,
-            'ignore_whitespace': self.ignore_whitespace,
-            'normalize_uni': self.normalize_unicode,
-            'remove_punct': self.remove_punctuation,
-            'remove_stops': self.remove_stopwords,
-        }

    def run_logic(self, context: Context, dataset_kind) -> CheckResult:
        """Run check."""
        dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
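For context (not in the diff): after this simplification the check is constructed with only the remaining parameters. A minimal sketch using the defaults documented above:

from deepchecks.nlp.checks import SpecialCharacters

# The five normalization flags deleted above are no longer parameters of the
# check; these are the ones that remain, shown with their documented defaults.
check = SpecialCharacters(
    special_characters_whitelist=None,  # None -> ' ' + string.punctuation is whitelisted
    n_most_common=10,
    n_samples=10_000_000,
    random_state=42,
)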
190 changes: 190 additions & 0 deletions deepchecks/nlp/checks/data_integrity/unknown_tokens.py
@@ -0,0 +1,190 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2023 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module contains the Unknown Tokens check."""
import typing as t
from collections import Counter

import nltk
import plotly.graph_objects as go
from transformers import BertTokenizer

from deepchecks.core import CheckResult, ConditionCategory, ConditionResult
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.nlp import Context, SingleDatasetCheck
from deepchecks.nlp._shared_docs import docstrings
from deepchecks.nlp.text_data import TextData
from deepchecks.utils.strings import format_list, format_percent
from deepchecks.utils.strings import get_ellipsis as truncate_string

__all__ = ['UnknownTokens']

OTHER_CAT_NAME = 'Other Unknown Words'


@docstrings
class UnknownTokens(SingleDatasetCheck):
"""Find samples that contain tokens unsupported by your tokenizer.
Parameters
----------
tokenizer: t.Any , default: None
Tokenizer from the HuggingFace transformers library to use for tokenization. If None,
BertTokenizer.from_pretrained('bert-base-uncased') will be used.
group_singleton_words: bool, default: False
If True, group all words that appear only once in the data into the "Other" category in the display.
n_most_common : int , default: 5
Number of most common words with unknown tokens to show in the display.
n_samples: int, default: 1_000_000
number of samples to use for this check.
random_state : int, default: 42
random seed for all check internals.
{max_text_length_for_display_param:1*indent}
"""

    def __init__(
        self,
        tokenizer: t.Any = None,
        group_singleton_words: bool = False,
        n_most_common: int = 5,
        n_samples: int = 1_000_000,
        random_state: int = 42,
        max_text_length_for_display: int = 30,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        if self.tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        if not hasattr(self.tokenizer, 'tokenize'):
            raise DeepchecksValueError('tokenizer must have a "tokenize" method')
        if not hasattr(self.tokenizer, 'unk_token_id'):
            raise DeepchecksValueError('tokenizer must have an "unk_token_id" attribute')
        if not hasattr(self.tokenizer, 'convert_tokens_to_ids'):
            raise DeepchecksValueError('tokenizer must have a "convert_tokens_to_ids" method')
        self.group_singleton_words = group_singleton_words
        self.n_most_common = n_most_common
        self.n_samples = n_samples
        self.random_state = random_state
        self.max_text_length_for_display = max_text_length_for_display

    def run_logic(self, context: Context, dataset_kind) -> CheckResult:
        """Run check."""
        dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
        dataset = t.cast(TextData, dataset)
        samples = dataset.text
        if len(samples) == 0:
            raise DeepchecksValueError('Dataset cannot be empty')

        indices = dataset.get_original_text_indexes()

        all_unknown_words_counter, total_words, unknown_word_indexes = self.find_unknown_words(samples, indices)
        if len(all_unknown_words_counter) == 0:
            display = None
            value = {'unknown_word_ratio': 0,
                     'unknown_word_details': {}}
        else:
            fig = self.create_pie_chart(all_unknown_words_counter, total_words)
            percent_explanation = (
                '<p style="font-size:0.9em;line-height:1;"><i>'
                'Percents shown above are the percent of each word (or group of words) '
                'out of all words in the data.</i></p>'
            )
            display = [fig, percent_explanation]

            # The value contains two fields - unknown_word_ratio and unknown_word_details.
            # The latter contains a dict, in which for each word we have its ratio of the data
            # and the list of indexes of the samples that contain it.
            unknown_word_details = {}
            for word, indexes in unknown_word_indexes.items():
                unknown_word_details[word] = {'ratio': all_unknown_words_counter[word] / total_words,
                                              'indexes': indexes}
            value = {'unknown_word_ratio': sum(all_unknown_words_counter.values()) / total_words,
                     'unknown_word_details': unknown_word_details}

        return CheckResult(value, display=display)

    def find_unknown_words(self, samples, indices):
        """Find words with unknown tokens in samples."""
        # Choose tokenizer based on availability of nltk
        if nltk.download('punkt'):
            tokenize = nltk.word_tokenize
        else:
            tokenize = str.split

        # Tokenize samples and count unknown words
        words_array = [tokenize(sample) for sample in samples]
        all_unknown_words = []
        unknown_word_indexes = {}
        total_words = 0
        for idx, words in zip(indices, words_array):
            total_words += len(words)
            for word in words:
                tokens = self.tokenizer.tokenize(word)
                if any(self.tokenizer.convert_tokens_to_ids(token) == self.tokenizer.unk_token_id
                       for token in tokens):
                    all_unknown_words.append(word)
                    unknown_word_indexes.setdefault(word, []).append(idx)

        return Counter(all_unknown_words), total_words, unknown_word_indexes

    def create_pie_chart(self, all_unknown_words_counter, total_words):
        """Create pie chart with most common unknown words."""
        most_common_unknown_words = [x[0] for x in all_unknown_words_counter.most_common(self.n_most_common)
                                     if ((x[1] > 1) or (not self.group_singleton_words))]
        other_words = [x for x in all_unknown_words_counter if x not in most_common_unknown_words]

        # Calculate percentages for each category
        other_words_count = sum(all_unknown_words_counter[word] for word in other_words)
        other_words_percentage = (other_words_count * 1. / total_words) * 100.
        labels = most_common_unknown_words
        percentages = [all_unknown_words_counter[word] * 1. / total_words * 100.
                       for word in most_common_unknown_words]

        # Add the "Other Unknown Words" category if it is non-empty
        if other_words_percentage > 0:
            labels.append(OTHER_CAT_NAME)
            percentages.append(other_words_percentage)

        # Truncate labels for display
        labels = [truncate_string(label, self.max_text_length_for_display) for label in labels]

        # Create pie chart with hover text and custom hover template
        fig = go.Figure(data=[go.Pie(
            labels=labels, values=percentages, texttemplate='%{label}<br>%{value}%',
            hovertext=[format_list(other_words, max_string_length=self.max_text_length_for_display)
                       if label == OTHER_CAT_NAME else label for label in labels],
            hovertemplate=['<b>Unknown Word</b>: %{hovertext}<br><b>Percent of All Words</b>: %{value}%'
                           '<extra></extra>'
                           if label != OTHER_CAT_NAME else
                           '<b>Other Unknown Words</b>: %{hovertext}<br>'
                           '<b>Percent of All Words</b>: %{value}%<extra></extra>'
                           for label in labels],
            pull=[0.1 if label == OTHER_CAT_NAME else 0 for label in labels]
        )])

        # Customize chart appearance
        fig.update_layout(title=f'Words containing Unknown Tokens - {self.tokenizer.name_or_path} Tokenizer',
                          legend_title='Words with Unknown Tokens')

        return fig

    def add_condition_ratio_of_unknown_words_less_or_equal(self, ratio: float = 0):
        """Add condition that checks if the ratio of unknown words is less than or equal to a given ratio.

        Parameters
        ----------
        ratio : float
            Maximal allowed ratio of unknown words.
        """
        def condition(result):
            passed = result['unknown_word_ratio'] <= ratio
            condition_result = ConditionCategory.PASS if passed else ConditionCategory.FAIL
            details = f'Ratio was {format_percent(result["unknown_word_ratio"])}'
            return ConditionResult(condition_result, details)

        return self.add_condition(f'Ratio of unknown words is less or equal to {format_percent(ratio)}',
                                  condition)
6 changes: 4 additions & 2 deletions deepchecks/nlp/suites/default_suites.py
@@ -18,7 +18,8 @@
from deepchecks.nlp import Suite
from deepchecks.nlp.checks import (ConflictingLabels, LabelDrift, MetadataSegmentsPerformance, PredictionDrift,
                                   PropertyLabelCorrelation, PropertySegmentsPerformance, SingleDatasetPerformance,
-                                   SpecialCharacters, TextDuplicates, TextPropertyOutliers, TrainTestSamplesMix)
+                                   SpecialCharacters, TextDuplicates, TextPropertyOutliers, TrainTestSamplesMix,
+                                   UnknownTokens)

__all__ = ['data_integrity', 'train_test_validation',
           'model_evaluation', 'full_suite']
@@ -61,7 +62,8 @@ def data_integrity(n_samples: int = None,
        TextPropertyOutliers(),
        TextDuplicates().add_condition_ratio_less_or_equal(),
        ConflictingLabels().add_condition_ratio_of_conflicting_labels_less_or_equal(),
-        SpecialCharacters().add_condition_ratio_of_samples_with_special_characters_less_or_equal()
+        SpecialCharacters().add_condition_ratio_of_samples_with_special_characters_less_or_equal(),
+        UnknownTokens().add_condition_ratio_of_unknown_words_less_or_equal()
    )


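Not in the diff, but for completeness: with the check wired into the default suite, a plain data-integrity run now includes it. A sketch, assuming the usual suite entry point:

from deepchecks.nlp.suites import data_integrity

suite = data_integrity()
suite_result = suite.run(data)  # `data` is a TextData, e.g. the toy dataset above
suite_result.show()             # UnknownTokens appears alongside the other integrity checks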
3 changes: 2 additions & 1 deletion requirements/dev-requirements.txt
@@ -48,4 +48,5 @@ beautifulsoup4>=4.11.1
nltk<=3.6.7
datasets
langdetect
-textblob
+textblob
+transformers
5 changes: 4 additions & 1 deletion spelling-allowlist.txt
@@ -145,4 +145,7 @@ embeddings
ONNX
f1
multiindex
-misclassified
+misclassified
+tokenizer
+nltk
+Tokenize
