
Commit

Propose a faster preprocessing mechanism by reducing the interprocess communications
Thomas committed Jul 28, 2021
1 parent 18201ce commit fac6e90
Showing 1 changed file with 322 additions and 0 deletions.
322 changes: 322 additions & 0 deletions tools/preprocess_data_fast.py
@@ -0,0 +1,322 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Processing data for pretraining. It's supposed to be a faster version compared to vanilla preprocess.py"""

import argparse
import collections
import itertools
import json
import multiprocessing
import os
import sys
import threading
from multiprocessing.connection import Connection

# Make the repository root importable before pulling in the megatron modules.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import time

import torch
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False

from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
from megatron.data.indexed_dataset import index_file_path, data_file_path


# This class subclasses an nltk type, so it can only be defined when NLTK is
# actually importable.
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
if nltk_available:
    class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):

        _period_context_fmt = r"""
            \S*                          # some word material
            %(SentEndChars)s             # a potential sentence ending
            \s*                          # <-- THIS is what I changed
            (?=(?P<after_tok>
                %(NonWord)s              # either other punctuation
                |
                (?P<next_tok>\S+)        # <-- Normally you would have \s+ here
            ))"""

class IdentitySplitter(object):
def tokenize(self, *text):
return text

class Encoder(object):
def __init__(self, args):
self.json_keys = args.json_keys
self.append_eod = args.append_eod
# Use Encoder class as a container for global data
self.tokenizer = build_tokenizer(args)
if args.split_sentences:
if not nltk_available:
print("NLTK is not available to split sentences.")
                sys.exit(1)
splitter = nltk.load("tokenizers/punkt/english.pickle")
if args.keep_newlines:
# this prevents punkt from eating newlines after sentences
self.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text = splitter._params,
lang_vars = CustomLanguageVars())
else:
self.splitter = splitter

else:
self.splitter = IdentitySplitter()

def encode(self, json_line):
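        # Split each requested json key's text into sentences (or keep it
        # whole via IdentitySplitter), tokenize every sentence, and return the
        # token ids plus the number of bytes read, for throughput logging.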
data = json.loads(json_line)
ids = {}
for key in self.json_keys:
text = data[key]
doc_ids = []
for sentence in self.splitter.tokenize(text):
sentence_ids = self.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
if len(doc_ids) > 0 and self.append_eod:
doc_ids[-1].append(self.tokenizer.eod)
ids[key] = doc_ids
return ids, len(json_line)


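# Each worker process pulls chunks of raw JSON lines from the shared queue,
# tokenizes them, and writes directly into its own per-process .bin/.idx shard.
# Only small (doc count, byte count) progress tuples cross process boundaries,
# which is the reduction in interprocess communication this script proposes.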
def process_samples(simple_queue, process_index, args, level, writer: Connection):
encoder = Encoder(args)

output_bin_files = {}
output_idx_files = {}
builders = {}
for key in args.json_keys:
output_filename = f"{args.output_prefix}_{key}_{level}_{process_index}"
output_bin_files[key] = data_file_path(output_filename)
output_idx_files[key] = index_file_path(output_filename)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl,
vocab_size=encoder.tokenizer.vocab_size)

json_lines = simple_queue.get()
while json_lines is not None:
try:
process_json_lines(json_lines, encoder, builders, writer)
        except Exception:
# Debugging code in order to understand why the encoder can fail
for json_line in json_lines:
try:
if json_line.strip() == "":
continue
encoder.encode(json_line)
                except Exception:
print(repr(json_line))
print(json_line.strip() == "")
raise

json_lines = simple_queue.get()

    # When finished, put the sentinel back so the other workers also see the end of input.
    simple_queue.put(None)
    # Closing the write end raises EOFError on the reader side, which ends the log loop.
    writer.close()

for key in args.json_keys:
builders[key].finalize(output_idx_files[key])


def process_json_lines(json_lines, encoder, builders, writer):
total_bytes_processed = 0
for json_line in json_lines:
if json_line.strip() == "":
continue

doc, bytes_processed = encoder.encode(json_line)

total_bytes_processed += bytes_processed

for key, sentences in doc.items():
if len(sentences) == 0:
continue
for sentence in sentences:
builders[key].add_item(torch.IntTensor(sentence))
builders[key].end_document()

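    # Report progress to the logging thread: (documents processed, bytes read).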
writer.send((len(json_lines), total_bytes_processed))


def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='Space-separated list of keys to extract from the JSON.')
group.add_argument('--split-sentences', action='store_true',
help='Split documents into sentences.')
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')

group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'PretrainedFromHF'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument("--tokenizer-name-or-path", type=str, default=None,
help="Name or path of the huggingface tokenizer.")

group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])

group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
help='Number of worker processes to launch')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False

if args.tokenizer_type.lower().startswith('bert'):
if not args.split_sentences:
print("Bert tokenizer detected, are you sure you don't want to split sentences?")

# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0

return args

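# Producer side: read the input file `chunk_size` lines at a time and push the
# chunks onto the shared queue; None is the end-of-input sentinel for workers.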
def fill_simple_queue(filename, simple_queue, chunk_size: int):
with open(filename, "r") as f:
print("Start filling queue")
while True:
acc = tuple(itertools.islice(f, chunk_size))
if len(acc) == 0:
simple_queue.put(None)
return
simple_queue.put(acc)

def log(readers, log_interval):
print("Start Logging")
proc_start = time.time()
total_bytes_processed = 0
doc_processed = 0
logged_docs = 0

# we want to compute a rolling average of bytes processed over last 10k documents (more or less)
bytes_queue_max_length = 10_000 // log_interval + 1
    bytes_queue = collections.deque(maxlen=bytes_queue_max_length)
# we fill the queue with (start_time, 0)
bytes_queue.extend([(proc_start, total_bytes_processed)]*bytes_queue_max_length)

while len(readers) != 0:
for r in multiprocessing.connection.wait(readers):
try:
nb_of_docs, bytes_processed = r.recv()
total_bytes_processed += bytes_processed
doc_processed += nb_of_docs
except EOFError:
r.close()
readers.remove(r)

if (doc_processed - logged_docs) >= log_interval:
logged_docs = doc_processed
current = time.time()
elapsed = current - proc_start

(old_start_time, old_bytes) = bytes_queue.popleft()
bytes_queue.append((current, total_bytes_processed))
mbs = (total_bytes_processed - old_bytes) / (current - old_start_time) / 1024 / 1024
print(f"Processed {doc_processed} documents",
f"({doc_processed / elapsed} docs/s, {mbs} MB/s).")

def main():
args = get_args()

print("Opening", args.input)
simple_queue = multiprocessing.Queue() # we can also limit the number of elements to reduce the memory footprint.
chunk_size = 25

if nltk_available and args.split_sentences:
nltk.download("punkt", quiet=True)

level = "document"
if args.split_sentences:
level = "sentence"

    assert args.workers > 2, "One worker is reserved for logging, one for filling the queue; the remaining workers tokenize."
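    # One one-way pipe per tokenizer process: each worker sends progress tuples
    # through its write end, and the logging thread waits on the read ends.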
readers, writers = list(zip(*[multiprocessing.Pipe(duplex=False) for _ in range(args.workers - 2)]))
processes = [multiprocessing.Process(target=process_samples, args=(simple_queue, i, args, level, writer)) for i, writer in enumerate(writers)]
log_thread = threading.Thread(target=log, args=(list(readers), args.log_interval))
    fill_process = multiprocessing.Process(target=fill_simple_queue, args=(args.input, simple_queue, chunk_size))

    fill_process.start()
log_thread.start()
    for process in processes:
process.start()

    # Close the parent's copies of the writable ends so each worker is the only
    # process owning a handle to its pipe. This ensures that when a worker
    # closes its writable end, wait() promptly reports the readable end as ready.
    # https://docs.python.org/fr/3/library/multiprocessing.html#multiprocessing.connection.Connection
for writer in writers:
writer.close()

    fill_process.join()
    fill_process.close()
for process in processes:
process.join()
process.close()
log_thread.join()

    # TODO: the merge below could instead be done later, as a separate step.
print("Merging files together")

tokenizer = build_tokenizer(args)

print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in args.json_keys:
output_filename = f"{args.output_prefix}_{key}_{level}"
output_bin_files[key] = data_file_path(output_filename)
output_idx_files[key] = index_file_path(output_filename)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)

for key in args.json_keys:
for process_index in range(len(processes)):
output_filename = f"{args.output_prefix}_{key}_{level}_{process_index}"
builders[key].merge_file_(output_filename)
builders[key].finalize(output_idx_files[key])

if __name__ == '__main__':
main()
