diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 1bcb48795..4cd10d4ea 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -263,8 +263,6 @@ def extract(self, file_path): self.logger.info('Starting file extraction') - extracted_directory_path = '' - def track_progress(members): for member in members: # this will be the current file being extracted @@ -272,11 +270,11 @@ def track_progress(members): print('Extracting {}'.format(member)) with tarfile.open(file_path) as tar: - tar.extractall(members=track_progress(tar)) + tar.extractall(path=self.data_directory, members=track_progress(tar)) self.logger.info('File extraction complete') - return extracted_directory_path + return self.data_directory def train(self): import glob diff --git a/tests/training_tests/test_ubuntu_corpus_training.py b/tests/training_tests/test_ubuntu_corpus_training.py index e22343df8..361d7d80b 100644 --- a/tests/training_tests/test_ubuntu_corpus_training.py +++ b/tests/training_tests/test_ubuntu_corpus_training.py @@ -19,16 +19,19 @@ def setUp(self): def tearDown(self): super(UbuntuCorpusTrainerTestCase, self).tearDown() + import shutil # Clean up by removing the corpus data directory - os.removedirs(self.chatbot.trainer.data_directory) + if os.path.exists(self.chatbot.trainer.data_directory): + shutil.rmtree(self.chatbot.trainer.data_directory) def _create_test_corpus(self): """ Create a small tar in a similar format to the Ubuntu corpus file in memory for testing. """ - tar = tarfile.TarFile('ubuntu_corpus.tar', 'w') + file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar') + tar = tarfile.TarFile(file_path, 'w') data1 = ( b'2004-11-04T16:49:00.000Z tom jane : Hello\n' + @@ -59,7 +62,14 @@ def _create_test_corpus(self): tsv2.close() tar.close() - return os.path.realpath(tar.name) + return file_path + + def _destroy_test_corpus(self): + """ + Remove the test corpus file. + """ + file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar') + os.remove(file_path) def _mock_get_response(self, *args, **kwargs): """ @@ -121,6 +131,12 @@ def test_extract(self): file_object_path = self._create_test_corpus() self.chatbot.trainer.extract(file_object_path) + self._destroy_test_corpus() + corpus_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_dialogs', '3') + + self.assertTrue(os.path.exists(os.path.join(corpus_path, '1.tsv'))) + self.assertTrue(os.path.exists(os.path.join(corpus_path, '2.tsv'))) + def test_train(self): """ Test that the chat bot is trained using data from the Ubuntu Corpus.