In [1]:
import io
import zlib
import json
import ijson
import torch
import tables
from tqdm import tqdm

from transformers import PreTrainedTokenizerFast

# Each stored window holds MAX_SEQ_LEN tokens; the extra (+1) token lets a
# window be split into an input (tokens[:-1]) and a next-token target
# (tokens[1:]) of 128 tokens each.
MAX_SEQ_LEN = 128 + 1
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="models/tokenizer.json",
    pad_token="[PAD]",
    unk_token="[UNK]",
    # `model_max_length` is the current name of the legacy `max_len` kwarg.
    model_max_length=MAX_SEQ_LEN,
    add_prefix_space=False,
)

In [2]:
def save_tokens_subsequences(
        source,
        dest,
        tokenizer,
        window_size,
        stride,
        compress=False
):
    """Tokenize every story in a JSON list and save fixed-size token windows.

    Each story is split into windows of ``window_size`` token ids using the
    tokenizer's overflow mechanism. Padding positions are zeroed by multiplying
    ``input_ids`` by ``attention_mask``. All windows are concatenated into a
    single tensor with one window per row.

    Args:
        source: Path to a JSON file whose top-level value is a list of stories.
        dest: Output path. With ``compress=True`` the file written is
            ``f"{dest}.zlib"`` instead.
        tokenizer: Hugging Face tokenizer (callable returning ``input_ids`` and
            ``attention_mask`` tensors when ``return_tensors="pt"``).
        window_size: Tokens per window (passed as ``max_length``).
        stride: Passed straight to the tokenizer. NOTE: in Hugging Face
            tokenizers this is the number of tokens *shared* by consecutive
            windows (the overlap), so each window starts
            ``window_size - stride`` tokens after the previous one.
        compress: If True, zlib-compress the serialized tensor.

    Returns:
        torch.Size: Shape of the saved tensor, ``(n_windows, window_size)``.
    """
    with open(source, "r", encoding="utf-8") as f_in:
        stories = json.load(f_in)

    windows = []
    for story in tqdm(stories):
        tokens = tokenizer(
            story,
            return_special_tokens_mask=False,
            truncation=True,
            add_special_tokens=False,
            max_length=window_size,
            padding="max_length",
            return_overflowing_tokens=True,
            stride=stride,
            return_tensors="pt"
        )
        # Zero out padded positions so padding is stored as token id 0.
        windows.append(tokens["input_ids"] * tokens["attention_mask"])

    if windows:
        data_out = torch.cat(windows)
    else:
        # torch.cat([]) raises; store an empty but correctly shaped tensor.
        data_out = torch.empty((0, window_size), dtype=torch.long)

    if compress:
        buffer = io.BytesIO()
        # Same pickle protocol as the uncompressed branch, for consistency.
        torch.save(data_out, buffer, pickle_protocol=5)
        with open(f"{dest}.zlib", "wb") as f_out:
            f_out.write(zlib.compress(buffer.getvalue()))
    else:
        torch.save(data_out, f"{dest}", pickle_protocol=5)

    return data_out.shape

In [None]:
source = "data/train-sampled.json"
dest = "data/train-sampled.pt"
# NOTE: `stride` is the overlap between consecutive windows (Hugging Face
# semantics), so MAX_SEQ_LEN - 4 advances each window by only 4 tokens.
# TODO(review): confirm this heavy overlap is intended for training data.
save_tokens_subsequences(
    source,
    dest,
    tokenizer,
    window_size=MAX_SEQ_LEN,
    stride=MAX_SEQ_LEN - 4,
)

100%|██████████| 135884/135884 [08:51<00:00, 255.55it/s]


torch.Size([3550385, 129])


In [4]:
source = "data/valid-sampled.json"
dest = "data/valid-sampled.pt"
# NOTE: `stride` is the overlap between consecutive windows (Hugging Face
# semantics), so MAX_SEQ_LEN - 1 advances each window by a single token.
# TODO(review): confirm this maximal overlap is intended for validation data.
save_tokens_subsequences(
    source,
    dest,
    tokenizer,
    window_size=MAX_SEQ_LEN,
    stride=MAX_SEQ_LEN - 1,
)

100%|██████████| 1381/1381 [00:17<00:00, 80.17it/s] 


torch.Size([139577, 129])


In [8]:
def save_token_subsequences_to_hdf5(
    source:str,
    dest:str,
    tokenizer,
    compression:tables.Filters,
    window_size:int,
    stride:int,
    name:str="data",
    expected_rows:int=5,
):
    """Stream stories from a JSON list into an extendable HDF5 array.

    Stories are read incrementally with ``ijson`` (prefix ``"item"``, i.e. the
    elements of a top-level JSON array). Each story is tokenized into windows of
    ``window_size`` token ids and appended to an Int64 EArray of shape
    ``(n_windows, window_size)``, stored at ``/<name>`` in ``f"{dest}.h5"``.
    Padding positions are zeroed by multiplying ``input_ids`` by
    ``attention_mask``.

    Args:
        source: Path to a JSON file whose top-level value is a list of stories.
        dest: Output path without extension; ``.h5`` is appended.
        tokenizer: Hugging Face tokenizer (callable).
        compression: PyTables filters applied to the file and the array.
        window_size: Tokens per window (passed as ``max_length``).
        stride: Passed straight to the tokenizer. NOTE: in Hugging Face
            tokenizers this is the overlap between consecutive windows, so
            each window starts ``window_size - stride`` tokens after the last.
        name: Node name of the array under the HDF5 root group.
        expected_rows: PyTables ``expectedrows`` hint used to size chunks.
            Defaults to the previously hard-coded 5. The actual row count is
            far larger, so consider tuning this. Presumably larger hints give
            larger chunks: better compression and sequential reads, but
            likely costlier random single-row reads.

    Returns:
        tuple: Final shape of the stored array.
    """
    # Count stories for the progress bar without holding them all in memory.
    with open(source, "r", encoding="utf-8") as f_in:
        n_stories = sum(1 for _ in ijson.items(f_in, "item"))

    with open(source, "r", encoding="utf-8") as f_in, \
        tables.open_file(f"{dest}.h5", "w", filters=compression) as h5_out:
        data_out = h5_out.create_earray(
            h5_out.root,
            atom=tables.Int64Atom(),
            shape=(0, window_size),
            name=name,
            expectedrows=expected_rows,
            filters=compression
        )

        for story in tqdm(ijson.items(f_in, "item"), total=n_stories):
            tokens = tokenizer(
                story,
                return_special_tokens_mask=False,
                truncation=True,
                add_special_tokens=False,
                max_length=window_size,
                padding="max_length",
                return_overflowing_tokens=True,
                stride=stride,
                return_tensors="np"
            )
            # Zero out padded positions so padding is stored as token id 0.
            data_out.append(tokens["input_ids"] * tokens["attention_mask"])

        data_out.flush()
        # Read the shape while the file is still open.
        return tuple(data_out.shape)


In [9]:
source = "data/train-sampled.json"
dest = "data/train-sampled"
# NOTE: `stride` is the overlap between consecutive windows (Hugging Face
# semantics), so MAX_SEQ_LEN - 2 advances each window by 2 tokens.
# TODO(review): confirm this differs intentionally from the .pt export's overlap.
save_token_subsequences_to_hdf5(
    source,
    dest,
    tokenizer,
    compression=tables.Filters(complevel=1, complib="blosc2"),
    window_size=MAX_SEQ_LEN,
    stride=MAX_SEQ_LEN - 2,
)

100%|██████████| 135884/135884 [12:53<00:00, 175.61it/s]


In [22]:
from torch.utils.data import Dataset, DataLoader

class TinyStoriesTokensLoader(Dataset):
    def __init__(
        self,
        file_path:str,
        max_seq_len:int,
        device="cpu",
        lazy_load=False,
    ):
        super(TinyStoriesTokensLoader, self).__init__()
        self.max_seq_len = max_seq_len,
        self.device = device

        self._load_data(file_path, lazy_load)

    def _load_data(self, file_path:str, lazy_load:bool):
        self.data = torch.load(file_path, mmap=lazy_load)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        tokens = self.data[idx].to(self.device)
        return tokens[:-1], tokens[1:]       
        

In [25]:
# Batches from the memory-mapped .pt file, shuffled, items moved to the GPU.
ld_a = DataLoader(
    dataset=TinyStoriesTokensLoader(
        file_path="data/train-sampled.pt",
        max_seq_len=MAX_SEQ_LEN,
        lazy_load=True,  # memory-map instead of reading the whole file
        device="cuda",
    ),
    shuffle=True,
    batch_size=512,
)

In [26]:
from itertools import islice
from time import perf_counter

# Time exactly `n_iters` batches. islice stops after n_iters fetches, whereas
# an enumerate/break loop fetches one extra batch before breaking.
n_iters = 5
t = perf_counter()
for _ in islice(ld_a, n_iters):
    pass
elapsed = perf_counter() - t  # measured once so total and average agree

print(f"Time taken: {elapsed:.2f}s")
print(f"Average = {elapsed / n_iters:.2f}s")

Time taken: 0.52s
Average = 0.10s


In [9]:
from torch.utils.data import Dataset, DataLoader

class TinyStoriesTokensHDF5Loader(Dataset):
    """Map-style dataset reading token windows from an HDF5 array via PyTables.

    Rows of the array at ``/<name>`` are read one at a time. Each item is an
    ``(input, target)`` pair of int64 tensors where ``target`` is ``input``
    shifted by one token. The HDF5 file stays open until ``close()`` is called
    (or the instance is used as a context manager).
    """

    def __init__(
        self,
        file_path:str,
        max_seq_len:int,
        compression:tables.Filters,
        name:str="data",
        device="cpu",
    ):
        """
        Args:
            file_path: Path to the ``.h5`` file.
            max_seq_len: Stored on the instance; not used for slicing here.
            compression: Filters passed to ``tables.open_file``. NOTE(review):
                presumably these have no effect when opening in read mode.
            name: Node name of the array under the root group.
            device: Device each item is moved to in ``__getitem__``.
        """
        super().__init__()
        self.max_seq_len = max_seq_len
        self.device = device

        self._load_data(file_path, compression, name)

    def _load_data(self, file_path:str, compression:tables.Filters, name:str):
        """Open the HDF5 file read-only and keep a handle to the array node."""
        self.file = tables.open_file(file_path, mode="r", filters=compression)
        self.data = self.file.root[name]

    def __len__(self):
        """Number of windows (rows) in the stored array."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(tokens[:-1], tokens[1:])`` for row ``idx`` as int64 on ``self.device``.

        NOTE(review): with ``shuffle=True`` each call is a random single-row read
        from a compressed array; that likely dominates loading time.
        """
        tokens = torch.from_numpy(self.data[idx]).to(self.device, torch.long)
        return tokens[:-1], tokens[1:]

    def close(self):
        """Close the underlying HDF5 file; the dataset is unusable afterwards."""
        if self.file.isopen:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        

In [10]:
# Batches read row by row from the blosc2-compressed HDF5 file, shuffled,
# with items moved to the GPU.
ld_b = DataLoader(
    dataset=TinyStoriesTokensHDF5Loader(
        file_path="data/train-sampled.h5",
        max_seq_len=MAX_SEQ_LEN,
        compression=tables.Filters(complevel=1, complib="blosc2"),
        device="cuda",
    ),
    shuffle=True,
    batch_size=512,
)

In [12]:
from itertools import islice
from time import perf_counter

# Time exactly `n_iters` batches. islice stops after n_iters fetches, whereas
# an enumerate/break loop fetches one extra batch before breaking.
n_iters = 10
t = perf_counter()
for _ in islice(ld_b, n_iters):
    pass
elapsed = perf_counter() - t  # measured once so total and average agree

print(f"Time taken: {elapsed:.2f}s")
print(f"Average = {elapsed / n_iters:.2f}s")

Time taken: 47.62s
Average = 4.76s


In [13]:
from time import perf_counter

# Latency to the first batch (iterator creation + one batch) for each loader.
# perf_counter is monotonic and high-resolution, unlike wall-clock time().
for loader in (ld_a, ld_b):
    t = perf_counter()
    next(iter(loader))
    print(perf_counter() - t)

0.016007423400878906
1.6069536209106445
