From ae3f5170ce3c1ca99dc05a0df62af21b49d4f94f Mon Sep 17 00:00:00 2001
From: Gunther Cox
Date: Tue, 8 Nov 2016 19:41:53 -0500
Subject: [PATCH] Add training class for Ubuntu corpus

---
 .gitignore                                    |  1 +
 chatterbot/trainers.py                        | 70 +++++++++++++-
 examples/ubuntu_corpus_training_example.py    | 23 +++++
 .../test_ubuntu_corpus_training.py            | 92 +++++++++++++++++++
 4 files changed, 185 insertions(+), 1 deletion(-)
 create mode 100644 examples/ubuntu_corpus_training_example.py
 create mode 100644 tests/training_tests/test_ubuntu_corpus_training.py

diff --git a/.gitignore b/.gitignore
index 9a0320637..116ddb6e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,4 @@ docs/_build/
 *.iml
 
 examples/settings.py
+examples/ubuntu_dialogs*
diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py
index 33ee38a3a..8ba5ca960 100644
--- a/chatterbot/trainers.py
+++ b/chatterbot/trainers.py
@@ -166,7 +166,75 @@ def get_statements(self):
         return statements
 
     def train(self):
-        for i in range(0, 10):
+        for _ in range(0, 10):
             statements = self.get_statements()
             for statement in statements:
                 self.storage.update(statement, force=True)
+
+
+class UbuntuCorpusTrainer(Trainer):
+    """
+    Allow chatbots to be trained with the data from
+    the Ubuntu Dialog Corpus.
+    """
+
+    def download(self, url, show_status=True):
+        """
+        Download a file from the given url.
+        Show a progress indicator for the download status.
+        Based on: http://stackoverflow.com/a/15645088/1547223
+        """
+        import sys
+        import requests
+
+        file_name = 'download.data'
+        with open(file_name, 'wb') as open_file:
+            print('Downloading %s' % file_name)
+            response = requests.get(url, stream=True)
+            total_length = response.headers.get('content-length')
+
+            if total_length is None:
+                # No content length header
+                open_file.write(response.content)
+            else:
+                downloadl = 0
+                total_length = int(total_length)
+                for data in response.iter_content(chunk_size=4096):
+                    downloadl += len(data)
+                    open_file.write(data)
+                    if show_status:
+                        done = int(50 * downloadl / total_length)
+                        sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
+                        sys.stdout.flush()
+
+    def extract(self, file_path):
+        """
+        Extract a tar file at the specified file path.
+        """
+        import tarfile
+
+        self.logger.info('Starting file extraction')
+
+        extracted_directory_path = ''
+
+        def track_progress(members):
+            for member in members:
+                # this will be the current file being extracted
+                yield member
+                print('Extracting {}'.format(member))
+
+        with tarfile.open(file_path) as tar:
+            tar.extractall(members=track_progress(tar))
+
+        self.logger.info('File extraction complete')
+
+        return extracted_directory_path
+
+    def train(self):
+        import glob
+
+        # data_directory = self.extract('C:/Users/Gunther/GitHub/ChatterBot/examples/ubuntu_dialogs.tgz')
+        data_directory = 'C:/Users/Gunther/GitHub/ChatterBot/examples/ubuntu_dialogs/dialogs/'
+
+        for file in glob.glob(data_directory):
+            print('file:', file)
diff --git a/examples/ubuntu_corpus_training_example.py b/examples/ubuntu_corpus_training_example.py
new file mode 100644
index 000000000..a1777f586
--- /dev/null
+++ b/examples/ubuntu_corpus_training_example.py
@@ -0,0 +1,23 @@
+from chatterbot import ChatBot
+import logging
+
+
+'''
+This is an example showing how to train a chat bot using the
+Ubuntu Corpus of conversation dialog.
+'''
+
+# Enable info level logging
+logging.basicConfig(level=logging.INFO)
+
+chatbot = ChatBot(
+    'Example Bot',
+    trainer='chatterbot.trainers.UbuntuCorpusTrainer'
+)
+
+# Start by training our bot with the Ubuntu corpus data
+chatbot.train()
+
+# Now let's get a response to a greeting
+response = chatbot.get_response('How are you doing today?')
+print(response)
\ No newline at end of file
diff --git a/tests/training_tests/test_ubuntu_corpus_training.py b/tests/training_tests/test_ubuntu_corpus_training.py
new file mode 100644
index 000000000..a7db3a1e3
--- /dev/null
+++ b/tests/training_tests/test_ubuntu_corpus_training.py
@@ -0,0 +1,92 @@
+from io import BytesIO
+import tarfile
+import os
+from mock import Mock
+
+from tests.base_case import ChatBotTestCase
+from chatterbot.trainers import UbuntuCorpusTrainer
+
+
+class UbuntuCorpusTrainerTestCase(ChatBotTestCase):
+    """
+    Test the Ubuntu Corpus trainer class.
+    """
+
+    def setUp(self):
+        super(UbuntuCorpusTrainerTestCase, self).setUp()
+        self.chatbot.set_trainer(UbuntuCorpusTrainer)
+
+    def _create_test_corpus(self):
+        """
+        Create a small tar in a similar format to the
+        Ubuntu corpus file in memory for testing.
+        """
+        tar = tarfile.TarFile('ubuntu_corpus.tar', 'w')
+
+        data1 = (
+            b'2004-11-04T16:49:00.000Z tom jane : Hello\n' +
+            b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' +
+            b'2004-11-04T16:49:00.000Z jane tom I am good' +
+            b'\n'
+        )
+
+        data2 = (
+            b'2004-11-04T16:49:00.000Z tom jane : Hello\n' +
+            b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' +
+            b'2004-11-04T16:49:00.000Z jane tom I am good' +
+            b'\n'
+        )
+
+        tsv1 = BytesIO(data1)
+        tsv2 = BytesIO(data2)
+
+        tarinfo = tarfile.TarInfo('ubuntu_dialogs/3/1.tsv')
+        tarinfo.size = len(data1)
+        tar.addfile(tarinfo, fileobj=tsv1)
+
+        tarinfo = tarfile.TarInfo('ubuntu_dialogs/3/2.tsv')
+        tarinfo.size = len(data2)
+        tar.addfile(tarinfo, fileobj=tsv2)
+
+        tsv1.close()
+        tsv2.close()
+        tar.close()
+
+        return os.path.realpath(tar.name)
+
+    def test_download(self):
+        """
+        Test the download function for the Ubuntu corpus trainer.
+        """
+        import requests
+
+        def mock_get_response(*args, **kwargs):
+            response = requests.Response()
+            response._content = b'Some response content'
+            response.headers['content-length'] = len(response.content)
+            return response
+
+        requests.get = Mock(side_effect=mock_get_response)
+        download_url = 'https://example.com/download.tar'
+        # self.chatbot.trainer.requests.get = Mock()
+        self.chatbot.trainer.download(download_url, show_status=False)
+        requests.get.assert_called_with(download_url, stream=True)
+
+    def test_download_does_not_exist(self):
+        """
+        Test the case that the file being downloaded does not exist.
+        """
+        pass
+
+    def test_extract(self):
+        """
+        Test the extraction of text from a decompressed Ubuntu Corpus file.
+        """
+        file_object_path = self._create_test_corpus()
+        self.chatbot.trainer.extract(file_object_path)
+
+    def test_train(self):
+        """
+        Test that the chat bot is trained using data from the Ubuntu Corpus.
+        """
+        pass