Skip to content

Commit

Permalink
Merge ababc88 into f645e2f
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Sep 19, 2018
2 parents f645e2f + ababc88 commit 54726e3
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 11 deletions.
5 changes: 3 additions & 2 deletions chatterbot/ext/django_chatterbot/abstract_models.py
Expand Up @@ -77,8 +77,9 @@ def add_tags(self, *tags):
Add a list of strings to the statement as tags.
(Overrides the method from StatementMixin)
"""
for tag in tags:
self.tags.create(name=tag)
for _tag in tags:
if not self.tags.filter(name=_tag).exists():
self.tags.create(name=_tag)


class AbstractBaseTag(models.Model):
Expand Down
14 changes: 10 additions & 4 deletions chatterbot/storage/django_storage.py
Expand Up @@ -51,15 +51,18 @@ def create(self, **kwargs):
Returns the created statement.
"""
Statement = self.get_model('statement')
Tag = self.get_model('tag')

tags = kwargs.pop('tags', [])

statement = Statement(**kwargs)

statement.save()

for tag in tags:
statement.tags.create(name=tag)
for _tag in tags:
tag, _ = Tag.objects.get_or_create(name=_tag)

statement.tags.add(tag)

return statement

Expand All @@ -68,6 +71,7 @@ def update(self, statement):
Update the provided statement.
"""
Statement = self.get_model('statement')
Tag = self.get_model('tag')

if hasattr(statement, 'id'):
statement.save()
Expand All @@ -79,8 +83,10 @@ def update(self, statement):
created_at=statement.created_at
)

for tag in statement.tags.all():
statement.tags.create(name=tag)
for _tag in statement.tags.all():
tag, _ = Tag.objects.get_or_create(name=_tag)

statement.tags.add(tag)

return statement

Expand Down
17 changes: 16 additions & 1 deletion chatterbot/storage/mongodb.py
Expand Up @@ -138,6 +138,9 @@ def create(self, **kwargs):
"""
Statement = self.get_model('statement')

if 'tags' in kwargs:
kwargs['tags'] = list(set(kwargs['tags']))

inserted = self.statements.insert_one(kwargs)

kwargs['id'] = inserted.inserted_id
Expand All @@ -147,6 +150,18 @@ def create(self, **kwargs):
def update(self, statement):
data = statement.serialize()
data.pop('id', None)
data.pop('tags', None)

update_data = {
'$set': data
}

if statement.tags:
update_data['$addToSet'] = {
'tags': {
'$each': statement.tags
}
}

search_parameters = {}

Expand All @@ -158,7 +173,7 @@ def update(self, statement):

update_operation = self.statements.update_one(
search_parameters,
{'$set': data},
update_data,
upsert=True
)

Expand Down
13 changes: 9 additions & 4 deletions chatterbot/storage/sql_storage.py
Expand Up @@ -160,13 +160,18 @@ def create(self, **kwargs):

session = self.Session()

tags = kwargs.pop('tags', [])
tags = set(kwargs.pop('tags', []))

statement = Statement(**kwargs)

statement.tags.extend([
Tag(name=tag) for tag in tags
])
for _tag in tags:
tag = session.query(Tag).filter_by(name=_tag).first()

if not tag:
# Create the tag
tag = Tag(name=_tag)

statement.tags.append(tag)

session.add(statement)

Expand Down
28 changes: 28 additions & 0 deletions tests/storage_adapter_tests/test_mongo_adapter.py
Expand Up @@ -289,6 +289,19 @@ def test_create_tags(self):
self.assertIn('a', results[0].get_tags())
self.assertIn('b', results[0].get_tags())

def test_create_duplicate_tags(self):
"""
The storage adapter should not create a statement with tags
that are duplicates.
"""
self.adapter.create(text='testing', tags=['ab', 'ab'])

results = self.adapter.filter()

self.assertEqual(len(results), 1)
self.assertEqual(len(results[0].get_tags()), 1)
self.assertEqual(results[0].get_tags(), ['ab'])


class StorageAdapterUpdateTestCase(MongoAdapterFilterTestCase):
"""
Expand All @@ -305,3 +318,18 @@ def test_update_adds_tags(self):
self.assertEqual(len(statements), 1)
self.assertIn('a', statements[0].get_tags())
self.assertIn('b', statements[0].get_tags())

def test_update_duplicate_tags(self):
"""
The storage adapter should not update a statement with tags
that are duplicates.
"""
statement = self.adapter.create(text='Testing', tags=['ab'])
statement.add_tags('ab')
self.adapter.update(statement)

statements = self.adapter.filter()

self.assertEqual(len(statements), 1)
self.assertEqual(len(statements[0].get_tags()), 1)
self.assertEqual(statements[0].get_tags(), ['ab'])
28 changes: 28 additions & 0 deletions tests/storage_adapter_tests/test_sqlalchemy_adapter.py
Expand Up @@ -303,6 +303,19 @@ def test_create_tags(self):
self.assertIn('a', results[0].get_tags())
self.assertIn('b', results[0].get_tags())

def test_create_duplicate_tags(self):
"""
The storage adapter should not create a statement with tags
that are duplicates.
"""
self.adapter.create(text='testing', tags=['ab', 'ab'])

results = self.adapter.filter()

self.assertEqual(len(results), 1)
self.assertEqual(len(results[0].get_tags()), 1)
self.assertEqual(results[0].get_tags(), ['ab'])


class StorageAdapterUpdateTestCase(SQLAlchemyAdapterTestCase):
"""
Expand All @@ -319,3 +332,18 @@ def test_update_adds_tags(self):
self.assertEqual(len(statements), 1)
self.assertIn('a', statements[0].get_tags())
self.assertIn('b', statements[0].get_tags())

def test_update_duplicate_tags(self):
"""
The storage adapter should not update a statement with tags
that are duplicates.
"""
statement = self.adapter.create(text='Testing', tags=['ab'])
statement.add_tags('ab')
self.adapter.update(statement)

statements = self.adapter.filter()

self.assertEqual(len(statements), 1)
self.assertEqual(len(statements[0].get_tags()), 1)
self.assertEqual(statements[0].get_tags(), ['ab'])
28 changes: 28 additions & 0 deletions tests_django/test_django_adapter.py
Expand Up @@ -263,6 +263,19 @@ def test_create_tags(self):
self.assertIn('a', results[0].get_tags())
self.assertIn('b', results[0].get_tags())

def test_create_duplicate_tags(self):
"""
The storage adapter should not create a statement with tags
that are duplicates.
"""
self.adapter.create(text='testing', tags=['ab', 'ab'])

results = self.adapter.filter()

self.assertEqual(len(results), 1)
self.assertEqual(len(results[0].get_tags()), 1)
self.assertEqual(results[0].get_tags(), ['ab'])


class StorageAdapterUpdateTestCase(DjangoStorageAdapterTestCase):
"""
Expand All @@ -279,3 +292,18 @@ def test_update_adds_tags(self):
self.assertEqual(len(statements), 1)
self.assertIn('a', statements[0].get_tags())
self.assertIn('b', statements[0].get_tags())

def test_update_duplicate_tags(self):
"""
The storage adapter should not update a statement with tags
that are duplicates.
"""
statement = self.adapter.create(text='Testing', tags=['ab'])
statement.add_tags('ab')
self.adapter.update(statement)

statements = self.adapter.filter()

self.assertEqual(len(statements), 1)
self.assertEqual(len(statements[0].get_tags()), 1)
self.assertEqual(statements[0].get_tags(), ['ab'])

0 comments on commit 54726e3

Please sign in to comment.