In [1]:
import os

import numpy as np
import torch

In [13]:
seed_offset = 0
class DataLoaderLite:
    
    def __init__(self, data_root, T, B,
                 split="train",   # "train" or "val"
                 num_processes=1, process_rank=0, seed=1234):
        
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.split = split
        assert split in ["train", "val"], f"Invalid split: {split}"
        self.rng = np.random.default_rng(seed+seed_offset)

        # get the shards filenames
        self.shards = [s for s in os.listdir(data_root) if self.split in s]
        self.shards = sorted(self.shards)
        self.shards = [os.path.join(data_root, s) for s in self.shards]
        print(f"{split}: {len(self.shards)} shard(s)")

        # Memory-map shards so they don’t fully load into RAM
        self.mem = [np.load(f, mmap_mode='r') for f in self.shards]
        self.shard_lengths = [m.shape[0] for m in self.mem]
        
        # Build global window index ---
        self._build_index()
        self.ptr = 0

    def _build_index(self):
        """Build a list of all (shard_id, start_offset) windows."""
        all_indices = []
        for sid, L in enumerate(self.shard_lengths):
            # number of windows in this shard (discard leftover < T+1 tokens)
            max_windows = (L - (self.T + 1)) // self.T
            if max_windows <= 0:
                continue
            starts = (np.arange(max_windows) * self.T).astype(np.int64)
            shard_ids = np.full_like(starts, sid, dtype=np.int64)
            pairs = np.stack([shard_ids, starts], axis=1)
            all_indices.append(pairs)

        all_indices = np.concatenate(all_indices, axis=0)
        print(all_indices.shape)

        if self.split == "train":
            # shuffle and split across ranks (each GPU sees unique slice)
            self.rng.shuffle(all_indices)
            self.index = all_indices[self.process_rank::self.num_processes]
        else:
            # validation: no shuffle, same across ranks
            self.index = all_indices

        print(f"{self.split} index size: {len(self.index)} windows")

    def __len__(self):
        """Number of full batches in the dataset."""
        return len(self.index) // self.B

    def get_batch(self):
        """Return one batch of shape (B, T) for x and y."""
        rows = self.index[self.ptr : self.ptr + self.B]
        self.ptr += self.B

        xs, ys = [], []
        for sid, start in rows:
            # slice (T+1) tokens from shard
            t = self.mem[sid][start : start + self.T + 1].astype(np.uint16)
            t = torch.from_numpy(t)
            xs.append(t[:-1])
            ys.append(t[1:])
        x = torch.stack(xs).long()
        y = torch.stack(ys).long()
        return x, y

    def reset(self, seed=None):
        """Reset pointer and reshuffle (train only)."""
        if self.split == "train":
            if seed is not None:
                self.rng = np.random.default_rng(seed+seed_offset)
            self._build_index()
        self.ptr = 0

In [14]:
train_loader = DataLoaderLite(data_root="../shards", B=4, T=1024, process_rank=0, num_processes=1, split="train")

# === Step 3: Inspect some batches ===
print("Train batches:")
for i in range(3):
    x, y = train_loader.get_batch()
    print(f"Batch {i}: x.shape={x.shape}, y.shape={y.shape}")
    print(x[:,:10])
    print(y[:,:10])
    

train: 5 shard(s)
(443342, 2)
train index size: 443342 windows
Train batches:
Batch 0: x.shape=torch.Size([4, 1024]), y.shape=torch.Size([4, 1024])
tensor([[  284,   467,   706,   340,    13,  5501,   286,   514,   481,  5409],
        [  286,   777,  6958,   481,  1205,  2354,   286,   262, 20583,   290],
        [ 2641,   262,  1027,    75,  5910,    13,  5855, 29252,    11,   314],
        [ 4800,   290, 28554,   286,   617,   286,   674,   749, 14133,   290]])
tensor([[  467,   706,   340,    13,  5501,   286,   514,   481,  5409,   329],
        [  777,  6958,   481,  1205,  2354,   286,   262, 20583,   290,  3675],
        [  262,  1027,    75,  5910,    13,  5855, 29252,    11,   314,  2626],
        [  290, 28554,   286,   617,   286,   674,   749, 14133,   290,  9566]])
Batch 1: x.shape=torch.Size([4, 1024]), y.shape=torch.Size([4, 1024])
tensor([[  743,  1663,   772,  4785,   780,   783,   336, 10141,   389,  3599],
        [ 7161,  9041,   460,   651, 40186,    13,  1002, 23

In [15]:
train_loader = DataLoaderLite(data_root="../shards", B=4, T=1024, process_rank=0, num_processes=2, split="train")

# Inspect some batches
print("Train batches:")
for i in range(3):
    x, y = train_loader.get_batch()
    print(f"Batch {i}: x.shape={x.shape}, y.shape={y.shape}")
    print(x[:,:10])
    print(y[:,:10])

train_loader = DataLoaderLite(data_root="../shards", B=4, T=1024, process_rank=1, num_processes=2, split="train")

# Inspect some batches
print("Train batches:")
for i in range(3):
    x, y = train_loader.get_batch()
    print(f"Batch {i}: x.shape={x.shape}, y.shape={y.shape}")
    print(x[:,:10])
    print(y[:,:10])

train: 5 shard(s)
(443342, 2)
train index size: 221671 windows
Train batches:
Batch 0: x.shape=torch.Size([4, 1024]), y.shape=torch.Size([4, 1024])
tensor([[  284,   467,   706,   340,    13,  5501,   286,   514,   481,  5409],
        [ 2641,   262,  1027,    75,  5910,    13,  5855, 29252,    11,   314],
        [  743,  1663,   772,  4785,   780,   783,   336, 10141,   389,  3599],
        [  292,  1040,   290, 20461,    78,   577, 11490,    13,  1052,  3942]])
tensor([[  467,   706,   340,    13,  5501,   286,   514,   481,  5409,   329],
        [  262,  1027,    75,  5910,    13,  5855, 29252,    11,   314,  2626],
        [ 1663,   772,  4785,   780,   783,   336, 10141,   389,  3599,   284],
        [ 1040,   290, 20461,    78,   577, 11490,    13,  1052,  3942,  1438]])
Batch 1: x.shape=torch.Size([4, 1024]), y.shape=torch.Size([4, 1024])
tensor([[  475,  2138,   287,  2754,   284,   262,  2842,   287,   543, 13384],
        [ 1244,   307,  1444,   416,   534,  1200,   264,  1

In [16]:
train_loader = DataLoaderLite(data_root="../shards", B=4, T=1024, process_rank=0, num_processes=3, split="train")

# Inspect some batches
print("Train batches:")
for i in range(3):
    x, y = train_loader.get_batch()
    print(f"Batch {i}: x.shape={x.shape}, y.shape={y.shape}")
    print(x[:,:10])
    print(y[:,:10])

train_loader = DataLoaderLite(data_root="../shards", B=4, T=1024, process_rank=1, num_processes=3, split="train")

# Inspect some batches
print("Train batches:")
for i in range(3):
    x, y = train_loader.get_batch()
    print(f"Batch {i}: x.shape={x.shape}, y.shape={y.shape}")
    print(x[:,:10])
    print(y[:,:10])

train_loader = DataLoaderLite(data_root="../shards", B=4, T=1024, process_rank=2, num_processes=3, split="train")

# Inspect some batches
print("Train batches:")
for i in range(3):
    x, y = train_loader.get_batch()
    print(f"Batch {i}: x.shape={x.shape}, y.shape={y.shape}")
    print(x[:,:10])
    print(y[:,:10])
    


train: 5 shard(s)
(443342, 2)
train index size: 147781 windows
Train batches:
Batch 0: x.shape=torch.Size([4, 1024]), y.shape=torch.Size([4, 1024])
tensor([[  284,   467,   706,   340,    13,  5501,   286,   514,   481,  5409],
        [ 4800,   290, 28554,   286,   617,   286,   674,   749, 14133,   290],
        [  292,  1040,   290, 20461,    78,   577, 11490,    13,  1052,  3942],
        [14429, 10223,  4493,   355,   880,    13, 14039,  1610,  3732,   329]])
tensor([[  467,   706,   340,    13,  5501,   286,   514,   481,  5409,   329],
        [  290, 28554,   286,   617,   286,   674,   749, 14133,   290,  9566],
        [ 1040,   290, 20461,    78,   577, 11490,    13,  1052,  3942,  1438],
        [10223,  4493,   355,   880,    13, 14039,  1610,  3732,   329,  4783]])
Batch 1: x.shape=torch.Size([4, 1024]), y.shape=torch.Size([4, 1024])
tensor([[  393,   262,  2347,   285,   286,   262, 11538,    11,   477,   530],
        [16388,  4040,   389, 23993,  2035,   780,   262,  6

In [17]:
indices = torch.randn((443342,2))
# num_processes = 2
print(indices[0::2].shape)
print(indices[1::2].shape)

torch.Size([221671, 2])
torch.Size([221671, 2])


In [18]:
# num_processes = 3
print(indices[0::3].shape)
print(indices[1::3].shape)
print(indices[2::3].shape)

torch.Size([147781, 2])
torch.Size([147781, 2])
torch.Size([147780, 2])


torch.Size([221671, 2])