In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("mps")

# 1. Mask and helpers

In [3]:
def make_pad_mask(ids, pad_id):
    """
    Inputs:
    ids: [B, T]
    pad_id: int

    Outputs:
    pad_mask: [B, T]
    
    """
    return ids.eq(pad_id)

In [14]:
def make_causal_mask(T, device=None):
    """
    Inputs
    ------
    T : int  (target length)

    Outputs
    -------
    causal_mask : FloatTensor [T, T]
        Upper triangle (strict) is -inf; others 0.
        For nn.MultiheadAttention(attn_mask) with batch_first=True.

    Purpose
    -------
    Prevent decoder positions from attending to future tokens.
    upper triangle are -inf, all other positions are 0
    """
    m = torch.zeros(T, T, dtype=torch.float32, device=device)
    m = m.masked_fill(torch.triu(torch.ones_like(m), diagonal=1).bool(), float('-inf'))
    return m

# 2. Positional Encoding

In [15]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal PE added to token embeddings
    给词向量添加正弦位置编码

    Inputs:
    x: [B, T, d_model]

    Outputs:
    x_pe: [B, T, d_model]

    add deterministic position information since attention is order-agnostic
    """
    def __init__(self, d_model, max_len=4096, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model, dtype=torch.float32) 
        # create an empty pe of shape [max_len, d_model], every row ==> correspond a position, every column ==> sin or cos value of a freq
        
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) 
        #[L,1], pos is 0, 1, 2,...,max_len-1
        
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        # div[i] = 10000^(- (2i/d_model)), the bigger i, the smaller value, shape is [d_model/2]
        
        pe[:, 0::2] = torch.sin(pos * div) # even --> sin
        pe[:, 1::2] = torch.cos(pos * div) # odd --> cos
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False) #[1, L, d]
        #register_buffer：告诉 PyTorch 这是个常量（不会被优化器更新），但会跟着 model.to(device) 移动到显卡/CPU。
        #persistent=False，表示不会写进 state_dict()。

    def forward(self, x):
        B, T, D = x.shape
        return self.dropout(x + self.pe[:, :T, :])

# 3. FFN, Encoder, and Decoder block

In [16]:
class PositionwiseFFN(nn.Module):
    """
    Position-wise feed-forward network

    inputs:
    x [B, T, d_modl]

    outputs:
    y [B, T, d_model]

    two layer MLP (per time step): linear (d->d_ff) -> GELU -> Dropout -> Linear(d_ff->d)
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x

In [37]:
class EncoderBlock(nn.Module):
    """
    one transformer encoder layer (pre-LN)

    inputs:
    x: [B, T_src, d_model]
    src_key_padding_mask: [B, T_src] or None

    outputs:
    y: [B, T_src, d_model]

    LN → 自注意力 → 残差，然后 LN → FFN → 残差
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_key_padding_mask):
        # first layer norm
        h = self.ln1(x)
        # multihead attention: Q=K=V=h
        attn_out, _ = self.mha(h, h, h, key_padding_mask=src_key_padding_mask, need_weights=False)
        # residual + dropout
        x = x + self.drop(attn_out)
        # layer norm
        h = self.ln2(x)
        # feedforward + dropout + residual
        x = x + self.drop(self.ffn(h))
        return x

In [18]:
class DecoderBlock(nn.Module):
    """
    one transformer decoder layer (pre-LN)

    inputs:
    y: [B, T_tgt, d_model]
    enc_out: [B, T_src, d_model]
    tgt_key_padding_mask: [B, T_tgt] or None
    tgt_causal_mask: [T_tgt, T_tgt] or None
    src_key_padding_mask: [B, T_src] or None

    outputs:
    y_out: [B, T_tgt, d_model]


    解码器自注意力（因果遮罩） → 交叉注意力（看编码器输出） → FFN，每段前 Pre-LN，段后做残差。
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ln3 = nn.LayerNorm(d_model)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, y, enc_out,
                tgt_key_padding_mask,
                tgt_causal_mask,
                src_key_padding_mask):
        # Masked self-attn (causal + tgt padding), must be causal, cannot see future
        h = self.ln1(y)
        sa_out, _ = self.self_attn(h, h, h,
                                   attn_mask=tgt_causal_mask,
                                   key_padding_mask=tgt_key_padding_mask,
                                   need_weights=False)
        # residual + dropout
        y = y + self.drop(sa_out)
        # Cross-attn over encoder outputs (mask pads on encoder keys) Q=decoder, K, V=encoder
        h = self.ln2(y)
        ca_out, _ = self.cross_attn(h, enc_out, enc_out,
                                    key_padding_mask=src_key_padding_mask,
                                    need_weights=False)
        y = y + self.drop(ca_out)
        # FFN
        h = self.ln3(y)
        y = y + self.drop(self.ffn(h))
        return y

# 4. Full encoder, decoder, and seq2seq model

In [32]:
class Encoder(nn.Module):
    """
    Encoder: embedding + posenc + N layers encoderblock
    inputs: src_ids [B, T_src]
            src_key_padding_mask [B, T_src]

    outputs: enc_out [B, T_src, d_model]
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, pad_id, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pe = PositionalEncoding(d_model, dropout=dropout)
        self.layers = nn.ModuleList([EncoderBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, src_ids, src_key_padding_mask):
        x = self.pe(self.emb(src_ids))
        for layer in self.layers:
            x = layer(x, src_key_padding_mask)
        return x

In [41]:
class Decoder(nn.Module):
    """
    decoder: embedding + posenc + N layers decoderblock
    inputs: tgt_in [B, T_tgt]
            enc_out [B, T_src, d]
            tgt_key_padding_mask [B, T_tgt]
            tgt_causal_mask [T_tgt, T_tgt]
            src_key_padding_mask [B, T_src]
    outputs:
            dec_out [B, T_tgt, d_model]
    """
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, pad_id, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pe = PositionalEncoding(d_model, dropout=dropout)
        self.layers = nn.ModuleList([DecoderBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, tgt_in, enc_out, tgt_key_padding_mask, tgt_causal_mask, src_key_padding_mask):
        y = self.pe(self.emb(tgt_in))
        for layer in self.layers:
            y = layer(y, enc_out, tgt_key_padding_mask, tgt_causal_mask, src_key_padding_mask)
        return y

In [39]:
class TransformerSeq2Seq(nn.Module):
    """
    full encode-decoder model
    inputs (training):
            src_ids [B, T_src]
            tgt_in [B, T_tgt]
            src_key_padding_mask [B, T_src]
            tgt_key_padding_mask [B, T_tgt]
            tgt_causal_mask [T_tgt, T_tgt]

    outputs:
            logits [B, T_tgt, V]
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, pad_id, dropout=0.1, tie_weights=True):
        super().__init__()
        self.encoder = Encoder(vocab_size, d_model, num_layers, num_heads, d_ff, pad_id, dropout)
        self.decoder = Decoder(vocab_size, d_model, num_layers, num_heads, d_ff, pad_id, dropout)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False) #project decoder to vocab
        if tie_weights:
            #weights share: output layer and decoder embedding sharing the weights (common skill)
            self.lm_head.weight = self.decoder.emb.weight

    def forward(self, src_ids, tgt_in, src_key_padding_mask, tgt_key_padding_mask, tgt_causal_mask):
        # encode a whole paragraph
        enc_out = self.encoder(src_ids, src_key_padding_mask) #[B, T_src, d]

        # decoder read in (w/ mask)
        dec_out = self.decoder(tgt_in, enc_out, tgt_key_padding_mask, tgt_causal_mask, src_key_padding_mask) #[B, T_tgt, d]
        
        # project to vocab
        logits = self.lm_head(dec_out) #[B, T_tgt, V]
        return logits

# 5. Toy data

In [23]:
def make_toy_batch(task, batch_size, T_max, V, pad, bos, eos):
    """
    构造随机小批次
    Inputs:
        task: 'copy' 或 'reverse'
        batch_size: B
        T_max: 最大源长度（每条样本真实长度随机于 [2, T_max]）
        V: 词表大小（>= eos+1）
        pad/bos/eos: 特殊符号 ID
    Returns:
        src_ids [B, T_src_max]
        tgt_in  [B, T_tgt_max]  (BOS + target[:-1])
        tgt_out [B, T_tgt_max]  (gold targets)
        src_pad [B, T_src_max]  (True=PAD)
        tgt_pad [B, T_tgt_max]  (True=PAD)
    """
    assert task in {"copy", "reverse"}
    lengths = np.random.randint(2, T_max + 1, size=batch_size)

    src_list, tgt_in_list, tgt_out_list = [], [], []
    for L in lengths:
        # 源 token 取自 [eos+1 .. V-1]，避免与 PAD/BOS/EOS 冲突
        src = np.random.randint(eos + 1, V, size=L).tolist()
        if task == "copy":
            tgt_core = src
        else:
            tgt_core = src[::-1]            # 反转任务

        # 目标序列加 BOS/EOS
        tgt = [bos] + tgt_core + [eos]
        # teacher-forcing: 输入右移一位
        tgt_in = tgt[:-1]
        tgt_out = tgt[1:]

        src_list.append(torch.tensor(src, dtype=torch.long))
        tgt_in_list.append(torch.tensor(tgt_in, dtype=torch.long))
        tgt_out_list.append(torch.tensor(tgt_out, dtype=torch.long))

    # pad_sequence 统一长度，padding_value=pad
    src_ids = nn.utils.rnn.pad_sequence(src_list, batch_first=True, padding_value=pad)
    tgt_in  = nn.utils.rnn.pad_sequence(tgt_in_list, batch_first=True, padding_value=pad)
    tgt_out = nn.utils.rnn.pad_sequence(tgt_out_list, batch_first=True, padding_value=pad)

    # 生成 PAD 掩码（True=PAD，需要屏蔽）
    src_pad = make_pad_mask(src_ids, pad)
    tgt_pad = make_pad_mask(tgt_out, pad)
    return src_ids, tgt_in, tgt_out, src_pad, tgt_pad

# 6. Loss, Decoding, Scheduler

In [24]:
def seq_ce_loss(logits, targets, pad_id, label_smoothing=0.1):
    """
    cross entropy
    inputs: logits [B, T, V], targets [B, T]
    outputs: loss (scaler)
    """

    B, T, V = logits.shape
    return F.cross_entropy(logits.reshape(B*T, V), targets.reshape(B*T), ignore_index=pad_id, label_smoothing=label_smoothing, reduction="mean")

In [45]:
@torch.no_grad()
def greedy_decode(model, src_ids, pad, bos, eos, max_len):
    """
    greedy decode (batch)
    Inputs:
            src_ids [B,T_src]
    Returns:
            ys [B, L]（以 BOS 开头，包含 EOS 或截断）
    """
    model.eval()
    B = src_ids.size(0)
    device = src_ids.device

    # 1). encoder once
    src_key_pad = make_pad_mask(src_ids, pad).to(device) #[B, T_src]
    enc_out = model.encoder(src_ids, src_key_pad) # [B, T_src, d]

    # 2). decoder sequence from BOS to generate stepwise
    ys = torch.full((B,1), bos, dtype=torch.long, device=device) #[B,1]
    finished = torch.zeros(B, dtype=torch.bool, device=device) # if generate EOS

    while ys.size(1) < max_len:
        #construct tgt pad and causal pad
        tgt_key_pad = make_pad_mask(ys, pad).to(device)             # [B,t]
        causal = make_causal_mask(ys.size(1), device=device)        # [t,t]

        # only use Decoder（already have enc_out）
        dec_out = model.decoder(ys, enc_out, tgt_key_pad, causal, src_key_pad)  # [B,t,d]
        logits = model.lm_head(dec_out)[:, -1, :]                    # 取最后一个位置的分布 [B,V]
        next_ids = torch.argmax(logits, dim=-1, keepdim=True)        # 贪心选最大概率 [B,1]
        ys = torch.cat([ys, next_ids], dim=1)                        # 追加到序列

        # if all samples have generated EOS, then stop earlier
        finished = finished | next_ids.squeeze(1).eq(eos)
        if torch.all(finished):
            break

    return ys

In [26]:
class TransformerLRScheduler:
    """
    Transformer 经典学习率：lr = d_model^-0.5 * min(step^-0.5, step*warmup^-1.5)
    用法：
        sched = TransformerLRScheduler(optimizer, d_model, warmup_steps)
        每个 step 调用 sched.step() 更新学习率
    """
    def __init__(self, optimizer, d_model, warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup = warmup_steps
        self.step_num = 0

    def _calc_lr(self, step):
        step = max(1, step)
        scale = self.d_model ** (-0.5)                     # d_model^-0.5
        return scale * min(step ** (-0.5), step * (self.warmup ** (-1.5)))

    def step(self):
        self.step_num += 1
        lr = self._calc_lr(self.step_num)
        for g in self.optimizer.param_groups:
            g['lr'] = lr


# 7. Train and Evaluate

In [27]:
def train_epoch(model, optimizer, scheduler, task, steps, V, T_max, pad, bos, eos,
                grad_clip=1.0, print_every=100):
    """
    合成数据上迭代 steps 次，观察 loss 下降
    """
    model.train()
    for step in range(1, steps + 1):
        # 取一批合成数据
        src, tgt_in, tgt_out, src_pad, tgt_pad = make_toy_batch(task, batch_size=64, T_max=T_max, V=V, pad=pad, bos=bos, eos=eos)
        src = src.to(DEVICE); tgt_in = tgt_in.to(DEVICE); tgt_out = tgt_out.to(DEVICE)
        src_pad = src_pad.to(DEVICE); tgt_pad = tgt_pad.to(DEVICE)

        # 构造解码端因果掩码（每个 batch 的 T_tgt 可不同，这里按本批长度构造）
        tgt_causal = make_causal_mask(tgt_in.size(1), device=DEVICE)

        # 前向：teacher forcing
        logits = model(src, tgt_in, src_pad, tgt_pad, tgt_causal)     # [B,T,V]
        loss = seq_ce_loss(logits, tgt_out, pad_id=pad, label_smoothing=0.1)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # 裁剪梯度防止爆炸
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()   # 更新学习率

        if step % print_every == 0:
            # 打印当前 loss 和 lr（便于观察 warmup 后的变化）
            print(f"[train step {step}/{steps}] loss={loss.item():.4f} lr={scheduler._calc_lr(scheduler.step_num):.6f}")

In [28]:
@torch.no_grad()
def evaluate_token_accuracy(model, task, V, T_max, pad, bos, eos, batches=20):
    """
    简单 token-level 准确率（忽略 PAD）
    """
    model.eval()
    total, correct = 0, 0
    for _ in range(batches):
        src, tgt_in, tgt_out, src_pad, tgt_pad = make_toy_batch(task, batch_size=64, T_max=T_max, V=V, pad=pad, bos=bos, eos=eos)
        src = src.to(DEVICE); tgt_in = tgt_in.to(DEVICE); tgt_out = tgt_out.to(DEVICE)
        src_pad = src_pad.to(DEVICE); tgt_pad = tgt_pad.to(DEVICE)
        tgt_causal = make_causal_mask(tgt_in.size(1), device=DEVICE)

        logits = model(src, tgt_in, src_pad, tgt_pad, tgt_causal)     # [B,T,V]
        pred = logits.argmax(-1)                                      # 取概率最大类 [B,T]
        mask = ~tgt_out.eq(pad)                                       # 只统计非 PAD
        correct += pred.eq(tgt_out).masked_select(mask).sum().item()
        total   += mask.sum().item()
    return correct / max(1, total)

# 8. Main

In [46]:
def main():
    # ------------ 超参 ------------
    TASK = "reverse"     # 'copy' 或 'reverse'
    V = 200              # 词表大小（包含特殊符号）
    PAD, BOS, EOS = 0, 1, 2
    D_MODEL = 128
    N_LAYERS = 2
    N_HEADS = 4
    D_FF = 4 * D_MODEL
    DROPOUT = 0.1
    WARMUP = 400
    STEPS  = 800         # 可增大到 2000+ 更稳定
    T_MAX  = 12          # 采样时的最大源长度

    # ------------ 构建模型 & 优化器 ------------
    model = TransformerSeq2Seq(
        vocab_size=V,
        d_model=D_MODEL,
        num_layers=N_LAYERS,
        num_heads=N_HEADS,
        d_ff=D_FF,
        pad_id=PAD,
        dropout=DROPOUT,
        tie_weights=True
    ).to(DEVICE)

    # AdamW + Transformer 学习率调度
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0, betas=(0.9, 0.98), weight_decay=0.01)
    scheduler = TransformerLRScheduler(optimizer, d_model=D_MODEL, warmup_steps=WARMUP)

    print(f"[Info] Device = {DEVICE}, Task = {TASK}")
    train_epoch(model, optimizer, scheduler, TASK, STEPS, V, T_MAX, PAD, BOS, EOS, grad_clip=1.0, print_every=100)

    # ------------ 评估 ------------
    acc = evaluate_token_accuracy(model, TASK, V, T_MAX, PAD, BOS, EOS, batches=20)
    print(f"[Eval] token accuracy ≈ {acc:.3f}")

    # ------------ 解码演示（贪心） ------------
    model.eval()
    src, tgt_in, tgt_out, _, _ = make_toy_batch(TASK, batch_size=3, T_max=T_MAX, V=V, pad=PAD, bos=BOS, eos=EOS)
    src = src.to(DEVICE)
    greedy = greedy_decode(model, src, PAD, BOS, EOS, max_len=T_MAX+3)
    print("\n[Greedy decode samples]")
    for i in range(src.size(0)):
        print("src :", src[i].tolist())
        print("pred:", greedy[i].tolist())

if __name__ == "__main__":
    main()

[Info] Device = mps, Task = reverse
[train step 100/800] loss=2030475.1250 lr=0.001105
[train step 200/800] loss=661915.6875 lr=0.002210
[train step 300/800] loss=289922.2500 lr=0.003315
[train step 400/800] loss=172438.4219 lr=0.004419
[train step 500/800] loss=112699.7344 lr=0.003953
[train step 600/800] loss=68927.2969 lr=0.003608
[train step 700/800] loss=52565.7969 lr=0.003341
[train step 800/800] loss=33424.7109 lr=0.003125
[Eval] token accuracy ≈ 0.083

[Greedy decode samples]
src : [84, 126, 122, 102, 78, 188, 165, 0, 0, 0, 0]
pred: [1, 2, 2]
src : [196, 137, 13, 176, 135, 0, 0, 0, 0, 0, 0]
pred: [1, 2, 2]
src : [51, 178, 76, 17, 156, 25, 104, 195, 89, 178, 195]
pred: [1, 68, 2]
