
Commit

Fast preprocessing by allowing workers to write directly to disk instead of sending processed inputs to the main process
Thomas committed Jul 25, 2021
1 parent 8d10378 commit 0e8287c
Showing 1 changed file with 172 additions and 49 deletions.
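The change replaces the previous pool.imap pipeline, in which worker processes sent every encoded document back to the main process, with workers that each own an output shard and write it to disk themselves: the main process only fills a queue with chunks of raw lines, collects progress counts over pipes, and merges the per-worker shards at the end. As a rough illustration of that pattern only (not the committed code), here is a minimal self-contained sketch; the worker function, the file names (input.txt, shard_*.txt, merged.txt) and the plain-text shard format are made up for the example, whereas the real script uses indexed_dataset builders, .bin/.idx shards and a separate logging process.

import multiprocessing


def worker(queue, shard_path):
    # Each worker writes its processed output straight to its own shard file
    # instead of returning it to the parent process.
    with open(shard_path, "w") as shard:
        chunk = queue.get()
        while chunk is not None:
            for line in chunk:
                shard.write(line.upper())  # stand-in for real tokenization
            chunk = queue.get()
        queue.put(None)  # pass the sentinel on so the other workers stop too


def main(input_path="input.txt", num_workers=3, chunk_size=25):
    queue = multiprocessing.Queue()
    shard_paths = [f"shard_{i}.txt" for i in range(num_workers)]
    workers = [multiprocessing.Process(target=worker, args=(queue, path))
               for path in shard_paths]
    for p in workers:
        p.start()

    # Producer: feed raw lines to the queue in fixed-size chunks.
    with open(input_path, "r") as f:
        chunk = []
        for line in f:
            chunk.append(line)
            if len(chunk) == chunk_size:
                queue.put(chunk)
                chunk = []
        if chunk:
            queue.put(chunk)
    queue.put(None)  # sentinel: no more input

    for p in workers:
        p.join()

    # Merge the per-worker shards into a single output file.
    with open("merged.txt", "w") as out:
        for path in shard_paths:
            with open(path, "r") as shard:
                out.write(shard.read())


if __name__ == "__main__":
    main()

In the commit below, fill_simple_queue plays the producer role, process_samples is the per-worker loop that owns one indexed_dataset builder per JSON key, the log process aggregates the (doc count, byte count) tuples sent over the pipes, and main merges the shards with merge_file_ before finalizing the index.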
221 changes: 172 additions & 49 deletions tools/preprocess_data.py
@@ -16,10 +16,16 @@
"""Processing data for pretraining."""

import argparse
+import collections
+import itertools
import json
import multiprocessing
import os
import sys
+from multiprocessing.connection import Connection
+
+from megatron.data.indexed_dataset import index_file_path, data_file_path
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import time
@@ -54,42 +60,103 @@ def tokenize(self, *text):

class Encoder(object):
    def __init__(self, args):
-        self.args = args
-
-    def initializer(self):
-        # Use Encoder class as a container for global data
-        Encoder.tokenizer = build_tokenizer(self.args)
-        if self.args.split_sentences:
+        self.json_keys = args.json_keys
+        self.append_eod = args.append_eod
+        self.tokenizer = build_tokenizer(args)
+        if args.split_sentences:
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            splitter = nltk.load("tokenizers/punkt/english.pickle")
-            if self.args.keep_newlines:
+            if args.keep_newlines:
                # this prevents punkt from eating newlines after sentences
-                Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
+                self.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                    train_text = splitter._params,
                    lang_vars = CustomLanguageVars())
            else:
-                Encoder.splitter = splitter
+                self.splitter = splitter

        else:
-            Encoder.splitter = IdentitySplitter()
+            self.splitter = IdentitySplitter()

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
-        for key in self.args.json_keys:
+        for key in self.json_keys:
            text = data[key]
            doc_ids = []
-            for sentence in Encoder.splitter.tokenize(text):
-                sentence_ids = Encoder.tokenizer.tokenize(sentence)
+            for sentence in self.splitter.tokenize(text):
+                sentence_ids = self.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
-            if len(doc_ids) > 0 and self.args.append_eod:
-                doc_ids[-1].append(Encoder.tokenizer.eod)
+            if len(doc_ids) > 0 and self.append_eod:
+                doc_ids[-1].append(self.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)


+def process_samples(simple_queue, process_index, args, level, writer: Connection):
+    encoder = Encoder(args)
+
+    output_bin_files = {}
+    output_idx_files = {}
+    builders = {}
+    for key in args.json_keys:
+        output_filename = f"{args.output_prefix}_{key}_{level}_{process_index}"
+        output_bin_files[key] = data_file_path(output_filename)
+        output_idx_files[key] = index_file_path(output_filename)
+        builders[key] = indexed_dataset.make_builder(output_bin_files[key],
+                                                     impl=args.dataset_impl,
+                                                     vocab_size=encoder.tokenizer.vocab_size)
+
+    json_lines = simple_queue.get()
+    while json_lines is not None:
+        try:
+            process_json_lines(json_lines, encoder, builders, writer)
+        except:
+            # Debugging code in order to understand why the encoder can fail
+            for json_line in json_lines:
+                try:
+                    if json_line.strip() == "":
+                        continue
+                    encoder.encode(json_line)
+                except:
+                    print(repr(json_line))
+                    print(json_line.strip() == "")
+                    raise
+
+        json_lines = simple_queue.get()
+
+    # in case finished, we still need to add None to signal to everyone else
+    simple_queue.put(None)
+    # we need to send EOFError
+    writer.close()
+
+    for key in args.json_keys:
+        builders[key].finalize(output_idx_files[key])
+
+
+def process_json_lines(json_lines, encoder, builders, writer):
+    total_bytes_processed = 0
+    for json_line in json_lines:
+        if json_line.strip() == "":
+            continue
+
+        doc, bytes_processed = encoder.encode(json_line)
+
+        total_bytes_processed += bytes_processed
+
+        for key, sentences in doc.items():
+            if len(sentences) == 0:
+                continue
+            for sentence in sentences:
+                builders[key].add_item(torch.IntTensor(sentence))
+            builders[key].end_document()
+
+    writer.send((len(json_lines), total_bytes_processed))
+
+
def get_args():
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
@@ -142,62 +209,118 @@ def get_args():

    return args

+def fill_simple_queue(filename, simple_queue, chunk_size:int):
+    with open(filename, "r") as f:
+        print("Start filling queue")
+        while True:
+            acc = tuple(itertools.islice(f, chunk_size))
+            if len(acc) == 0:
+                simple_queue.put(None)
+                return
+            simple_queue.put(acc)
+
+def log(readers, log_interval):
+    proc_start = time.time()
+    total_bytes_processed = 0
+    doc_processed = 0
+    logged_docs = 0
+
+    # we want to compute a rolling average of bytes processed over last 10k documents (more or less)
+    bytes_queue_max_length = 10_000 // log_interval + 1
+    bytes_queue = collections.deque(maxlen= bytes_queue_max_length)
+    # we fill the queue with (start_time, 0)
+    bytes_queue.extend([(proc_start, total_bytes_processed)]*bytes_queue_max_length)
+
+    print("Start Logging")
+
+    while len(readers) != 0:
+        for r in multiprocessing.connection.wait(readers):
+            try:
+                nb_of_docs, bytes_processed = r.recv()
+                total_bytes_processed += bytes_processed
+                doc_processed += nb_of_docs
+            except EOFError:
+                r.close()
+                readers.remove(r)
+
+            if (doc_processed - logged_docs) >= log_interval:
+                logged_docs = doc_processed
+                current = time.time()
+                elapsed = current - proc_start
+
+                (old_start_time, old_bytes) = bytes_queue.popleft()
+                bytes_queue.append((current, total_bytes_processed))
+                mbs = (total_bytes_processed - old_bytes) / (current - old_start_time) / 1024 / 1024
+                print(f"Processed {doc_processed} documents",
+                      f"({doc_processed / elapsed} docs/s, {mbs} MB/s).")

def main():
+    # multiprocessing.set_start_method('spawn')
    args = get_args()
    startup_start = time.time()

-    print("Opening", args.input)
-    fin = open(args.input, 'r', encoding='utf-8')
+    # TODO: Figure out why it's impossible to create a thread safe iterator.
+    simple_queue = multiprocessing.Queue() # we can switch to other queue in order to bound the number of elements.
+    chunk_size = 25

    if nltk_available and args.split_sentences:
        nltk.download("punkt", quiet=True)

-    encoder = Encoder(args)
-    tokenizer = build_tokenizer(args)
-    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
-    encoded_docs = pool.imap(encoder.encode, fin, 25)
-    #encoded_docs = map(encoder.encode, fin)

    level = "document"
    if args.split_sentences:
        level = "sentence"

+    assert args.workers > 2, "One worker is used for logging, one for filling the queue"
+    readers, writers = list(zip(*[multiprocessing.Pipe(duplex=False) for _ in range(args.workers - 2)]))
+    processes = [multiprocessing.Process(target=process_samples, args=(simple_queue, i, args, level, writer)) for i, writer in enumerate(writers)]
+    log_thread = multiprocessing.Process(target=log, args=(list(readers), args.log_interval))
+    fill_thread = multiprocessing.Process(target=fill_simple_queue, args=(args.input, simple_queue, chunk_size))
+
+    fill_thread.start()
+    log_thread.start()
+    for i, process in enumerate(processes):
+        process.start()
+
+    # We close the writable end of the pipe now to be sure that
+    # p is the only process which owns a handle for it. This
+    # ensures that when p closes its handle for the writable end,
+    # wait() will promptly report the readable end as being ready.
+    # https://docs.python.org/fr/3/library/multiprocessing.html#multiprocessing.connection.Connection
+    for writer in writers:
+        writer.close()
+
+    # fill_simple_queue(args.input, simple_queue, chunk_size)
+
+    fill_thread.join()
+    fill_thread.close()
+    for process in processes:
+        process.join()
+        process.close()
+    log_thread.join()
+    log_thread.close()
+
+    # TODO: this may be done after.
+    print("Merging files together")
+
+    tokenizer = build_tokenizer(args)
+
    print(f"Vocab size: {tokenizer.vocab_size}")
    print(f"Output prefix: {args.output_prefix}")
    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    for key in args.json_keys:
-        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
-                                                      key, level)
-        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
-                                                      key, level)
+        output_filename = f"{args.output_prefix}_{key}_{level}"
+        output_bin_files[key] = data_file_path(output_filename)
+        output_idx_files[key] = index_file_path(output_filename)
        builders[key] = indexed_dataset.make_builder(output_bin_files[key],
-                                               impl=args.dataset_impl,
-                                               vocab_size=tokenizer.vocab_size)
-
-    startup_end = time.time()
-    proc_start = time.time()
-    total_bytes_processed = 0
-    print("Time to startup:", startup_end - startup_start)
-
-    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
-        total_bytes_processed += bytes_processed
-        for key, sentences in doc.items():
-            if len(sentences) == 0:
-                continue
-            for sentence in sentences:
-                builders[key].add_item(torch.IntTensor(sentence))
-            builders[key].end_document()
-        if i % args.log_interval == 0:
-            current = time.time()
-            elapsed = current - proc_start
-            mbs = total_bytes_processed/elapsed/1024/1024
-            print(f"Processed {i} documents",
-                  f"({i/elapsed} docs/s, {mbs} MB/s).",
-                  file=sys.stderr)
+                                                     impl=args.dataset_impl,
+                                                     vocab_size=tokenizer.vocab_size)

    for key in args.json_keys:
+        for process_index in range(len(processes)):
+            output_filename = f"{args.output_prefix}_{key}_{level}_{process_index}"
+            builders[key].merge_file_(output_filename)
        builders[key].finalize(output_idx_files[key])

if __name__ == '__main__':