Skip to content

Commit

Permalink
Filter response text using stemmed values
Browse files (browse the repository at this point in the history)
  • Loading branch information
gunthercox committed Oct 21, 2018
1 parent e73dff3 commit 2634e83
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 38 deletions.
6 changes: 2 additions & 4 deletions chatterbot/logic/best_match.py
@@ -1,5 +1,4 @@
from chatterbot.logic import LogicAdapter
from chatterbot import stemming


class BestMatch(LogicAdapter):
Expand All @@ -13,9 +12,8 @@ def get(self, input_statement):
Takes a statement string and a list of statement strings.
Returns the closest matching statement from the list.
"""
s = stemming.RidiculouslySimpleStemmer()
statement_list = self.chatbot.storage.filter(
stemmed_text=s.stem(input_statement.text)
statement_list = self.chatbot.storage.get_response_statements(
text=input_statement.text
)

closest_match = input_statement
Expand Down
3 changes: 3 additions & 0 deletions chatterbot/stemming.py
Expand Up @@ -18,6 +18,9 @@ def __init__(self, language='english'):

def stem(self, text):

if not text:
return ''

# Remove punctuation
text = text.translate(self.punctuation_table)

Expand Down
20 changes: 14 additions & 6 deletions chatterbot/storage/sql_storage.py
Expand Up @@ -157,7 +157,8 @@ def create(self, **kwargs):
tags = set(kwargs.pop('tags', []))

if 'stemmed_text' not in kwargs:
kwargs['stemmed_text'] = self.stemmer.stem(kwargs['text'])
if kwargs.get('in_response_to'):
kwargs['stemmed_text'] = self.stemmer.stem(kwargs['in_response_to'])

statement = Statement(**kwargs)

Expand Down Expand Up @@ -199,7 +200,8 @@ def create_many(self, statements):
tags = set(statement_data.pop('tags', []))

if 'stemmed_text' not in statement_data:
statement_data['stemmed_text'] = self.stemmer.stem(statement_data['text'])
if statement_data.get('in_response_to'):
statement_data['stemmed_text'] = self.stemmer.stem(statement_data['in_response_to'])

statement = Statement(**statement_data)

Expand Down Expand Up @@ -254,7 +256,8 @@ def update(self, statement):

record.created_at = statement.created_at

record.stemmed_text = self.stemmer.stem(record.text)
if statement.in_response_to:
record.stemmed_text = self.stemmer.stem(statement.in_response_to)

for _tag in statement.tags:
tag = session.query(Tag).filter_by(name=_tag).first()
Expand Down Expand Up @@ -290,14 +293,14 @@ def get_random(self):
session.close()
return statement

def get_response_statements(self, page_size=1000):
def get_response_statements(self, text=None, page_size=1000):
"""
Return only statements that are in response to another statement.
A statement must exist which lists the closest matching statement in the
in_response_to field. Otherwise, the logic adapter may find a closest
matching statement that does not have a known response.
"""
from sqlalchemy import func
from sqlalchemy import func, or_

Statement = self.get_model('statement')

Expand All @@ -308,10 +311,15 @@ def get_response_statements(self, page_size=1000):
start = 0
stop = min(page_size, total_statements)

or_query = [
Statement.stemmed_text.contains(trigram) for trigram in self.stemmer.stem(text).split(' ')
]

while stop <= total_statements:

statement_set = session.query(Statement).filter(
Statement.in_response_to.isnot(None)
Statement.in_response_to.isnot(None),
or_(*or_query)
).slice(start, stop)

start += page_size
Expand Down
28 changes: 0 additions & 28 deletions tests/test_stemming.py

This file was deleted.

0 comments on commit 2634e83

Please sign in to comment.