Noam/dee 538 better handle download of nltk #2493

Merged 2 commits on May 4, 2023
deepchecks/nlp/checks/data_integrity/unknown_tokens.py (4 additions, 1 deletion)

@@ -10,6 +10,7 @@
 #
 """Module contains the Unknown Tokens check."""
 import typing as t
+import warnings
 from collections import Counter
 
 import nltk
@@ -113,9 +114,11 @@ def run_logic(self, context: Context, dataset_kind) -> CheckResult:
     def find_unknown_words(self, samples, indices):
         """Find words with unknown tokens in samples."""
         # Choose tokenizer based on availability of nltk
-        if nltk.download('punkt'):
+        if nltk.download('punkt', quiet=True):
             tokenize = nltk.word_tokenize
         else:
+            warnings.warn('nltk punkt is not available, using str.split instead to identify individual words. '
+                          'Please check your internet connection.')
             tokenize = str.split
 
         # Tokenize samples and count unknown words
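As background, nltk.download() returns True when the requested resource is already present or was fetched successfully, and False otherwise, which is what lets the hunk above branch on it. A minimal standalone sketch of that pattern (the pick_tokenizer helper is illustrative, not part of this PR):

    import warnings

    import nltk

    def pick_tokenizer():
        """Return NLTK's word tokenizer if punkt is usable, else str.split."""
        # quiet=True suppresses download progress output; the boolean result
        # still reports whether the resource is available.
        if nltk.download('punkt', quiet=True):
            return nltk.word_tokenize
        warnings.warn('nltk punkt is not available, falling back to str.split.')
        return str.split

    tokenize = pick_tokenizer()
    print(tokenize('Tokenizers should degrade gracefully, not crash.'))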
deepchecks/nlp/utils/text.py (12 additions, 6 deletions)

@@ -12,6 +12,7 @@
 import string
 import typing as t
 import unicodedata
+import warnings
 
 import nltk
 from nltk.corpus import stopwords
@@ -62,10 +63,6 @@ def break_to_lines_and_trim(s, max_lines: int = 10, min_line_length: int = 50, m
     return '<br>'.join(lines)
 
 
-nltk.download('stopwords')
-nltk.download('punkt')
-
-
 def remove_punctuation(text: str) -> str:
     """Remove punctuation characters from a string."""
     return text.translate(str.maketrans('', '', string.punctuation))
@@ -78,8 +75,17 @@ def normalize_unicode(text: str) -> str:
 
 def remove_stopwords(text: str) -> str:
     """Remove stop words from a string."""
-    stop_words = set(stopwords.words('english'))
-    words = word_tokenize(text)
+    if nltk.download('stopwords', quiet=True):
+        stop_words = set(stopwords.words('english'))
+    else:
+        warnings.warn('nltk stopwords not found, stopwords won\'t be ignored when considering text duplicates.'
+                      ' Please check your internet connection.')
+        return text
+    if nltk.download('punkt', quiet=True):
+        tokenize = word_tokenize
+    else:
+        tokenize = str.split
+    words = tokenize(text)
     return ' '.join([word for word in words if word.lower() not in stop_words])
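The net effect in this file is that the two module-level downloads no longer run at import time; remove_stopwords now fetches its resources on first use and degrades instead of raising when they cannot be obtained. A rough usage sketch (the expected output assumes the English stopword list and punkt are available locally):

    from deepchecks.nlp.utils.text import remove_stopwords

    # With nltk data available, stop words are stripped:
    print(remove_stopwords('This is a test of the system'))  # roughly: 'test system'
    # Without nltk data and without network access, the call warns and
    # returns the input text unchanged.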
docs/source/nlp/usage_guides/text_data_object.rst (1 addition, 1 deletion)

@@ -4,7 +4,7 @@
 The TextData Object
 ===================
 
-The :class:`TextData <deepchecks.nlp.text_data.TextData>` is a container for your textual data, labels, and relevant
+The :class:`TextData <deepchecks.nlp.TextData>` is a container for your textual data, labels, and relevant
 metadata for NLP tasks and is a basic building block in the ``deepchecks.nlp`` subpackage.
 In order to use any functionality of the ``deepchecks.nlp`` subpackage, you need to first create a ``TextData`` object.
 The ``TextData`` object enables easy access to metadata, embeddings and properties relevant for training and validating ML
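Since the paragraph above presents TextData as the entry point to deepchecks.nlp, a minimal construction sketch may help; the raw_text keyword and the toy samples are assumptions for illustration, not taken from this page:

    from deepchecks.nlp import TextData

    # Hypothetical two-sample corpus; real usage passes your dataset's texts.
    data = TextData(raw_text=['great product, works as advertised',
                              'terrible support, would not recommend'])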