Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jun 30, 2020
1 parent 402d625 commit fa011aa
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions scripts/pretraining/pretraining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,11 @@ def prepare_pretrain_text_dataset(
"""Create dataset based on the raw text files"""
if not isinstance(filenames, (list, tuple)):
filenames = [filenames]
# generate a filename based on the input filename ensuring no crash.
# filename example: urlsf_subset00-130_data.txt
suffix = re.findall(r'\d+-\d+', filenames[0])[0]
if cached_file_path:
output_file = os.path.join(cached_file_path, "owt-pretrain-record-{}.npz".format(suffix))
# generate a filename based on the input filename ensuring no crash.
# filename example: urlsf_subset00-130_data.txt
suffix = re.split(r'\.|/', filenames[0])[-2]
output_file = os.path.join(cached_file_path, "{}-pretrain-record.npz".format(suffix))
else:
output_file = None
np_features = get_all_features(
Expand Down
2 changes: 1 addition & 1 deletion scripts/pretraining/run_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def train(args):
'num_workers: {}, rank: {}'.format(
args.num_buckets, num_workers, rank))
if args.from_raw_text:
if not os.path.exists(args.cached_file_path):
if args.cached_file_path and not os.path.exists(args.cached_file_path):
os.mkdir(args.cached_file_path)
get_dataset_fn = functools.partial(get_pretrain_data_text,
max_seq_length=args.max_seq_length,
Expand Down

0 comments on commit fa011aa

Please sign in to comment.