# batched data loader with torch.stack

In [58]:
import tiktoken
import torch 
import datasets
import random

# r50k_base vocab size: 50,257 https://arxiv.org/pdf/2404.09894
enc = tiktoken.get_encoding("r50k_base")
assert enc.decode(enc.encode("hello world")) == "hello world"

def encode(string):
    return torch.tensor(enc.encode(string))

def decode(tensor):
    return enc.decode([tensor.item()])
    
dataset = datasets.load_dataset('karpathy/tiny_shakespeare')
dataset_tok = dataset.map(lambda row: {"tok": torch.tensor(encode(row["text"]), device="mps")}, remove_columns="text")

In [None]:
from tqdm import tqdm 

def get_sample(split, sample_length):
    tokens = dataset_tok[split]["tok"][0]
    s = random.randint(0, len(tokens)-sample_length)
    return tokens[s:s+sample_length], tokens[s+sample_length]

for i in tqdm(range(200)): get_sample("train", 10)

In [7]:
import torch

In [21]:
tokens=torch.tensor(dataset_tok["train"]["tok"][0])

In [22]:
tokens[0]

tensor(5962)

In [27]:
bs=10
lst=[tokens[x:x+10] for x in range(1,bs)]
lst

[tensor([22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,  3285]),
 tensor([  25,  198, 8421,  356, 5120,  597, 2252,   11, 3285,  502]),
 tensor([ 198, 8421,  356, 5120,  597, 2252,   11, 3285,  502, 2740]),
 tensor([8421,  356, 5120,  597, 2252,   11, 3285,  502, 2740,   13]),
 tensor([ 356, 5120,  597, 2252,   11, 3285,  502, 2740,   13,  198]),
 tensor([5120,  597, 2252,   11, 3285,  502, 2740,   13,  198,  198]),
 tensor([ 597, 2252,   11, 3285,  502, 2740,   13,  198,  198, 3237]),
 tensor([2252,   11, 3285,  502, 2740,   13,  198,  198, 3237,   25]),
 tensor([  11, 3285,  502, 2740,   13,  198,  198, 3237,   25,  198])]

In [28]:
torch.stack(lst)

tensor([[22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,  3285],
        [   25,   198,  8421,   356,  5120,   597,  2252,    11,  3285,   502],
        [  198,  8421,   356,  5120,   597,  2252,    11,  3285,   502,  2740],
        [ 8421,   356,  5120,   597,  2252,    11,  3285,   502,  2740,    13],
        [  356,  5120,   597,  2252,    11,  3285,   502,  2740,    13,   198],
        [ 5120,   597,  2252,    11,  3285,   502,  2740,    13,   198,   198],
        [  597,  2252,    11,  3285,   502,  2740,    13,   198,   198,  3237],
        [ 2252,    11,  3285,   502,  2740,    13,   198,   198,  3237,    25],
        [   11,  3285,   502,  2740,    13,   198,   198,  3237,    25,   198]])

In [30]:
torch.stack(lst).size()

torch.Size([9, 10])

In [29]:
torch.cat(lst)

tensor([22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,  3285,
           25,   198,  8421,   356,  5120,   597,  2252,    11,  3285,   502,
          198,  8421,   356,  5120,   597,  2252,    11,  3285,   502,  2740,
         8421,   356,  5120,   597,  2252,    11,  3285,   502,  2740,    13,
          356,  5120,   597,  2252,    11,  3285,   502,  2740,    13,   198,
         5120,   597,  2252,    11,  3285,   502,  2740,    13,   198,   198,
          597,  2252,    11,  3285,   502,  2740,    13,   198,   198,  3237,
         2252,    11,  3285,   502,  2740,    13,   198,   198,  3237,    25,
           11,  3285,   502,  2740,    13,   198,   198,  3237,    25,   198])

In [31]:
torch.cat(lst).size()

torch.Size([90])

## final

In [59]:
import tiktoken, torch, datasets, random

# r50k_base vocab size: 50,257 https://arxiv.org/pdf/2404.09894
enc = tiktoken.get_encoding("r50k_base")
assert enc.decode(enc.encode("hello world")) == "hello world"

def encode(string):
    return torch.tensor(enc.encode(string))

def decode(tensor):
    return enc.decode([tensor.item()])
    
dataset = datasets.load_dataset('karpathy/tiny_shakespeare')
dataset_tok = dataset.map(lambda row: {"tok": encode(row["text"])}, remove_columns="text")

In [118]:
def get_sample(split, sample_length, batch_size):
    tokens = dataset_tok[split]["tok"][0]
    idcs = torch.randint(len(tokens)-sample_length, (batch_size,))
    x = torch.stack([torch.tensor(tokens[x:x+sample_length]) for x in idcs])
    y = torch.stack([torch.tensor(tokens[x+1:x+sample_length+1]) for x in idcs])
    return x, y

In [119]:
get_sample("test", 3, 10)

(tensor([[   13,   198,   198],
         [ 3073,   503,   286],
         [   40,  1276, 22389],
         [  750, 12201,   502],
         [   81, 13055,   511],
         [  198,  1870,   407],
         [  952,   318,   339],
         [ 8643,  1340,  9399],
         [30158,  3525,  9399],
         [ 1139,    11,   393]]),
 tensor([[  198,   198,    38],
         [  503,   286,   262],
         [ 1276, 22389,    25],
         [12201,   502,    13],
         [13055,   511, 15626],
         [ 1870,   407,  9642],
         [  318,   339,   668],
         [ 1340,  9399,    25],
         [ 3525,  9399,    25],
         [   11,   393,   356]]))