Skip to content

Commit

Permalink
Add filter option for search_text_contains
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Dec 9, 2018
1 parent 91975d0 commit 4cfa79a
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 2 deletions.
11 changes: 11 additions & 0 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def filter(self, **kwargs):
exclude_text = kwargs.pop('exclude_text', None)
exclude_text_words = kwargs.pop('exclude_text_words', [])
persona_not_startswith = kwargs.pop('persona_not_startswith', None)
search_text_contains = kwargs.pop('search_text_contains', None)

# Convert a single string into a list if only one tag is provided
if type(tags) == str:
Expand Down Expand Up @@ -73,6 +74,16 @@ def filter(self, **kwargs):
persona__startswith='bot:'
)

if search_text_contains:
or_query = Q()

for word in search_text_contains.split(' '):
or_query |= Q(search_text__contains=word)

statements = statements.filter(
or_query
)

if order_by:
statements = statements.order_by(*order_by)

Expand Down
7 changes: 7 additions & 0 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def filter(self, **kwargs):
exclude_text = kwargs.pop('exclude_text', None)
exclude_text_words = kwargs.pop('exclude_text_words', [])
persona_not_startswith = kwargs.pop('persona_not_startswith', None)
search_text_contains = kwargs.pop('search_text_contains', None)

if tags:
kwargs['tags'] = {
Expand Down Expand Up @@ -119,6 +120,12 @@ def filter(self, **kwargs):
}
kwargs['persona']['$not'] = re.compile('^bot:*')

if search_text_contains:
or_regex = '|'.join([
'{}'.format(word) for word in search_text_contains.split(' ')
])
kwargs['search_text'] = re.compile(or_regex)

mongo_ordering = []

if order_by:
Expand Down
9 changes: 9 additions & 0 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def filter(self, **kwargs):
exclude_text = kwargs.pop('exclude_text', None)
exclude_text_words = kwargs.pop('exclude_text_words', [])
persona_not_startswith = kwargs.pop('persona_not_startswith', None)
search_text_contains = kwargs.pop('search_text_contains', None)

# Convert a single string into a list if only one tag is provided
if type(tags) == str:
Expand Down Expand Up @@ -151,6 +152,14 @@ def filter(self, **kwargs):
~Statement.persona.startswith('bot:')
)

if search_text_contains:
or_query = [
Statement.search_text.contains(word) for word in search_text_contains.split(' ')
]
statements = statements.filter(
or_(*or_query)
)

if order_by:

if 'created_at' in order_by:
Expand Down
3 changes: 1 addition & 2 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,8 @@ def track_progress(members):

def train(self):
import glob
from chatterbot.stemming import SimpleStemmer

stemmer = SimpleStemmer()
stemmer = SimpleStemmer(language=self.stemmer.language)

# Download and extract the Ubuntu dialog corpus if needed
corpus_download_path = self.download(self.data_download_url)
Expand Down
21 changes: 21 additions & 0 deletions tests/storage/test_mongo_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,27 @@ def test_persona_not_startswith(self):
self.assertEqual(len(results), 1)
self.assertEqual(results[0].text, 'Hi everyone!')

def test_search_text_contains(self):
    """
    Filtering by a single search word should return only the
    statement whose search_text includes that word.
    """
    fixtures = (
        ('Hello!', 'hello exclamation'),
        ('Hi everyone!', 'hi everyone'),
    )

    for statement_text, indexed_text in fixtures:
        self.adapter.create(text=statement_text, search_text=indexed_text)

    matches = list(self.adapter.filter(search_text_contains='everyone'))

    self.assertEqual(len(matches), 1)
    self.assertEqual(matches[0].text, 'Hi everyone!')

def test_search_text_contains_multiple_matches(self):
    """
    Filtering by two space-separated words should return the
    statements whose search_text includes either of them.
    """
    fixtures = (
        ('Hello!', 'hello exclamation'),
        ('Hi everyone!', 'hi everyone'),
    )

    for statement_text, indexed_text in fixtures:
        self.adapter.create(text=statement_text, search_text=indexed_text)

    matches = list(self.adapter.filter(search_text_contains='hello everyone'))

    self.assertEqual(len(matches), 2)


class MongoOrderingTestCase(MongoAdapterTestCase):
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/storage/test_sql_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,27 @@ def test_persona_not_startswith(self):
self.assertEqual(len(results), 1)
self.assertEqual(results[0].text, 'Hi everyone!')

def test_search_text_contains(self):
    """
    Filtering by a single search word should return only the
    statement whose search_text includes that word.
    """
    fixtures = (
        ('Hello!', 'hello exclamation'),
        ('Hi everyone!', 'hi everyone'),
    )

    for statement_text, indexed_text in fixtures:
        self.adapter.create(text=statement_text, search_text=indexed_text)

    matches = list(self.adapter.filter(search_text_contains='everyone'))

    self.assertEqual(len(matches), 1)
    self.assertEqual(matches[0].text, 'Hi everyone!')

def test_search_text_contains_multiple_matches(self):
    """
    Filtering by two space-separated words should return the
    statements whose search_text includes either of them.
    """
    fixtures = (
        ('Hello!', 'hello exclamation'),
        ('Hi everyone!', 'hi everyone'),
    )

    for statement_text, indexed_text in fixtures:
        self.adapter.create(text=statement_text, search_text=indexed_text)

    matches = list(self.adapter.filter(search_text_contains='hello everyone'))

    self.assertEqual(len(matches), 2)


class SQLOrderingTests(SQLStorageAdapterTestCase):
"""
Expand Down
21 changes: 21 additions & 0 deletions tests_django/test_django_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,27 @@ def test_persona_not_startswith(self):
self.assertEqual(len(results), 1)
self.assertEqual(results[0].text, 'Hi everyone!')

def test_search_text_contains(self):
    """
    Filtering by a single search word should return only the
    statement whose search_text includes that word.
    """
    fixtures = (
        ('Hello!', 'hello exclamation'),
        ('Hi everyone!', 'hi everyone'),
    )

    for statement_text, indexed_text in fixtures:
        self.adapter.create(text=statement_text, search_text=indexed_text)

    matches = list(self.adapter.filter(search_text_contains='everyone'))

    self.assertEqual(len(matches), 1)
    self.assertEqual(matches[0].text, 'Hi everyone!')

def test_search_text_contains_multiple_matches(self):
    """
    Filtering by two space-separated words should return the
    statements whose search_text includes either of them.
    """
    fixtures = (
        ('Hello!', 'hello exclamation'),
        ('Hi everyone!', 'hi everyone'),
    )

    for statement_text, indexed_text in fixtures:
        self.adapter.create(text=statement_text, search_text=indexed_text)

    matches = list(self.adapter.filter(search_text_contains='hello everyone'))

    self.assertEqual(len(matches), 2)


class DjangoOrderingTests(DjangoAdapterTestCase):
"""
Expand Down

0 comments on commit 4cfa79a

Please sign in to comment.