From 47bfb4574ea1c19c5550c6e1dcc3958c82e88891 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Tue, 15 Nov 2016 20:04:58 -0500 Subject: [PATCH] Add tests for file extraction --- chatterbot/trainers.py | 89 +++++++++++++++---- .../test_ubuntu_corpus_training.py | 45 +++++++--- 2 files changed, 103 insertions(+), 31 deletions(-) diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index ee577734d..39c4896a2 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -17,6 +17,18 @@ def train(self, *args, **kwargs): """ raise self.TrainerInitializationException() + def get_or_create(self, statement_text): + """ + Return a statement if it exists. + Create and return the statement if it does not exist. + """ + statement = self.storage.find(statement_text) + + if not statement: + statement = Statement(statement_text) + + return statement + class TrainerInitializationException(Exception): """ Exception raised when a base class has not overridden @@ -59,18 +71,6 @@ class ListTrainer(Trainer): where the list represents a conversation. """ - def get_or_create(self, statement_text): - """ - Return a statement if it exists. - Create and return the statement if it does not exist. - """ - statement = self.storage.find(statement_text) - - if not statement: - statement = Statement(statement_text) - - return statement - def train(self, conversation): """ Train the chat bot based on the provided list of @@ -217,6 +217,11 @@ def __init__(self, storage, **kwargs): super(UbuntuCorpusTrainer, self).__init__(storage, **kwargs) import os + self.data_download_url = kwargs.get( + 'ubuntu_corpus_data_download_url', + 'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz' + ) + self.data_directory = kwargs.get( 'ubuntu_corpus_data_directory', './data/' @@ -241,6 +246,7 @@ def download(self, url, show_status=True): # Do not download the data if it already exists if os.path.exists(file_path): + self.logger.info('File is already downloaded') return file_path with open(file_path, 'wb') as open_file: @@ -268,28 +274,75 @@ def extract(self, file_path): """ Extract a tar file at the specified file path. """ + import os import tarfile + dir_name = os.path.split(file_path)[-1].split('.')[0] + + extracted_file_directory = os.path.join( + self.data_directory, + dir_name + ) + + # Do not extract if the extracted directory already exists + if os.path.isdir(extracted_file_directory): + return False + self.logger.info('Starting file extraction') def track_progress(members): for member in members: # this will be the current file being extracted yield member - print('Extracting {}'.format(member)) + print('Extracting {}'.format(member.path)) with tarfile.open(file_path) as tar: tar.extractall(path=self.data_directory, members=track_progress(tar)) self.logger.info('File extraction complete') - return self.data_directory + return True def train(self): import glob + import csv + import os + + # Download and extract the Ubuntu dialog corpus + corpus_download_path = self.download(self.data_download_url) + + self.extract(corpus_download_path) + + extracted_corpus_path = os.path.join( + self.data_directory, + os.path.split(corpus_download_path)[-1].split('.')[0], + '**', '*.tsv' + ) + + for file in glob.iglob(extracted_corpus_path): + self.logger.info('Training from: {}'.format(file)) + + with open(file, 'r') as tsv: + reader = csv.reader(tsv, delimiter='\t') + + statement_history = [] + + for row in reader: + if len(row) > 0: + text = row[3] + statement = self.get_or_create(text) + print(text, len(row)) + + statement.add_extra_data('datetime', row[0]) + statement.add_extra_data('speaker', row[1]) + + if row[2].strip(): + statement.add_extra_data('addressing_speaker', row[2]) - # data_directory = self.extract('C:/Users/Gunther/GitHub/ChatterBot/examples/ubuntu_dialogs.tgz') - data_directory = 'C:/Users/Gunther/GitHub/ChatterBot/examples/ubuntu_dialogs/dialogs/' + if statement_history: + statement.add_response( + Response(statement_history[-1].text) + ) - for file in glob.glob(data_directory): - print('file:', file) + statement_history.append(statement) + self.storage.update(statement, force=True) diff --git a/tests/training_tests/test_ubuntu_corpus_training.py b/tests/training_tests/test_ubuntu_corpus_training.py index 361d7d80b..fb2fd2eb6 100644 --- a/tests/training_tests/test_ubuntu_corpus_training.py +++ b/tests/training_tests/test_ubuntu_corpus_training.py @@ -30,20 +30,20 @@ def _create_test_corpus(self): Create a small tar in a similar format to the Ubuntu corpus file in memory for testing. """ - file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar') + file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_dialogs.tgz') tar = tarfile.TarFile(file_path, 'w') data1 = ( - b'2004-11-04T16:49:00.000Z tom jane : Hello\n' + - b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' + - b'2004-11-04T16:49:00.000Z jane tom I am good' + + b'2004-11-04T16:49:00.000Z tom jane Hello\n' + + b'2004-11-04T16:49:00.000Z tom jane Is anyone there?\n' + + b'2004-11-04T16:49:00.000Z jane Yes\n' + b'\n' ) data2 = ( - b'2004-11-04T16:49:00.000Z tom jane : Hello\n' + - b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' + - b'2004-11-04T16:49:00.000Z jane tom I am good' + + b'2004-11-04T16:49:00.000Z tom jane Hello\n' + + b'2004-11-04T16:49:00.000Z tom Is anyone there?\n' + + b'2004-11-04T16:49:00.000Z jane Yes\n' + b'\n' ) @@ -68,7 +68,7 @@ def _destroy_test_corpus(self): """ Remove the test corpus file. """ - file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar') + file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_dialogs.tgz') os.remove(file_path) def _mock_get_response(self, *args, **kwargs): @@ -88,7 +88,7 @@ def test_download(self): import requests requests.get = Mock(side_effect=self._mock_get_response) - download_url = 'https://example.com/download.tar' + download_url = 'https://example.com/download.tgz' self.chatbot.trainer.download(download_url, show_status=False) file_name = download_url.split('/')[-1] @@ -106,11 +106,11 @@ def test_download_file_exists(self): """ import requests - file_path = os.path.join(self.chatbot.trainer.data_directory, 'download.tar') + file_path = os.path.join(self.chatbot.trainer.data_directory, 'download.tgz') open(file_path, 'a').close() requests.get = Mock(side_effect=self._mock_get_response) - download_url = 'https://example.com/download.tar' + download_url = 'https://example.com/download.tgz' self.chatbot.trainer.download(download_url, show_status=False) # Remove the dummy download_url @@ -118,7 +118,7 @@ def test_download_file_exists(self): self.assertFalse(requests.get.called) - def test_download_url_does_not_exist(self): + def test_download_url_not_found(self): """ Test the case that the url being downloaded does not exist. """ @@ -137,8 +137,27 @@ def test_extract(self): self.assertTrue(os.path.exists(os.path.join(corpus_path, '1.tsv'))) self.assertTrue(os.path.exists(os.path.join(corpus_path, '2.tsv'))) + def test_already_extracted(self): + """ + Test that extraction is only done if the compressed file + has not already been extracted. + """ + file_object_path = self._create_test_corpus() + created = self.chatbot.trainer.extract(file_object_path) + not_created = self.chatbot.trainer.extract(file_object_path) + self._destroy_test_corpus() + + self.assertTrue(created) + self.assertFalse(not_created) + def test_train(self): """ Test that the chat bot is trained using data from the Ubuntu Corpus. """ - pass + path = self._create_test_corpus() + + self.chatbot.train() + self._destroy_test_corpus() + + response = self.chatbot.get_response('Is anyone there?') + self.assertEqual(response, 'Yes')