
Commit

Remove multiprocessing to prevent CI errors on Travis
gunthercox committed Apr 6, 2019
1 parent 8efc371 commit f002e39
Showing 1 changed file with 25 additions and 79 deletions: chatterbot/trainers.py
@@ -2,7 +2,6 @@
 import sys
 import csv
 import time
-from multiprocessing import Pool, Manager
 from dateutil import parser as date_parser
 from chatterbot.conversation import Statement
 from chatterbot.tagging import PosLemmaTagger
@@ -174,41 +173,6 @@ def train(self, *corpus_paths):
         self.chatbot.storage.create_many(statements_to_create)


-def read_file(files, queue, preprocessors, tagger):
-
-    statements_from_file = []
-
-    for tsv_file in files:
-        with open(tsv_file, 'r', encoding='utf-8') as tsv:
-            reader = csv.reader(tsv, delimiter='\t')
-
-            previous_statement_text = None
-            previous_statement_search_text = ''
-
-            for row in reader:
-                if len(row) > 0:
-                    statement = Statement(
-                        text=row[3],
-                        in_response_to=previous_statement_text,
-                        conversation='training',
-                        created_at=date_parser.parse(row[0]),
-                        persona=row[1]
-                    )
-
-                    for preprocessor in preprocessors:
-                        statement = preprocessor(statement)
-
-                    statement.search_text = tagger.get_bigram_pair_string(statement.text)
-                    statement.search_in_response_to = previous_statement_search_text
-
-                    previous_statement_text = statement.text
-                    previous_statement_search_text = statement.search_text
-
-                    statements_from_file.append(statement)
-
-    queue.put(tuple(statements_from_file))
-
-
 class UbuntuCorpusTrainer(Trainer):
     """
     Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.
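The read_file function deleted above existed only to run inside worker processes: it parsed a group of TSV files and pushed the resulting statements back to the parent through a Manager-backed queue (the driver code removed further down in this diff). A minimal, self-contained sketch of that pattern, with a hypothetical worker and data rather than ChatterBot's API:

# Sketch of the removed pattern (hypothetical worker, not ChatterBot's API):
# a Manager-backed queue is shared with Pool workers, and the parent
# drains it after starmap returns.
from multiprocessing import Pool, Manager

def worker(items, queue):
    # Stand-in for read_file's TSV parsing.
    queue.put(tuple(item.upper() for item in items))

if __name__ == '__main__':
    manager = Manager()
    queue = manager.Queue()
    groups = [('a', 'b'), ('c', 'd')]
    with Pool() as pool:
        pool.starmap(worker, [(group, queue) for group in groups])
    while not queue.empty():
        print(queue.get())

A Manager queue is used rather than a plain multiprocessing.Queue because manager proxies can be pickled and passed to pool workers as ordinary arguments.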
@@ -337,9 +301,6 @@ def train(self):
             '**', '**', '*.tsv'
         )

-        manager = Manager()
-        queue = manager.Queue()
-
         def chunks(items, items_per_chunk):
             for start_index in range(0, len(items), items_per_chunk):
                 end_index = start_index + items_per_chunk
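The chunks generator kept above (its closing yield line sits in the collapsed portion of the diff) simply slices a sequence into fixed-size groups; the trainer uses it to split the TSV file list into groups of 10,000. A standalone sketch with hypothetical file names:

# Standalone sketch of the chunks generator with hypothetical inputs.
def chunks(items, items_per_chunk):
    for start_index in range(0, len(items), items_per_chunk):
        end_index = start_index + items_per_chunk
        yield items[start_index:end_index]

print(tuple(chunks(['1.tsv', '2.tsv', '3.tsv'], 2)))
# (['1.tsv', '2.tsv'], ['3.tsv'])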
@@ -349,55 +310,40 @@ def chunks(items, items_per_chunk):

         file_groups = tuple(chunks(file_list, 10000))

-        argument_groups = tuple(
-            (
-                file_names,
-                queue,
-                self.chatbot.preprocessors,
-                tagger,
-            ) for file_names in file_groups
-        )
-
-        pool_batches = chunks(argument_groups, 9)
-
-        total_batches = len(file_groups)
-        batch_number = 0

         start_time = time.time()

-        with Pool() as pool:
-            for pool_batch in pool_batches:
-                pool.starmap(read_file, pool_batch)
-
-                while True:
-
-                    if queue.empty():
-                        break
-
-                    batch_number += 1
-
-                    print('Training with batch {} with {} batches remaining...'.format(
-                        batch_number,
-                        total_batches - batch_number
-                    ))
-
-                    self.chatbot.storage.create_many(queue.get())
-
-                    elapsed_time = time.time() - start_time
-                    time_per_batch = elapsed_time / batch_number
-                    remaining_time = time_per_batch * (total_batches - batch_number)
-
-                    print('{:.0f} hours {:.0f} minutes {:.0f} seconds elapsed.'.format(
-                        elapsed_time // 3600 % 24,
-                        elapsed_time // 60 % 60,
-                        elapsed_time % 60
-                    ))
-
-                    print('{:.0f} hours {:.0f} minutes {:.0f} seconds remaining.'.format(
-                        remaining_time // 3600 % 24,
-                        remaining_time // 60 % 60,
-                        remaining_time % 60
-                    ))
-                    print('---')
+        for tsv_files in file_groups:
+
+            statements_from_file = []
+
+            for tsv_file in tsv_files:
+                with open(tsv_file, 'r', encoding='utf-8') as tsv:
+                    reader = csv.reader(tsv, delimiter='\t')
+
+                    previous_statement_text = None
+                    previous_statement_search_text = ''
+
+                    for row in reader:
+                        if len(row) > 0:
+                            statement = Statement(
+                                text=row[3],
+                                in_response_to=previous_statement_text,
+                                conversation='training',
+                                created_at=date_parser.parse(row[0]),
+                                persona=row[1]
+                            )
+
+                            for preprocessor in self.chatbot.preprocessors:
+                                statement = preprocessor(statement)
+
+                            statement.search_text = tagger.get_bigram_pair_string(statement.text)
+                            statement.search_in_response_to = previous_statement_search_text
+
+                            previous_statement_text = statement.text
+                            previous_statement_search_text = statement.search_text
+
+                            statements_from_file.append(statement)
+
+            self.chatbot.storage.create_many(statements_from_file)

         print('Training took', time.time() - start_time, 'seconds.')
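Taken together, the change replaces a process pool feeding a shared queue with a plain loop over the same file groups, calling create_many once per group. A condensed, hypothetical sketch of the new control flow, with a stub create_many standing in for ChatterBot's storage (row indices follow the code above: row[0] timestamp, row[1] persona, row[3] text):

import csv

def train_sequential(file_groups, create_many):
    # One pass per group: parse rows, link each statement to the previous
    # one, then bulk-insert. No Pool, Manager, or queue is involved, which
    # is what this commit removed to prevent the CI errors on Travis.
    for tsv_files in file_groups:
        statements = []
        for path in tsv_files:
            with open(path, encoding='utf-8') as tsv:
                previous_text = None  # reset per file, as in the diff above
                for row in csv.reader(tsv, delimiter='\t'):
                    if len(row) > 0:
                        statements.append({
                            'text': row[3],
                            'in_response_to': previous_text,
                        })
                        previous_text = row[3]
        create_many(statements)

Note that the per-batch progress and ETA printing from the multiprocessing version was dropped rather than ported; only the total elapsed time survives in the final print.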
