# DataLoader

```{note}
`Dataset` 和 `DataLoader` 是 PyTorch 数据处理的核心组件。简单来说：

*   Dataset：负责 “存” 和 “取”。它定义了数据在哪里，以及如何取出一个样本（通常通过索引）。
*   DataLoader：负责 “运”。它从 Dataset 里批量取出数据，打成包 (Batch)，可以顺便做打乱 (Shuffle) 和多进程加速。
```

## Dataset

为了直观理解，我们结合刚才编写的 `BPETokenizer` 来理解 `Dataset`。

我们需要继承 `torch.utils.data.Dataset` 并实现两个方法：
- `__len__`: 告诉 PyTorch 数据集有多大。
- `__getitem__`: 告诉 PyTorch 第 `idx` 个样本长什么样。

In [8]:
import torch
from torch.utils.data import Dataset
import math
import random
import pandas as pd

class PretrainDataset(Dataset):
    """Map-style dataset that packs tokenized documents into fixed-length chunks.

    Every text read from the Parquet files is encoded, terminated with the
    ``<EOS>`` id and concatenated into one long token stream kept in memory.
    Sample ``idx`` is the slice ``[idx * seq_len, (idx + 1) * seq_len)`` of
    that stream; only the last sample can be short, in which case it is
    right-padded with the ``<PAD>`` id.
    """

    def __init__(self, file_paths, tokenizer, seq_len, text_col='page'):
        """
        :param file_paths: list of Parquet file paths
        :param tokenizer: tokenizer exposing a ``special_tokens`` mapping and
            ``encode(text)``; must define ``<EOS>`` (``<PAD>`` is optional)
        :param seq_len: training sequence length (context window), must be > 0
        :param text_col: name of the column that holds the raw text
        :raises ValueError: if ``seq_len`` is not positive or the tokenizer
            has no ``<EOS>`` token
        """
        if seq_len <= 0:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}.")

        self.seq_len = seq_len
        self.tokenizer = tokenizer

        # Look up the special token ids.
        self.eos_id = tokenizer.special_tokens.get("<EOS>")
        self.pad_id = tokenizer.special_tokens.get("<PAD>")

        if self.eos_id is None:
            raise ValueError("Tokenizer must have <EOS> defined for pretraining.")
        if self.pad_id is None:
            print("Warning: <PAD> not found in tokenizer. Using 0 as pad_id.")
            self.pad_id = 0

        print(f"Processing {len(file_paths)} files...")

        # 1. Read every file. Loading is best-effort: an unreadable file is
        #    reported and skipped instead of aborting the whole run.
        texts = []
        for file_path in file_paths:
            try:
                df = pd.read_parquet(file_path)
                if text_col in df.columns:
                    texts.extend(df[text_col].dropna().tolist())
                else:
                    # Make a wrong column name visible instead of silently
                    # ending up with an empty dataset.
                    print(f"Warning: column '{text_col}' not found in {file_path}, skipping.")
            except Exception as e:
                print(f"Error reading {file_path}: {e}")

        # Shuffle documents before concatenation so neighbouring chunks do not
        # always come from documents that were adjacent on disk.
        random.shuffle(texts)

        # 2. Encode and concatenate, closing every document with EOS.
        all_token_ids = []
        for text in texts:
            all_token_ids.extend(tokenizer.encode(text))
            all_token_ids.append(self.eos_id)

        # Store the whole stream as a single 1-D long tensor.
        self.data = torch.tensor(all_token_ids, dtype=torch.long)

        # ceil() so that a trailing partial chunk still counts as a sample.
        self.num_samples = math.ceil(len(self.data) / self.seq_len)
        print(f"Total tokens: {len(self.data)}. Total samples (seq_len={seq_len}): {self.num_samples}")

    def __len__(self):
        """Return the number of ``seq_len``-sized samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` as a 1-D long tensor of length ``seq_len``.

        Negative indices count from the end, like a list.

        :raises IndexError: if ``idx`` is out of range. Without this,
            ``for x in dataset`` would never stop, because plain iteration
            relies on ``IndexError`` to detect the end of a sequence.
        """
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(
                f"sample index out of range for dataset with {self.num_samples} samples"
            )

        start = idx * self.seq_len
        # Slicing clamps at the end of the stream, so the last chunk may be short.
        chunk = self.data[start:start + self.seq_len]

        # Right-pad the (last) short chunk up to seq_len.
        pad_len = self.seq_len - len(chunk)
        if pad_len > 0:
            padding = torch.full((pad_len,), self.pad_id, dtype=torch.long)
            chunk = torch.cat([chunk, padding])

        return chunk

## IterableDataset
          
面对超大文件或者海量数据，不要一次性读入内存。你需要切换到 PyTorch 的 `IterableDataset`，并采用流式读取 (Streaming)的方式。

这里有三个核心调整点：

1.  继承 `IterableDataset`：
    *   不要实现 `__getitem__`（因为它暗示随机访问，需要知道总长度）。
    *   实现 `__iter__`，使用 Python 的 `yield` 关键字逐行/逐块“吐”出数据。

2.  文件切分 (Sharding)：
    *   在多进程 (`num_workers > 0`) 模式下，如果不做处理，每个 Worker 都会从头读取所有文件，导致数据重复。
    *   必须在 `__iter__` 中获取 `torch.utils.data.get_worker_info()`，根据 Worker ID 分配不同的文件子集。

3.  Shuffle 的变化：
    *   `DataLoader(shuffle=True)` 不能用于 IterableDataset：传入 `shuffle=True` 会直接抛出 `ValueError`。
    *   解决方案：维护一个内存缓冲区 (Buffer)，先读入比如 10000 条数据，在缓冲区内随机打乱，然后 yield 出来。注意：下面的示例代码为了保持简洁，只演示了流式读取和文件切分，并没有实现这个打乱缓冲区。

In [None]:
from torch.utils.data import IterableDataset

class IterablePretrainDataset(IterableDataset):
    """Streaming dataset that yields fixed-length token chunks from Parquet files.

    Files are read batch by batch instead of being loaded whole. Each text is
    encoded, terminated with the ``<EOS>`` id and appended to a running token
    buffer; every time the buffer holds at least ``seq_len`` tokens, full
    chunks are yielded. With several DataLoader workers, each worker handles
    only the files whose position in ``file_paths`` maps to its worker id.
    """

    def __init__(self, file_paths, tokenizer, seq_len):
        """
        :param file_paths: list of Parquet file paths
        :param tokenizer: tokenizer exposing a ``special_tokens`` mapping and
            ``encode(text)``; must define ``<EOS>`` (``<PAD>`` is optional)
        :param seq_len: training sequence length (context window), must be > 0
        :raises ValueError: if ``seq_len`` is not positive or the tokenizer
            has no ``<EOS>`` token
        """
        super().__init__()
        # A zero seq_len would make the chunking loop yield empty tensors forever.
        if seq_len <= 0:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}.")

        self.file_paths = file_paths
        self.seq_len = seq_len
        self.tokenizer = tokenizer

        # Look up the special token ids.
        self.eos_id = tokenizer.special_tokens.get("<EOS>")
        self.pad_id = tokenizer.special_tokens.get("<PAD>")

        if self.eos_id is None:
            raise ValueError("Tokenizer must have <EOS> defined for pretraining.")
        if self.pad_id is None:
            print("Warning: <PAD> not found in tokenizer. Using 0 as pad_id.")
            self.pad_id = 0

    def _iter_texts(self, file_path):
        """Yield the non-null texts of one Parquet file, 1000 rows at a time."""
        # Imported here so the class can be defined (and an empty file list
        # iterated) without pyarrow; it is only needed once a file is opened.
        import pyarrow.parquet as pq

        parquet_file = pq.ParquetFile(file_path)
        for batch in parquet_file.iter_batches(batch_size=1000):
            df = batch.to_pandas()

            # Prefer the "page" column and fall back to "text".
            if "page" in df.columns:
                text_col = "page"
            elif "text" in df.columns:
                text_col = "text"
            else:
                continue  # no known text column in this batch

            yield from df[text_col].dropna().tolist()

    def __iter__(self):
        """Yield 1-D long tensors of ``seq_len`` tokens.

        The very last tensor may be shorter: leftover tokens are yielded
        unpadded, so batching them with full chunks needs a padding
        ``collate_fn``.
        """
        from torch.utils.data import get_worker_info

        worker_info = get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator owns every file.
            my_files = self.file_paths
        else:
            # Stride sharding: file i goes to worker i % num_workers, so no
            # file is read by two workers.
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            my_files = [
                f for i, f in enumerate(self.file_paths)
                if i % num_workers == worker_id
            ]

        seq_len = self.seq_len
        buffer = []  # tokens carried over between texts and files

        for file_path in my_files:
            try:
                for text in self._iter_texts(file_path):
                    buffer.extend(self.tokenizer.encode(text))
                    buffer.append(self.eos_id)

                    # Split off all complete chunks in one go, so the
                    # leftover is copied once per text rather than once per
                    # yielded chunk.
                    cut = len(buffer) - len(buffer) % seq_len
                    ready, buffer = buffer[:cut], buffer[cut:]
                    for start in range(0, cut, seq_len):
                        yield torch.tensor(ready[start:start + seq_len], dtype=torch.long)
            except Exception as e:
                # Best-effort streaming: report the broken file and move on.
                print(f"Error reading file {file_path}: {e}")

        # Flush whatever is left once every file has been consumed.
        if buffer:
            yield torch.tensor(buffer, dtype=torch.long)

## DataLoader

我们以 `PretrainDataset` 为例，演示如何使用 DataLoader 加载数据。

In [10]:
from bpe_tokenizer import BPETokenizer

# Create the BPE tokenizer and load it from 'wiki-tokenizer-1.json'
# (presumably the vocabulary/merges saved by an earlier training run —
# confirm against bpe_tokenizer). Later cells use this `tokenizer` global.
tokenizer = BPETokenizer()
tokenizer.load('wiki-tokenizer-1.json')

In [11]:
from torch.utils.data import DataLoader

# Parquet file(s) to tokenize; judging by the name, the WikiText-103
# validation split.
file_paths = ['data/wikitext-103-raw-v1-validation.parquet']
# Reads the 'page' column, tokenizes it and cuts the token stream into
# 256-token samples.
dataset = PretrainDataset(file_paths, tokenizer, seq_len=256, text_col='page')
dataloader = DataLoader(
    dataset,
    batch_size=2,      # 2 samples per batch (256-token chunks, not sentences)
    shuffle=True,      # draw the samples in random order
)


Processing 1 files...
Total tokens: 266833. Total samples (seq_len=256): 1043


In [12]:
# Pull a single batch from the DataLoader and decode its first sample.
print("Starting iteration...")
for batch_idx, batch in enumerate(dataloader):
    print(f"\n--- Batch {batch_idx} ---")
    print(f"Shape: {batch.shape}")
    # print(f"Data:\n{batch}")

    # Decoding demo: only the first sample of the batch is decoded.
    # tokenizer.decode expects a list[int], hence the tolist().
    first_sample_ids = batch[0].tolist()
    # Drop padding using the id the dataset actually pads with, not a
    # hard-coded 0: the tokenizer's <PAD> id is not necessarily 0.
    valid_ids = [i for i in first_sample_ids if i != dataset.pad_id]
    print(f"Decoded (1st sample): '{tokenizer.decode(valid_ids)}'")
    break  # one batch is enough for the demonstration


Starting iteration...

--- Batch 0 ---
Shape: torch.Size([2, 256])
Decoded (1st sample): ' ) , this caused concern among German government officials and clergy over data security and the potential for espionage . To assuage these concerns , Microsoft Germany agreed to provide a means to disable the utility . Following letters of complaint about discrimination from Scientology lawyers , some American companies such as General Electric , IBM and Ford Motor Company instructed their German subsidiaries to cease the use of protective declarations . 
 The city @-@ state of Hamburg set up a full @-@ time office dedicated to opposing Scientology , the Scientology Task Force for the Hamburg Interior Authority , under the leadership of Ursula Caberta . In 2005 , in a case brought by a Scientologist , the Federal Administrative Court of Germany ordered the city of Hamburg to cease recommending the use of protective declarations to its business community , finding that the practice infringed relig