In [32]:
import pickle
import numpy as np
import itertools
import time
import torch
import os
import segutil

In [None]:
with open('')

In [2]:
data_train = np.memmap('data/shakespeare_char/train.bin', dtype=np.uint16, mode='r')
data_val = np.memmap('data/shakespeare_char/val.bin', dtype=np.uint16, mode='r')

In [3]:
len(data_val), len(data_train)

(111540, 1003854)

In [4]:
dataset = 'shakespeare_char'
block_size = 256
batch_size = 64
device = 'mps'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

In [7]:
data_dir = os.path.join('data', dataset)
def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [5]:
# 012345678  --->  012, 234, 456, 678
# 01234567  --->  012, 234, 456, 67  -> last block is not full
block_size=3
data_size=8
n_blocks=4

(data_size-1)/(block_size-1)

3.5

In [31]:
class DataLoader: 
    def __init__(self, fpath, batch_size, block_size) -> None:
        self.batch_size = batch_size
        self.fpath = fpath
        # We recreate np.memmap every batch to avoid a memory leak, as per
        # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
        data = np.memmap(self.fpath, dtype=np.uint16, mode='r')
        n_blocks = np.ceil((len(data)-1)/(block_size-1)).astype(int)
        self.idx = np.arange(n_blocks*block_size).reshape(-1, block_size) - np.arange(n_blocks).reshape(-1, 1)
        self.idx = self.idx % len(data)

    def __iter__(self): 
        self.i = 0
        return self
    
    def __next__(self): 
        if self.i >= self.idx.shape[0]: 
            raise StopIteration()
        else: 
            data = np.memmap(self.fpath, dtype=np.uint16, mode='r')
            batch_idx = self.idx[self.i:self.i+batch_size]
            return data[batch_idx]
            self.i += self.batch_size
    
dl = DataLoader(os.path.join(data_dir, 'train.bin'), batch_size=8, block_size=32)
for bla in itertools.islice(dl, 4): 
    print(bla)

[[18 47 56 ... 43 43 42]
 [42  1 39 ...  0  0 13]
 [13 50 50 ... 47 64 43]
 ...
 [57 43 12 ... 63  0 47]
 [47 57  1 ...  1 46 43]
 [43 56 43 ... 51 43  6]]
[[18 47 56 ... 43 43 42]
 [42  1 39 ...  0  0 13]
 [13 50 50 ... 47 64 43]
 ...
 [57 43 12 ... 63  0 47]
 [47 57  1 ...  1 46 43]
 [43 56 43 ... 51 43  6]]
[[18 47 56 ... 43 43 42]
 [42  1 39 ...  0  0 13]
 [13 50 50 ... 47 64 43]
 ...
 [57 43 12 ... 63  0 47]
 [47 57  1 ...  1 46 43]
 [43 56 43 ... 51 43  6]]
[[18 47 56 ... 43 43 42]
 [42  1 39 ...  0  0 13]
 [13 50 50 ... 47 64 43]
 ...
 [57 43 12 ... 63  0 47]
 [47 57  1 ...  1 46 43]
 [43 56 43 ... 51 43  6]]


In [27]:
class RestartableBatchIterator:
    def __init__(self, iterable, batch_size):
        self.iterable = iterable
        self.batch_size = batch_size
    def __iter__(self):
        self.iter = iter(itertools.batched(self.iterable, self.batch_size))
        return self
    def __next__(self):
        return next(self.iter)
    
list(RestartableBatchIterator([1,2,3,4,5], 3))

[(1, 2, 3), (4, 5)]