# Train

```{note}
现在有了 model、tokenizer 和 dataset，我们可以开始训练了。
```

## Configuration

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

# Imports from local files
from torch_train.torch_tokenizer import BPETokenizer
from torch_train.torch_model import TransformerModel
from torch_train.torch_dataset import PretrainDataset

# ==========================================
# Configuration
# ==========================================
CONFIG = {
    # Data
    "parquet_path": "data/wikitext-103-raw-v1-train.parquet",  # WikiText-103 (raw) training split
    "tokenizer_path": "wiki-tokenizer-1.json",  # trained BPE tokenizer file
    "seq_len": 512,           # tokens per model input; the dataset is asked for seq_len + 1 to allow the input/target shift
    "batch_size": 8,          # sequences per batch

    # Model (~87.6M parameters with the 20k-token vocab)
    "d_model": 512,           # embedding / residual width
    "n_head": 8,              # attention heads
    "d_hidden": 2048,         # FFN inner width (gate/up/down projections)
    "n_layer": 16,            # number of transformer blocks
    "dropout": 0.1,

    # Training
    "lr": 3e-4,               # AdamW learning rate
    "epochs": 2,              # full passes over the dataset
    "log_interval": 10,       # logging frequency, in batches
    "save_path": "llm_checkpoint.pt",  # model state_dict, overwritten after every epoch
    "seed": 42,               # seed for torch's RNG (weight init, dropout, shuffling)
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

# Seed torch's global RNG (CPU and CUDA) so weight initialization, dropout and
# DataLoader shuffling are reproducible across Restart & Run All.
torch.manual_seed(CONFIG["seed"])

模型的参数量（下面的估算不含 bias 项，embedding 层与 output 层分别计数、不共享权重）：
1. 除 transformer block 外的参数数量：
    * embedding 层：`d_model * vocab_size`
    * output 层：`vocab_size * d_model`
    * RMSNorm：`d_model`
2. 每个 transformer block 的参数数量：
    * attention 部分：
        * QKV 投影：`3 * d_model * d_model`
        * O 投影：`d_model * d_model`
        * attention 的 RMSNorm 以及 Q、K 的 RMSNorm：`3 * d_model`
    * feed forward network 部分：
        * gate 和 up 投影：`2 * d_model * d_hidden`
        * down 投影：`d_hidden * d_model`
        * FFN 的 RMSNorm：`d_model`

合计：

$$
\text{参数量} = d_{\text{model}} \times (2 \times \text{vocab\_size} + 1) + n_{\text{layer}} \times \left(4\, d_{\text{model}}^2 + 3\, d_{\text{model}}\, d_{\text{hidden}} + 4\, d_{\text{model}}\right)
$$

这只是一个估算：它与后面实际统计的参数量（87.61M）略有出入。例如，如果 Q、K 的 RMSNorm 作用在每个 head 上，它们各自只有 `d_model / n_head` 个参数。


In [2]:
def estimate_param_count(d_model: int, d_hidden: int, n_layer: int, vocab_size: int) -> int:
    """Estimate the model's parameter count with the closed-form formula (no bias terms).

    Outside the blocks: untied embedding + output matrices plus one RMSNorm.
    Per block: QKV/O projections, gate/up/down FFN projections and
    4 * d_model of RMSNorm weights.

    Returns:
        The estimated number of parameters as an int.
    """
    outside_blocks = d_model * (2 * vocab_size + 1)
    per_block = 4 * d_model * d_model + 3 * d_model * d_hidden + 4 * d_model
    return outside_blocks + n_layer * per_block


# Read the sizes from CONFIG so this estimate tracks config changes.
# The real vocab size is only known once the tokenizer is loaded (next section);
# 20000 is what that tokenizer reports — update this if the tokenizer changes.
estimate_param_count(CONFIG["d_model"], CONFIG["d_hidden"], CONFIG["n_layer"], vocab_size=20000)

87622144

## 加载 tokenizer

In [4]:
print(f"Using device: {CONFIG['device']}")

print("Loading Tokenizer...")
tokenizer = BPETokenizer()
# Guard clause: without the trained vocab file there is nothing to train with.
if not os.path.exists(CONFIG["tokenizer_path"]):
    raise FileNotFoundError(f"Tokenizer file not found at {CONFIG['tokenizer_path']}")
tokenizer.load(CONFIG["tokenizer_path"])

vocab_size = len(tokenizer.vocab)
print(f"Tokenizer loaded. Vocab size: {vocab_size}")

# Token id handed to the loss as ignore_index.
# NOTE(review): when the tokenizer defines no "<PAD>" token this falls back to 0,
# which may be an ordinary vocabulary id and would then be silently excluded
# from the loss — confirm against the tokenizer's special tokens.
pad_id = tokenizer.special_tokens.get("<PAD>", 0)

Using device: cuda
Loading Tokenizer...
Tokenizer loaded. Vocab size: 20000


## DataLoader

In [None]:
print(f"Loading data from {CONFIG['parquet_path']}...")
file_paths = [CONFIG["parquet_path"]]

print("Initializing Dataset...")
# Request seq_len + 1 tokens per sample: the training loop splits each sample
# into input x[:-1] and target x[1:], both exactly seq_len long.
dataset = PretrainDataset(file_paths, tokenizer, seq_len=CONFIG["seq_len"] + 1)

dataloader = DataLoader(
    dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    # Configurable worker count; falls back to the previously hard-coded value.
    num_workers=CONFIG.get("num_workers", 2),
    # Page-locked host memory makes host-to-GPU copies faster; useless on CPU.
    pin_memory=CONFIG["device"] == "cuda",
)

## 模型

In [5]:
print("Initializing Model...")
# max_seq_len is sized to the dataset's sample length (seq_len + 1); after the
# input/target shift the model is fed seq_len tokens per sequence.
model_kwargs = dict(
    d_model=CONFIG["d_model"],
    n_head=CONFIG["n_head"],
    d_hidden=CONFIG["d_hidden"],
    n_layer=CONFIG["n_layer"],
    vocab_size=vocab_size,
    max_seq_len=CONFIG["seq_len"] + 1,
    dropout=CONFIG["dropout"],
)
model = TransformerModel(**model_kwargs).to(CONFIG["device"])

# Compare against the closed-form estimate computed earlier.
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params / 1e6:.2f}M")

Initializing Model...
Model parameters: 87.61M


## Training Loop

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=CONFIG["lr"])
# Target positions equal to pad_id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

# Guard against a missing or zero interval so the modulo below is always valid.
log_every = max(int(CONFIG.get("log_interval", 1)), 1)

model.train()

print("Starting training...")
for epoch in range(CONFIG["epochs"]):
    total_loss = 0.0
    num_batches = 0
    # Wrap the dataloader with tqdm to display a progress bar.
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}", unit="batch")

    for batch_idx, data in enumerate(pbar):
        # non_blocking lets the copy overlap compute when the batch is in pinned memory.
        data = data.to(CONFIG["device"], non_blocking=True)

        # Input: [B, seq_len], Target: [B, seq_len] (shifted by one token)
        input_ids = data[:, :-1]
        target_ids = data[:, 1:]

        # Position ids 0..T-1, broadcast to every row of the batch
        B, T = input_ids.shape
        position_ids = torch.arange(T, device=CONFIG["device"]).unsqueeze(0).expand(B, T)

        optimizer.zero_grad()
        output = model(input_ids, position_ids)

        # Flatten to [B*T, vocab_size] logits vs [B*T] targets for cross-entropy
        loss = criterion(output.reshape(-1, vocab_size), target_ids.reshape(-1))

        loss.backward()
        # Clip the total gradient norm to 1.0 before the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once per step (each .item() forces a device sync)
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1

        # Refresh the loss on the progress bar every `log_every` batches
        if batch_idx % log_every == 0:
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

    # Divide by the batches actually seen: no ZeroDivisionError on an empty
    # dataloader, and no requirement that the dataset define __len__.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Complete. Average Loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), CONFIG["save_path"])
    print(f"Checkpoint saved to {CONFIG['save_path']}")