Use stemmed text to search for statement matches #1452

Merged
3 commits merged on Dec 15, 2018
12 changes: 12 additions & 0 deletions chatterbot/logic/best_match.py
@@ -13,21 +13,33 @@ def get(self, input_statement):
Takes a statement string and a list of statement strings.
Returns the closest matching statement from the list.
"""
self.chatbot.logger.info('Beginning search for close text match')

input_search_text = self.chatbot.storage.stemmer.get_bigram_pair_string(
input_statement.text
)

statement_list = self.chatbot.storage.filter(
search_text_contains=input_search_text,
persona_not_startswith='bot:',
page_size=self.search_page_size
)

closest_match = input_statement
closest_match.confidence = 0

self.chatbot.logger.info('Processing search results')

# Find the closest matching known statement
for statement in statement_list:
confidence = self.compare_statements(input_statement, statement)

if confidence > closest_match.confidence:
statement.confidence = confidence
closest_match = statement
self.chatbot.logger.info('Similar text found: {} {}'.format(
closest_match.text, confidence
))

# Stop searching if a match that is close enough is found
if closest_match.confidence >= self.maximum_similarity_threshold:
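
Taken together, the new lines above make retrieval a two-step process: stem the input into a bigram-pair search string, then ask the storage layer only for statements whose indexed search_text contains it. A minimal usage-level sketch of that flow, assuming a ChatBot instance named bot, an arbitrary input sentence, and a stand-in value for self.search_page_size:

from chatterbot import ChatBot
from chatterbot.conversation import Statement

bot = ChatBot('Example')  # assumed minimal configuration

incoming = Statement(text='How are you today?')

# Step 1: reduce the input to its stemmed search representation.
search_text = bot.storage.stemmer.get_bigram_pair_string(incoming.text)

# Step 2: fetch only candidates whose stored search_text contains that
# string, excluding the bot's own statements, as BestMatch.get() does above.
candidates = bot.storage.filter(
    search_text_contains=search_text,
    persona_not_startswith='bot:',
    page_size=20  # stands in for self.search_page_size
)
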
40 changes: 24 additions & 16 deletions chatterbot/stemming.py
@@ -1,5 +1,6 @@
import string
from nltk import pos_tag
from nltk.data import load as load_data
from nltk.corpus import wordnet, stopwords


@@ -120,8 +121,6 @@ class PosHypernymStemmer(object):
"""

def __init__(self, language='english'):
self.punctuation_table = str.maketrans(dict.fromkeys(string.punctuation))

self.language = language

self.stopwords = None
@@ -182,32 +181,41 @@ def get_bigram_pair_string(self, text):

DT:beautiful JJ:wetland
"""
words = text.split()
WORD_INDEX = 0
POS_INDEX = 1

pos_tags = []

sentence_detector = load_data('tokenizers/punkt/english.pickle')

# Separate punctuation from last word in string
if words:
word_with_punctuation_removed = words[-1].strip(string.punctuation)
for sentence in sentence_detector.tokenize(text.strip()):

if word_with_punctuation_removed:
words[-1] = word_with_punctuation_removed
# Remove punctuation
if sentence and sentence[-1] in string.punctuation:
sentence_with_punctuation_removed = sentence[:-1]

pos_tags = pos_tag(words)
if sentence_with_punctuation_removed:
sentence = sentence_with_punctuation_removed

words = sentence.split()

pos_tags.extend(pos_tag(words))

hypernyms = self.get_hypernyms(pos_tags)

high_quality_bigrams = []
all_bigrams = []

word_count = len(words)
word_count = len(pos_tags)

if word_count <= 1:
all_bigrams = words
if all_bigrams:
all_bigrams[0] = all_bigrams[0].lower()
if word_count == 1:
all_bigrams.append(
pos_tags[0][WORD_INDEX].lower()
)

for index in range(1, word_count):
word = words[index].lower()
previous_word_pos = pos_tags[index - 1][1]
word = pos_tags[index][WORD_INDEX].lower()
previous_word_pos = pos_tags[index - 1][POS_INDEX]
if word not in self.get_stopwords() and len(word) > 1:
bigram = previous_word_pos + ':' + hypernyms[index].lower()
high_quality_bigrams.append(bigram)
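
The reworked get_bigram_pair_string splits the input into sentences with the Punkt tokenizer, strips one trailing punctuation character per sentence, POS-tags the remaining words, and pairs each word's hypernym with the tag of the word before it, skipping stopwords and one-character tokens. A small usage sketch, assuming the required NLTK data (punkt, the default POS tagger model, wordnet, stopwords) is already downloaded; the expected output is taken from test_stemming.py further down:

from chatterbot.stemming import PosHypernymStemmer

stemmer = PosHypernymStemmer(language='english')

# Punkt splits the text into sentences; stopwords and one-character tokens
# are dropped, and each remaining word's hypernym is paired with the POS
# tag of the preceding word.
print(stemmer.get_bigram_pair_string('Hello Dr. Salazar. How are you today?'))
# Expected per test_stemming.py below: 'NNP:scholar NNP:salazar PRP:present'
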
4 changes: 2 additions & 2 deletions chatterbot/storage/storage_adapter.py
@@ -1,5 +1,5 @@
import logging
from chatterbot.stemming import SimpleStemmer
from chatterbot.stemming import PosHypernymStemmer


class StorageAdapter(object):
@@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs):
self.logger = kwargs.get('logger', logging.getLogger(__name__))
self.adapter_supports_queries = True

self.stemmer = SimpleStemmer(language=kwargs.get(
self.stemmer = PosHypernymStemmer(language=kwargs.get(
'stemmer_language', 'english'
))

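
With this swap every storage adapter builds the hypernym-based stemmer by default, still configured through the stemmer_language keyword argument. A rough sketch of the construction path, using SQLStorageAdapter as an assumed concrete subclass with its default database settings:

from chatterbot.storage import SQLStorageAdapter

# The base StorageAdapter now creates a PosHypernymStemmer instead of a
# SimpleStemmer; 'english' remains the default language.
adapter = SQLStorageAdapter(stemmer_language='english')
print(type(adapter.stemmer).__name__)  # PosHypernymStemmer
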
6 changes: 3 additions & 3 deletions chatterbot/trainers.py
@@ -5,7 +5,7 @@
from multiprocessing import Pool, Manager
from dateutil import parser as date_parser
from chatterbot.conversation import Statement
from chatterbot.stemming import SimpleStemmer
from chatterbot.stemming import PosHypernymStemmer
from chatterbot import utils


@@ -30,7 +30,7 @@ def __init__(self, chatbot, **kwargs):
environment_default
)

self.stemmer = SimpleStemmer(language=kwargs.get(
self.stemmer = PosHypernymStemmer(language=kwargs.get(
'stemmer_language', 'english'
))

@@ -430,7 +430,7 @@ def track_progress(members):
def train(self):
import glob

stemmer = SimpleStemmer(language=self.stemmer.language)
stemmer = PosHypernymStemmer(language=self.stemmer.language)

# Download and extract the Ubuntu dialog corpus if needed
corpus_download_path = self.download(self.data_download_url)
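
The same stemmer is now used while training, so imported statements are indexed with POS:hypernym search text rather than character trigrams, which is what the updated training tests below assert. A hedged sketch with the ListTrainer and a hypothetical two-line conversation:

from chatterbot import ChatBot
from chatterbot.trainers import ListTrainer

bot = ChatBot('Example')
trainer = ListTrainer(bot)

# Hypothetical conversation, purely for illustration.
trainer.train([
    'Hello, how are you?',
    'I am doing well.',
])

# Statements saved by the trainer now carry PosHypernym-style search_text
# values (POS:hypernym pairs) rather than the old character-trigram strings
# such as 'ik' or 'urub' asserted in the tests below.
for statement in bot.storage.filter():
    print(statement.text, '->', statement.search_text)
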
@@ -49,7 +49,10 @@ def test_confidence_exact_match(self):
self.assertEqual(match.confidence, 1)

def test_confidence_half_match(self):
self.chatbot.storage.create(text='xxyy', in_response_to='xxyy')
# Assume that the storage adapter returns a partial match
self.adapter.chatbot.storage.filter = MagicMock(return_value=[
Statement(text='xxyy')
])

statement = Statement(text='wwxx')
match = self.adapter.get(statement)
6 changes: 4 additions & 2 deletions tests/test_chatbot.py
@@ -12,7 +12,8 @@ def test_get_initialization_functions(self):
functions = self.chatbot.get_initialization_functions()

self.assertIn('initialize_nltk_stopwords', functions)
self.assertIsLength(functions, 1)
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIsLength(functions, 2)

def test_get_initialization_functions_synset_distance(self):
"""
@@ -38,8 +39,9 @@ def test_get_initialization_functions_sentiment_comparison(self):
functions = self.chatbot.get_initialization_functions()

self.assertIn('initialize_nltk_stopwords', functions)
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIn('initialize_nltk_vader_lexicon', functions)
self.assertIsLength(functions, 2)
self.assertIsLength(functions, 3)

def test_get_initialization_functions_jaccard_similarity(self):
"""
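
Because the hypernym stemmer relies on WordNet, initialize_nltk_wordnet now joins the stopwords download in the chatbot's initialization functions. Outside of ChatterBot's own initialization, the roughly equivalent manual downloads would be the ones below, an assumption based on the imports in chatterbot/stemming.py above:

import nltk

# Corpora and models used by PosHypernymStemmer, per the imports in
# chatterbot/stemming.py:
nltk.download('punkt')                       # Punkt sentence tokenizer
nltk.download('averaged_perceptron_tagger')  # default model behind pos_tag
nltk.download('wordnet')                     # hypernym lookups
nltk.download('stopwords')                   # stopword filtering
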
2 changes: 1 addition & 1 deletion tests/test_stemming.py
@@ -209,7 +209,7 @@ def test_get_bigram_pair_string_multiple_words(self):
'Hello Dr. Salazar. How are you today?'
)

self.assertEqual(bigram_string, 'NNP:scholar NNP:salazar. PRP:present')
self.assertEqual(bigram_string, 'NNP:scholar NNP:salazar PRP:present')

def test_get_bigram_pair_string_single_character_words(self):
bigram_string = self.stemmer.get_bigram_pair_string(
4 changes: 2 additions & 2 deletions tests/training/test_chatterbot_corpus_training.py
@@ -29,15 +29,15 @@ def test_train_with_english_greeting_corpus_search_text(self):
results = list(self.chatbot.storage.filter(text='Hello'))

self.assertGreater(len(results), 1)
self.assertEqual(results[0].search_text, 'ell')
self.assertEqual(results[0].search_text, 'hello')

def test_train_with_english_greeting_corpus_search_in_response_to(self):
self.trainer.train('chatterbot.corpus.english.greetings')

results = list(self.chatbot.storage.filter(in_response_to='Hello'))

self.assertGreater(len(results), 1)
self.assertEqual(results[0].search_in_response_to, 'ell')
self.assertEqual(results[0].search_in_response_to, 'hello')

def test_train_with_english_greeting_corpus_tags(self):
self.trainer.train('chatterbot.corpus.english.greetings')
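
The greeting-corpus expectations change because a single-word statement such as 'Hello' has no preceding tag to pair with, so the stemmer falls back to the lowercased word itself instead of the old character trigram 'ell'. A quick check, assuming the NLTK data is installed:

from chatterbot.stemming import PosHypernymStemmer

stemmer = PosHypernymStemmer()

# A single word has no preceding tag to pair with, so the stemmer returns
# the lowercased word itself.
print(stemmer.get_bigram_pair_string('Hello'))  # expected: 'hello'
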
4 changes: 2 additions & 2 deletions tests/training/test_list_training.py
@@ -65,7 +65,7 @@ def test_training_sets_search_text(self):
))

self.assertIsLength(statements, 1)
self.assertEqual(statements[0].search_text, "ik")
self.assertEqual(statements[0].search_text, 'RB:kind PRP$:headdress')

def test_training_sets_search_in_response_to(self):

@@ -81,7 +81,7 @@ def test_training_sets_search_in_response_to(self):
))

self.assertIsLength(statements, 1)
self.assertEqual(statements[0].search_in_response_to, "ik")
self.assertEqual(statements[0].search_in_response_to, 'PRP:kind PRP$:headdress')

def test_database_has_correct_format(self):
"""
4 changes: 2 additions & 2 deletions tests/training/test_twitter_trainer.py
@@ -88,11 +88,11 @@ def test_train_sets_search_text(self):
statements = list(self.trainer.chatbot.storage.filter())

self.assertGreater(len(statements), 1)
self.assertEqual(statements[0].search_text, 'urub')
self.assertEqual(statements[0].search_text, 'PRP:sure IN:jewel')

def test_train_sets_search_in_response_to(self):
self.trainer.train()
statements = list(self.trainer.chatbot.storage.filter())

self.assertGreater(len(statements), 1)
self.assertEqual(statements[0].search_in_response_to, 'urub')
self.assertEqual(statements[0].search_in_response_to, 'PRP:sure IN:jewel')
4 changes: 2 additions & 2 deletions tests/training/test_ubuntu_corpus_training.py
@@ -176,7 +176,7 @@ def test_train_sets_search_text(self):
results = list(self.chatbot.storage.filter(text='Is anyone there?'))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].search_text, 'nyon')
self.assertEqual(results[0].search_text, 'VBZ:anyone')

def test_train_sets_search_in_response_to(self):
"""
@@ -190,7 +190,7 @@ def test_train_sets_search_in_response_to(self):
results = list(self.chatbot.storage.filter(in_response_to='Is anyone there?'))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].search_in_response_to, 'nyon')
self.assertEqual(results[0].search_in_response_to, 'VBZ:anyone')

def test_is_extracted(self):
"""
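
The Ubuntu-corpus expectation also illustrates the stopword filtering: in 'Is anyone there?' both 'is' and 'there' are stopwords, so only the pair built from 'anyone' and the tag of the preceding word survives. A quick check, assuming the NLTK data is installed and acknowledging that tagger output can vary slightly across NLTK versions:

from chatterbot.stemming import PosHypernymStemmer

stemmer = PosHypernymStemmer()

# 'is' and 'there' are stopwords, so only the pair built from 'anyone' and
# the tag of the preceding word ('Is' -> VBZ) remains in the search text.
print(stemmer.get_bigram_pair_string('Is anyone there?'))  # expected: 'VBZ:anyone'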