In [1]:
import ijson
import torch
import pickle
import numpy as np

from tqdm import tqdm
from numpy.lib.format import open_memmap

from transformers import PreTrainedTokenizerFast

MAX_SEQ_LEN = 256
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="models/tokenizer.json", 
    pad_token="[PAD]", 
    unk_token="[UNK]", 
    eos_token = "<|endoftext|>",
    max_len=MAX_SEQ_LEN,
    add_prefix_space=False,
    
)

In [2]:
def save_mmap_numpy_array(
    source,
    dest,
    tokenizer,
    window_size,
    stride_size,
    FLUSH_COUNTER=1000000
):
    with open(source, "r", encoding="utf-8") as f_in:
        print("Calculating number of examples...", end=" ")
        n = sum(1 for _ in ijson.items(f_in, "item"))
        print(f"Total of {n} stories")
        
    with open(source, "r", encoding="utf-8") as f_in, open(f"{dest}.pickle", "wb") as f_out:
        items = ijson.items(f_in, "item")
        array_length = 0
        current_pos = 0
        idx2pos = []
        
        for story in tqdm(items, total=n, desc="Calculating number of elements needed...", leave=True):
            tokens = tokenizer(
                story,
                truncation=False,
                padding="max_length",
                max_length=window_size,
                stride=stride_size,
                return_overflowing_tokens=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )["input_ids"]

            array_length += len(tokens)
            idx2pos.extend(
                [current_pos +i for i in range(len(tokens) - window_size + 1)]
            )
            current_pos+=len(tokens)
        print(len(idx2pos), array_length)
        pickle.dump(idx2pos, f_out)
    
    with open(source, "r", encoding="utf-8") as f_in:
        items = ijson.items(f_in, "item")
        total_tokens = open_memmap(
            f"{dest}.npy", 
            dtype=np.uint16, 
            mode="w+",
            shape=(array_length,)
        )

        current_index=0
        for i, story in tqdm(enumerate(items), total=n):
            tokens = tokenizer(
                story,
                truncation=False,
                padding="max_length",
                max_length=window_size,
                stride=stride_size,
                return_tensors="np",
                return_token_type_ids=False
            )
            tokens = tokens["input_ids"].astype(np.uint16)
            total_tokens[
                current_index:current_index+tokens.shape[1]
            ] = tokens[0]
            current_index+=tokens.shape[1]

            if i % FLUSH_COUNTER == FLUSH_COUNTER - 1:
                total_tokens.flush()
        total_tokens.flush()
    return total_tokens.shape

In [3]:
souce = "data/train-sampled.json"
dest = "data/train-sampled"
save_mmap_numpy_array(
    souce, 
    dest, 
    tokenizer, 
    window_size=MAX_SEQ_LEN,
    stride_size=MAX_SEQ_LEN-1,
)

Calculating number of examples... Total of 271769 stories


Calculating number of elements needed...: 100%|██████████| 271769/271769 [04:17<00:00, 1055.25it/s]


4156149 73457244


100%|██████████| 271769/271769 [03:58<00:00, 1139.39it/s]


(73457244,)

In [5]:
souce = "data/valid.json"
dest = "data/valid"
save_mmap_numpy_array(
    souce, 
    dest, 
    tokenizer, 
    window_size=MAX_SEQ_LEN,
    stride_size=MAX_SEQ_LEN - 1,
)

Calculating number of examples... Total of 27630 stories


Calculating number of elements needed...: 100%|██████████| 27630/27630 [00:21<00:00, 1290.19it/s]


396687 7442337


100%|██████████| 27630/27630 [00:23<00:00, 1159.88it/s]


(7442337,)

In [6]:
from torch.utils.data import Dataset, DataLoader

class TinyStoriesDataset(Dataset):
    def __init__(
        self,  
        input_file,
        tokenizer,
        seq_len,
        device="cuda",
        lazy_load=True
    ):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.DEVICE = device
        self.lazy_load = lazy_load
        
        self._load_data(input_file, lazy_load)

    def _load_data(self, file, lazy_load):
        memmap_flag = "r" if lazy_load else None
        self.data = np.load(f"{file}.npy", mmap_mode = memmap_flag)
        with open(f"{file}.pickle", "rb") as f:
            self.idx2pos = pickle.load(f)
    
    def __len__(self):
        return len(self.idx2pos)

    def __getitem__(self, idx):
        i = self.idx2pos[idx]
        x = self.data[i:i+self.seq_len]
        
        if (i+self.seq_len+1)>=len(self.data):
            y = np.pad(x[1:], pad_width=(0,1), mode="constant", constant_values=0)
        else:
            next_token = self.data[i+self.seq_len+1]
            if (
                x[-1] in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id) and 
                next_token!=self.tokenizer.pad_token_id
            ):
                y = np.pad(x[1:], pad_width=(0,1), mode="constant", constant_values=0)
            else:
                y = self.data[i+1:i+1+self.seq_len]

        return (
            torch.from_numpy(x.astype(np.int64)).to(self.DEVICE),
            torch.from_numpy(y.astype(np.int64)).to(self.DEVICE)
        )
        

In [7]:
ld_a = DataLoader(
    TinyStoriesDataset(
        input_file="data/train-sampled",
        tokenizer=tokenizer,
        seq_len=MAX_SEQ_LEN,
        device="cuda",
        lazy_load=True,
    ),
    batch_size=128,
    shuffle=True
)

In [9]:
from time import time

n_iters = 10
t = time()
for i, _ in enumerate(ld_a):
    if i==n_iters:
        break
    
print(f"Time taken: {time() - t:.2f}s")
print(f"Average = {(time() - t) / n_iters:.2f}s")

Time taken: 0.72s
Average = 0.07s


In [9]:
ld_b = DataLoader(
    TinyStoriesDataset(
        input_file="data/train-sampled",
        tokenizer=tokenizer,
        seq_len=MAX_SEQ_LEN,
        device="cuda",
        lazy_load=False,
    ),
    batch_size=128,
    shuffle=True
)

In [10]:
from time import time

n_iters = 10
t = time()
for i, _ in enumerate(ld_b):
    if i==n_iters:
        break
    
print(f"Time taken: {time() - t:.2f}s")
print(f"Average = {(time() - t) / n_iters:.2f}s")

Time taken: 2.71s
Average = 0.27s


In [11]:
ld_a.dataset[0]

(tensor([ 675,  682,  506,  644,   15,  567,  506,  805, 1217,   15,  642,  529,
          506,  593, 1540,  840,   17,  567,  508,  840,  863,  506,  638,  724,
          648,  922,   17,  922,  775,  513, 1334,  901,  806,  552,  773,  610,
          794,  552, 2164,  609,   17,  552,  824, 1637,  552,  773,  556,  506,
         1990,   15,  506, 2195,   15,  994, 1071,  506, 1060,  349,  530,   17,
          687,  578,   15,  922,  750,  506,  805, 1540,  857,  567,  581,  946,
           17,  552,  696,  608,  892,  806,  529,  932,   15,  614,  552,  824,
          547, 1626,  556,  820, 1081,   17,  552,  947,  513, 1243,  508,  857,
         1098, 1196,  552,  529, 1008,  530,  514,  712, 1305,  547,  633,  581,
         1664, 1718,   17,   92, 1430,  364,  670, 1076,   15,  514,  922, 2164,
          609,   17,  552,  927,  506, 2195,  567,  508], device='cuda:0'),
 tensor([ 682,  506,  644,   15,  567,  506,  805, 1217,   15,  642,  529,  506,
          593, 1540,  840,   17, 