In [4]:
!pip install zstandard
!pip install jsonlines
import os
import zstandard
import json
import jsonlines
import io
import datetime

def json_serial(obj):
    """JSON serializer for objects not serializable by default json code"""

    if isinstance(obj, (datetime.datetime,)):
        return obj.isoformat()
    raise TypeError ("Type %s not serializable" % type(obj))

# Modified version of lm_dataformat Archive for single file.
class Archive:
    def __init__(self, file_path, compression_level=3):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)    
        self.fh = open(self.file_path, 'wb')
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)        
    
    def add_data(self, data, meta={}):
        self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
    
    def commit(self):
        self.compressor.flush(zstandard.FLUSH_FRAME)        
        self.fh.flush()
        self.fh.close()

# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    def __init__(self):
        pass

    def read_jsonl(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
        with open(file, 'rb') as fh:
            self.fh = fh
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            for ob in rdr:
                # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
                if isinstance(ob, str):
                    assert not get_meta
                    yield ob
                    continue

                text = ob['text']

                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, (ob['meta'] if 'meta' in ob else {})
                else:
                    yield text



In [91]:
import glob
import os
import math

import tqdm

document_count = 0
total_text_size = 0
dataset_directory = "/datadrive/openwebtext2"
files = glob.glob(os.path.join(dataset_directory, "*jsonl.zst"))

archives = [Archive("{}/shards/shard_{}".format(dataset_directory,i)) for i in range(100)]
archives_dict = {i:a for i, a in enumerate(archives)}

for file_path in tqdm.tqdm(files, dynamic_ncols=True):
    reader = Reader()
    for document, metadata in reader.read_jsonl(file_path, get_meta=True):
        document_count += 1
        archive_index = document_count % 100
        archives_dict[archive_index].add_data(document)
        total_text_size += len(document)

for a in archives:
    a.commit()

billion = math.pow(10, 9)
print(f"Total Document Count: {document_count:,}")
print(f"Total Uncompressed Text Size: {(total_text_size / billion):.2f} GB")

100%|██████████| 179/179 [41:40<00:00, 13.97s/it]

Total Document Count: 17,103,059
Total Uncompressed Text Size: 65.86 GB





In [126]:
import os
from pathlib import Path

import torch
from tokenizers import ByteLevelBPETokenizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset
from transformers import GPT2Tokenizer

from utils.vocabulary import Reader
from utils.vocabulary import Vocab


class WebTextDocumentIterator:
    def __init__(self, dataset_paths):
        self.dataset_paths = dataset_paths

    def get_document(self, reader, path):
        return reader.read_jsonl(path)

    def __iter__(self):
        reader = Reader()
        for path in self.dataset_paths:
            yield from self.get_document(reader, path)


class FileIterator:
    def __init__(self, dataset_paths):
        self.dataset_paths = dataset_paths

    def get_file(self, path):
        with open(path, "r", encoding="utf-8") as f:
            yield from f.readlines()

    def __iter__(self):
        for path in self.dataset_paths:
            yield from self.get_file(path)


class TokenizerIterator:
    def __init__(self, seq_len, tokenizer, dataset_paths):
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        self.document_iter = WebTextDocumentIterator(dataset_paths)

    def tokenize_doc(self, x):
        tokenized = self.tokenizer(text=x, truncation=True).input_ids
        tokenized.append(self.tokenizer.eos_token_id)

        tokenized.insert(0, self.tokenizer.eos_token_id)
        if len(tokenized) >= self.seq_len:
            for i in range(len(tokenized) - self.seq_len):
                yield tokenized[i : i + self.seq_len], tokenized[i + 1 : i + 1 + self.seq_len], len(
                    tokenized[i : i + self.seq_len]
                )
        else:
            pass

    def __iter__(self):
        for x in self.document_iter:
            yield from self.tokenize_doc(x)


class BatchIterator:
    def __init__(self, seq_len, batch_size, drop_last, tokenizer, dataset_paths):
        self.tokenizer_iter = TokenizerIterator(
            seq_len=seq_len, tokenizer=tokenizer, dataset_paths=dataset_paths
        )
        self.batch_size = batch_size
        self.drop_last = drop_last

    def collate_fn(self, batch):
        data_list, label_list, seq_len_list = [], [], []
        for _data, _label, _seq in batch:
            data_list.append(_data)
            label_list.append(_label)
            seq_len_list.append(_seq)
        return (
            torch.LongTensor(data_list),
            torch.LongTensor(label_list),
            torch.LongTensor(seq_len_list),
        )

    def __iter__(self):
        batch = []
        for x in self.tokenizer_iter:
            batch.append(x)
            if len(batch) == self.batch_size:
                yield self.collate_fn(batch)
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield self.collate_fn(batch)
        else:
            pass


class WebTextIter(IterableDataset):
    def __init__(self, batch_size, drop_last, dataset_paths, seq_len, tokenizer=None):
        if tokenizer is None:
            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.seq_len = seq_len
        self.dataset_paths = dataset_paths
        self.batch_iter = BatchIterator(
            seq_len=seq_len,
            batch_size=batch_size,
            drop_last=drop_last,
            tokenizer=tokenizer,
            dataset_paths=dataset_paths,
        )

    def __iter__(self):
        for x in self.batch_iter:
            yield x



In [127]:
def collate_fn(batch):
        data_list, label_list, seq_len_list = [], [], []
        for _data, _label, _seq in batch:
            data_list.append(_data)
            label_list.append(_label)
            seq_len_list.append(_seq)
        return (
            torch.LongTensor(data_list),
            torch.LongTensor(label_list),
            torch.LongTensor(seq_len_list),
        )

dataset_directory = "/datadrive/openwebtext2/shards/"
files = glob.glob(os.path.join(dataset_directory, "*"))
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.eos_token_id

50256

In [128]:
wt = WebTextIter(dataset_paths=files[0:1], batch_size=32, drop_last=True, seq_len=1024)
count = 0
for i, n in enumerate(wt):
    count+=1

In [101]:
import random
random.seed(4)
random.shuffle(files)
total_size = 0
for path in files:
    total_size += os.path.getsize(path)
total_size

26406002781

In [129]:
count

3015

In [111]:
test_size = 0
test_files = []
 
for path in files:
    if test_size >= test:
        break
    file_size = os.path.getsize(path)
    print(file_size)
    if (test_size + file_size) >= test:
        print("no")
        continue
    else:
        test_size += file_size
        test_files.append(path)
test_files, test_size

263090877
no
263287002
no
264040938
no
263660218
no
263895576
no
264419479
no
265058953
no
262984568
no
264424789
no
263419654
no
264099334
no
263798847
no
265420565
no
262535265
no
263247368
no
264325418
no
263874260
no
264918087
no
263857548
no
264084020
no
263262171
no
264017063
no
264662936
no
262955373
no
264793648
no
264213747
no
262676355
no
263735509
no
262528529
no
263453641
no
263331397
no
265449467
no
263940650
no
264326414
no
266099415
no
263179145
no
263528765
no
264510899
no
263951767
no
263931596
no
264601985
no
262915959
no
265175054
no
263042530
no
264318073
no
263965778
no
264920592
no
264827199
no
264744368
no
264673493
no
264279959
no
265542126
no
264195155
no
263804243
no
263445165
no
263724380
no
263877082
no
264790324
no
263149832
no
264498005
no
262558457
no
265021735
no
263875077
no
263224181
no
264054118
no
263756430
no
263844602
no
263849177
no
263906114
no
264914946
no
262860344
no
265120063
no
263734497
no
264716065
no
262874629
no
263573503
no
264360138
no

([], 0)

In [52]:
val_size = 0
val_files = []
 
for path in files:
    if path in test_files:
        continue
    if val_size >= val:
        break
    file_size = os.path.getsize(path)
    if val_size + file_size >= val:
        continue
    else:
        val_size += file_size
    val_files.append(path)
val_files, val_size

(['/datadrive/openwebtext2/2009-05.jsonl.zst',
  '/datadrive/openwebtext2/2005-10.jsonl.zst',
  '/datadrive/openwebtext2/2007-05.jsonl.zst',
  '/datadrive/openwebtext2/2005-08.jsonl.zst',
  '/datadrive/openwebtext2/2005-09.jsonl.zst'],
 28723231)

In [62]:
import shutil
for f in files:
    fname = os.path.basename(f)

    if f in val_files or f in test_files:
        print(fname)
        continue
    shutil.move("/datadrive/openwebtext2/train/{}".format(fname), f)

In [65]:
len(files)

179