diff --git a/spellchecker/spellchecker.py b/spellchecker/spellchecker.py index 5f9f644..64a25f0 100644 --- a/spellchecker/spellchecker.py +++ b/spellchecker/spellchecker.py @@ -136,10 +136,19 @@ def candidates(self, word): word (str): The word for which to calculate candidate spellings Returns: set: The set of words that are possible candidates ''' - - return (self.known([word]) or self.known(self.edit_distance_1(word)) or - (self._distance == 2 and - self.known(self.edit_distance_2(word))) or {word}) + if self.known([word]): # short-cut if word is correct already + return {word} + # get edit distance 1... + res = [x for x in self.edit_distance_1(word)] + tmp = self.known(res) + if tmp: + return tmp + # if still not found, use the edit distance 1 to calc edit distance 2 + if self._distance == 2: + tmp = self.known([x for x in self.__edit_distance_alt(res)]) + if tmp: + return tmp + return {word} def known(self, words): ''' The subset of `words` that appear in the dictionary of words @@ -150,7 +159,8 @@ def known(self, words): Returns: set: The set of those words from the input that are in the \ corpus ''' - return set(w for w in words if w in self._word_frequency.dictionary or + tmp = [w.lower() for w in words] + return set(w for w in tmp if w in self._word_frequency.dictionary or not self._check_if_should_check(w)) def unknown(self, words): @@ -162,7 +172,7 @@ def unknown(self, words): Returns: set: The set of those words from the input that are not in \ the corpus ''' - tmp = [w for w in words if self._check_if_should_check(w)] + tmp = [w.lower() for w in words if self._check_if_should_check(w)] return set(w for w in tmp if w not in self._word_frequency.dictionary) def edit_distance_1(self, word): @@ -174,6 +184,7 @@ def edit_distance_1(self, word): Returns: set: The set of strings that are edit distance one from the \ provided word ''' + word = word.lower() if self._check_if_should_check(word) is False: return {word} letters = self._word_frequency.letters @@ -193,8 +204,21 @@ def edit_distance_2(self, word): Returns: set: The set of strings that are edit distance two from the \ provided word ''' - return (e2 for e1 in self.edit_distance_1(word) - for e2 in self.edit_distance_1(e1)) + word = word.lower() + return [e2 for e1 in self.edit_distance_1(word) + for e2 in self.edit_distance_1(e1)] + + def __edit_distance_alt(self, words): + ''' Compute all strings that are 1 edits away from all the words using + only the letters in the corpus + + Args: + words (list): The words for which to calculate the edit distance + Returns: + set: The set of strings that are edit distance two from the \ + provided words ''' + words = [x.lower() for x in words] + return [e2 for e1 in words for e2 in self.edit_distance_1(e1)] @staticmethod def _check_if_should_check(word): @@ -214,6 +238,7 @@ class WordFrequency(object): different methods to load the data and update over time ''' __slots__ = ['_dictionary', '_total_words', '_unique_words', '_letters'] + def __init__(self): self._dictionary = Counter() self._total_words = 0 @@ -222,11 +247,20 @@ def __init__(self): def __contains__(self, key): ''' turn on contains ''' - return key in self._dictionary + return key.lower() in self._dictionary def __getitem__(self, key): ''' turn on getitem ''' - return self._dictionary[key] + return self._dictionary[key.lower()] + + def pop(self, key, default=None): + ''' Remove the key and return the associated value or default if not + found + + Args: + key (str): The key to remove + default (obj): The value to return if key is not present ''' + return self._dictionary.pop(key.lower(), default) @property def dictionary(self): @@ -328,7 +362,7 @@ def add(self, word): Args: word (str): The word to add ''' - self.load_words([word.lower()]) + self.load_words([word]) def remove_words(self, words): ''' Remove a list of words from the word frequency list @@ -353,7 +387,7 @@ def remove_by_threshold(self, threshold=5): Args: threshold (int): The threshold at which a word is to be \ removed ''' - keys = [x.lower() for x in self._dictionary.keys()] + keys = [x for x in self._dictionary.keys()] for key in keys: if self._dictionary[key] <= threshold: self._dictionary.pop(key) diff --git a/tests/spellchecker_test.py b/tests/spellchecker_test.py index 823626a..dcf4332 100644 --- a/tests/spellchecker_test.py +++ b/tests/spellchecker_test.py @@ -235,3 +235,36 @@ def test_import_export_gzip(self): self.assertFalse('bananna' in sp) os.remove(new_filepath) + + def test_capitalization(self): + ''' test that capitalization doesn't affect in comparisons ''' + spell = SpellChecker(language=None) + spell.word_frequency.add('Bob') + spell.word_frequency.add('Bob') + spell.word_frequency.add('Bab') + self.assertEqual('Bob' in spell, True) + self.assertEqual('BOb' in spell, True) + self.assertEqual('BOB' in spell, True) + self.assertEqual('bob' in spell, True) + + words = ['Bb', 'bb', 'BB'] + self.assertEqual(spell.unknown(words), {'bb'}) + + known_words = ['BOB', 'bOb'] + self.assertEqual(spell.known(known_words), {'bob'}) + + self.assertEqual(spell.candidates('BB'), {'bob', 'bab'}) + self.assertEqual(spell.correction('BB'), 'bob') + + def test_pop(self): + ''' test the popping of a word ''' + spell = SpellChecker() + self.assertEqual('apple' in spell, True) + self.assertGreater(spell.word_frequency.pop('apple'), 1) + self.assertEqual('apple' in spell, False) + + def test_pop_default(self): + ''' test the default value being set for popping a word ''' + spell = SpellChecker() + self.assertEqual('appleies' in spell, False) + self.assertEqual(spell.word_frequency.pop('appleies', False), False)