Skip to content

Commit

Permalink
Use preprocessors during training
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Jan 27, 2018
1 parent 19be583 commit b6ff1a2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
3 changes: 3 additions & 0 deletions chatterbot/chatterbot.py
Expand Up @@ -160,6 +160,9 @@ def set_trainer(self, training_class, **kwargs):
:param \**kwargs: Any parameters that should be passed to the training class.
"""
if 'chatbot' not in kwargs:
kwargs['chatbot'] = self

self.trainer = training_class(self.storage, **kwargs)

@property
Expand Down
31 changes: 25 additions & 6 deletions chatterbot/trainers.py
Expand Up @@ -2,7 +2,7 @@
import os
import sys
from .conversation import Statement, Response
from .utils import print_progress_bar
from . import utils


class Trainer(object):
Expand All @@ -11,13 +11,28 @@ class Trainer(object):
"""

def __init__(self, storage, **kwargs):
self.chatbot = kwargs.get('chatbot')
self.storage = storage
self.logger = logging.getLogger(__name__)
self.show_training_progress = kwargs.get('show_training_progress', True)

def get_preprocessed_statement(self, input_statement):
"""
Preprocess the input statement.
"""

# The chatbot is optional to prevent backwards-incompatible changes
if not self.chatbot:
return input_statement

for preprocessor in self.chatbot.preprocessors:
input_statement = preprocessor(self, input_statement)

return input_statement

def train(self, *args, **kwargs):
"""
This class must be overridden by a class the inherits from 'Trainer'.
This method must be overridden by a child class.
"""
raise self.TrainerInitializationException()

Expand All @@ -26,10 +41,14 @@ def get_or_create(self, statement_text):
Return a statement if it exists.
Create and return the statement if it does not exist.
"""
statement = self.storage.find(statement_text)
temp_statement = self.get_preprocessed_statement(
Statement(text=statement_text)
)

statement = self.storage.find(temp_statement.text)

if not statement:
statement = Statement(statement_text)
statement = Statement(temp_statement.text)

return statement

Expand Down Expand Up @@ -83,7 +102,7 @@ def train(self, conversation):

for conversation_count, text in enumerate(conversation):
if self.show_training_progress:
print_progress_bar(
utils.print_progress_bar(
'List Trainer',
conversation_count + 1, len(conversation)
)
Expand Down Expand Up @@ -128,7 +147,7 @@ def train(self, *corpus_paths):
for conversation_count, conversation in enumerate(corpus):

if self.show_training_progress:
print_progress_bar(
utils.print_progress_bar(
str(os.path.basename(corpus_files[corpus_count])) + ' Training',
conversation_count + 1,
len(corpus)
Expand Down
30 changes: 30 additions & 0 deletions tests/training_tests/test_training_preprocessors.py
@@ -0,0 +1,30 @@
from tests.base_case import ChatBotTestCase
from chatterbot import trainers
from chatterbot import preprocessors


class PreprocessorTrainingTests(ChatBotTestCase):
"""
These tests are designed to ensure that preprocessors
will be used to process the input the chat bot is given
during the training process.
"""

def test_training_cleans_whitespace(self):
"""
Test that the ``clean_whitespace`` preprocessor is used during
the training process.
"""
self.chatbot.preprocessors = [preprocessors.clean_whitespace]
self.chatbot.set_trainer(trainers.ListTrainer)

self.chatbot.train([
'Can I help you with anything?',
'No, I think I am all set.',
'Okay, have a nice day.',
'Thank you, you too.'
])

response = self.chatbot.get_response('Can I help you with anything?')

self.assertEqual(response.text, 'No, I think I am all set.')

0 comments on commit b6ff1a2

Please sign in to comment.