In [80]:
import glob
import os
from typing import List, Optional

import numpy as np
import tensorflow as tf

In [184]:
def get_all_files_with_specific_filetypes_in_a_directory(directory: str, filetypes: Optional[List[str]] = None) -> List[str]:
    """Return the paths directly inside ``directory`` that end with one of ``filetypes``.

    Args:
        directory: Directory to search (non-recursive). A trailing path
            separator is optional.
        filetypes: Extensions without the leading dot, e.g. ``["jsonl.zst"]``.
            ``None`` or an empty list falls back to
            ``["jsonl.zst", "txt", "xz", "tar.gz"]``.

    Returns:
        Matching paths, grouped in the order of ``filetypes`` and sorted by
        name inside each group.
    """
    if not filetypes:
        filetypes = ["jsonl.zst", "txt", "xz", "tar.gz"]

    files_with_specific_filetypes: List[str] = []
    for filetype in filetypes:
        # os.path.join only inserts a separator when one is missing, so both
        # "data" and "data/" work; plain concatenation turned "data" into the
        # pattern "data*.ext", which silently matched nothing inside the folder.
        pattern = os.path.join(directory, f"*.{filetype}")
        # glob returns entries in filesystem order; sort for reproducible runs.
        files_with_specific_filetypes.extend(sorted(glob.glob(pattern)))

    return files_with_specific_filetypes

In [209]:
# List every *.jsonl.zst file under openwebtext2/; the conversion loop
# iterates over this list as its input.
directory = "openwebtext2/"
files_with_jsonl_zst_filetype = get_all_files_with_specific_filetypes_in_a_directory(directory, ["jsonl.zst"])

In [222]:
# List the *.tfrecords files under tfrecords/ and display how many were found.
# Note: this rebinds `directory`, which another cell points at the input folder.
directory = "tfrecords/"
files_with_tfrecords_filetype = get_all_files_with_specific_filetypes_in_a_directory(directory, ["tfrecords"])
len(files_with_tfrecords_filetype)

159

In [223]:
# Build a single dataset over all listed TFRecord files; AUTOTUNE lets tf.data
# choose how many files are read in parallel.
tfrecords_dataset = tf.data.TFRecordDataset(files_with_tfrecords_filetype, num_parallel_reads=tf.data.AUTOTUNE)

In [218]:
# Number of parsed examples grouped into one batch by the `.batch(...)` cell.
batch_size = 2

In [224]:
# Parse every serialized record with decode_fn. decode_fn is defined in a cell
# further down this notebook, so that cell must have been run first.
tfrecords_dataset = tfrecords_dataset.map(decode_fn, num_parallel_calls=tf.data.AUTOTUNE)

In [225]:
# Group examples into batches of `batch_size`, dropping a final short batch.
# The name is rebound in place, so re-running this cell batches the dataset again.
tfrecords_dataset = tfrecords_dataset.batch(batch_size, drop_remainder=True)  # batch must be *AFTER* map

In [227]:
# Prefetch upcoming elements with an autotuned buffer size.
tfrecords_dataset = tfrecords_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [240]:
# Sanity check: the "text" feature of the first element must come back as a
# NumPy array. Uses the `np` alias imported at the top of the notebook (the
# bare name `numpy` was never imported) and the built-in next() instead of the
# iterator's legacy .next() method.
isinstance(next(tfrecords_dataset.as_numpy_iterator())['text'], np.ndarray)

True

In [231]:
# Show the type and shape of the "text" array for the first three batches.
for batch in tfrecords_dataset.take(3).as_numpy_iterator():
    text_array = batch["text"]
    print(type(text_array), text_array.shape)

<class 'numpy.ndarray'> (2, 1024)
<class 'numpy.ndarray'> (2, 1024)
<class 'numpy.ndarray'> (2, 1024)


In [113]:
import tempfile
import numpy as np

In [138]:
# Path of a scratch file named example.tfrecords inside the system temp directory.
example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")

In [149]:
# Print the length of every chunk currently held in splitted_chunks.
print(list(map(len, splitted_chunks)))

[1024, 1024, 1024]


In [220]:
def decode_fn(record_bytes):
    """Parse one serialized example back into a dict with a "text" tensor.

    The "text" feature is read as a fixed-length int64 vector whose length is
    the module-level ``chunk_size``.
    """
    feature_spec = {
        "text": tf.io.FixedLenFeature((chunk_size,), dtype=tf.int64),
    }
    return tf.io.parse_single_example(record_bytes, feature_spec)

In [None]:
# Decode and print every record stored in the first generated TFRecord file.
tfrecords_file_name = f"tfrecords/openwebtext2_0_{num_of_chunks_per_file}.tfrecords"

for batch in (
    tf.data.TFRecordDataset([tfrecords_file_name])
    .map(decode_fn)
):
    print(batch)

In [8]:
from tokenizers import Tokenizer
from transformers import GPT2TokenizerFast

In [9]:
def construct_gpt2_tokenizer():
    """Load and return the pretrained fast tokenizer registered as 'gpt2'."""
    pretrained_name = 'gpt2'
    return GPT2TokenizerFast.from_pretrained(pretrained_name)

In [10]:
# Module-level tokenizer instance; tokenize() reads this global.
tokenizer = construct_gpt2_tokenizer()

In [12]:
from lm_dataformat import Reader

In [18]:
# Display the input path currently bound to the conversion loop's file variable.
file_with_jsonl_zst_filetype

'openwebtext2\\2005-06.jsonl.zst'

In [55]:
def tokenize(doc: str, token_that_separate_docs: int) -> List[int]:
    """Encode ``doc`` with the module-level ``tokenizer`` and end it with a separator.

    Args:
        doc: Text of one document.
        token_that_separate_docs: Token id placed after the document's tokens.

    Returns:
        The encoded token ids followed by ``token_that_separate_docs``.
    """
    return [*tokenizer.encode(doc), token_that_separate_docs]

In [75]:
def split_tokenized_doc_into_chunks(tokenized_doc: List[int], chunk_size: int) -> List[List[int]]:
    """Cut ``tokenized_doc`` into consecutive slices of ``chunk_size`` tokens.

    Every slice has exactly ``chunk_size`` items except possibly the last,
    which holds whatever remains. An empty input yields an empty list.
    """
    pieces: List[List[int]] = []
    for start in range(0, len(tokenized_doc), chunk_size):
        stop = start + chunk_size
        pieces.append(tokenized_doc[start:stop])
    return pieces

In [188]:
# Token id appended after every tokenized document (see tokenize).
# Presumably GPT-2's end-of-text id — TODO confirm against the tokenizer.
token_that_separate_docs = 50256
# Number of token ids per chunk; also the fixed feature length used by decode_fn.
chunk_size = 1024
# A trailing partial chunk shorter than this is discarded rather than carried over.
minimized_chunk_size = 128
# Number of chunks written into each .tfrecords file.
num_of_chunks_per_file = 100000

In [189]:
# Display the length of each chunk produced for the current tokenized_doc.
list(map(len, split_tokenized_doc_into_chunks(tokenized_doc, chunk_size)))

[1024, 1024, 496]

In [190]:
def prepend_the_last_chunk_to_next_doc_or_discard_it(last_chunk: List[int], minimized_chunk_size: int) -> List[int]:
    """Decide whether a document's trailing chunk is kept for the next document.

    Returns ``last_chunk`` itself when it holds at least
    ``minimized_chunk_size`` tokens, otherwise a new empty list.
    """
    too_short = len(last_chunk) < minimized_chunk_size
    return [] if too_short else last_chunk

In [96]:
# Toy two-token tail used to try out the keep-or-discard helper by hand.
last_chunk = [1, 2] # splitted_chunks.pop(-1)

In [97]:
# Result is [] when last_chunk is shorter than minimized_chunk_size, else last_chunk.
prepended_data = prepend_the_last_chunk_to_next_doc_or_discard_it(last_chunk, minimized_chunk_size)

In [157]:
# Type alias: one chunk is a list of integer token ids.
Chunk = List[int]

In [171]:
def write_chunks_to_tfrecords_file(chunks: List[Chunk], tfrecords_file_index: int, num_of_chunks_per_file: int):
    """Serialize ``chunks`` into one TFRecord file under ``tfrecords/``.

    Each chunk becomes one ``tf.train.Example`` with a single int64-list
    feature named "text".

    Args:
        chunks: Token-id lists to write, one record per chunk.
        tfrecords_file_index: Running file number embedded in the file name.
        num_of_chunks_per_file: Embedded in the file name; the number of
            records written is ``len(chunks)``, whatever this value is.
    """
    tfrecords_file_name = f"tfrecords/openwebtext2_{tfrecords_file_index}_{num_of_chunks_per_file}.tfrecords"
    # Make sure the output folder exists so opening the writer does not fail
    # on a fresh checkout where tfrecords/ has not been created yet.
    os.makedirs(os.path.dirname(tfrecords_file_name), exist_ok=True)

    with tf.io.TFRecordWriter(tfrecords_file_name) as file_writer:
        for chunk in chunks:
            token_ids = tf.train.Int64List(value=chunk)
            example = tf.train.Example(features=tf.train.Features(feature={
                "text": tf.train.Feature(int64_list=token_ids)
            }))
            file_writer.write(example.SerializeToString())

In [None]:
# State carried across iterations of the conversion loop.
# Full chunks waiting to be flushed to the next TFRecord file.
chunks_to_be_written: List[Chunk] = []
# Tail of the previous document that is carried into the next one.
prepended_data = []
# Index embedded in the name of the next TFRecord file to be written.
tfrecords_file_index = 0


In [203]:

# Tokenize every document of every input file, cut the token stream into
# chunk_size pieces and flush them to TFRecord files, num_of_chunks_per_file
# chunks at a time.
for file_with_jsonl_zst_filetype in files_with_jsonl_zst_filetype:
    reader = Reader(file_with_jsonl_zst_filetype)

    for doc in reader.stream_data(threaded=False):
        # The tail carried over from the previous document has to come BEFORE
        # this document's tokens. Appending it afterwards (as `+=` did) placed
        # the end of the previous document after the current one and broke the
        # token order across document boundaries.
        tokenized_doc = prepended_data + tokenize(doc, token_that_separate_docs)
        splitted_chunks = split_tokenized_doc_into_chunks(tokenized_doc, chunk_size)
        # tokenize() always appends the separator, so there is at least one chunk.
        last_chunk = splitted_chunks.pop(-1)
        prepended_data = prepend_the_last_chunk_to_next_doc_or_discard_it(last_chunk, minimized_chunk_size)

        chunks_to_be_written.extend(splitted_chunks)
        if len(chunks_to_be_written) >= num_of_chunks_per_file:
            write_chunks_to_tfrecords_file(chunks_to_be_written[:num_of_chunks_per_file], tfrecords_file_index, num_of_chunks_per_file)
            tfrecords_file_index += 1
            # Keep the overflow for the next file.
            chunks_to_be_written = chunks_to_be_written[num_of_chunks_per_file:]
            print(tfrecords_file_index)

# NOTE(review): chunks still held in chunks_to_be_written when the loop ends
# (fewer than num_of_chunks_per_file) are not written by this cell.

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159


In [204]:
# Leftover state after the loop: (carried-over tokens, chunks not yet written,
# next output file index).
len(prepended_data), len(chunks_to_be_written), tfrecords_file_index

(0, 8815, 159)

In [None]:
def tfrecord_parse(one_tfrecord):
    """Unfinished placeholder for parsing a single serialized record.

    The definition had no body, which is a syntax error if the cell is run.
    Until it is implemented, calling it fails loudly instead.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("tfrecord_parse is an unfinished stub; use decode_fn to parse records")

In [21]:
# Peek at the first dataset element as NumPy values.
# NOTE(review): the saved output is raw serialized bytes, so this cell was
# apparently last run on the dataset before decode_fn was mapped — confirm.
list(tfrecords_dataset.take(1).as_numpy_iterator())

[b'\n\xe1\x0f\n\xde\x0f\n\x04text\x12\xd5\x0f\x1a\xd2\x0f\n\xcf\x0f\xc7\xe6\x01\xb1(\x9f\x02\xaf\x07\xfd\xab\x01\xbb\n\xb7\x97\x01\x83 \x9f\x02\xb4T\xc2\x05\xd7\x03\r2\r\x94/\xd2\x02\xe7\xae\x01\xa2\x1a\x804\x9aK\xfe\x1a\x81\x08\xf5\x02\xb1(\xc2\x05\xd4\x06\xaaY\xa0\x03\xec\x0c\xe3\x02\x81\x02\xed&\xf1\x04\x86\x02\xb1\'\r\xb5\xe8\x01\xe1Z\xf5d\x19\xa7\xc8\x01\xaf\x07\xfd\xab\x01\x94/\xdb+\xe0\x1e\xb2\x08\x9c\x02\xf5/\xad\x17\xdb+\xe6\'\x9e\x03\x81\x02\x98\x12\xc9\x02\xcd\x96\x01\xdb+\xd8E\xa2\x19\x81\x02\x86-\x9f\x02\xcas\x1e\xdb+\xf5\xa6\x01\x19\xe1W\xd0F\xc9\x02\x9d\xb6\x01\xd0\x81\x01\x9c\x03\x0c\xa9,\xb2 \xa3A\xf1\x08\xd6$\xaf\x07\xfd\xab\x01\xb2H\x93N\x0b\x8e\x1e\xa2\x02\xbc%\x89\x03\x89\x03\xaaj\xce\x1a\xe0[\xe71\x9d\x0c)\xc12\xc5\x16\x0b\xaf\x07\xfd\xab\x01\xe5\x02\xaa\x90\x01\x08\xe1\n\x9c\x08\xaf\x07\xa5\x04\x90\x02\xbc\'\x0b\x86\x02\xac\x17\x9e\x02\x81\x02\xbc\x0c\x0c\x95\x0f\x0c\xd7\x05\x83 \x9f\x02\x86\x02\xe7\xae\x01\x9e\x02\x99\x02\xb2H\xb7w\x0b\xfb\x04\xd4\x03\xcb\x04\xb