Skip to content

Commit

Permalink
Add training class for Ubuntu corpus
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Nov 14, 2016
1 parent bf1fd63 commit ae3f517
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ docs/_build/
*.iml

examples/settings.py
examples/ubuntu_dialogs*
70 changes: 69 additions & 1 deletion chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,75 @@ def get_statements(self):
return statements

def train(self):
for i in range(0, 10):
for _ in range(0, 10):
statements = self.get_statements()
for statement in statements:
self.storage.update(statement, force=True)


class UbuntuCorpusTrainer(Trainer):
"""
Allow chatbots to be trained with the data from
the Ubuntu Dialog Corpus.
"""

def download(self, url, show_status=True):
"""
Download a file from the given url.
Show a progress indicator for the download status.
Based on: http://stackoverflow.com/a/15645088/1547223
"""
import sys
import requests

file_name = 'download.data'
with open(file_name, 'wb') as open_file:
print('Downloading %s' % file_name)
response = requests.get(url, stream=True)
total_length = response.headers.get('content-length')

if total_length is None:
# No content length header
open_file.write(response.content)
else:
downloadl = 0
total_length = int(total_length)
for data in response.iter_content(chunk_size=4096):
downloadl += len(data)
open_file.write(data)
if show_status:
done = int(50 * downloadl / total_length)
sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
sys.stdout.flush()

def extract(self, file_path):
"""
Extract a tar file at the specified file path.
"""
import tarfile

self.logger.info('Starting file extraction')

extracted_directory_path = ''

def track_progress(members):
for member in members:
# this will be the current file being extracted
yield member
print('Extracting {}'.format(member))

with tarfile.open(file_path) as tar:
tar.extractall(members=track_progress(tar))

self.logger.info('File extraction complete')

return extracted_directory_path

def train(self):
import glob

# data_directory = self.extract('C:/Users/Gunther/GitHub/ChatterBot/examples/ubuntu_dialogs.tgz')
data_directory = 'C:/Users/Gunther/GitHub/ChatterBot/examples/ubuntu_dialogs/dialogs/'

for file in glob.glob(data_directory):
print('file:', file)
23 changes: 23 additions & 0 deletions examples/ubuntu_corpus_training_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from chatterbot import ChatBot
import logging


'''
This is an example showing how to train a chat bot using the
Ubuntu Corpus of conversation dialog.
'''

# Enable info level logging
logging.basicConfig(level=logging.INFO)

chatbot = ChatBot(
'Example Bot',
trainer='chatterbot.trainers.UbuntuCorpusTrainer'
)

# Start by training our bot with the Ubuntu corpus data
chatbot.train()

# Now let's get a response to a greeting
response = chatbot.get_response('How are you doing today?')
print(response)
92 changes: 92 additions & 0 deletions tests/training_tests/test_ubuntu_corpus_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from io import BytesIO
import tarfile
import os
from mock import Mock

from tests.base_case import ChatBotTestCase
from chatterbot.trainers import UbuntuCorpusTrainer


class UbuntuCorpusTrainerTestCase(ChatBotTestCase):
"""
Test the Ubuntu Corpus trainer class.
"""

def setUp(self):
super(UbuntuCorpusTrainerTestCase, self).setUp()
self.chatbot.set_trainer(UbuntuCorpusTrainer)

def _create_test_corpus(self):
"""
Create a small tar in a similar format to the
Ubuntu corpus file in memory for testing.
"""
tar = tarfile.TarFile('ubuntu_corpus.tar', 'w')

data1 = (
b'2004-11-04T16:49:00.000Z tom jane : Hello\n' +
b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' +
b'2004-11-04T16:49:00.000Z jane tom I am good' +
b'\n'
)

data2 = (
b'2004-11-04T16:49:00.000Z tom jane : Hello\n' +
b'2004-11-04T16:49:00.000Z tom jane : Is anyone there?\n' +
b'2004-11-04T16:49:00.000Z jane tom I am good' +
b'\n'
)

tsv1 = BytesIO(data1)
tsv2 = BytesIO(data2)

tarinfo = tarfile.TarInfo('ubuntu_dialogs/3/1.tsv')
tarinfo.size = len(data1)
tar.addfile(tarinfo, fileobj=tsv1)

tarinfo = tarfile.TarInfo('ubuntu_dialogs/3/2.tsv')
tarinfo.size = len(data2)
tar.addfile(tarinfo, fileobj=tsv2)

tsv1.close()
tsv2.close()
tar.close()

return os.path.realpath(tar.name)

def test_download(self):
"""
Test the download function for the Ubuntu corpus trainer.
"""
import requests

def mock_get_response(*args, **kwargs):
response = requests.Response()
response._content = b'Some response content'
response.headers['content-length'] = len(response.content)
return response

requests.get = Mock(side_effect=mock_get_response)
download_url = 'https://example.com/download.tar'
# self.chatbot.trainer.requests.get = Mock()
self.chatbot.trainer.download(download_url, show_status=False)
requests.get.assert_called_with(download_url, stream=True)

def test_download_does_not_exist(self):
"""
Test the case that the file being downloaded does not exist.
"""
pass

def test_extract(self):
"""
Test the extraction of text from a decompressed Ubuntu Corpus file.
"""
file_object_path = self._create_test_corpus()
self.chatbot.trainer.extract(file_object_path)

def test_train(self):
"""
Test that the chat bot is trained using data from the Ubuntu Corpus.
"""
pass

0 comments on commit ae3f517

Please sign in to comment.