From cc5d40e2f48b178d9663b25408030786f8f9f5d8 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Thu, 10 Nov 2016 19:29:13 -0500 Subject: [PATCH] Add Django management command to train chat bot --- chatterbot/chatterbot.py | 4 ++- .../django_chatterbot/management/__init__.py | 0 .../management/commands/__init__.py | 0 .../management/commands/train.py | 23 ++++++++++++++ chatterbot/trainers.py | 13 ++++++-- docs/django/index.rst | 1 + docs/django/settings.rst | 4 +++ docs/django/training.rst | 30 +++++++++++++++++++ examples/django_app/example_app/settings.py | 6 +++- examples/django_app/tests/test_commands.py | 21 +++++++++++++ setup.py | 2 ++ 11 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 chatterbot/ext/django_chatterbot/management/__init__.py create mode 100644 chatterbot/ext/django_chatterbot/management/commands/__init__.py create mode 100644 chatterbot/ext/django_chatterbot/management/commands/train.py create mode 100644 docs/django/training.rst create mode 100644 examples/django_app/tests/test_commands.py diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index 9acb345a8..ca0b6593e 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -2,7 +2,6 @@ from .adapters.logic import LogicAdapter, MultiLogicAdapter from .adapters.input import InputAdapter from .adapters.output import OutputAdapter -from .conversation import Statement, Response from .utils.queues import ResponseQueue from .utils.module_loading import import_module import logging @@ -67,6 +66,7 @@ def __init__(self, name, **kwargs): trainer = kwargs.get('trainer', 'chatterbot.trainers.Trainer') TrainerClass = import_module(trainer) self.trainer = TrainerClass(self.storage, **kwargs) + self.training_data = kwargs.get('training_data') self.logger = kwargs.get('logger', logging.getLogger(__name__)) @@ -228,6 +228,8 @@ def learn_response(self, statement): """ Learn that the statement provided is a valid response. """ + from .conversation import Response + previous_statement = self.get_last_response_statement() if previous_statement: diff --git a/chatterbot/ext/django_chatterbot/management/__init__.py b/chatterbot/ext/django_chatterbot/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chatterbot/ext/django_chatterbot/management/commands/__init__.py b/chatterbot/ext/django_chatterbot/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chatterbot/ext/django_chatterbot/management/commands/train.py b/chatterbot/ext/django_chatterbot/management/commands/train.py new file mode 100644 index 000000000..308aeeb55 --- /dev/null +++ b/chatterbot/ext/django_chatterbot/management/commands/train.py @@ -0,0 +1,23 @@ +from django.core.management.base import BaseCommand + + +class Command(BaseCommand): + help = 'Trains the database used by the chat bot' + can_import_settings = True + + def handle(self, *args, **options): + from chatterbot import ChatBot + from chatterbot.ext.django_chatterbot import settings + + chatterbot = ChatBot(**settings.CHATTERBOT) + + chatterbot.train(chatterbot.training_data) + + # Django 1.8 does not define SUCCESS + if hasattr(self.style, 'SUCCESS'): + style = self.style.SUCCESS + else: + style = self.style.NOTICE + + training_class = chatterbot.trainer.__class__.__name__ + self.stdout.write(style('ChatterBot trained using "%s"' % training_class)) \ No newline at end of file diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 33ee38a3a..c5ffae78d 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -1,5 +1,4 @@ from .conversation import Statement, Response -from .corpus import Corpus import logging @@ -7,7 +6,6 @@ class Trainer(object): def __init__(self, storage, **kwargs): self.storage = storage - self.corpus = Corpus() self.logger = logging.getLogger(__name__) def train(self, *args, **kwargs): @@ -76,9 +74,20 @@ def train(self, conversation): class ChatterBotCorpusTrainer(Trainer): + def __init__(self, storage, **kwargs): + super(ChatterBotCorpusTrainer, self).__init__(storage, **kwargs) + from .corpus import Corpus + + self.corpus = Corpus() + def train(self, *corpora): trainer = ListTrainer(self.storage) + # Allow a list of coupora to be passed instead of arguments + if len(corpora) == 1: + if isinstance(corpora[0], list): + corpora = corpora[0] + for corpus in corpora: corpus_data = self.corpus.load_corpus(corpus) for data in corpus_data: diff --git a/docs/django/index.rst b/docs/django/index.rst index aef74c7fc..834ddce4c 100644 --- a/docs/django/index.rst +++ b/docs/django/index.rst @@ -10,6 +10,7 @@ Django applications. :maxdepth: 2 settings + training views Installation diff --git a/docs/django/settings.rst b/docs/django/settings.rst index 9f4b1e92d..5cee6fa64 100644 --- a/docs/django/settings.rst +++ b/docs/django/settings.rst @@ -12,6 +12,10 @@ You can edit the ChatterBot configuration through your Django settings.py file. 'chatterbot.adapters.logic.MathematicalEvaluation', 'chatterbot.adapters.logic.TimeLogicAdapter', 'chatterbot.adapters.logic.ClosestMatchAdapter' + ], + 'trainer': 'chatterbot.trainers.ChatterBotCorpusTrainer', + 'training_data': [ + 'chatterbot.corpus.english.greetings' ] } diff --git a/docs/django/training.rst b/docs/django/training.rst new file mode 100644 index 000000000..3d2fcd35e --- /dev/null +++ b/docs/django/training.rst @@ -0,0 +1,30 @@ +======== +Training +======== + +Management command +================== + +When using ChatterBot with Django, the training process can be +executed by running the training management command. + +.. code-block:: bash + + python manage.py train + +Training settings +================= + +You can specify any data that you want to be passed to the chat bot +trainer in the :code:`training_data` parameter in your :code:`CHATTERBOT` +Django settings. + +.. code-block:: python + + CHATTERBOT = { + # ... + 'trainer': 'chatterbot.trainers.ChatterBotCorpusTrainer', + 'training_data': [ + 'chatterbot.corpus.english.greetings' + ] + } \ No newline at end of file diff --git a/examples/django_app/example_app/settings.py b/examples/django_app/example_app/settings.py index 872d0baa6..6e81e4ace 100644 --- a/examples/django_app/example_app/settings.py +++ b/examples/django_app/example_app/settings.py @@ -33,7 +33,11 @@ # ChatterBot settings CHATTERBOT = { - 'name': 'Django ChatterBot Example' + 'name': 'Django ChatterBot Example', + 'trainer': 'chatterbot.trainers.ChatterBotCorpusTrainer', + 'training_data': [ + 'chatterbot.corpus.english.greetings' + ] } MIDDLEWARE_CLASSES = ( diff --git a/examples/django_app/tests/test_commands.py b/examples/django_app/tests/test_commands.py new file mode 100644 index 000000000..f074edbda --- /dev/null +++ b/examples/django_app/tests/test_commands.py @@ -0,0 +1,21 @@ +from django.core.management import call_command +from django.test import TestCase +from django.utils.six import StringIO +from chatterbot.ext.django_chatterbot.models import Statement + + +class TrainCommandTestCase(TestCase): + + def test_command_output(self): + out = StringIO() + call_command('train', stdout=out) + self.assertIn('ChatterBot trained', out.getvalue()) + + def test_command_data_argument(self): + out = StringIO() + statements_before = Statement.objects.exists() + call_command('train', stdout=out) + statements_after = Statement.objects.exists() + + self.assertFalse(statements_before) + self.assertTrue(statements_after) \ No newline at end of file diff --git a/setup.py b/setup.py index 4dba9ac16..090737b9c 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,8 @@ 'chatterbot.ext', 'chatterbot.ext.django_chatterbot', 'chatterbot.ext.django_chatterbot.migrations', + 'chatterbot.ext.django_chatterbot.management', + 'chatterbot.ext.django_chatterbot.management.commands', 'chatterbot.utils' ], package_dir={'chatterbot': 'chatterbot'},