added support for tfrecord from single file. (#140)
ablacklama authored Oct 18, 2021
1 parent 592f6ff commit 1ac79c1
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions create_finetune_tfrecords.py
@@ -4,6 +4,7 @@
 import random

 from pathlib import Path
+from typing import List

 import ftfy
 import tensorflow as tf
@@ -31,7 +32,11 @@ def parse_args():
 - this causes data loss if you have many .tfrecords files
 - This is probably not appropriate for very large datasets
 """, formatter_class=argparse.RawTextHelpFormatter)
-    parser.add_argument("input_dir", type=str, help="Path to where your files are located.")
+    parser.add_argument(
+        "input_path",
+        type=str,
+        help="Path to an input file, or a directory that contains the input files.",
+    )
     parser.add_argument("name", type=str,
                         help="Name of output file will be {name}_{seqnum}.tfrecords, where seqnum is total sequence count")
     parser.add_argument("--output-dir", type=str, default="", help="Output directory (default: current directory)")
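The practical effect of this hunk is that the positional argument now names either a file or a directory; validation is deferred to get_files() below. A minimal sketch of the new behavior (the paths here are hypothetical examples, not from the commit):

# Minimal sketch of the new "input_path" argument; paths are hypothetical.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str,
                    help="Path to an input file, or a directory that contains the input files.")

for argv in (["data/"], ["data/docs.jsonl.zst"]):
    args = parser.parse_args(argv)
    p = Path(args.input_path)  # mirrors the conversion added in parse_args below
    print(p, "->", "dir" if p.is_dir() else "file" if p.is_file() else "not found")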
@@ -65,17 +70,29 @@ def parse_args():

     args = parser.parse_args()

-    if not args.input_dir.endswith("/"):
-        args.input_dir = args.input_dir + "/"
+    # convert input_path to a pathlib.Path
+    args.input_path = Path(args.input_path)

     return args


-def get_files(input_dir):
-    filetypes = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
-    files = [list(Path(input_dir).glob(f"*{ft}")) for ft in filetypes]
-    # flatten list of list -> list and stringify Paths
-    return [str(item) for sublist in files for item in sublist]
+def get_files(input_path: Path) -> List[str]:
+    supported_file_types = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
+    if input_path.is_dir():
+        # get all files with supported file types
+        files = [list(Path(input_path).glob(f"*{ft}")) for ft in supported_file_types]
+        # flatten list
+        files = [f for sublist in files for f in sublist]
+        assert files, f"No files with supported types found in directory: {input_path}"
+    elif input_path.is_file():
+        assert any(
+            str(input_path).endswith(f_type) for f_type in supported_file_types
+        ), f"Input file type must be one of: {supported_file_types}"
+        files = [input_path]
+    else:
+        raise FileNotFoundError(f"No such file or directory: {input_path=}")
+
Inline review comment from vfbd (Contributor), Oct 18, 2021:

    Shouldn't this be raise FileNotFoundError(f"No such file or directory: {input_path}")?
+    return [str(f) for f in files]
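For context on the review note above: a trailing "=" inside an f-string replacement field echoes the expression along with the repr of its value, which is why the two spellings render differently. A quick standalone illustration (the path is hypothetical; output shown for Python 3.8+ on POSIX):

from pathlib import Path

input_path = Path("data/missing.txt")  # hypothetical path
print(f"No such file or directory: {input_path=}")
# -> No such file or directory: input_path=PosixPath('data/missing.txt')
print(f"No such file or directory: {input_path}")
# -> No such file or directory: data/missing.txt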


def wikitext_detokenizer(string):
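Stepping back to the new get_files() as a whole, here are both call forms it now supports, sketched with hypothetical paths (assumes the function above is in scope):

from pathlib import Path

files = get_files(Path("data/"))          # directory: globs all supported types
files = get_files(Path("data/docs.txt"))  # single file: type-checked, then wrapped
print(files)                              # -> ['data/docs.txt']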
@@ -142,7 +159,7 @@ def split_list(l, n):


 def enforce_min_unique(seqs, min_unique_tokens, enc, verbose=False):
-    for seq in tqdm(seqs, mininterval=1, smoothing=0):
+    for seq in tqdm(seqs, mininterval=1, smoothing=0, desc="enforce_min_unique_tokens"):
         if len(set(seq)) >= min_unique_tokens:
             yield seq
         elif verbose:
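The only change in this hunk (and the matching one below) is the tqdm desc keyword, which prefixes the progress bar with a label, e.g.:

from tqdm import tqdm

for _ in tqdm(range(3), desc="enforce_min_unique_tokens"):
    pass  # bar renders as: enforce_min_unique_tokens: 100%|██████| 3/3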
@@ -199,7 +216,7 @@ def read_files_to_tokenized_docs(files, args, encoder):
     else:
         random.shuffle(files)

-    for f in tqdm(files, mininterval=10, smoothing=0):
+    for f in tqdm(files, mininterval=10, smoothing=0, desc="reading/tokenizing files"):
         docs.extend(file_to_tokenized_docs_generator(f, encoder, args))

     if not args.preserve_data_order:
@@ -221,8 +238,7 @@ def arrays_to_sequences(token_list_iterable, sequence_length=2049):

         if len(accum) > sequence_length:
             chunks = split_list(accum, sequence_length)
-            for chunk in chunks[:-1]:
-                yield chunk
+            yield from chunks[:-1]
             accum = chunks[-1]

     if len(accum) > 0:
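The yield from rewrite is behavior-preserving; a self-contained check (the chunks value is a hypothetical output of split_list):

chunks = [[1, 2], [3, 4], [5]]  # hypothetical output of split_list

def old_style():
    for chunk in chunks[:-1]:
        yield chunk

def new_style():
    yield from chunks[:-1]

assert list(old_style()) == list(new_style())  # both yield [1, 2] then [3, 4]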
@@ -285,6 +301,7 @@ def create_tfrecords(files, args):

     if args.output_dir:
         os.makedirs(args.output_dir, exist_ok=True)
-    files = get_files(args.input_dir)
+    files = get_files(args.input_path)
+    print(f"Creating TFRecords from files: {files}")

     results = create_tfrecords(files, args)
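With this commit the script can be pointed at a single file as well as a directory, so both python create_finetune_tfrecords.py data/docs.txt my_dataset and python create_finetune_tfrecords.py data/ my_dataset should work (the data/ paths are hypothetical; name is the output-file prefix described in the argparse help).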
