In [None]:
import os

import numpy as np
import torch

In [None]:
import tiktoken

filename = "../shards/finewebedu10T_train_0001.npy"

enc = tiktoken.get_encoding("gpt2")
eot = enc.eot_token  # end of text token (same value as enc._special_tokens["<|endoftext|>"])
print(eot)

# Load one shard. np.array(..., dtype=np.int64) copies the data out of the memory
# map and widens it in NumPy, so torch never has to convert an unsigned dtype
# (e.g. uint16), which torch.tensor() rejects on older PyTorch releases.
np_tensor = np.load(filename, mmap_mode='r')
t = torch.from_numpy(np.array(np_tensor, dtype=np.int64))
print(t.shape)

# Split on EOT: each document keeps its trailing EOT; a final run without an EOT
# is kept as the last document.
doc_breaks = (t == eot).nonzero().flatten().tolist()
print("total eots", len(doc_breaks))
docs, start = [], 0
for end in doc_breaks:
    docs.append(t[start:end+1])  # include EOT
    start = end + 1
if start < len(t):  # last doc without trailing EOT
    docs.append(t[start:])

# Lengths of the first few documents before and after shuffling, as a quick visual check.
for i, doc in enumerate(docs[:3]):
    print(i, len(doc))
rng = np.random.RandomState(42)
rng.shuffle(docs)
for i, doc in enumerate(docs[:3]):
    print(i, len(doc))

out_tensor = torch.cat(docs, dim=0)
total_eots = (out_tensor == eot).sum().item()
print("total eots", total_eots)
print(out_tensor.shape)

# Shuffling must only reorder documents: same token count and same number of EOTs.
assert out_tensor.shape == t.shape, (out_tensor.shape, t.shape)
assert total_eots == len(doc_breaks), (total_eots, len(doc_breaks))


In [None]:
seed_offset = 0  # added to every loader's `seed`; change it for a different but reproducible shuffle


class DataLoaderLite:
    """Sequential ``(B, T)`` batch loader over tokenized ``.npy`` shards.

    Shard files in ``data_root`` whose name contains ``split`` are sorted and
    then shuffled once with a seeded RNG. Each process starts at offset
    ``process_rank * B * T`` and advances by ``B * T * num_processes`` tokens
    per batch, moving to the next shard (cyclically) when the following batch
    would not fit. For the ``"train"`` split the documents inside a shard
    (token runs ending in the end-of-text token) are shuffled every time the
    shard is loaded; ``"val"`` shards are used in file order.

    Args:
        B: number of sequences per batch.
        T: tokens per sequence.
        process_rank: index of this process among ``num_processes``.
        num_processes: number of processes reading interleaved slices.
        split: ``"train"`` or ``"val"``.
        data_root: directory containing the shard files.
        seed: base seed; the RNG is seeded with ``seed + seed_offset``.
        eot: end-of-text token id used to delimit documents. When ``None``
            (default) it is taken from tiktoken's GPT-2 encoding.
    """

    def __init__(self, B, T, process_rank, num_processes, split, data_root='../shards', seed=42, eot=None):
        assert split in ["train", "val"], f"Invalid split: {split}"
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.split = split
        self.rng = np.random.RandomState(seed + seed_offset)
        if eot is None:
            # Imported lazily so the class does not rely on a global `enc`
            # created in another notebook cell.
            import tiktoken
            eot = tiktoken.get_encoding("gpt2").eot_token
        self.eot = eot

        # Collect, validate, then shuffle the shard paths.
        shard_names = sorted(s for s in os.listdir(data_root) if split in s)
        assert len(shard_names) > 0, f"No shards found for split: {split}"
        self.shards = [os.path.join(data_root, s) for s in shard_names]
        self.rng.shuffle(self.shards)
        print(self.shards)

        # Load the first shard and position this process's cursor.
        self.reset()

    def get_batch(self):
        """Return the next ``(x, y)`` pair, each of shape ``(B, T)``.

        ``y`` is ``x`` shifted by one token (next-token targets). After reading,
        the cursor advances by ``B * T * num_processes``; if the following batch
        would run past the current shard, the next shard is loaded.

        Raises:
            RuntimeError: if the current shard is too short to supply a full batch.
        """
        B, T = self.B, self.T
        buf = self.tokens[self.cur_pos : self.cur_pos + B * T + 1]
        if buf.numel() != B * T + 1:
            raise RuntimeError(
                f"Shard {self.shards[self.cur_shard]} has {len(self.tokens)} tokens, "
                f"too few for a {B}x{T} batch at position {self.cur_pos}"
            )
        x = buf[:-1].view(B, T)  # input to the model
        y = buf[1:].view(B, T)  # output of the model
        # advance position
        self.cur_pos += B * T * self.num_processes

        # if loading next batch is out of bounds, load the next shard
        if self.cur_pos + B * T * self.num_processes + 1 > len(self.tokens):
            self.cur_shard = (self.cur_shard + 1) % len(self.shards)
            self.tokens = self._load_tokens(self.shards[self.cur_shard])
            self.cur_pos = self.process_rank * (B * T)

        return x, y

    def _split_documents(self, tokens):
        """Split a 1-D token tensor into documents.

        Each document ends with (and includes) ``self.eot``; a trailing run with
        no final EOT becomes the last document. An empty input gives ``[]``.
        """
        docs, start = [], 0
        for end in (tokens == self.eot).nonzero().flatten().tolist():
            docs.append(tokens[start:end + 1])  # include EOT
            start = end + 1
        if start < len(tokens):  # last doc without trailing EOT
            docs.append(tokens[start:])
        return docs

    def _load_tokens(self, filename):
        """Load one shard as a 1-D ``torch.long`` tensor.

        For the ``"train"`` split the shard's documents are shuffled with
        ``self.rng`` and concatenated back together; ``"val"`` shards are
        returned in file order.
        """
        # np.array(...) copies the memory-mapped shard fully into an int64 array,
        # so the whole shard is resident in memory. Widening in NumPy also avoids
        # torch.tensor() rejecting unsigned dtypes (e.g. uint16) on older PyTorch.
        tokens = torch.from_numpy(np.array(np.load(filename, mmap_mode='r'), dtype=np.int64))
        if self.split == "val":
            return tokens
        # Shuffle document order to break temporal patterns within a shard.
        docs = self._split_documents(tokens)
        self.rng.shuffle(docs)
        # An empty shard has no documents, and torch.cat([]) would raise.
        return torch.cat(docs, dim=0) if docs else tokens

    def reset(self):
        """Rewind to the first shard (reloading it) and this process's starting offset."""
        self.cur_shard = 0
        self.tokens = self._load_tokens(self.shards[self.cur_shard])
        self.cur_pos = self.process_rank * (self.B * self.T)

# Train and val loaders share the same batch geometry (4 sequences x 1024 tokens)
# for a single process; the train loader is constructed first, then val.
train_loader, val_loader = (
    DataLoaderLite(B=4, T=1024, process_rank=0, num_processes=1, split=split_name)
    for split_name in ("train", "val")
)
