Skip to content

Commit

Permalink
Add tar file extraction tests
Browse files: browse the repository at this point in the history
  • Loading branch information
gunthercox committed Nov 15, 2016
1 parent f962d9d commit 99c5624
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
6 changes: 2 additions & 4 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,20 +263,18 @@ def extract(self, file_path):

self.logger.info('Starting file extraction')

extracted_directory_path = ''

def track_progress(members):
for member in members:
# this will be the current file being extracted
yield member
print('Extracting {}'.format(member))

with tarfile.open(file_path) as tar:
tar.extractall(members=track_progress(tar))
tar.extractall(path=self.data_directory, members=track_progress(tar))

self.logger.info('File extraction complete')

return extracted_directory_path
return self.data_directory

def train(self):
import glob
Expand Down
22 changes: 19 additions & 3 deletions tests/training_tests/test_ubuntu_corpus_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ def setUp(self):

def tearDown(self):
super(UbuntuCorpusTrainerTestCase, self).tearDown()
import shutil

# Clean up by removing the corpus data directory
os.removedirs(self.chatbot.trainer.data_directory)
if os.path.exists(self.chatbot.trainer.data_directory):
shutil.rmtree(self.chatbot.trainer.data_directory)

def _create_test_corpus(self):
"""
Create a small tar in a similar format to the
Ubuntu corpus file in memory for testing.
"""
tar = tarfile.TarFile('ubuntu_corpus.tar', 'w')
file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar')
tar = tarfile.TarFile(file_path, 'w')

data1 = (
b'2004-11-04T16:49:00.000Z tom jane : Hello\n' +
Expand Down Expand Up @@ -59,7 +62,14 @@ def _create_test_corpus(self):
tsv2.close()
tar.close()

return os.path.realpath(tar.name)
return file_path

def _destroy_test_corpus(self):
"""
Remove the test corpus file.
"""
file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar')
os.remove(file_path)

def _mock_get_response(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -121,6 +131,12 @@ def test_extract(self):
file_object_path = self._create_test_corpus()
self.chatbot.trainer.extract(file_object_path)

self._destroy_test_corpus()
corpus_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_dialogs', '3')

self.assertTrue(os.path.exists(os.path.join(corpus_path, '1.tsv')))
self.assertTrue(os.path.exists(os.path.join(corpus_path, '2.tsv')))

def test_train(self):
"""
Test that the chat bot is trained using data from the Ubuntu Corpus.
Expand Down

0 comments on commit 99c5624

Please sign in to comment.