In [2]:
import os
import zstandard
import json
import jsonlines
import io
import datetime

def json_serial(obj):
    """Fallback ``default=`` hook for :func:`json.dumps`.

    ``datetime.datetime`` instances are rendered as ISO-8601 strings via
    ``isoformat()``; every other type is rejected with ``TypeError`` so that
    ``json.dumps`` reports it as unserialisable.
    """
    if not isinstance(obj, datetime.datetime):
        raise TypeError ("Type %s not serializable" % type(obj))
    return obj.isoformat()

# Modified version of lm_dataformat Archive for single file.
class Archive:
    """Write documents to a single zstd-compressed JSON-lines file.

    Each :meth:`add_data` call appends one line ``{"text": ..., "meta": ...}``;
    :meth:`commit` ends the zstd frame, flushes and closes the file. The archive
    can also be used as a context manager, which commits on exit.

    Args:
        file_path: destination path; its parent directory is created if missing.
        compression_level: zstd level passed to ``zstandard.ZstdCompressor``.
    """

    def __init__(self, file_path, compression_level=3):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, 'wb')
        try:
            self.cctx = zstandard.ZstdCompressor(level=compression_level)
            self.compressor = self.cctx.stream_writer(self.fh)
        except Exception:
            # Don't leak the open file handle if compressor setup fails (e.g. bad level).
            self.fh.close()
            raise

    def add_data(self, data, meta=None):
        """Append one document as a JSON line.

        Args:
            data: the document text (any JSON-serialisable value).
            meta: optional metadata mapping; ``None`` (the default) is written as ``{}``.
                ``datetime.datetime`` values are serialised via ``json_serial``.
        """
        # ``None`` default instead of a shared mutable ``{}`` default argument.
        if meta is None:
            meta = {}
        record = json.dumps({'text': data, 'meta': meta}, default=json_serial)
        self.compressor.write(record.encode('UTF-8') + b'\n')

    def commit(self):
        """End the zstd frame, flush, and close the underlying file."""
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Commit even if the body raised, so the output is a complete zstd frame.
        self.commit()
        return False

# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    """Streaming reader for zstd-compressed JSON-lines archives.

    After ``read_jsonl`` starts iterating, the open raw file handle is exposed
    as ``self.fh``. Callers can use it to peek at read progress (e.g. for tqdm).
    """

    def __init__(self):
        pass

    def read_jsonl(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
        """Yield documents from a ``.jsonl.zst`` file.

        Args:
            file: path to the compressed JSON-lines file.
            get_meta: when True, yield ``(text, meta)`` pairs; ``meta`` is ``{}`` if absent.
            autojoin_paragraphs: join list-valued ``text`` fields with ``para_joiner``.
            para_joiner: separator used when joining paragraph lists.

        Yields:
            ``text`` or ``(text, meta)``. Legacy records that are bare strings are
            yielded as-is; they are only allowed when ``get_meta`` is False (asserted).
        """
        with open(file, 'rb') as raw:
            self.fh = raw
            decompressor = zstandard.ZstdDecompressor()
            buffered = io.BufferedReader(decompressor.stream_reader(raw))
            for record in jsonlines.Reader(buffered):
                if isinstance(record, str):
                    # Legacy format: each line is just the string itself, with no metadata.
                    assert not get_meta
                    yield record
                    continue

                text = record['text']
                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if not get_meta:
                    yield text
                else:
                    yield text, record.get('meta', {})

ModuleNotFoundError: No module named 'zstandard'

In [None]:
import glob
import os
import math

import tqdm

# Re-shard the raw OpenWebText2 *.jsonl.zst files into NUM_SHARDS round-robin shards.
# NOTE(review): only the document text is written; per-document metadata is read
# (get_meta=True) but discarded — confirm that is intended.
NUM_SHARDS = 100

document_count = 0
total_text_size = 0
dataset_directory = "/datadrive/openwebtext2"
# Sorted so shard contents are reproducible; glob order is filesystem-dependent.
files = sorted(glob.glob(os.path.join(dataset_directory, "*jsonl.zst")))

archives = [Archive("{}/shards/shard_{}".format(dataset_directory, i)) for i in range(NUM_SHARDS)]
archives_dict = {i: a for i, a in enumerate(archives)}

try:
    for file_path in tqdm.tqdm(files, dynamic_ncols=True):
        reader = Reader()
        for document, metadata in reader.read_jsonl(file_path, get_meta=True):
            document_count += 1
            # Round-robin assignment spreads documents evenly across shards.
            archive_index = document_count % NUM_SHARDS
            archives_dict[archive_index].add_data(document)
            total_text_size += len(document)
finally:
    # Always finish every shard, so an interrupted run still leaves complete zstd frames.
    for a in archives:
        a.commit()

billion = math.pow(10, 9)
print(f"Total Document Count: {document_count:,}")
print(f"Total Uncompressed Text Size: {(total_text_size / billion):.2f} GB")

In [None]:
import os
from pathlib import Path

import torch
from tokenizers import ByteLevelBPETokenizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset
from transformers import GPT2Tokenizer

from utils.vocabulary import Reader
from utils.vocabulary import Vocab


class WebTextDocumentIterator:
    """Iterate over every document in a sequence of compressed JSON-lines shards.

    Args:
        dataset_paths: shard paths, read in order with ``Reader.read_jsonl``.
    """

    def __init__(self, dataset_paths):
        self.dataset_paths = dataset_paths

    def get_document(self, reader, path):
        """Return the lazy document iterator ``reader`` produces for ``path``."""
        documents = reader.read_jsonl(path)
        return documents

    def __iter__(self):
        # One Reader instance is shared across all shards.
        shared_reader = Reader()
        for shard_path in self.dataset_paths:
            documents = self.get_document(shared_reader, shard_path)
            yield from documents


class FileIterator:
    """Iterate line by line over a sequence of UTF-8 text files.

    Lines keep their trailing newline and are yielded file by file, in order.
    """

    def __init__(self, dataset_paths):
        self.dataset_paths = dataset_paths

    def get_file(self, path):
        """Lazily yield the lines of ``path``.

        Iterating the file object streams it; the previous ``readlines()`` call
        loaded the whole file into memory before yielding the first line.
        """
        with open(path, "r", encoding="utf-8") as f:
            yield from f

    def __iter__(self):
        for path in self.dataset_paths:
            yield from self.get_file(path)


class TokenizerIterator:
    """Turn a stream of documents into fixed-length ``(input, target)`` windows.

    Each document is tokenized and wrapped in an EOS token on both sides. It is
    then cut into windows of ``seq_len`` tokens whose targets are the same
    window shifted right by one token.

    Args:
        seq_len: tokens per input/target window.
        tokenizer: called as ``tokenizer(text=..., truncation=True).input_ids``; must
            expose ``eos_token_id``.
        dataset_paths: shard paths forwarded to ``WebTextDocumentIterator``.
        stride: step between consecutive window starts. The default of 1 keeps the
            original fully overlapping sliding window; ``stride=seq_len`` gives
            non-overlapping windows.
    """

    # Class-level default, so instances created without __init__ still have a stride.
    stride = 1

    def __init__(self, seq_len, tokenizer, dataset_paths, stride=1):
        if stride < 1:
            raise ValueError("stride must be a positive integer")
        self.seq_len = seq_len
        self.stride = stride
        self.tokenizer = tokenizer
        self.document_iter = WebTextDocumentIterator(dataset_paths)

    def tokenize_doc(self, x):
        """Yield ``(input_ids, target_ids, seq_len)`` windows for one document.

        A document yields nothing unless it has at least ``seq_len + 1`` tokens
        (counting both EOS markers). Every target window needs one token past the
        end of its input window.
        """
        # NOTE(review): truncation=True with no max_length truncates to the tokenizer's
        # model_max_length, silently dropping the tail of long documents — confirm intended.
        eos = self.tokenizer.eos_token_id
        tokenized = [eos] + list(self.tokenizer(text=x, truncation=True).input_ids) + [eos]
        for i in range(0, len(tokenized) - self.seq_len, self.stride):
            # Each slice is exactly seq_len long because i + seq_len + 1 <= len(tokenized).
            yield tokenized[i : i + self.seq_len], tokenized[i + 1 : i + 1 + self.seq_len], self.seq_len

    def __iter__(self):
        for x in self.document_iter:
            yield from self.tokenize_doc(x)


class BatchIterator:
    """Build batches by taking one window from each of ``batch_size`` parallel streams.

    ``dataset_paths`` is split into ``batch_size`` contiguous, near-equal groups. Each
    group feeds its own ``TokenizerIterator``. Row ``k`` of consecutive batches
    therefore continues the same stream.

    Args:
        seq_len: tokens per window.
        batch_size: number of parallel streams, i.e. rows per batch.
        drop_last: when one stream runs out, False still yields the incomplete final
            batch; True discards it.
        tokenizer: forwarded to every ``TokenizerIterator``.
        dataset_paths: sliceable sequence of shard paths, at least ``batch_size`` long.

    Raises:
        ValueError: if ``batch_size < 1`` or there are fewer paths than ``batch_size``.
    """

    def __init__(self, seq_len, batch_size, drop_last, tokenizer, dataset_paths):
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if len(dataset_paths) < batch_size:
            raise ValueError(
                "need at least batch_size dataset paths, got {} for batch_size={}".format(
                    len(dataset_paths), batch_size
                )
            )
        # Near-equal contiguous split. Plain floor division would silently drop the
        # last len(dataset_paths) % batch_size files.
        n = len(dataset_paths)
        bounds = [(i * n) // batch_size for i in range(batch_size + 1)]
        streams = [dataset_paths[bounds[i]:bounds[i + 1]] for i in range(batch_size)]
        self.tokenizers = [
            TokenizerIterator(seq_len=seq_len, tokenizer=tokenizer, dataset_paths=stream)
            for stream in streams
        ]
        self.batch_size = batch_size
        self.drop_last = drop_last

    def collate_fn(self, batch):
        """Stack ``(input_ids, target_ids, length)`` samples into LongTensors.

        Inputs and targets are transposed from ``(batch, window)`` to ``(window, batch)``;
        lengths stay ``(batch,)``.
        """
        data_list, label_list, seq_len_list = [], [], []
        for _data, _label, _seq in batch:
            data_list.append(_data)
            label_list.append(_label)
            seq_len_list.append(_seq)
        return (
            torch.LongTensor(data_list).permute(1, 0),
            torch.LongTensor(label_list).permute(1, 0),
            torch.LongTensor(seq_len_list),
        )

    def __iter__(self):
        # Create each stream's iterator exactly once. Calling iter() on a
        # TokenizerIterator restarts it from the first document, so re-creating it
        # every step (as before) yielded the same first batch forever.
        streams = [iter(tokenizer) for tokenizer in self.tokenizers]
        exhausted = object()
        while True:
            batch = []
            for stream in streams:
                sample = next(stream, exhausted)
                if sample is exhausted:
                    # A stream ran out: emit the incomplete batch unless dropping it, then stop.
                    if batch and not self.drop_last:
                        yield self.collate_fn(batch)
                    return
                batch.append(sample)
            yield self.collate_fn(batch)


class WebTextIter(IterableDataset):
    """Iterable dataset that re-exposes the batches produced by ``BatchIterator``.

    Args:
        batch_size, drop_last, dataset_paths, seq_len: forwarded to ``BatchIterator``.
        tokenizer: tokenizer to use; when omitted, ``GPT2Tokenizer.from_pretrained("gpt2")``.
    """

    def __init__(self, batch_size, drop_last, dataset_paths, seq_len, tokenizer=None):
        resolved_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if tokenizer is None else tokenizer
        self.seq_len = seq_len
        self.dataset_paths = dataset_paths
        self.batch_iter = BatchIterator(
            seq_len=seq_len,
            batch_size=batch_size,
            drop_last=drop_last,
            tokenizer=resolved_tokenizer,
            dataset_paths=dataset_paths,
        )

    def __iter__(self):
        # NOTE(review): the BatchIterator is shared, so with DataLoader num_workers > 0
        # each worker would presumably replay the same streams — confirm before scaling up.
        for batch in self.batch_iter:
            yield batch

            



In [None]:

# Imported here so this cell does not depend on an earlier cell having run.
import glob
import os

from transformers import GPT2Tokenizer

dataset_directory = "/datadrive/openwebtext2/shards/"
# Sorted so later slicing (files[:20]) and the seeded shuffle are reproducible;
# glob order is filesystem-dependent.
files = sorted(glob.glob(os.path.join(dataset_directory, "*")))
# NOTE(review): this notebook-global `tokenizer` is also read directly by the second
# BatchIterator.process_data definition in this notebook.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.eos_token_id

In [None]:
import os
from pathlib import Path

import torch
from tokenizers import ByteLevelBPETokenizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset
from transformers import GPT2Tokenizer

from utils.vocabulary import Reader
from utils.vocabulary import Vocab
import random
from itertools import chain, cycle, islice
import torch.utils.data as data

import time
import torch
import numpy as np


class WebTextDocumentIterator:
    """Iterate over every document in a sequence of compressed JSON-lines shards.

    Args:
        dataset_paths: shard paths, read in order with ``Reader.read_jsonl``.
    """

    def __init__(self, dataset_paths):
        self.dataset_paths = dataset_paths

    def get_document(self, reader, path):
        """Return the lazy document iterator ``reader`` produces for ``path``."""
        documents = reader.read_jsonl(path)
        return documents

    def __iter__(self):
        # One Reader instance is shared across all shards.
        shared_reader = Reader()
        for shard_path in self.dataset_paths:
            documents = self.get_document(shared_reader, shard_path)
            yield from documents


class FileIterator:
    """Iterate line by line over a sequence of UTF-8 text files.

    Lines keep their trailing newline and are yielded file by file, in order.
    """

    def __init__(self, dataset_paths):
        self.dataset_paths = dataset_paths

    def get_file(self, path):
        """Lazily yield the lines of ``path``.

        Iterating the file object streams it; the previous ``readlines()`` call
        loaded the whole file into memory before yielding the first line.
        """
        with open(path, "r", encoding="utf-8") as f:
            yield from f

    def __iter__(self):
        for path in self.dataset_paths:
            yield from self.get_file(path)


class TokenizerIterator:
    """Turn a stream of documents into fixed-length ``(input, target)`` windows.

    Each document is tokenized and wrapped in an EOS token on both sides. It is
    then cut into windows of ``seq_len`` tokens whose targets are the same
    window shifted right by one token.

    Args:
        seq_len: tokens per input/target window.
        tokenizer: called as ``tokenizer(text=..., truncation=True).input_ids``; must
            expose ``eos_token_id``.
        dataset_paths: shard paths forwarded to ``WebTextDocumentIterator``.
        stride: step between consecutive window starts. The default of 1 keeps the
            original fully overlapping sliding window; ``stride=seq_len`` gives
            non-overlapping windows.
    """

    # Class-level default, so instances created without __init__ still have a stride.
    stride = 1

    def __init__(self, seq_len, tokenizer, dataset_paths, stride=1):
        if stride < 1:
            raise ValueError("stride must be a positive integer")
        self.seq_len = seq_len
        self.stride = stride
        self.tokenizer = tokenizer
        self.document_iter = WebTextDocumentIterator(dataset_paths)

    def tokenize_doc(self, x):
        """Yield ``(input_ids, target_ids, seq_len)`` windows for one document.

        A document yields nothing unless it has at least ``seq_len + 1`` tokens
        (counting both EOS markers). Every target window needs one token past the
        end of its input window.
        """
        # NOTE(review): truncation=True with no max_length truncates to the tokenizer's
        # model_max_length, silently dropping the tail of long documents — confirm intended.
        eos = self.tokenizer.eos_token_id
        tokenized = [eos] + list(self.tokenizer(text=x, truncation=True).input_ids) + [eos]
        for i in range(0, len(tokenized) - self.seq_len, self.stride):
            # Each slice is exactly seq_len long because i + seq_len + 1 <= len(tokenized).
            yield tokenized[i : i + self.seq_len], tokenized[i + 1 : i + 1 + self.seq_len], self.seq_len

    def __iter__(self):
        for x in self.document_iter:
            yield from self.tokenize_doc(x)


class BatchIterator:
    """Zip ``batch_size`` endlessly cycling token streams into tuples of samples.

    ``dataset_paths`` is split into ``batch_size`` contiguous, near-equal groups. Each
    group is shuffled once and cycled forever, and each file in it is tokenized into
    windows by a ``TokenizerIterator``. Iterating yields tuples holding one raw sample
    per stream; turning them into tensors is left to the caller (see ``collate_fn``).

    Args:
        seq_len: tokens per window.
        batch_size: number of parallel streams, i.e. samples per yielded tuple.
        drop_last: stored but currently unused (the streams cycle forever).
        tokenizer: tokenizer passed to every ``TokenizerIterator``.
        dataset_paths: sliceable sequence of shard paths, at least ``batch_size`` long.

    Raises:
        ValueError: if ``batch_size < 1`` or there are fewer paths than ``batch_size``.
    """

    def __init__(self, seq_len, batch_size, drop_last, tokenizer, dataset_paths):
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if len(dataset_paths) < batch_size:
            # An empty group would make zip() stop immediately, giving a silently empty dataset.
            raise ValueError(
                "need at least batch_size dataset paths, got {} for batch_size={}".format(
                    len(dataset_paths), batch_size
                )
            )
        self.dataset_paths = dataset_paths
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.seq_len = seq_len
        # Keep the tokenizer we were given. process_data previously read a
        # notebook-global ``tokenizer`` and silently ignored this argument.
        self.tokenizer = tokenizer

    def collate_fn(self, batch):
        """Stack ``(input_ids, target_ids, length)`` samples into LongTensors.

        Inputs and targets are transposed from ``(batch, window)`` to ``(window, batch)``;
        lengths stay ``(batch,)``.
        """
        data_list, label_list, seq_len_list = [], [], []
        for _data, _label, _seq in batch:
            data_list.append(_data)
            label_list.append(_label)
            seq_len_list.append(_seq)
        return (
            torch.LongTensor(data_list).permute(1, 0),
            torch.LongTensor(label_list).permute(1, 0),
            torch.LongTensor(seq_len_list),
        )

    def process_data(self, dataset):
        """Yield tokenized windows for a single dataset file."""
        # Local, not ``self.tokenizer_iter``: all streams share this object, so each one
        # would overwrite the attribute.
        yield from TokenizerIterator(self.seq_len, self.tokenizer, [dataset])

    def shuffled_data_list(self, i):
        """Return a shuffled copy of the ``i``-th contiguous group of ``dataset_paths``.

        Group boundaries are ``i * n // batch_size``, so the remainder files are spread
        across groups instead of being dropped. When ``n`` divides evenly, the groups
        match the previous ``n // batch_size`` slicing.
        """
        n = len(self.dataset_paths)
        start = (i * n) // self.batch_size
        stop = ((i + 1) * n) // self.batch_size
        dataset_paths = self.dataset_paths[start:stop]
        return random.sample(dataset_paths, len(dataset_paths))

    def get_stream(self, data_list):
        """Endless chain of windows over ``data_list``, cycled in the same order forever."""
        return chain.from_iterable(map(self.process_data, cycle(data_list)))

    def get_streams(self):
        """Zip one stream per group, so each item has one sample from each stream."""
        return zip(*[self.get_stream(self.shuffled_data_list(i)) for i in range(self.batch_size)])

    def __iter__(self):
        return self.get_streams()



class WebTextIter(IterableDataset):
    """Iterable dataset yielding collated ``(inputs, targets, lengths)`` LongTensor triples.

    Wraps a ``BatchIterator`` (which yields tuples of raw samples) and collates each
    tuple with :meth:`collate_fn`.

    Args:
        batch_size, drop_last, dataset_paths, seq_len: forwarded to ``BatchIterator``.
        tokenizer: tokenizer to use; when omitted, ``GPT2Tokenizer.from_pretrained("gpt2")``.
    """

    def __init__(self, batch_size, drop_last, dataset_paths, seq_len, tokenizer=None):
        resolved_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if tokenizer is None else tokenizer
        self.seq_len = seq_len
        self.dataset_paths = dataset_paths
        self.batch_iter = BatchIterator(
            seq_len=seq_len,
            batch_size=batch_size,
            drop_last=drop_last,
            tokenizer=resolved_tokenizer,
            dataset_paths=dataset_paths,
        )

    def __iter__(self):
        for samples in self.batch_iter:
            yield self.collate_fn(samples)

    def collate_fn(self, batch):
        """Stack ``(input_ids, target_ids, length)`` samples into LongTensors.

        Inputs and targets are transposed from ``(batch, window)`` to ``(window, batch)``;
        lengths stay ``(batch,)``.
        """
        inputs, targets, lengths = [], [], []
        for sample in batch:
            input_ids, target_ids, length = sample
            inputs.append(input_ids)
            targets.append(target_ids)
            lengths.append(length)
        return (
            torch.LongTensor(inputs).permute(1, 0),
            torch.LongTensor(targets).permute(1, 0),
            torch.LongTensor(lengths),
        )


In [None]:
from torch.utils.data import DataLoader

# Smoke test: pull a single batch through a DataLoader. batch_size=None disables the
# DataLoader's own batching, because WebTextIter already yields collated batches.
wt = WebTextIter(dataset_paths=files[:20], batch_size=10, drop_last=True, seq_len=5)
dl = DataLoader(wt, batch_size=None, sampler=None)
count = 0
seen = []
first_batch = next(iter(dl), None)
if first_batch is not None:
    print(first_batch)

In [None]:
# Post-mortem debugger for the most recent exception raised in this kernel.
# NOTE(review): leftover interactive debugging aid — remove before sharing the notebook.
%debug


In [None]:
import os
import random

# Deterministically shuffle the shards before carving out test/val splits.
# Sort first: glob order is filesystem-dependent, and re-running an in-place
# shuffle on an already-shuffled list would give a different split each time.
SPLIT_SEED = 4
files = sorted(files)
random.seed(SPLIT_SEED)
random.shuffle(files)
total_size = sum(os.path.getsize(path) for path in files)
total_size

In [None]:
# NOTE(review): `count` is set to 0 in the DataLoader smoke-test cell, and nothing in this
# notebook increments it, so this always displays 0 — looks like a leftover; confirm before removing.
count

In [None]:
# Greedily pick shuffled shards for the test split, skipping any shard that would bring
# the running total to or above the byte budget `test`.
# NOTE(review): `test` is not defined in any visible cell — define it (e.g. in a config
# cell) first, or this raises NameError on a fresh kernel.
test_size = 0
test_files = []

for path in files:
    # Only reachable when the budget is non-positive; additions below keep test_size < test.
    if test_size >= test:
        break
    file_size = os.path.getsize(path)
    if test_size + file_size >= test:
        # Too big for the remaining budget; a smaller shard later on may still fit.
        continue
    test_size += file_size
    test_files.append(path)
test_files, test_size

In [None]:
# Greedily fill the validation split from shards not already reserved for test, skipping
# any shard that would bring the running total to or above the byte budget `val`.
# NOTE(review): `val` is not defined in any visible cell — define it before running.
val_size = 0
val_files = []

for path in files:
    if path in test_files:
        continue
    if val_size >= val:
        break
    file_size = os.path.getsize(path)
    if val_size + file_size < val:
        val_size += file_size
        val_files.append(path)
val_files, val_size

In [None]:
import shutil

# Held-out (val/test) shards are listed and skipped; every other shard is moved.
# NOTE(review): shutil.move(src, dst) here moves FROM /datadrive/openwebtext2/train/<name>
# INTO the shard path. If the intent is to move training shards into train/, the
# arguments look swapped — confirm. train/ is also not created anywhere visible.
held_out = set(val_files) | set(test_files)
for shard_path in files:
    shard_name = os.path.basename(shard_path)
    if shard_path in held_out:
        print(shard_name)
        continue
    shutil.move("/datadrive/openwebtext2/train/{}".format(shard_name), shard_path)

In [None]:
# Number of shard paths currently in `files`.
len(files)