Skip to content

Commit

Permalink
Noam/dee 538 better handle download of nltk (#2493)
Browse files Browse the repository at this point in the history
* quiet download and warnings for nltk stuff

* docs fix
  • Loading branch information
noamzbr committed May 4, 2023
1 parent bcbfa68 commit 294ed16
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
5 changes: 4 additions & 1 deletion deepchecks/nlp/checks/data_integrity/unknown_tokens.py
Expand Up @@ -10,6 +10,7 @@
#
"""Module contains the Unknown Tokens check."""
import typing as t
import warnings
from collections import Counter

import nltk
Expand Down Expand Up @@ -113,9 +114,11 @@ def run_logic(self, context: Context, dataset_kind) -> CheckResult:
def find_unknown_words(self, samples, indices):
"""Find words with unknown tokens in samples."""
# Choose tokenizer based on availability of nltk
if nltk.download('punkt'):
if nltk.download('punkt', quiet=True):
tokenize = nltk.word_tokenize
else:
warnings.warn('nltk punkt is not available, using str.split instead to identify individual words. '
'Please check your internet connection.')
tokenize = str.split

# Tokenize samples and count unknown words
Expand Down
18 changes: 12 additions & 6 deletions deepchecks/nlp/utils/text.py
Expand Up @@ -12,6 +12,7 @@
import string
import typing as t
import unicodedata
import warnings

import nltk
from nltk.corpus import stopwords
Expand Down Expand Up @@ -62,10 +63,6 @@ def break_to_lines_and_trim(s, max_lines: int = 10, min_line_length: int = 50, m
return '<br>'.join(lines)


nltk.download('stopwords')
nltk.download('punkt')


# Translation table deleting every ASCII punctuation character; built once at
# import time instead of on every call (str.maketrans scans the whole
# punctuation string each time it runs).
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def remove_punctuation(text: str) -> str:
    """Remove punctuation characters from a string.

    Parameters
    ----------
    text : str
        The text to clean.

    Returns
    -------
    str
        A copy of ``text`` with all ``string.punctuation`` characters removed.
    """
    return text.translate(_PUNCTUATION_TABLE)
Expand All @@ -78,8 +75,17 @@ def normalize_unicode(text: str) -> str:

def remove_stopwords(text: str) -> str:
    """Remove English stop words from a string.

    Required nltk resources are downloaded quietly on first use. If a
    resource cannot be fetched (e.g. no internet connection), the function
    degrades gracefully with a warning instead of raising.

    Parameters
    ----------
    text : str
        The text to filter.

    Returns
    -------
    str
        The text with stop words removed, remaining words joined by single
        spaces. Returned unchanged if the stopwords corpus is unavailable.
    """
    if nltk.download('stopwords', quiet=True):
        stop_words = set(stopwords.words('english'))
    else:
        # Without the corpus we cannot tell which words are stop words,
        # so warn and return the input untouched rather than fail.
        warnings.warn('nltk stopwords not found, stopwords won\'t be ignored when considering text duplicates.'
                      ' Please check your internet connection.')
        return text
    if nltk.download('punkt', quiet=True):
        tokenize = word_tokenize
    else:
        # Warn on this fallback too, consistent with the stopwords branch
        # above and with the same fallback in unknown_tokens.py (str.split
        # is a cruder tokenizer: punctuation stays attached to words).
        warnings.warn('nltk punkt is not available, using str.split instead to identify individual words.'
                      ' Please check your internet connection.')
        tokenize = str.split
    words = tokenize(text)
    return ' '.join(word for word in words if word.lower() not in stop_words)


Expand Down
2 changes: 1 addition & 1 deletion docs/source/nlp/usage_guides/text_data_object.rst
Expand Up @@ -4,7 +4,7 @@
The TextData Object
===================

The :class:`TextData <deepchecks.nlp.text_data.TextData>` is a container for your textual data, labels, and relevant
The :class:`TextData <deepchecks.nlp.TextData>` is a container for your textual data, labels, and relevant
metadata for NLP tasks and is a basic building block in the ``deepchecks.nlp`` subpackage.
In order to use any functionality of the ``deepchecks.nlp`` subpackage, you need to first create a ``TextData`` object.
The ``TextData`` object enables easy access to metadata, embeddings and properties relevant for training and validating ML
Expand Down

0 comments on commit 294ed16

Please sign in to comment.