In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math

# Hyperparameters
batch_size = 32         # number of sequences processed per batch
block_size = 68         # maximum context length
max_iters = 1000        # total training iterations (1000 in this example; adjust as needed)
eval_interval = 100     # run an evaluation every this many iterations
learning_rate = 1e-3    # optimizer learning rate
n_embd = 64             # embedding dimension
n_head = 4              # number of attention heads
n_layer = 4             # number of Transformer blocks
dropout = 0             # dropout probability

# Initial number of parameter tokens (learnable key/value rows) in each attention head
num_param_tokens = 64

# Device selection: prefer CUDA, then Apple MPS, and fall back to the CPU.
# The getattr guard keeps this working on PyTorch builds that predate
# `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Number of batches averaged per loss estimate
eval_iters = 200

# Fixed seed so runs are repeatable
torch.manual_seed(1337)

##########################
# TokenFormer版注意力模块 #
##########################

class TokenFormerHead(nn.Module):
    """A single attention head built on tokenized parameters.

    - query: projected from the input tokens
    - key & value: taken from two sets of learnable parameter tokens
      (``KP`` and ``VP``), each of shape (num_param_tokens, head_size)

    The input tokens therefore attend to the parameter tokens rather than
    to each other. The input width and dropout rate are read from the
    module-level ``n_embd`` and ``dropout`` settings.
    """
    def __init__(self, head_size, num_param_tokens):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=True)
        # Learnable parameter tokens that play the role of keys and values
        self.KP = nn.Parameter(torch.randn(num_param_tokens, head_size))
        self.VP = nn.Parameter(torch.randn(num_param_tokens, head_size))
        self.dropout = nn.Dropout(dropout)
        self.scale = head_size ** -0.5

    def forward(self, x):
        """Attend from ``x`` to the parameter tokens.

        ``x`` has shape (..., n_embd), typically (B, T, n_embd); the result
        has shape (..., head_size). No particular rank is required because
        every operation below acts on the last dimension only.
        """
        q = self.query(x)  # (..., head_size)
        attn_scores = (q @ self.KP.T) * self.scale  # (..., num_param_tokens)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ self.VP  # (..., head_size)

class MultiHeadAttention(nn.Module):
    """Multi-head TokenFormer attention.

    Runs several TokenFormerHead modules on the same input, concatenates
    their outputs along the feature axis and projects the result back to
    ``n_embd`` features, followed by dropout.
    """
    def __init__(self, num_heads, head_size, num_param_tokens):
        super().__init__()
        head_modules = [
            TokenFormerHead(head_size, num_param_tokens)
            for _ in range(num_heads)
        ]
        self.heads = nn.ModuleList(head_modules)
        concat_width = head_size * num_heads
        self.proj = nn.Linear(concat_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Every head sees the same input; outputs are joined feature-wise.
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

##########################
# 其他模块保持不变       #
##########################

class FeedFoward(nn.Module):
    """A simple position-wise feed-forward layer.

    Expands to four times the embedding width, applies ReLU, projects back
    to the embedding width and finishes with dropout.
    """
    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.net(x)
        return transformed

class Block(nn.Module):
    """Transformer block: attention first, then the feed-forward computation.

    Both sub-layers are applied to a layer-normalised copy of their input
    and added back to it as a residual.
    """
    def __init__(self, n_embd, n_head):
        super().__init__()
        # The embedding width is split evenly across the heads.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head, num_param_tokens)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        after_attention = x + self.sa(self.ln1(x))
        after_ffwd = after_attention + self.ffwd(self.ln2(after_attention))
        return after_ffwd

class GPTLanguageModel(nn.Module):
    """Language model built from a stack of TokenFormer blocks.

    Token and learned position embeddings are summed, passed through
    ``n_layer`` Block modules and a final LayerNorm, then projected to
    vocabulary logits. Sizes come from the module-level hyperparameters
    (``n_embd``, ``n_head``, ``n_layer``, ``block_size``).
    """
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights with N(0, 0.02) and zero biases.

        Other parameters (for example the KP/VP parameter tokens) are not
        touched here and keep their constructor initialisation.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        ``idx`` is a (B, T) tensor of token ids. Without targets the result
        is ``(logits, None)`` with logits of shape (B, T, vocab). With targets
        the logits are flattened to (B*T, vocab) and returned with the loss.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # Build position ids on the same device as the input rather than on
        # the global training device, so the model works wherever it is moved.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoids building graphs
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` and return it.

        ``idx`` is a (B, T) tensor of token ids. At each step only the last
        ``block_size`` tokens are fed to the model, and the next token is
        sampled from the softmax over the final position's logits.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

##########################
# Progressive Scaling 部分 #
##########################

def progressive_scale_model(model, new_num_param_tokens):
    """Grow the parameter tokens of every TokenFormerHead in ``model``.

    For each head in ``model.blocks[*].sa.heads`` the ``KP`` and ``VP``
    parameters are extended to ``new_num_param_tokens`` rows. Existing rows
    are kept and the added rows are zero-initialised with the same dtype and
    device as the existing ones. Heads that already have at least that many
    rows are left untouched.

    Note that grown heads receive *new* ``nn.Parameter`` objects: an optimizer
    created before this call does not track them and must be given the new
    parameters by the caller. Also, a zero key row yields a score of 0, which
    still receives non-zero weight under the softmax used by the heads.

    Prints the cumulative training time (when the module-level
    ``global_start_time`` has been set) and the new parameter count.

    Returns:
        The same ``model`` object, modified in place.
    """
    for block in model.blocks:
        for head in block.sa.heads:
            extra_rows = new_num_param_tokens - head.KP.shape[0]
            if extra_rows <= 0:
                continue
            # Build the enlarged tensors outside autograd so the new
            # parameters are clean leaf tensors.
            with torch.no_grad():
                for attr in ("KP", "VP"):
                    old = getattr(head, attr)  # (old_num, head_size)
                    # new_zeros inherits dtype and device from the old tensor,
                    # so torch.cat cannot silently promote the dtype.
                    padding = old.new_zeros(extra_rows, old.shape[1])
                    setattr(head, attr, nn.Parameter(torch.cat([old, padding], dim=0)))
    # The start timestamp is set by the training script; skip the timing
    # line instead of failing when scaling is done outside a training run.
    started_at = globals().get("global_start_time")
    if started_at is not None:
        elapsed_total = time.time() - started_at
        print(f"Cumulative training time: {elapsed_total:.2f} seconds")
    print("\n========== Scaling Event ==========")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameter tokens per head expanded to: {new_num_param_tokens}, Total model parameters now: {total_params/1e6:.2f}")
    return model

##########################
# 数据预处理及训练部分 #
##########################

# Load the raw training corpus.
with open('pokemon.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = dict(zip(chars, range(vocab_size)))   # character -> integer id
itos = dict(enumerate(chars))                # integer id -> character


def encode(s):
    """Map a string to the list of its character ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a list of character ids back to a string."""
    return ''.join(itos[i] for i in l)


# Encode the whole corpus; the first 90% is used for training, the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    ``split == 'train'`` reads from ``train_data``; any other value reads
    from ``val_data``. Both returned tensors have shape
    (batch_size, block_size) and are moved to ``device``; the targets are
    the inputs shifted one position to the right.
    """
    source = val_data if split != 'train' else train_data
    offsets = torch.randint(len(source) - block_size, (batch_size,))
    # Cut one window of block_size + 1 tokens per offset, then split it into
    # the input (all but the last token) and the target (all but the first).
    windows = [source[start:start + block_size + 1] for start in offsets]
    inputs = torch.stack([w[:-1] for w in windows])
    targets = torch.stack([w[1:] for w in windows])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss_and_ppl(split='train'):
    """Estimate the mean loss and perplexity of ``model`` on one data split.

    Averages the loss over ``eval_iters`` random batches drawn with
    ``get_batch(split)``. The model is put in eval mode for the measurement
    and returned to whatever mode it was in beforehand.

    Args:
        split: Passed to ``get_batch``; ``'train'`` (the default) samples the
            training data, any other value samples the validation data.

    Returns:
        ``(avg_loss, perplexity)`` where perplexity is ``exp(avg_loss)``.
    """
    was_training = model.training
    model.eval()
    try:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        avg_loss = losses.mean().item()
    finally:
        # Restore the caller's mode even if a batch fails.
        model.train(was_training)
    return avg_loss, math.exp(avg_loss)

# Report the total number of input tokens in the training set
total_input_tokens = len(train_data)
print(f"Total input tokens in training data: {total_input_tokens}")

# Build the model and move it to the selected device
model = GPTLanguageModel(vocab_size)
model = model.to(device)
print(f"Initial model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f} M parameters")

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Progressive scaling plan (iteration -> parameter tokens per head);
# the example uses fairly large expansions
scaling_schedule = {250: 256, 500: 512, 750: 1024, 900: 2048}

# Token count in effect after the most recent scaling event
prev_token_count = num_param_tokens

# Running total of input tokens processed during training
total_training_tokens = 0

# Global training start time, used to report cumulative training time
global_start_time = time.time()

start_time = time.time()
for step in range(max_iters):
    # Apply a scaling event if one is scheduled for this iteration
    if step in scaling_schedule:
        new_token_count = scaling_schedule[step]
        model = progressive_scale_model(model, new_token_count)
        prev_token_count = new_token_count

        # Scaling replaces KP/VP with new nn.Parameter objects. The optimizer
        # only knows the objects it was built with, so without this the grown
        # parameters would never be updated (nor have their grads zeroed).
        # Adding a param group keeps the AdamW state of all other parameters.
        tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
        untracked = [p for p in model.parameters() if id(p) not in tracked]
        if untracked:
            optimizer.add_param_group({"params": untracked})

    # Draw a batch and count the input tokens it contains
    xb, yb = get_batch('train')
    total_training_tokens += xb.shape[0] * xb.shape[1]

    if step % eval_interval == 0 or step == max_iters - 1:
        loss_val, ppl_val = estimate_loss_and_ppl()
        print(f"Iter {step}: Loss {loss_val:.4f}, Perplexity {ppl_val:.4f}")

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
end_time = time.time()

print(f"\nTotal training time: {end_time - global_start_time:.2f} seconds.")
print(f"Total training tokens processed: {total_training_tokens}")

# Generate text, starting from a single token with id 0
context = torch.zeros((1, 1), dtype=torch.long, device=device)
model.eval()
with torch.no_grad():  # sampling needs no gradients
    generated = model.generate(context, max_new_tokens=500)[0].tolist()
print("\nGenerated Text:")
print(decode(generated))


Total input tokens in training data: 458319
Initial model parameters: 0.41 M parameters
Iter 0: Loss 4.3477, Perplexity 77.3031
Iter 100: Loss 2.6374, Perplexity 13.9765
Iter 200: Loss 2.5282, Perplexity 12.5315
Cumulative training time: 37.09 seconds

Parameter tokens per head expanded to: 256, Total model parameters now: 0.61
Iter 300: Loss 2.4851, Perplexity 12.0020
Iter 400: Loss 2.4708, Perplexity 11.8323
Cumulative training time: 65.99 seconds

Parameter tokens per head expanded to: 512, Total model parameters now: 0.87
Iter 500: Loss 2.4442, Perplexity 11.5212
Iter 600: Loss 2.4480, Perplexity 11.5649
Iter 700: Loss 2.4395, Perplexity 11.4672
Cumulative training time: 97.52 seconds

Parameter tokens per head expanded to: 1024, Total model parameters now: 1.39
Iter 800: Loss 2.4331, Perplexity 11.3942
Cumulative training time: 116.40 seconds

Parameter tokens per head expanded to: 2048, Total model parameters now: 2.44
Iter 900: Loss 2.4298, Perplexity 11.3566
Iter 999: Loss 2.43