Skip to content

Commit

Permalink
Switch from SimpleStemmer to PosHypernymStemmer
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Dec 9, 2018
1 parent 31642b0 commit 7289b9a
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 35 deletions.
33 changes: 22 additions & 11 deletions chatterbot/stemming.py
Expand Up @@ -185,27 +185,38 @@ def get_bigram_pair_string(self, text):
words = text.split()

# Separate punctuation from last word in string
word_with_punctuation_removed = words[-1].strip(string.punctuation)
if words:
word_with_punctuation_removed = words[-1].strip(string.punctuation)

if word_with_punctuation_removed:
words[-1] = word_with_punctuation_removed
if word_with_punctuation_removed:
words[-1] = word_with_punctuation_removed

pos_tags = pos_tag(words)

hypernyms = self.get_hypernyms(pos_tags)

bigrams = []
high_quality_bigrams = []
all_bigrams = []

word_count = len(words)

if word_count <= 1:
bigrams = words
if bigrams:
bigrams[0] = bigrams[0].lower()
all_bigrams = words
if all_bigrams:
all_bigrams[0] = all_bigrams[0].lower()

for index in range(1, word_count):
if words[index].lower() not in self.get_stopwords():
bigram = pos_tags[index - 1][1] + ':' + hypernyms[index].lower()
bigrams.append(bigram)
word = words[index].lower()
previous_word_pos = pos_tags[index][1]
if word not in self.get_stopwords() and len(word) > 1:
bigram = previous_word_pos + ':' + hypernyms[index].lower()
high_quality_bigrams.append(bigram)
all_bigrams.append(bigram)
else:
bigram = previous_word_pos + ':' + word
all_bigrams.append(bigram)

return ' '.join(bigrams)
if high_quality_bigrams:
all_bigrams = high_quality_bigrams

return ' '.join(all_bigrams)
2 changes: 1 addition & 1 deletion chatterbot/storage/sql_storage.py
Expand Up @@ -158,7 +158,7 @@ def filter(self, **kwargs):
]

statements = statements.filter(
*or_query
or_(*or_query)
)

if order_by:
Expand Down
4 changes: 2 additions & 2 deletions chatterbot/storage/storage_adapter.py
@@ -1,5 +1,5 @@
import logging
from chatterbot.stemming import SimpleStemmer
from chatterbot.stemming import PosHypernymStemmer


class StorageAdapter(object):
Expand All @@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs):
self.logger = kwargs.get('logger', logging.getLogger(__name__))
self.adapter_supports_queries = True

self.stemmer = SimpleStemmer(language=kwargs.get(
self.stemmer = PosHypernymStemmer(language=kwargs.get(
'stemmer_language', 'english'
))

Expand Down
7 changes: 3 additions & 4 deletions chatterbot/trainers.py
Expand Up @@ -5,7 +5,7 @@
from multiprocessing import Pool, Manager
from dateutil import parser as date_parser
from chatterbot.conversation import Statement
from chatterbot.stemming import SimpleStemmer
from chatterbot.stemming import PosHypernymStemmer
from chatterbot import utils


Expand All @@ -30,7 +30,7 @@ def __init__(self, chatbot, **kwargs):
environment_default
)

self.stemmer = SimpleStemmer(language=kwargs.get(
self.stemmer = PosHypernymStemmer(language=kwargs.get(
'stemmer_language', 'english'
))

Expand Down Expand Up @@ -429,9 +429,8 @@ def track_progress(members):

def train(self):
import glob
from chatterbot.stemming import SimpleStemmer

stemmer = SimpleStemmer()
stemmer = PosHypernymStemmer(language=self.stemmer.language)

# Download and extract the Ubuntu dialog corpus if needed
corpus_download_path = self.download(self.data_download_url)
Expand Down
Expand Up @@ -49,7 +49,10 @@ def test_confidence_exact_match(self):
self.assertEqual(match.confidence, 1)

def test_confidence_half_match(self):
self.chatbot.storage.create(text='xxyy', in_response_to='xxyy')
# Assume that the storage adapter returns a partial match
self.adapter.chatbot.storage.filter = MagicMock(return_value=[
Statement(text='xxyy')
])

statement = Statement(text='wwxx')
match = self.adapter.get(statement)
Expand Down
21 changes: 21 additions & 0 deletions tests/storage/test_sql_adapter.py
Expand Up @@ -224,6 +224,27 @@ def test_persona_not_startswith(self):
self.assertEqual(len(results), 1)
self.assertEqual(results[0].text, 'Hi everyone!')

def test_search_text_contains(self):
self.adapter.create(text='Hello!', search_text='hello')
self.adapter.create(text='Hi everyone!', search_text='everyone')

results = list(self.adapter.filter(
search_text_contains='everyone'
))

self.assertEqual(len(results), 1)
self.assertEqual(results[0].text, 'Hi everyone!')

def test_search_text_contains_multiple_matches(self):
self.adapter.create(text='Hello!', search_text='hello')
self.adapter.create(text='Hi everyone!', search_text='everyone')

results = list(self.adapter.filter(
search_text_contains='hello everyone'
))

self.assertEqual(len(results), 2)


class SQLOrderingTests(SQLStorageAdapterTestCase):
"""
Expand Down
6 changes: 4 additions & 2 deletions tests/test_chatbot.py
Expand Up @@ -12,7 +12,8 @@ def test_get_initialization_functions(self):
functions = self.chatbot.get_initialization_functions()

self.assertIn('initialize_nltk_stopwords', functions)
self.assertIsLength(functions, 1)
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIsLength(functions, 2)

def test_get_initialization_functions_synset_distance(self):
"""
Expand All @@ -38,8 +39,9 @@ def test_get_initialization_functions_sentiment_comparison(self):
functions = self.chatbot.get_initialization_functions()

self.assertIn('initialize_nltk_stopwords', functions)
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIn('initialize_nltk_vader_lexicon', functions)
self.assertIsLength(functions, 2)
self.assertIsLength(functions, 3)

def test_get_initialization_functions_jaccard_similarity(self):
"""
Expand Down
26 changes: 20 additions & 6 deletions tests/test_stemming.py
Expand Up @@ -7,6 +7,13 @@ class SimpleStemmerTests(TestCase):
def setUp(self):
self.stemmer = stemming.SimpleStemmer()

def test_empty_string(self):
stemmed_text = self.stemmer.get_bigram_pair_string(
''
)

self.assertEqual(stemmed_text, '')

def test_stemming(self):
stemmed_text = self.stemmer.get_stemmed_words(
'Hello, how are you doing on this awesome day?'
Expand Down Expand Up @@ -112,12 +119,19 @@ class PosHypernymStemmerTests(TestCase):
def setUp(self):
self.stemmer = stemming.PosHypernymStemmer()

def test_empty_string(self):
stemmed_text = self.stemmer.get_bigram_pair_string(
''
)

self.assertEqual(stemmed_text, '')

def test_stemming(self):
stemmed_text = self.stemmer.get_bigram_pair_string(
'Hello, how are you doing on this awesome day?'
)

self.assertEqual(stemmed_text, 'DT:awesome JJ:time_unit')
self.assertEqual(stemmed_text, 'JJ:awesome NN:time_unit')

def test_string_becomes_lowercase(self):
stemmed_text = self.stemmer.get_bigram_pair_string('THIS IS HOW IT BEGINS!')
Expand All @@ -127,12 +141,12 @@ def test_string_becomes_lowercase(self):
def test_stemming_medium_sized_words(self):
stemmed_text = self.stemmer.get_bigram_pair_string('Hello, my name is Gunther.')

self.assertEqual(stemmed_text, 'PRP$:language_unit VBZ:gunther')
self.assertEqual(stemmed_text, 'NN:language_unit NNP:gunther')

def test_stemming_long_words(self):
stemmed_text = self.stemmer.get_bigram_pair_string('I play several orchestra instruments for pleasuer.')

self.assertEqual(stemmed_text, 'PRP:compete VBP:several JJ:orchestra JJ:device IN:pleasuer')
self.assertEqual(stemmed_text, 'VBP:compete JJ:several JJ:orchestra NNS:device NN:pleasuer')

def test_get_bigram_pair_string_punctuation_only(self):
bigram_string = self.stemmer.get_bigram_pair_string(
Expand Down Expand Up @@ -195,18 +209,18 @@ def test_get_bigram_pair_string_multiple_words(self):
'Hello Dr. Salazar. How are you today?'
)

self.assertEqual(bigram_string, 'NNP:scholar NNP:salazar. PRP:present')
self.assertEqual(bigram_string, 'NNP:scholar NNP:salazar. NN:present')

def test_get_bigram_pair_string_single_character_words(self):
bigram_string = self.stemmer.get_bigram_pair_string(
'a e i o u'
)

self.assertEqual(bigram_string, 'DT:antioxidant VBP:nucleotide')
self.assertEqual(bigram_string, 'NN:e NN:i VBP:o NN:u')

def test_get_bigram_pair_string_two_character_words(self):
bigram_string = self.stemmer.get_bigram_pair_string(
'Lo my mu it is of us'
)

self.assertEqual(bigram_string, 'PRP$:letter IN:us')
self.assertEqual(bigram_string, 'NN:letter PRP:us')
4 changes: 2 additions & 2 deletions tests/training/test_chatterbot_corpus_training.py
Expand Up @@ -29,15 +29,15 @@ def test_train_with_english_greeting_corpus_search_text(self):
results = list(self.chatbot.storage.filter(text='Hello'))

self.assertGreater(len(results), 1)
self.assertEqual(results[0].search_text, 'ell')
self.assertEqual(results[0].search_text, 'hello')

def test_train_with_english_greeting_corpus_search_in_response_to(self):
self.trainer.train('chatterbot.corpus.english.greetings')

results = list(self.chatbot.storage.filter(in_response_to='Hello'))

self.assertGreater(len(results), 1)
self.assertEqual(results[0].search_in_response_to, 'ell')
self.assertEqual(results[0].search_in_response_to, 'hello')

def test_train_with_english_greeting_corpus_tags(self):
self.trainer.train('chatterbot.corpus.english.greetings')
Expand Down
4 changes: 2 additions & 2 deletions tests/training/test_list_training.py
Expand Up @@ -65,7 +65,7 @@ def test_training_sets_search_text(self):
))

self.assertIsLength(statements, 1)
self.assertEqual(statements[0].search_text, "ik")
self.assertEqual(statements[0].search_text, 'IN:kind NN:headdress')

def test_training_sets_search_in_response_to(self):

Expand All @@ -81,7 +81,7 @@ def test_training_sets_search_in_response_to(self):
))

self.assertIsLength(statements, 1)
self.assertEqual(statements[0].search_in_response_to, "ik")
self.assertEqual(statements[0].search_in_response_to, 'IN:kind NN:headdress')

def test_database_has_correct_format(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/training/test_twitter_trainer.py
Expand Up @@ -88,11 +88,11 @@ def test_train_sets_search_text(self):
statements = list(self.trainer.chatbot.storage.filter())

self.assertGreater(len(statements), 1)
self.assertEqual(statements[0].search_text, 'urub')
self.assertEqual(statements[0].search_text, 'VBP:sure NNP:jewel')

def test_train_sets_search_in_response_to(self):
self.trainer.train()
statements = list(self.trainer.chatbot.storage.filter())

self.assertGreater(len(statements), 1)
self.assertEqual(statements[0].search_in_response_to, 'urub')
self.assertEqual(statements[0].search_in_response_to, 'VBP:sure NNP:jewel')
4 changes: 2 additions & 2 deletions tests/training/test_ubuntu_corpus_training.py
Expand Up @@ -176,7 +176,7 @@ def test_train_sets_search_text(self):
results = list(self.chatbot.storage.filter(text='Is anyone there?'))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].search_text, 'nyon')
self.assertEqual(results[0].search_text, 'NN:anyone')

def test_train_sets_search_in_response_to(self):
"""
Expand All @@ -190,7 +190,7 @@ def test_train_sets_search_in_response_to(self):
results = list(self.chatbot.storage.filter(in_response_to='Is anyone there?'))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].search_in_response_to, 'nyon')
self.assertEqual(results[0].search_in_response_to, 'NN:anyone')

def test_is_extracted(self):
"""
Expand Down

0 comments on commit 7289b9a

Please sign in to comment.