Commit

multiprocessing for wiki

zheyuye committed Jun 30, 2020
1 parent ddbde75 commit 402d625
Showing 3 changed files with 105 additions and 45 deletions.
14 changes: 10 additions & 4 deletions scripts/datasets/pretrain_corpus/prepare_openwebtext.py
@@ -51,8 +51,9 @@ def extract_files(full_name, output_dir, shuffle=False):
    """
    if not full_name.endswith(".xz"):
        return
    file_prefix = re.split('\.|/', full_name)[-2]
    with open("{}.txt".format(os.path.join(output_dir, file_prefix)),"w") as fp:
    file_prefix = re.split(r'\.|/', full_name)[-2]
    file_prefix = file_prefix.replace('urlsf_subset', 'openwebtext-prepared-')
    with open("{}.txt".format(os.path.join(output_dir, file_prefix)), "w") as fp:
        with tarfile.open(full_name) as t:
            txt_names = t.getnames()
            if shuffle:
@@ -63,7 +64,7 @@ def extract_files(full_name, output_dir, shuffle=False):
                    # skip empty line
                    line = line.strip()
                    if line:
                        fp.write(line.decode()+'\n')
                        fp.write(line.decode() + '\n')
                # Two extra line break to mark the document separation
                fp.write('\n')

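As a side note on the format written by extract_files above: the non-empty lines of each document are written out, followed by a blank line as a document separator. A minimal sketch of reading such a prepared file back into a list of documents (the file name is hypothetical):

# Minimal sketch: split a prepared .txt file back into documents,
# assuming a blank line separates documents as written above.
# "openwebtext-prepared-00.txt" is a hypothetical file name.
with open("openwebtext-prepared-00.txt", "r") as f:
    raw = f.read()
documents = [doc.strip() for doc in raw.split("\n\n") if doc.strip()]
print("number of documents:", len(documents))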
@@ -80,7 +81,12 @@ def main(args):
    print('Start extracting {} files with {} cores'.format(len(fnames), num_process))
    start_time = time.time()
    with multiprocessing.Pool(num_process) as pool:
        iter = pool.imap(functools.partial(extract_files, output_dir=args.output, shuffle=args.shuffle), fnames)
        iter = pool.imap(
            functools.partial(
                extract_files,
                output_dir=args.output,
                shuffle=args.shuffle),
            fnames)
        for f_index, _ in enumerate(iter):
            if f_index > 0 and f_index % 250 == 0:
                elapsed = time.time() - start_time
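The parallelism above follows the standard pool.imap pattern: functools.partial pins the fixed keyword arguments so each worker call receives only one archive name, and imap yields results lazily in input order, which is what makes the periodic progress report possible. A self-contained sketch of the same pattern, with a made-up worker and inputs:

import functools
import multiprocessing


def process_one(name, output_dir, shuffle=False):
    # Stand-in worker; the real script extracts one .xz archive here.
    return "{}/{}.txt".format(output_dir, name)


if __name__ == "__main__":
    names = ["a", "b", "c"]
    worker = functools.partial(process_one, output_dir="out", shuffle=True)
    with multiprocessing.Pool(2) as pool:
        # imap yields results lazily and in input order, so enumerate()
        # can be used for progress reporting exactly as in the script.
        for i, result in enumerate(pool.imap(worker, names)):
            print(i, result)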
134 changes: 94 additions & 40 deletions scripts/datasets/pretrain_corpus/prepare_wikipedia.py
@@ -1,10 +1,15 @@
"""Prepare the Wikipedia dataset that contain cleaned articles of all languages."""
import os
import sys
import argparse
import glob
import math
import time
import tarfile
import argparse
import multiprocessing

from gluonnlp.registry import DATA_MAIN_REGISTRY, DATA_PARSER_REGISTRY
from gluonnlp.utils.misc import download, load_checksum_stats
from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY

_CITATION = """\
@ONLINE {wikidump,
@@ -52,9 +57,10 @@

_URLS = {
    'wikipedia-en-20200620':
        'https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikicorpus_one_article_per_line_20200620.txt',
        'https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz',
}


def get_url(lang, date):
    return _BASE_URL_TMPL.format(lang=lang, date=date)

@@ -71,42 +77,48 @@ def try_import_wikiextractor():
                sha1_hash='3c4896a837b75c476d23c037e8d6c7fdfd9a29eb')
            sys.path.append(_CURR_DIR)
            import WikiExtractor
        except:
        except BaseException:
            raise ImportError('Cannot import WikiExtractor! You can download the "WikiExtractor.py"'
                              ' in https://github.com/attardi/wikiextractor to {}'
                              .format(_CURR_DIR))
    return WikiExtractor


class WikicorpusTextFormatting:
    def __init__(self, wiki_path, output_filename, recursive=False):
        self.wiki_path = wiki_path
        self.recursive = recursive
        self.output_filename = output_filename

    # This puts one article per line
    def merge(self):
        with open(self.output_filename, mode='w', newline='\n') as ofile:
            for dirname in glob.glob(os.path.join(self.wiki_path, '*'), recursive=False):
                for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=self.recursive):
                    print(filename)
                    article_lines = []
                    article_open = False

                    with open(filename, mode='r', newline='\n') as file:
                        for line in file:
                            if '<doc id=' in line:
                                article_open = True
                            elif '</doc>' in line:
                                article_open = False
                                for oline in article_lines[1:]:
                                    if oline != '\n':
                                        ofile.write(oline.rstrip() + " ")
                                ofile.write("\n\n")
                                article_lines = []
                            else:
                                if article_open:
                                    article_lines.append(line)
def get_formatting_list(wiki_path, recursive=False):
    """
    get formatting list of file names from extracted content
    """
    filenames = []
    for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False):
        for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=recursive):
            filenames.append(filename)
    return filenames


def merge(x):
    """
    Puts one article per line
    """
    file_list, output_filename = x
    article_lines = []
    article_open = False

    with open(output_filename, mode='w', newline='\n') as ofile:
        for filename in file_list:
            with open(filename, mode='r', newline='\n') as file:
                for line in file:
                    if '<doc id=' in line:
                        article_open = True
                    elif '</doc>' in line:
                        article_open = False
                        for oline in article_lines[1:]:
                            if oline != '\n':
                                ofile.write(oline.rstrip() + " ")
                        ofile.write("\n\n")
                        article_lines = []
                    else:
                        if article_open:
                            article_lines.append(line)


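For reference, merge() consumes the <doc id=...> ... </doc> markup that WikiExtractor emits, skips the first collected line of each article (normally the title), and joins the remaining lines into one output line, with a blank line between articles. A rough usage sketch, assuming it runs in the same module as the merge() defined above and that the extracted markup looks roughly like this:

# Hypothetical extracted file in WikiExtractor-style markup.
sample = (
    '<doc id="12" url="https://en.wikipedia.org/wiki?curid=12" title="Anarchism">\n'
    'Anarchism\n'
    '\n'
    'Anarchism is a political philosophy.\n'
    'It spans several traditions.\n'
    '</doc>\n'
)
with open("wiki_00", "w") as f:
    f.write(sample)

# merge() takes a (file_list, output_filename) tuple so it can be passed
# to pool.imap as a single argument.
merge((["wiki_00"], "wikipedia-prepared-0000.txt"))

with open("wikipedia-prepared-0000.txt") as f:
    print(f.read())  # one article per line, blank line after each article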
@DATA_PARSER_REGISTRY.register('prepare_wikipedia')
@@ -132,8 +144,13 @@ def get_parser():
    parser.add_argument("-o", "--output", default="wikicorpus",
                        help="directory for downloaded or formatted files")
    parser.add_argument("-b", "--bytes", default="100M",
                        help="maximum bytes per output file (default %(default)s)",
                        help="maximum bytes per extracted file (default %(default)s)",
                        metavar="n[KMG]")
    parser.add_argument("--num_process", type=int, default=8,
                        help="number of processes for multiprocessing")
    parser.add_argument("--num_out_files", type=int, default=1000,
                        help="Number of desired output files, where each is processed"
                             " independently by a worker.")
    return parser


@@ -153,38 +170,75 @@ def download_wikicorpus(lang, date, output):
    return output_file


def format_wikicorpus(input, output, bytes):
def format_wikicorpus(input, output, bytes, num_process, num_out_files):
    if input is None:
        raise ValueError('input file is empty.')
    if not input.endswith('xml.bz2'):
        raise ValueError('input file not *.xml.bz2.')
    if not os.path.exists(output):
        os.makedirs(output)

    # Use WikiExtractor to extract the content
    WikiExtractor = try_import_wikiextractor()
    wiki_path = os.path.join(output, 'extracted')
    sys.argv = ['prog', '-b', bytes, '-o', wiki_path, input]
    WikiExtractor.main()
    output_filename = os.path.join(output, 'wikicorpus_one_article_per_line.txt')
    wiki_formatter = WikicorpusTextFormatting(wiki_path, output_filename, recursive=True)
    wiki_formatter.merge()

    # Merge extracted content into txt files
    prepared_path = os.path.join(output, 'prepared_wikipedia')
    if not os.path.exists(prepared_path):
        os.makedirs(prepared_path)
    filenames = get_formatting_list(wiki_path, recursive=True)
    num_files = len(filenames)
    num_out_files = min(num_out_files, num_files)
    file_volume = math.ceil(num_files / num_out_files)
    splited_files = [filenames[i: i + file_volume] for i in range(0, num_files, file_volume)]
    num_out_files = len(splited_files)
    output_files = [
        os.path.join(
            prepared_path,
            "wikipedia-prepared-{}.txt".format(
                str(i).zfill(4))) for i in range(num_out_files)]
    print("All prepared raw text will be saved in {} txt files".format(num_out_files))
    num_process = min(num_process, num_out_files)
    print('Start preprocessing {} text files with {} cores'.format(num_files, num_process))
    process_args = [(splited_files[i], output_files[i]) for i in range(num_out_files)]

    start_time = time.time()
    with multiprocessing.Pool(num_process) as pool:
        f_read = 0
        for i, _ in enumerate(pool.imap(merge, process_args)):
            elapsed = time.time() - start_time
            f_read += len(splited_files[i])
            print("prepared {:} files, Elapsed: {:.2f}s, ETA: {:.2f}s, ".format(
                f_read, elapsed, (num_files - f_read) / (num_files / elapsed)))
    print("Done preparation within {:.2f} seconds".format(elapsed))


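The chunking in format_wikicorpus above packs the extracted files into at most num_out_files groups of size ceil(num_files / num_out_files), and each group becomes one prepared output file. A small worked example with made-up numbers:

import math

# 10 extracted files packed into at most 4 output files.
filenames = ["wiki_{:02d}".format(i) for i in range(10)]
num_files = len(filenames)
num_out_files = min(4, num_files)
file_volume = math.ceil(num_files / num_out_files)   # ceil(10 / 4) = 3
splited_files = [filenames[i: i + file_volume]
                 for i in range(0, num_files, file_volume)]
print(len(splited_files))   # 4 chunks with sizes 3, 3, 3, 1
print(splited_files[-1])    # ['wiki_09']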
@DATA_MAIN_REGISTRY.register('prepare_wikipedia')
def main(args):
    num_process = min(multiprocessing.cpu_count(), args.num_process)
    if args.mode == 'download':
        download_wikicorpus(args.lang, args.date, args.output)
    elif args.mode == 'format':
        format_wikicorpus(args.input, args.output, args.bytes)
        format_wikicorpus(args.input, args.output, args.bytes, num_process, args.num_out_files)
    elif args.mode == 'download+format':
        downloaded_file = download_wikicorpus(args.lang, args.date, args.output)
        format_wikicorpus(downloaded_file, args.output, args.bytes)
        format_wikicorpus(downloaded_file, args.output, args.bytes, num_process, args.num_out_files)
    elif args.mode == 'download_prepared':
        url = _URLS['wikipedia-en-20200620']
        file_hash = _URL_FILE_STATS[url]
        target_download_location = os.path.join(args.output,
                                                os.path.basename(url))
        download(url, target_download_location, sha1_hash=file_hash)
        tar = tarfile.open(target_download_location)
        names = tar.getnames()
        print('Start unarchiving raw text files')
        start_time = time.time()
        for name in names:
            tar.extract(name, path=args.output)
        tar.close()
        print("Done unarchiving within {:.2f} seconds".format(time.time() - start_time))
    else:
        raise NotImplementedError

2 changes: 1 addition & 1 deletion scripts/datasets/url_checksums/wikipedia.txt
@@ -1 +1 @@
https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikicorpus_one_article_per_line_20200620.txt 67825b9c721192acbf385816984ac8a250cf5216 13538212348
https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz 1e1d77c31622744aaa45ff5bfbfca397154d9186 5068070627

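Each entry in this checksum file stores the URL, its SHA-1 digest, and the size in bytes; download() uses the sha1_hash argument seen above to verify the fetched archive. A minimal sketch of the same verification with hashlib (the local path is hypothetical):

import hashlib


def sha1_of_file(path, chunk_size=1024 * 1024):
    # Stream the file in chunks so large archives need not fit in memory.
    sha1 = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha1.update(chunk)
    return sha1.hexdigest()


expected = "1e1d77c31622744aaa45ff5bfbfca397154d9186"
assert sha1_of_file("wikipedia-en-20200620.tar.gz") == expected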