Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only control write permission in the ChatBot class #561

Merged
merged 1 commit into from
Jan 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def __init__(self, name, **kwargs):

self.logger = kwargs.get('logger', logging.getLogger(__name__))

# Allow the bot to save input it receives so that it can learn
self.read_only = kwargs.get('read_only', False)

if kwargs.get('initialize', True):
self.initialize()

Expand Down Expand Up @@ -138,7 +141,8 @@ def learn_response(self, statement, previous_statement):
))

# Save the statement after selecting a response
self.storage.update(statement)
if not self.read_only:
self.storage.update(statement)

def set_trainer(self, training_class, **kwargs):
"""
Expand Down
36 changes: 17 additions & 19 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,36 +73,34 @@ def filter(self, **kwargs):

return statements

def update(self, statement, **kwargs):
def update(self, statement):
"""
Update the provided statement.
"""
from chatterbot.ext.django_chatterbot.models import Statement as StatementModel

response_statement_cache = statement.response_statement_cache

# Do not alter the database unless writing is enabled
if not self.read_only:
statement, created = StatementModel.objects.get_or_create(text=statement.text)
statement.extra_data = getattr(statement, 'extra_data', '')
statement.save()
statement, created = StatementModel.objects.get_or_create(text=statement.text)
statement.extra_data = getattr(statement, 'extra_data', '')
statement.save()

for _response_statement in response_statement_cache:
for _response_statement in response_statement_cache:

response_statement, created = StatementModel.objects.get_or_create(
text=_response_statement.text
)
response_statement.extra_data = getattr(_response_statement, 'extra_data', '')
response_statement.save()
response_statement, created = StatementModel.objects.get_or_create(
text=_response_statement.text
)
response_statement.extra_data = getattr(_response_statement, 'extra_data', '')
response_statement.save()

response, created = statement.in_response.get_or_create(
statement=statement,
response=response_statement
)
response, created = statement.in_response.get_or_create(
statement=statement,
response=response_statement
)

if not created:
response.occurrence += 1
response.save()
if not created:
response.occurrence += 1
response.save()

return statement

Expand Down
32 changes: 13 additions & 19 deletions chatterbot/storage/jsonfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class JsonFileStorageAdapter(StorageAdapter):
:keyword silence_performance_warning: If set to True, the :code:`UnsuitableForProductionWarning`
will not be displayed.
:type silence_performance_warning: bool

:keyword read_only: If set to True, ChatterBot will not save information to the database.
False by default.
:type read_only: bool
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -154,24 +150,22 @@ def filter(self, **kwargs):

return results

def update(self, statement, **kwargs):
def update(self, statement):
"""
Update a statement in the database.
"""
# Do not alter the database unless writing is enabled
if not self.read_only:
data = statement.serialize()

# Remove the text key from the data
del data['text']
self.database.data(key=statement.text, value=data)

# Make sure that an entry for each response exists
for response_statement in statement.in_response_to:
response = self.find(response_statement.text)
if not response:
response = self.Statement(response_statement.text)
self.update(response)
data = statement.serialize()

# Remove the text key from the data
del data['text']
self.database.data(key=statement.text, value=data)

# Make sure that an entry for each response exists
for response_statement in statement.in_response_to:
response = self.find(response_statement.text)
if not response:
response = self.Statement(response_statement.text)
self.update(response)

return statement

Expand Down
52 changes: 22 additions & 30 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ class MongoDatabaseAdapter(StorageAdapter):
.. code-block:: python

database_uri='mongodb://example.com:8100/'


:keyword read_only: If set to True, ChatterBot will not save information to the database.
False by default.
:type read_only: bool
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -206,41 +201,38 @@ def filter(self, **kwargs):

return results

def update(self, statement, **kwargs):
def update(self, statement):
from pymongo import UpdateOne
from pymongo.errors import BulkWriteError

force = kwargs.get('force', False)
# Do not alter the database unless writing is enabled
if force or not self.read_only:
data = statement.serialize()
data = statement.serialize()

operations = []

update_operation = UpdateOne(
{'text': statement.text},
{'$set': data},
upsert=True
)
operations.append(update_operation)

operations = []
# Make sure that an entry for each response is saved
for response_dict in data.get('in_response_to', []):
response_text = response_dict.get('text')

# $setOnInsert does nothing if the document is not created
update_operation = UpdateOne(
{'text': statement.text},
{'$set': data},
{'text': response_text},
{'$set': response_dict},
upsert=True
)
operations.append(update_operation)

# Make sure that an entry for each response is saved
for response_dict in data.get('in_response_to', []):
response_text = response_dict.get('text')

# $setOnInsert does nothing if the document is not created
update_operation = UpdateOne(
{'text': response_text},
{'$set': response_dict},
upsert=True
)
operations.append(update_operation)

try:
self.statements.bulk_write(operations, ordered=False)
except BulkWriteError as bwe:
# Log the details of a bulk write error
self.logger.error(str(bwe.details))
try:
self.statements.bulk_write(operations, ordered=False)
except BulkWriteError as bwe:
# Log the details of a bulk write error
self.logger.error(str(bwe.details))

return statement

Expand Down
1 change: 0 additions & 1 deletion chatterbot/storage/storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(self, base_query=None, *args, **kwargs):
"""
self.kwargs = kwargs
self.logger = kwargs.get('logger', logging.getLogger(__name__))
self.read_only = kwargs.get('read_only', False)
self.adapter_supports_queries = True
self.base_query = None

Expand Down
6 changes: 3 additions & 3 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def train(self, conversation):
)

statement_history.append(statement)
self.storage.update(statement, force=True)
self.storage.update(statement)


class ChatterBotCorpusTrainer(Trainer):
Expand Down Expand Up @@ -204,7 +204,7 @@ def train(self):
for _ in range(0, 10):
statements = self.get_statements()
for statement in statements:
self.storage.update(statement, force=True)
self.storage.update(statement)


class UbuntuCorpusTrainer(Trainer):
Expand Down Expand Up @@ -345,4 +345,4 @@ def train(self):
)

statement_history.append(statement)
self.storage.update(statement, force=True)
self.storage.update(statement)
29 changes: 0 additions & 29 deletions tests/storage_adapter_tests/integration_tests/base.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from tests.base_case import ChatBotTestCase
from .base import StorageIntegrationTests


class JsonStorageIntegrationTests(StorageIntegrationTests, ChatBotTestCase):
pass
class JsonStorageIntegrationTests(ChatBotTestCase):

def test_database_is_updated(self):
"""
Test that the database is updated when read_only is set to false.
"""
input_text = 'What is the airspeed velocity of an unladen swallow?'
exists_before = self.chatbot.storage.find(input_text)

response = self.chatbot.get_response(input_text)
exists_after = self.chatbot.storage.find(input_text)

self.assertFalse(exists_before)
self.assertTrue(exists_after)
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from tests.base_case import ChatBotMongoTestCase
from .base import StorageIntegrationTests


class MongoStorageIntegrationTests(StorageIntegrationTests, ChatBotMongoTestCase):
pass
class MongoStorageIntegrationTests(ChatBotMongoTestCase):

def test_database_is_updated(self):
"""
Test that the database is updated when read_only is set to false.
"""
input_text = 'What is the airspeed velocity of an unladen swallow?'
exists_before = self.chatbot.storage.find(input_text)

response = self.chatbot.get_response(input_text)
exists_after = self.chatbot.storage.find(input_text)

self.assertFalse(exists_before)
self.assertTrue(exists_after)
30 changes: 0 additions & 30 deletions tests/storage_adapter_tests/test_json_file_storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,33 +404,3 @@ def test_order_by_created_at(self):
self.assertEqual(len(results), 2)
self.assertEqual(results[0], statement_a)
self.assertEqual(results[1], statement_b)


class ReadOnlyJsonFileStorageAdapterTestCase(JsonAdapterTestCase):

def test_update_does_not_add_new_statement(self):
self.adapter.read_only = True

statement = Statement("New statement")
self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(statement_found, None)

def test_update_does_not_modify_existing_statement(self):
statement = Statement("New statement")
self.adapter.update(statement)

self.adapter.read_only = True

statement.add_response(
Response("New response")
)

self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(statement_found.text, statement.text)
self.assertEqual(
len(statement_found.in_response_to), 0
)
31 changes: 0 additions & 31 deletions tests/storage_adapter_tests/test_mongo_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,34 +430,3 @@ def test_order_by_created_at(self):
self.assertEqual(len(results), 2)
self.assertEqual(results[0], statement_a)
self.assertEqual(results[1], statement_b)


class ReadOnlyMongoDatabaseAdapterTestCase(MongoAdapterTestCase):

def test_update_does_not_add_new_statement(self):
self.adapter.read_only = True

statement = Statement("New statement")
self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(statement_found, None)

def test_update_does_not_modify_existing_statement(self):
statement = Statement("New statement")
self.adapter.update(statement)

self.adapter.read_only = True

statement.add_response(
Response("New response")
)
self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(
statement_found.text, statement.text
)
self.assertEqual(
len(statement_found.in_response_to), 0
)