# Train

```{note}
现在有了 model, tokenizer 和 dataset, 我们可以开始训练了。
```

## Configuration

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

# Imports from local files
from torch_train.torch_tokenizer import BPETokenizer
from torch_train.torch_model import TransformerModel
from torch_train.torch_dataset import PretrainDataset

# ==========================================
# Configuration
# ==========================================
CONFIG = {
    # Data
    "parquet_path": "data/wikitext-103-raw-v1-train-sampled.parquet", # Use sample set for lower resource usage
    "tokenizer_path": "wiki-tokenizer-1.json",
    "seq_len": 512,           # Reduced context window
    "batch_size": 8,         # Reduced batch size
    
    # Model (Tiny size for stability)
    "d_model": 256,          # Reduced d_model
    "n_head": 8,
    "d_hidden": 1024,         # Reduced d_hidden
    "n_layer": 16,            # Reduced n_layer
    "dropout": 0.1,
    
    # Training
    "lr": 3e-4,
    "epochs": 2,             # Increased epochs since data is smaller
    "log_interval": 10,
    "save_path": "llm_checkpoint_27m.pt",
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

模型的参数量：
1. 除 transformer block 外的参数数量：
    *  embedding 层：`d_model * vocab_size`
    *  output 层：`vocab_size * d_model`
    *  RMSNorm：`d_model`
2.  每个 transformer block 的参数数量：
    *  attention 的参数量：
        *  qkv 投影：`3 * d_model * d_model`
        *  o 投影：`d_model * d_model`
        *  总的和 QK 的 RMSNorm：`3 * d_model`
    *  feed forward network 的参数量：
        *  gate 和 up projection：`2 * d_model * d_hidden`
        *  down projection：`d_hidden * d_model`


In [2]:
d_model = CONFIG["d_model"]
vocab_size = 20000
n_layer = CONFIG["n_layer"]
d_hidden = CONFIG["d_hidden"]
d_model * (2 * vocab_size + 1 ) + n_layer * (4 * d_model * d_model + 3 * d_model * d_hidden + 4 * d_model)


27033856

## 加载 tokenizer

In [12]:
print(f"Using device: {CONFIG['device']}")

print("Loading Tokenizer...")
tokenizer = BPETokenizer()
tokenizer.load(CONFIG["tokenizer_path"])
print(f"Tokenizer loaded. Vocab size: {len(tokenizer.vocab)}")

vocab_size = len(tokenizer.vocab)
pad_id = tokenizer.special_tokens.get("<PAD>", 0)

Using device: cuda
Loading Tokenizer...
Tokenizer loaded. Vocab size: 20000


## DataLoader

In [4]:
print(f"Loading data from {CONFIG['parquet_path']}...")
file_paths = [CONFIG["parquet_path"]] 

print("Initializing Dataset...")
# Request seq_len + 1 to handle input/target shift
dataset = PretrainDataset(file_paths, tokenizer, seq_len=CONFIG["seq_len"] + 1)

Loading data from data/wikitext-103-raw-v1-train-sampled.parquet...
Initializing Dataset...
Processing 1 files...


Processing files: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Tokenizing texts: 100%|██████████| 2944/2944 [01:01<00:00, 48.00it/s]


Total tokens: 12880663. Total samples (seq_len=513): 25109


In [5]:
dataloader = DataLoader(
    dataset, 
    batch_size=CONFIG["batch_size"], 
    shuffle=True, 
    num_workers=2
)

## 模型

In [6]:
print("Initializing Model...")
model = TransformerModel(
    d_model=CONFIG["d_model"],
    n_head=CONFIG["n_head"],
    d_hidden=CONFIG["d_hidden"],
    n_layer=CONFIG["n_layer"],
    vocab_size=vocab_size,
    max_seq_len=CONFIG["seq_len"] + 1, 
    dropout=CONFIG["dropout"]
).to(CONFIG["device"])

print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

Initializing Model...
Model parameters: 27.03M


## Training Loop

In [7]:
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = optim.AdamW(model.parameters(), lr=CONFIG["lr"])
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

# Total steps for scheduler
total_steps = len(dataloader) * CONFIG["epochs"]
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)

model.train()

print("Starting training...")
for epoch in range(CONFIG["epochs"]):
    total_loss = 0
    # 使用 tqdm 包装 dataloader
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}", unit="batch")
    
    for batch_idx, data in enumerate(pbar):
        data = data.to(CONFIG["device"])
        
        # Input: [B, seq_len], Target: [B, seq_len] (shifted)
        input_ids = data[:, :-1]
        target_ids = data[:, 1:]

        # Generate position_ids
        B, T = input_ids.shape
        position_ids = torch.arange(T, device=CONFIG["device"]).unsqueeze(0).expand(B, T)
        
        optimizer.zero_grad()
        output = model(input_ids, position_ids)
        
        loss = criterion(output.reshape(-1, vocab_size), target_ids.reshape(-1))
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        
        # 更新进度条上的 loss 和 LR
        current_lr = scheduler.get_last_lr()[0]
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "lr": f"{current_lr:.2e}"})
        
        # 每 200 step 打印一次 loss
        if (batch_idx + 1) % 200 == 0:
             print(f"Epoch {epoch+1} | Step {batch_idx+1}/{len(dataloader)} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}")
            
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} Complete. Average Loss: {avg_loss:.4f}")
    
    torch.save(model.state_dict(), CONFIG["save_path"])
    print(f"Checkpoint saved to {CONFIG['save_path']}")

Starting training...


Epoch 1/2:   6%|▋         | 200/3139 [00:50<12:25,  3.94batch/s, loss=6.3364, lr=2.99e-04]

Epoch 1 | Step 200/3139 | Loss: 6.3364 | LR: 2.99e-04


Epoch 1/2:  13%|█▎        | 400/3139 [01:40<11:23,  4.01batch/s, loss=5.9198, lr=2.97e-04]

Epoch 1 | Step 400/3139 | Loss: 5.9198 | LR: 2.97e-04


Epoch 1/2:  19%|█▉        | 600/3139 [02:31<10:44,  3.94batch/s, loss=5.6097, lr=2.93e-04]

Epoch 1 | Step 600/3139 | Loss: 5.6097 | LR: 2.93e-04


Epoch 1/2:  25%|██▌       | 800/3139 [03:21<09:52,  3.95batch/s, loss=5.2382, lr=2.88e-04]

Epoch 1 | Step 800/3139 | Loss: 5.2382 | LR: 2.88e-04


Epoch 1/2:  32%|███▏      | 1000/3139 [04:12<08:57,  3.98batch/s, loss=5.1867, lr=2.82e-04]

Epoch 1 | Step 1000/3139 | Loss: 5.1867 | LR: 2.82e-04


Epoch 1/2:  38%|███▊      | 1200/3139 [05:02<07:59,  4.04batch/s, loss=5.0102, lr=2.74e-04]

Epoch 1 | Step 1200/3139 | Loss: 5.0102 | LR: 2.74e-04


Epoch 1/2:  45%|████▍     | 1400/3139 [05:53<07:25,  3.91batch/s, loss=5.0341, lr=2.65e-04]

Epoch 1 | Step 1400/3139 | Loss: 5.0341 | LR: 2.65e-04


Epoch 1/2:  51%|█████     | 1600/3139 [06:45<06:28,  3.96batch/s, loss=4.8956, lr=2.54e-04]

Epoch 1 | Step 1600/3139 | Loss: 4.8956 | LR: 2.54e-04


Epoch 1/2:  57%|█████▋    | 1800/3139 [07:36<05:39,  3.95batch/s, loss=4.7985, lr=2.43e-04]

Epoch 1 | Step 1800/3139 | Loss: 4.7985 | LR: 2.43e-04


Epoch 1/2:  64%|██████▎   | 2000/3139 [08:27<04:51,  3.91batch/s, loss=4.7320, lr=2.31e-04]

Epoch 1 | Step 2000/3139 | Loss: 4.7320 | LR: 2.31e-04


Epoch 1/2:  70%|███████   | 2200/3139 [09:17<03:57,  3.96batch/s, loss=4.7224, lr=2.18e-04]

Epoch 1 | Step 2200/3139 | Loss: 4.7224 | LR: 2.18e-04


Epoch 1/2:  76%|███████▋  | 2400/3139 [10:08<03:06,  3.96batch/s, loss=4.7545, lr=2.04e-04]

Epoch 1 | Step 2400/3139 | Loss: 4.7545 | LR: 2.04e-04


Epoch 1/2:  83%|████████▎ | 2600/3139 [10:59<02:15,  3.96batch/s, loss=4.7348, lr=1.90e-04]

Epoch 1 | Step 2600/3139 | Loss: 4.7348 | LR: 1.90e-04


Epoch 1/2:  89%|████████▉ | 2800/3139 [11:49<01:24,  4.02batch/s, loss=4.4990, lr=1.75e-04]

Epoch 1 | Step 2800/3139 | Loss: 4.4990 | LR: 1.75e-04


Epoch 1/2:  96%|█████████▌| 3000/3139 [12:40<00:35,  3.94batch/s, loss=4.5412, lr=1.60e-04]

Epoch 1 | Step 3000/3139 | Loss: 4.5412 | LR: 1.60e-04


Epoch 1/2: 100%|██████████| 3139/3139 [13:15<00:00,  3.95batch/s, loss=4.5805, lr=1.50e-04]


Epoch 1 Complete. Average Loss: 5.2000
Checkpoint saved to llm_checkpoint_27m.pt


Epoch 2/2:   6%|▋         | 200/3139 [00:50<12:23,  3.95batch/s, loss=4.6308, lr=1.35e-04]

Epoch 2 | Step 200/3139 | Loss: 4.6308 | LR: 1.35e-04


Epoch 2/2:  13%|█▎        | 400/3139 [01:41<11:39,  3.92batch/s, loss=4.3921, lr=1.20e-04]

Epoch 2 | Step 400/3139 | Loss: 4.3921 | LR: 1.20e-04


Epoch 2/2:  19%|█▉        | 600/3139 [02:31<10:38,  3.98batch/s, loss=4.5454, lr=1.06e-04]

Epoch 2 | Step 600/3139 | Loss: 4.5454 | LR: 1.06e-04


Epoch 2/2:  25%|██▌       | 800/3139 [03:22<09:50,  3.96batch/s, loss=4.3014, lr=9.15e-05]

Epoch 2 | Step 800/3139 | Loss: 4.3014 | LR: 9.15e-05


Epoch 2/2:  32%|███▏      | 1000/3139 [04:13<08:59,  3.96batch/s, loss=4.2931, lr=7.80e-05]

Epoch 2 | Step 1000/3139 | Loss: 4.2931 | LR: 7.80e-05


Epoch 2/2:  38%|███▊      | 1200/3139 [05:03<08:09,  3.96batch/s, loss=4.2559, lr=6.52e-05]

Epoch 2 | Step 1200/3139 | Loss: 4.2559 | LR: 6.52e-05


Epoch 2/2:  45%|████▍     | 1400/3139 [05:54<07:20,  3.95batch/s, loss=4.3630, lr=5.33e-05]

Epoch 2 | Step 1400/3139 | Loss: 4.3630 | LR: 5.33e-05


Epoch 2/2:  51%|█████     | 1600/3139 [06:44<06:31,  3.93batch/s, loss=4.1623, lr=4.23e-05]

Epoch 2 | Step 1600/3139 | Loss: 4.1623 | LR: 4.23e-05


Epoch 2/2:  57%|█████▋    | 1800/3139 [07:35<05:41,  3.92batch/s, loss=4.3556, lr=3.24e-05]

Epoch 2 | Step 1800/3139 | Loss: 4.3556 | LR: 3.24e-05


Epoch 2/2:  64%|██████▎   | 2000/3139 [08:26<04:48,  3.95batch/s, loss=4.1755, lr=2.37e-05]

Epoch 2 | Step 2000/3139 | Loss: 4.1755 | LR: 2.37e-05


Epoch 2/2:  70%|███████   | 2200/3139 [09:16<03:57,  3.96batch/s, loss=4.3754, lr=1.63e-05]

Epoch 2 | Step 2200/3139 | Loss: 4.3754 | LR: 1.63e-05


Epoch 2/2:  76%|███████▋  | 2400/3139 [10:07<03:04,  4.01batch/s, loss=4.0642, lr=1.01e-05]

Epoch 2 | Step 2400/3139 | Loss: 4.0642 | LR: 1.01e-05


Epoch 2/2:  83%|████████▎ | 2600/3139 [10:58<02:16,  3.95batch/s, loss=4.7023, lr=5.42e-06]

Epoch 2 | Step 2600/3139 | Loss: 4.7023 | LR: 5.42e-06


Epoch 2/2:  89%|████████▉ | 2800/3139 [11:48<01:25,  3.97batch/s, loss=4.2579, lr=2.15e-06]

Epoch 2 | Step 2800/3139 | Loss: 4.2579 | LR: 2.15e-06


Epoch 2/2:  96%|█████████▌| 3000/3139 [12:39<00:34,  4.02batch/s, loss=4.3050, lr=3.63e-07]

Epoch 2 | Step 3000/3139 | Loss: 4.3050 | LR: 3.63e-07


Epoch 2/2: 100%|██████████| 3139/3139 [13:14<00:00,  3.95batch/s, loss=4.3537, lr=0.00e+00]


Epoch 2 Complete. Average Loss: 4.3229
Checkpoint saved to llm_checkpoint_27m.pt


## 推理

In [10]:
import math

# 在测试集上进行评估
test_path = "data/wikitext-103-raw-v1-test.parquet"
print(f"Loading test data from {test_path}...")
test_file_paths = [test_path]
test_dataset = PretrainDataset(test_file_paths, tokenizer, seq_len=CONFIG["seq_len"] + 1)
test_dataloader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=2)

def calculate_perplexity(model, dataloader, device):
    model.eval()
    total_loss = 0
    total_steps = 0
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
    
    print("Calculating Perplexity...")
    with torch.no_grad():
        for data in tqdm(dataloader, desc="Evaluating"):
            data = data.to(device)
            input_ids = data[:, :-1]
            target_ids = data[:, 1:]
            
            B, T = input_ids.shape
            position_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
            
            output = model(input_ids, position_ids)
            loss = criterion(output.reshape(-1, vocab_size), target_ids.reshape(-1))
            
            total_loss += loss.item()
            total_steps += 1
            
    avg_loss = total_loss / total_steps
    ppl = math.exp(avg_loss)
    return avg_loss, ppl

test_loss, test_ppl = calculate_perplexity(model, test_dataloader, CONFIG["device"])
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Perplexity: {test_ppl:.4f}")

Loading test data from data/wikitext-103-raw-v1-test.parquet...
Processing 1 files...


Processing files: 100%|██████████| 1/1 [00:00<00:00, 117.19it/s]
Tokenizing texts: 100%|██████████| 62/62 [00:01<00:00, 40.49it/s]


Total tokens: 304473. Total samples (seq_len=513): 594
Calculating Perplexity...


Evaluating: 100%|██████████| 75/75 [00:06<00:00, 11.99it/s]

Test Loss: 4.4311
Test Perplexity: 84.0240





In [11]:
def generate(model, tokenizer, prompt, max_new_tokens=50):
    model.eval()
    
    # 1. Tokenize the prompt
    input_ids = tokenizer.encode(prompt)
    input_tensor = torch.tensor(input_ids, dtype=torch.long, device=CONFIG["device"]).unsqueeze(0)
    
    print(f"Prompt: {prompt}")
    print("-" * 40)
    
    generated = input_ids.copy()
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # 2. Crop context if too long (keep only the last seq_len tokens)
            ctx = input_tensor[:, -CONFIG["seq_len"]:]
            
            # 3. Create position IDs
            B, T = ctx.shape
            position_ids = torch.arange(T, device=CONFIG["device"]).unsqueeze(0).expand(B, T)
            
            # 4. Forward pass
            output = model(ctx, position_ids)
            
            # 5. Get the logits for the last token
            next_token_logits = output[0, -1, :]
            
            # 6. Greedy decoding: choose the token with the highest probability
            next_token_id = torch.argmax(next_token_logits).item()
            
            # 7. Append the predicted token
            generated.append(next_token_id)
            input_tensor = torch.cat((input_tensor, torch.tensor([[next_token_id]], device=CONFIG["device"])), dim=1)
            
            # 8. Stop if <EOS> is generated
            if next_token_id == tokenizer.special_tokens.get("<EOS>"):
                break
                
    # 9. Decode the generated tokens to string
    decoded = tokenizer.decode(generated)
    print(decoded)
    print("=" * 40)
    return decoded

# Sample prompts
prompts = [
    "The weather is",
    "It was a",
    "In the early",
    "The game began with"
]

for p in prompts:
    generate(model, tokenizer, p)

Prompt: The weather is
----------------------------------------
The weather is a tropical cyclone that is a tropical cyclone , and is a tropical cyclone that is a tropical storm . The storm is a tropical storm , and is a tropical storm . The storm is a tropical storm , and the storm is a tropical storm . The storm
Prompt: It was a
----------------------------------------
It was a member of the Royal Navy . The ship was a ship of the ship 's main armament , and the ship was built in the Atlantic Ocean . The ship was built in the Pacific Ocean , and the ship was built in the Atlantic Ocean . The ship
Prompt: In the early
----------------------------------------
In the early 20th century . The first of the first two @-@ third @-@ century @-@ century @-@ style castles were the first of the first two @-@ third @-@ century American , the first of the first two @-@ third @-@ century American , the first of the first two
Prompt: The game began with
----------------------------------------
The 