Skip to content

Commit

Permalink
Add tests for file extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Nov 24, 2016
1 parent 13dda38 commit 47bfb45
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 31 deletions.
89 changes: 71 additions & 18 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ def train(self, *args, **kwargs):
"""
raise self.TrainerInitializationException()

def get_or_create(self, statement_text):
    """
    Fetch the statement matching ``statement_text`` from storage.

    When the storage lookup yields nothing truthy, a fresh ``Statement``
    is built from the text instead. The new statement is only returned
    here; this method does not write it to storage.
    """
    existing = self.storage.find(statement_text)

    if existing:
        return existing

    return Statement(statement_text)

class TrainerInitializationException(Exception):
"""
Exception raised when a base class has not overridden
Expand Down Expand Up @@ -59,18 +71,6 @@ class ListTrainer(Trainer):
where the list represents a conversation.
"""

def get_or_create(self, statement_text):
    """
    Resolve ``statement_text`` to a statement object.

    The storage adapter is queried first; if its result is falsy, a new
    ``Statement`` wrapping the text is returned instead.
    """
    return self.storage.find(statement_text) or Statement(statement_text)

def train(self, conversation):
"""
Train the chat bot based on the provided list of
Expand Down Expand Up @@ -217,6 +217,11 @@ def __init__(self, storage, **kwargs):
super(UbuntuCorpusTrainer, self).__init__(storage, **kwargs)
import os

self.data_download_url = kwargs.get(
'ubuntu_corpus_data_download_url',
'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
)

self.data_directory = kwargs.get(
'ubuntu_corpus_data_directory',
'./data/'
Expand All @@ -241,6 +246,7 @@ def download(self, url, show_status=True):

# Do not download the data if it already exists
if os.path.exists(file_path):
self.logger.info('File is already downloaded')
return file_path

with open(file_path, 'wb') as open_file:
Expand Down Expand Up @@ -268,28 +274,75 @@ def extract(self, file_path):
"""
Extract a tar file at the specified file path.
"""
import os
import tarfile

dir_name = os.path.split(file_path)[-1].split('.')[0]

extracted_file_directory = os.path.join(
self.data_directory,
dir_name
)

# Do not extract if the extracted directory already exists
if os.path.isdir(extracted_file_directory):
return False

self.logger.info('Starting file extraction')

def track_progress(members):
for member in members:
# this will be the current file being extracted
yield member
print('Extracting {}'.format(member))
print('Extracting {}'.format(member.path))

with tarfile.open(file_path) as tar:
tar.extractall(path=self.data_directory, members=track_progress(tar))

self.logger.info('File extraction complete')

return self.data_directory
return True

def train(self):
    """
    Train the chat bot from the Ubuntu dialog corpus.

    Downloads the archive from ``self.data_download_url``, extracts it,
    then reads every ``.tsv`` file matched under the extracted directory.
    Rows are read as tab-separated columns used as: datetime, speaker,
    addressed speaker, text. Within a file, each statement is recorded
    as a response to the statement on the previous row, and every
    statement is written to storage.
    """
    import glob
    import csv
    import os

    # Download and extract the Ubuntu dialog corpus
    corpus_download_path = self.download(self.data_download_url)

    self.extract(corpus_download_path)

    # The corpus is expected in a directory named after the archive
    # (everything before the first dot in its file name)
    extracted_corpus_path = os.path.join(
        self.data_directory,
        os.path.split(corpus_download_path)[-1].split('.')[0],
        '**', '*.tsv'
    )

    for tsv_path in glob.iglob(extracted_corpus_path):
        self.logger.info('Training from: {}'.format(tsv_path))

        with open(tsv_path, 'r') as tsv:
            reader = csv.reader(tsv, delimiter='\t')

            # Statements seen so far in this file, in order
            statement_history = []

            for row in reader:
                # Skip blank lines and rows missing the text column
                # instead of failing with an IndexError on row[3]
                if len(row) < 4:
                    continue

                statement = self.get_or_create(row[3])

                statement.add_extra_data('datetime', row[0])
                statement.add_extra_data('speaker', row[1])

                # The addressed speaker column may be empty
                if row[2].strip():
                    statement.add_extra_data('addressing_speaker', row[2])

                if statement_history:
                    statement.add_response(
                        Response(statement_history[-1].text)
                    )

                statement_history.append(statement)
                self.storage.update(statement, force=True)
45 changes: 32 additions & 13 deletions tests/training_tests/test_ubuntu_corpus_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ def _create_test_corpus(self):
Create a small tar in a similar format to the
Ubuntu corpus file in memory for testing.
"""
file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar')
file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_dialogs.tgz')
tar = tarfile.TarFile(file_path, 'w')

data1 = (
b'2004-11-04T16:49:00.000Z tom jane : Hello\n' +
b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' +
b'2004-11-04T16:49:00.000Z jane tom I am good' +
b'2004-11-04T16:49:00.000Z tom jane Hello\n' +
b'2004-11-04T16:49:00.000Z tom jane Is anyone there?\n' +
b'2004-11-04T16:49:00.000Z jane Yes\n' +
b'\n'
)

data2 = (
b'2004-11-04T16:49:00.000Z tom jane : Hello\n' +
b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' +
b'2004-11-04T16:49:00.000Z jane tom I am good' +
b'2004-11-04T16:49:00.000Z tom jane Hello\n' +
b'2004-11-04T16:49:00.000Z tom Is anyone there?\n' +
b'2004-11-04T16:49:00.000Z jane Yes\n' +
b'\n'
)

Expand All @@ -68,7 +68,7 @@ def _destroy_test_corpus(self):
"""
Remove the test corpus file.
"""
file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_corpus.tar')
file_path = os.path.join(self.chatbot.trainer.data_directory, 'ubuntu_dialogs.tgz')
os.remove(file_path)

def _mock_get_response(self, *args, **kwargs):
Expand All @@ -88,7 +88,7 @@ def test_download(self):
import requests

requests.get = Mock(side_effect=self._mock_get_response)
download_url = 'https://example.com/download.tar'
download_url = 'https://example.com/download.tgz'
self.chatbot.trainer.download(download_url, show_status=False)

file_name = download_url.split('/')[-1]
Expand All @@ -106,19 +106,19 @@ def test_download_file_exists(self):
"""
import requests

file_path = os.path.join(self.chatbot.trainer.data_directory, 'download.tar')
file_path = os.path.join(self.chatbot.trainer.data_directory, 'download.tgz')
open(file_path, 'a').close()

requests.get = Mock(side_effect=self._mock_get_response)
download_url = 'https://example.com/download.tar'
download_url = 'https://example.com/download.tgz'
self.chatbot.trainer.download(download_url, show_status=False)

# Remove the dummy download file
os.remove(file_path)

self.assertFalse(requests.get.called)

def test_download_url_does_not_exist(self):
def test_download_url_not_found(self):
"""
Test the case that the url being downloaded does not exist.
"""
Expand All @@ -137,8 +137,27 @@ def test_extract(self):
self.assertTrue(os.path.exists(os.path.join(corpus_path, '1.tsv')))
self.assertTrue(os.path.exists(os.path.join(corpus_path, '2.tsv')))

def test_already_extracted(self):
    """
    Extracting the same archive twice should only unpack it once:
    the first call must report a truthy result and the second a
    falsy one.
    """
    corpus_path = self._create_test_corpus()
    trainer = self.chatbot.trainer

    first_result = trainer.extract(corpus_path)
    second_result = trainer.extract(corpus_path)

    self._destroy_test_corpus()

    self.assertTrue(first_result)
    self.assertFalse(second_result)

def test_train(self):
    """
    Test that the chat bot is trained using data from the Ubuntu Corpus.
    """
    self._create_test_corpus()

    # Remove the corpus file even if training raises, so a failure
    # here does not leave the archive behind for other tests
    try:
        self.chatbot.train()
    finally:
        self._destroy_test_corpus()

    response = self.chatbot.get_response('Is anyone there?')
    self.assertEqual(response, 'Yes')

0 comments on commit 47bfb45

Please sign in to comment.