In [1]:
import math, json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

# Prefer the first CUDA device, but fall back to the CPU so that later
# `.to(device)` calls do not crash on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Tokenizer utils

In [2]:
class CharTokenizer:
    """
    Character-level tokenizer: every distinct character is one token.

    Args:
        chars: Iterable of vocabulary entries (single characters). Order defines the ids.
        add_unk: If True, the literal entry "<unk>" is placed at id 0 and used by
            `encode` for characters that are not in the vocabulary.

    Attributes:
        chars (list): Vocabulary entries, index == token id.
        stoi (dict): Entry -> id.
        itos (dict): Id -> entry.
        vocab_size (int): Number of entries (including "<unk>" when enabled).
        eos_id: Always None; this tokenizer defines no end-of-sequence token.
    """
    def __init__(self, chars, add_unk=True):
        self.add_unk = add_unk
        # Optional UNK entry: always stored at vocab[0]. Any "<unk>" already present in
        # `chars` is dropped first, so reloading a saved vocabulary does not duplicate it.
        if add_unk:
            chars = ["<unk>"] + [c for c in chars if c != "<unk>"]
        self.chars = list(chars)
        self.stoi  = {ch: i for i, ch in enumerate(self.chars)}
        self.itos  = {i: ch for i, ch in enumerate(self.chars)}
        self.vocab_size = len(self.chars)
        self.eos_id = None  # character-level modelling here does not use an EOS token

    @classmethod
    def build_from_text(cls, train_text, *, cover_val=None, cover_test=None, add_unk=True):
        """
        Build a tokenizer whose vocabulary is the sorted set of characters in `train_text`.

        Pass `cover_val` / `cover_test` to also merge the characters of those texts into
        the vocabulary (guarantees no out-of-vocabulary characters for them).
        """
        vocab = sorted(list(set(train_text)))
        if cover_val is not None:
            vocab = sorted(list(set(vocab).union(set(cover_val))))
        if cover_test is not None:
            vocab = sorted(list(set(vocab).union(set(cover_test))))
        return cls(vocab, add_unk=add_unk)

    def encode(self, s: str):
        """
        Map each character of `s` to its id.

        With `add_unk=True` unknown characters map to the "<unk>" id; otherwise an
        unknown character raises KeyError.
        """
        if self.add_unk:
            unk = self.stoi["<unk>"]
            return [self.stoi.get(c, unk) for c in s]
        else:
            # No-UNK mode: the caller must guarantee that `s` has no unknown characters.
            return [self.stoi[c] for c in s]

    def decode(self, ids):
        """Join the entries for `ids` into a string; ids outside the vocabulary are skipped."""
        return "".join(self.itos.get(int(i), "") for i in ids)

    def save(self, path: str):
        """Write the vocabulary and the `add_unk` flag to `path` as UTF-8 JSON."""
        obj = {"chars": self.chars, "add_unk": self.add_unk}
        # Context manager: the handle is flushed and closed deterministically.
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(obj, fh, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str):
        """Rebuild a tokenizer from a file written by `save` (`add_unk` defaults to True)."""
        with open(path, "r", encoding="utf-8") as fh:
            obj = json.load(fh)
        return cls(obj["chars"], add_unk=obj.get("add_unk", True))


# ===== 3) Encode text to token ids and persist them as a raw .bin file (uint32 memmap) =====
def encode_to_memmap(s: str, spath: str, tokenizer: "CharTokenizer"):
    """
    Encode `s` with `tokenizer` and write the ids to `spath` as raw uint32 values.

    Args:
        s: Text to encode.
        spath: Output file path. Missing parent directories are created.
        tokenizer: Object exposing `encode(str) -> list[int]`.

    Returns:
        int: Number of tokens written.
    """
    ids = np.array(tokenizer.encode(s), dtype=np.uint32)

    # A fresh checkout has no output directory yet; create it instead of failing.
    parent = os.path.dirname(spath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if ids.size == 0:
        # A zero-length file cannot be memory-mapped on every NumPy version,
        # so an empty corpus is written as an empty file directly.
        with open(spath, "wb"):
            pass
        return ids.size

    m = np.memmap(spath, dtype=np.uint32, mode="w+", shape=(ids.size,))
    m[:] = ids
    m.flush()   # make sure the data reaches the file before the mapping is dropped
    del m
    return ids.size

In [4]:
# 1) Read the three splits (each split is a single row holding the whole text).
ds_name = "karpathy/tiny_shakespeare"
ds = load_dataset(ds_name)

train_text = ds["train"][0]["text"]
val_text   = ds["validation"][0]["text"]
test_text  = ds["test"][0]["text"]

# Every artefact below is written under data/; create it so a fresh run does not fail.
os.makedirs("data", exist_ok=True)

train_bin = "data/train.bin"
val_bin   = "data/val.bin"
test_bin  = "data/test.bin"

# Recommended: build the vocabulary from train only and rely on <unk> as a fallback.
tokenizer = CharTokenizer.build_from_text(train_text, add_unk=True)
# To fully cover the characters of val/test instead, use:
# tokenizer = CharTokenizer.build_from_text(train_text, cover_val=val_text, cover_test=test_text, add_unk=False)
tok_path = "data/char_tokenizer.json"
tokenizer.save(tok_path)

# Safety check (only meaningful without <unk>): every val/test character must be in the
# tokenizer vocabulary. It runs BEFORE encoding, because encoding an unknown character
# without <unk> raises KeyError, and it compares against the tokenizer vocabulary (not
# train_text) so that the cover_val/cover_test option above passes.
if not tokenizer.add_unk:
    known_chars = set(tokenizer.stoi)
    assert set(val_text).issubset(known_chars) and set(test_text).issubset(known_chars), \
        "val/test 含有 train 未见字符；要么加 <unk>，要么改为并入 val/test 字符建表"

train_tokens = encode_to_memmap(train_text, train_bin, tokenizer)
val_tokens   = encode_to_memmap(val_text,   val_bin,   tokenizer)
test_tokens  = encode_to_memmap(test_text,  test_bin,  tokenizer)

# Save metadata describing the tokenizer and the encoded corpora.
meta = {
    "tokenizer_type": "char",
    "tokenizer_path": tok_path,
    "vocab_size": tokenizer.vocab_size,
    "has_unk": tokenizer.add_unk,
    "unk_id": tokenizer.stoi.get("<unk>") if tokenizer.add_unk else None,
    "train_tokens": int(train_tokens),
    "val_tokens": int(val_tokens),
    "test_tokens": int(test_tokens),
}
with open("data/meta.json", "w", encoding="utf-8") as meta_file:
    json.dump(meta, meta_file, ensure_ascii=False, indent=2)

print("vocab_size:", tokenizer.vocab_size)
print("corpus sizes:", train_tokens, val_tokens, test_tokens)

# ===== 4) Read the files back into memory as LongTensors (handy for debugging) =====
# The ids are widened to int64 in NumPy first: torch.from_numpy does not accept uint32
# arrays on every PyTorch version, while int64 always maps to torch.long.
train_data = torch.from_numpy(np.array(np.memmap(train_bin, dtype=np.uint32, mode="r"), dtype=np.int64))
val_data   = torch.from_numpy(np.array(np.memmap(val_bin,   dtype=np.uint32, mode="r"), dtype=np.int64))
test_data  = torch.from_numpy(np.array(np.memmap(test_bin,  dtype=np.uint32, mode="r"), dtype=np.int64))

vocab_size: 66
corpus sizes: 1003854 55770 55770


# Load and prepare dataset

In [5]:
class RandomContiguousWindows(Dataset):
    """
    Dataset that serves random contiguous windows cut from one long 1-D token sequence.

    All candidate windows are exposed once, in the constructor, through a single
    `unfold` call. The result is a strided view of shape (N - T) x (T + 1) over the
    source tensor (no data is copied), so `__getitem__` only has to index one row.

    Args:
        src_1d_long: 1-D tensor of dtype torch.long holding the token ids.
        block_size: Context length T; each window holds T + 1 ids.
        epoch_samples: Value reported by `__len__`, i.e. how many random windows make
            up one nominal epoch.
    """
    def __init__(self, src_1d_long: torch.Tensor, block_size: int, epoch_samples: int):
        assert src_1d_long.dtype == torch.long and src_1d_long.dim() == 1
        self.src = src_1d_long
        self.T = block_size
        self.length = epoch_samples
        # One row per possible start offset; this is a view, not a copy.
        self.windows = self.src.unfold(0, self.T + 1, 1)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # `idx` only counts how many samples were requested; the start offset is
        # drawn uniformly at random on every call.
        start = torch.randint(0, self.windows.size(0), (1,)).item()
        window = self.windows[start]          # (T+1,)
        # Inputs are the first T ids, targets are the same ids shifted by one.
        return window[:-1], window[1:]


In [6]:
# Context length in tokens; each sample is a window of block_size + 1 ids.
block_size = 128

# `epoch_samples` only sets the nominal epoch length: windows are drawn at random.
train_ds = RandomContiguousWindows(
    src_1d_long=train_data,
    block_size=block_size,
    epoch_samples=100_000,
)
val_ds = RandomContiguousWindows(
    src_1d_long=val_data,
    block_size=block_size,
    epoch_samples=10_000,
)

In [7]:
# shuffle=False on purpose: the dataset already picks a random window per item.
train_loader = DataLoader(
    train_ds,
    batch_size=64,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
    prefetch_factor=2,
    persistent_workers=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=64,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
)

# Architecture

## DropPath

In [8]:
class DropPath(nn.Module):
    """
    Stochastic depth (a.k.a. DropPath) for residual branches.

    During training each sample (index along dim 0) has its whole branch output either
    zeroed with probability `drop_prob` or kept and rescaled by 1 / (1 - drop_prob).
    In eval mode, or when `drop_prob == 0`, the input is returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's branch; must lie in [0, 1].

    Raises:
        ValueError: If `drop_prob` is outside [0, 1].
    """
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        drop_prob = float(drop_prob)
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep = 1.0 - self.drop_prob
        if keep <= 0.0:
            # drop_prob == 1: every branch is dropped. Return zeros directly; the
            # general formula below would compute x / 0 * 0 = NaN.
            return torch.zeros_like(x)
        # One Bernoulli(keep) draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = keep + torch.rand(shape, dtype=x.dtype, device=x.device)
        mask.floor_()   # floor(keep + U[0,1)) is 1 with probability `keep`, else 0
        # Rescale kept branches by 1/keep.
        return x / keep * mask

    def extra_repr(self) -> str:
        return f"drop_prob={self.drop_prob}"

## FFN

In [9]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm as _spectral_norm


def _round_to_multiple(x: int, base: int, mode: str = "nearest") -> int:
    """
    Round integer x to a multiple of `base`.

    Args:
        x: Value to round.
        base: The multiple to round to (e.g., 128, 256). If <=0, returns x.
        mode: One of {"nearest", "up", "down"}.

    Returns:
        int: Rounded value.
    """
    if base <= 0:
        return x
    if mode == "up":
        return int(math.ceil(x / base) * base)
    if mode == "down":
        return int(math.floor(x / base) * base)
    # nearest
    down = int(math.floor(x / base) * base)
    up = int(math.ceil(x / base) * base)
    return up if (x - down) > (up - x) else down


class FFN(nn.Module):
    """
    Flexible feed-forward network for Transformer blocks (no norm/residual inside).

    Design:
        - GLU family ("swiglu"/"geglu"): one Linear produces 2 * hidden_dim, split into
          value and gate, then element-wise multiply, followed by the output projection.
        - Plain two-layer FFN ("gelu"/"silu"): Linear -> activation -> Linear.
        - No normalization, residual, or DropPath in this module; those belong outside.

    Args:
        d_model (int):
            Input (and default output) channel size, typically the Transformer model width.
        hidden_dim (int | None):
            Intermediate width. If None, it is inferred based on `kind` and then aligned:
              * kind in {"gelu", "silu"}   -> target = 4 * d_model
              * kind in {"swiglu", "geglu"} -> target = (8/3) * d_model (≈ 2.67×)
            The target is rounded to the nearest multiple of `align_to`
            (default: 256 if d_model >= 1024 else 128).
        out_dim (int | None):
            Output channel size. Defaults to `d_model` (convenient for residual add).
        kind (str):
            Activation/topology: one of {"swiglu", "geglu", "gelu", "silu"}.
            - "swiglu": gate uses SiLU
            - "geglu" : gate uses GELU
            - "gelu"  : plain two-layer FFN with GELU
            - "silu"  : plain two-layer FFN with SiLU
        dropout (float):
            Dropout probability applied once on the FFN output. Default 0.0.
        bias (bool):
            Whether Linear layers use bias. Defaults to False (common in modern LLMs).
        spectral_norm (bool):
            If True, wrap Linear layers with spectral normalization. Default False.
            (Not typically needed; reserved for special stability constraints.)
        align_to (int | None):
            Multiple to align `hidden_dim` to when auto-inferred. If None, uses
            256 when d_model >= 1024, else 128.

    Attributes:
        d_model (int): Saved model width.
        hidden_dim (int): Final intermediate width after inference/alignment.
        out_dim (int): Final output width.
    """
    def __init__(
        self,
        d_model: int,
        hidden_dim: int | None = None,
        out_dim: int | None = None,
        kind: str = "swiglu",
        dropout: float = 0.0,
        bias: bool = False,
        spectral_norm: bool = False,
        align_to: int | None = None,
    ):
        super().__init__()
        assert kind in {"swiglu", "geglu", "gelu", "silu"}

        self.d_model = d_model
        self.kind = kind
        self.out_dim = out_dim if out_dim is not None else d_model

        # --- Auto-infer hidden_dim when not provided ---
        if hidden_dim is None:
            base = align_to if align_to is not None else (256 if d_model >= 1024 else 128)
            if kind in {"swiglu", "geglu"}:
                target = int(round((8.0 / 3.0) * d_model))  # ≈2.67× d_model
            else:  # 'gelu' or 'silu'
                target = 4 * d_model
            hidden_dim = _round_to_multiple(target, base, mode="nearest")
        self.hidden_dim = hidden_dim

        # --- Layers ---
        self.dropout = nn.Dropout(dropout) if (dropout and dropout > 0.0) else nn.Identity()

        if kind in {"swiglu", "geglu"}:
            # Produce 2*hidden_dim in one matmul; then split into (value, gate).
            lin_in = nn.Linear(d_model, 2 * hidden_dim, bias=bias)
            lin_out = nn.Linear(hidden_dim, self.out_dim, bias=bias)
        else:  # 'gelu' or 'silu'
            lin_in = nn.Linear(d_model, hidden_dim, bias=bias)
            lin_out = nn.Linear(hidden_dim, self.out_dim, bias=bias)

        if spectral_norm:
            lin_in = _spectral_norm(lin_in)
            lin_out = _spectral_norm(lin_out)

        self.proj_in = lin_in
        self.proj_out = lin_out

        # Activation: used as gate (GLU) or mid-activation (two-layer FFN).
        self.act = {
            "swiglu": F.silu,   # gate activation
            "geglu":  F.gelu,   # gate activation
            "gelu":   F.gelu,   # mid activation
            "silu":   F.silu,   # mid activation
        }[kind]

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform for Linear weights; zeros for biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (..., d_model).

        Returns:
            Tensor of shape (..., out_dim). If out_dim == d_model, it fits residual add.
        """
        if self.kind in {"swiglu", "geglu"}:
            value, gate = self.proj_in(x).chunk(2, dim=-1)
            z = value * self.act(gate)      # GLU-style gating
        else:
            z = self.act(self.proj_in(x))   # plain two-layer FFN mid-activation
        out = self.proj_out(z)
        return self.dropout(out)


class TransformerFFNBlock(nn.Module):
    """
    Pre-Norm feed-forward sublayer with a residual connection and optional DropPath.
    Attention is deliberately not part of this block; it is only the MLP piece.

        x -> x + DropPath( FFN( Norm(x) ) )

    Args:
        d_model (int): Model width.
        hidden_dim (int | None): Forwarded to FFN (None triggers FFN's auto-inference).
        norm_type (str): "rmsnorm" | "layernorm".
        drop_path (float): Stochastic depth probability on the residual branch.
        ffn_kind (str): Forwarded to FFN ("swiglu" | "geglu" | "gelu" | "silu").
        ffn_dropout (float): Output dropout inside FFN.
        bias (bool): Linear bias in FFN.
        spectral_norm (bool): Spectral norm on FFN linears.

    Raises:
        ValueError: If `norm_type` is not one of the supported names.
    """
    def __init__(
        self,
        d_model: int,
        hidden_dim: int | None,
        *,
        norm_type: str = "rmsnorm",
        drop_path: float = 0.0,
        ffn_kind: str = "swiglu",
        ffn_dropout: float = 0.0,
        bias: bool = False,
        spectral_norm: bool = False,
    ):
        super().__init__()
        # Reject unsupported norms up front, then build the requested one.
        if norm_type != "rmsnorm" and norm_type != "layernorm":
            raise ValueError(f"Unknown norm_type: {norm_type}")
        self.norm = (
            nn.RMSNorm(d_model, eps=1e-6)
            if norm_type == "rmsnorm"
            else nn.LayerNorm(d_model, eps=1e-5)
        )

        # The FFN always projects back to d_model so the residual add is valid.
        self.ffn = FFN(
            d_model=d_model,
            hidden_dim=hidden_dim,
            out_dim=d_model,
            kind=ffn_kind,
            dropout=ffn_dropout,
            bias=bias,
            spectral_norm=spectral_norm,
        )

        # Only pay for DropPath when a positive probability was requested.
        if drop_path and drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.ffn(self.norm(x))
        return x + self.drop_path(branch)

## Attention

In [10]:
import math, torch, torch.nn as nn, torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention in which position t attends only to positions <= t.

    Args:
        n_embd: Model width C; must be divisible by `n_head`.
        n_head: Number of attention heads.
        dropout: Dropout probability, used both on the attention weights and on the
            output projection.
        max_seq_len: Side length of the pre-built causal mask; `forward` slices the
            top-left T x T corner of it.
    """
    def __init__(self, n_embd, n_head, dropout=0.1, max_seq_len=2048):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head   = n_head
        self.head_dim = n_embd // n_head

        # A single projection yields q, k and v; `proj` is the output projection.
        self.qkv  = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd,     bias=False)

        self.attn_drop  = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)

        # Strictly-upper-triangular boolean mask (True = future position), built once
        # and reused; persistent=False keeps it out of the state_dict.
        future = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool).triu(diagonal=1)
        self.register_buffer("causal_mask", future, persistent=False)

    def forward(self, x):  # x: (B, T, C)
        B, T, C = x.shape

        # Split the fused projection into q, k, v and move heads forward: (B, h, T, d).
        q, k, v = (
            part.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
            for part in self.qkv(x).split(C, dim=2)
        )

        # Scaled dot-product scores with future positions masked out.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)   # (B,h,T,T)
        scores = scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))

        # Softmax runs in float32 for numerical stability, then is cast back.
        weights = self.attn_drop(F.softmax(scores.float(), dim=-1).to(q.dtype))

        # Weighted sum of values, heads merged back into the channel dim: (B, T, C).
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(merged))

class TransformerAttnBlock(nn.Module):
    """
    Pre-Norm attention sublayer with a residual connection and optional DropPath:

        x -> x + DropPath( Attn( Norm(x) ) )

    Args:
        d_model: Model width.
        n_head: Number of attention heads.
        norm_type: "rmsnorm" | "layernorm".
        attn_dropout: Dropout passed to the attention module (attention weights and
            output projection).
        drop_path: Stochastic depth probability on the residual branch.
        max_seq_len: Size of the attention module's cached causal mask.
        bias: Accepted for signature parity with the FFN block.
            NOTE(review): this value is not forwarded to CausalSelfAttention, so it
            currently has no effect on the attention projections.

    Raises:
        ValueError: If `norm_type` is not one of the supported names.
    """
    def __init__(
        self,
        d_model: int,
        n_head: int,
        *,
        norm_type: str = "rmsnorm",     # "rmsnorm" | "layernorm"
        attn_dropout: float = 0.0,      # dropout on attn weights / proj
        drop_path: float = 0.0,         # stochastic depth prob
        max_seq_len: int = 2048,
        bias: bool = False,
    ):
        super().__init__()
        # Reject unsupported norms up front, then build the requested one.
        if norm_type != "rmsnorm" and norm_type != "layernorm":
            raise ValueError(f"Unknown norm_type: {norm_type}")
        self.norm = (
            nn.RMSNorm(d_model, eps=1e-6)
            if norm_type == "rmsnorm"
            else nn.LayerNorm(d_model, eps=1e-5)
        )

        # Attention with a cached causal mask (avoids rebuilding the mask per call).
        self.attn = CausalSelfAttention(
            n_embd=d_model,
            n_head=n_head,
            dropout=attn_dropout,
            max_seq_len=max_seq_len,
        )

        # Only pay for DropPath when a positive probability was requested.
        if drop_path and drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.attn(self.norm(x))
        return x + self.drop_path(branch)

## Transformer Block

In [11]:
class TransformerBlock(nn.Module):
    """
    Full decoder block: an attention sublayer followed by a feed-forward sublayer,
    each with Pre-Norm, residual add and optional DropPath.

        x -> x + DropPath( Attn( Norm(x) ) )
          -> x -> x + DropPath( FFN(  Norm(x) ) )

    Args:
        d_model: Model width.
        n_head: Number of attention heads.
        max_seq_len: Forwarded to the attention sublayer.
        norm_type: "rmsnorm" | "layernorm", used by both sublayers.
        drop_path: Stochastic depth probability, used by both sublayers.
        attn_dropout: Dropout inside the attention sublayer.
        ffn_kind: "swiglu" | "geglu" | "gelu" | "silu".
        ffn_hidden: FFN hidden width; None lets the FFN infer and align it.
        ffn_dropout: Output dropout inside the FFN.
        bias: Forwarded to both sublayers.
        spectral_norm: Forwarded to the FFN sublayer.
    """
    def __init__(
        self,
        d_model: int,
        n_head: int,
        *,
        max_seq_len: int = 2048,
        norm_type: str = "rmsnorm",
        drop_path: float = 0.0,
        attn_dropout: float = 0.0,
        ffn_kind: str = "swiglu",          # "swiglu" | "geglu" | "gelu" | "silu"
        ffn_hidden: int | None = None,     # None -> inferred by the FFN
        ffn_dropout: float = 0.0,
        bias: bool = False,
        spectral_norm: bool = False,
    ):
        super().__init__()
        # Settings that both sublayers receive unchanged.
        shared = dict(norm_type=norm_type, drop_path=drop_path, bias=bias)

        self.attn_blk = TransformerAttnBlock(
            d_model,
            n_head,
            attn_dropout=attn_dropout,
            max_seq_len=max_seq_len,
            **shared,
        )
        self.ffn_blk = TransformerFFNBlock(
            d_model=d_model,
            hidden_dim=ffn_hidden,
            ffn_kind=ffn_kind,
            ffn_dropout=ffn_dropout,
            spectral_norm=spectral_norm,
            **shared,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention first, then the feed-forward sublayer.
        return self.ffn_blk(self.attn_blk(x))


## GPT

In [12]:
class TinyGPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_layer: int = 12,
        n_head: int  = 12,
        n_embd: int  = 768,
        block_size: int = 1024,
        *,
        norm_type: str = "rmsnorm",
        drop_path_max: float = 0.0,        # e.g. 0.05~0.1 for deeper nets
        ffn_kind: str = "swiglu",
        ffn_hidden: int | None = None,     # None -> auto per your FFN
        ffn_dropout: float = 0.0,
        attn_dropout: float = 0.0,
        bias: bool = False,
        spectral_norm: bool = False,
        emb_dropout: float = 0.0           # keep small for pretrain
    ):
        super().__init__()
        self.block_size = block_size

        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)   # 简单位置；需要 RoPE 时替换
        self.drop    = nn.Dropout(emb_dropout)

        # linearly increasing DropPath across depth
        def get_dp(i):  # i in [0, n_layer-1]
            return drop_path_max * (i / max(1, n_layer - 1))

        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=n_embd, n_head=n_head,
                max_seq_len=block_size,
                norm_type=norm_type,
                drop_path=get_dp(i),
                attn_dropout=attn_dropout,
                ffn_kind=ffn_kind,
                ffn_hidden=ffn_hidden,
                ffn_dropout=ffn_dropout,
                bias=bias,
                spectral_norm=spectral_norm,
            )
            for i in range(n_layer)
        ])

        self.ln_f = nn.RMSNorm(n_embd, eps=1e-6) if norm_type == "rmsnorm" else nn.LayerNorm(n_embd, eps=1e-5)

        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # weight tying
        self.lm_head.weight = self.tok_emb.weight

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    @property
    def num_params(self): 
        return sum(p.numel() for p in self.parameters())

    def forward(self, idx, targets=None):          # idx: (B,T)
        B, T = idx.shape
        assert T <= self.block_size
        tok = self.tok_emb(idx)                    # (B,T,C)
        pos = self.pos_emb(torch.arange(T, device=idx.device, dtype=torch.long))  # (T,C)
        x   = self.drop(tok + pos[None, :, :])     # (B,T,C)

        for blk in self.blocks:
            x = blk(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)                   # (B,T,V)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        self.eval()
        idx = idx.to(next(self.parameters()).device)
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-6)
            if top_k is not None and top_k > 0:
                k = min(top_k, logits.size(-1))
                v = torch.topk(logits, k, dim=-1).values
                logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)   # (B,1)
            idx = torch.cat([idx, next_tok], dim=1)
        return idx

In [18]:
# 8 layers, 4 heads, width 256; the context length matches the dataset windows.
model = TinyGPT(
    tokenizer.vocab_size,
    n_layer=8,
    n_head=4,
    n_embd=256,
    block_size=block_size,
).to(device)

In [19]:
# Resolve the runtime device as a plain string ("cuda" or "cpu") and move the model.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

# Take a single batch from train_loader; xb and yb are both (B, T).
xb, yb = (t.to(device, non_blocking=True) for t in next(iter(train_loader)))

# One forward pass with targets gives the loss of the untrained model.
_, loss = model(xb, yb)

print(f"Model Param #: {model.num_params/1e6:.4f} M" , 
      f" Loss at rand initialization: {float(loss):.4f}")

Model Param #: 6.0833 M  Loss at rand initialization: 4.1207


In [20]:
import math
from itertools import cycle

# Optimisation schedule.
max_steps = 1_500          # total optimizer steps
eval_interval = 100        # run validation every this many steps
init_lr = 3e-4             # peak learning rate reached after warmup
# Linear warmup over 3% of the run, never fewer than one step.
warmup_steps = max(1, int(0.03 * max_steps))

def grouped_params(model, wd=0.1):
    """
    Split the trainable parameters of `model` into two optimizer groups.

    A parameter receives weight decay `wd` when it has at least two dimensions and its
    name does not contain "norm" (case-insensitive); all other trainable parameters
    get weight decay 0. Parameters with `requires_grad=False` are left out.

    Returns:
        list[dict]: [decay group, no-decay group], ready to pass to an optimizer.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_matrix = param.ndim >= 2
        is_norm = 'norm' in name.lower()
        if is_matrix and not is_norm:
            decay.append(param)
        else:
            no_decay.append(param)
    return [
        {"params": decay, "weight_decay": wd},
        {"params": no_decay, "weight_decay": 0.0},
    ]

optimizer = torch.optim.AdamW(grouped_params(model, wd=0.1),
                              lr=init_lr, betas=(0.9,0.95), eps=1e-8)


def _endless_batches(loader):
    """
    Yield batches from `loader` forever, restarting it whenever it is exhausted.

    This replaces itertools.cycle, which stores a copy of every batch it has yielded
    (memory grows with each step) and, after the first pass, replays those stored
    batches instead of drawing new ones from the loader.
    """
    while True:
        yield from loader


# `device` may be a string or a torch.device depending on which cells were run;
# normalise it before deciding whether CUDA autocast is active.
use_amp = torch.device(device).type == "cuda"

model.train()
train_iter = _endless_batches(train_loader)
tokens_seen = 0

for step in range(1, max_steps + 1):
    xb, yb = next(train_iter)
    xb = xb.to(device, non_blocking=True); yb = yb.to(device, non_blocking=True)

    # Linear warmup over the first `warmup_steps` steps (3% of the run).
    if step <= warmup_steps:
        scale = step / warmup_steps
        for g in optimizer.param_groups: g["lr"] = init_lr * scale

    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_amp):
        _, loss = model(xb, yb)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    tokens_seen += xb.numel()

    if step % 10 == 0:
        lr_now = optimizer.param_groups[0]["lr"]
        print(f"step {step:4d} | train loss {loss.item():.3f} | lr {lr_now:.6g} | tokens {tokens_seen:,}")

    if step % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            vloss_total, vcount = 0.0, 0
            for vx, vy in val_loader:
                vx = vx.to(device, non_blocking=True); vy = vy.to(device, non_blocking=True)
                # Same autocast setting as training, so train and eval losses are
                # computed in the same numeric regime.
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_amp):
                    _, vloss = model(vx, vy)
                vloss_total += vloss.item(); vcount += 1
        print(f"==> eval loss {vloss_total / max(1,vcount):.3f}")
        model.train()


step   10 | train loss 3.561 | lr 6.66667e-05 | tokens 81,920
step   20 | train loss 3.426 | lr 0.000133333 | tokens 163,840
step   30 | train loss 3.294 | lr 0.0002 | tokens 245,760
step   40 | train loss 2.928 | lr 0.000266667 | tokens 327,680
step   50 | train loss 2.726 | lr 0.0003 | tokens 409,600
step   60 | train loss 2.647 | lr 0.0003 | tokens 491,520
step   70 | train loss 2.583 | lr 0.0003 | tokens 573,440
step   80 | train loss 2.523 | lr 0.0003 | tokens 655,360
step   90 | train loss 2.482 | lr 0.0003 | tokens 737,280
step  100 | train loss 2.433 | lr 0.0003 | tokens 819,200
==> eval loss 2.463
step  110 | train loss 2.430 | lr 0.0003 | tokens 901,120
step  120 | train loss 2.388 | lr 0.0003 | tokens 983,040
step  130 | train loss 2.378 | lr 0.0003 | tokens 1,064,960
step  140 | train loss 2.363 | lr 0.0003 | tokens 1,146,880
step  150 | train loss 2.356 | lr 0.0003 | tokens 1,228,800
step  160 | train loss 2.335 | lr 0.0003 | tokens 1,310,720
step  170 | train loss 2.313 |

In [21]:
import torch
import torch.nn.functional as F
from typing import List, Optional, Callable, Union


@torch.no_grad()
def stream_generate(
    model,
    tokenizer: Optional[object] = None,
    prompt: Union[str, List[int]] = "",
    *,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    eos_id: Optional[int] = None,
    device: str = "cuda",
    print_every: int = 1,
):
    """
    Sample tokens one at a time, printing them as they are produced.

    Args:
        model: Callable module returning `(logits, loss)` for a (1, T) LongTensor. If it
            has a `block_size` attribute, only the last `block_size` tokens are fed in.
        tokenizer: Object with `encode` (needed for a string prompt) and `decode`
            (needed for streaming output).
        prompt: Prompt text or a list of token ids; must yield at least one token.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Logits are divided by this value when it is positive.
        top_k: If a positive int, keep only the k most likely tokens.
        top_p: If in (0, 1), nucleus sampling with this cumulative probability.
        repetition_penalty: Values > 1 make tokens already in the sequence less likely.
        eos_id: Stop as soon as this id is sampled (it is returned but not printed).
        device: "cuda..." is used only when CUDA is available, otherwise "cpu".
        print_every: Print the buffered new tokens every this many steps (min. 1);
            any remaining tokens are printed at the end.

    Returns:
        (text, out_ids): decoded full sequence (or `str(out_ids)` without a decoder)
        and the list of all token ids, prompt included.

    Raises:
        ValueError: If `prompt` is a string and no tokenizer is given, or if the
            prompt contains no tokens.
    """
    was_training = model.training
    model.eval()
    try:
        device = device if torch.cuda.is_available() and device.startswith("cuda") else "cpu"
        model = model.to(device)

        if isinstance(prompt, str):
            if tokenizer is None:
                raise ValueError("prompt 是字符串时需要提供 tokenizer（含 encode/decode）。")
            ids = tokenizer.encode(prompt)
            print(prompt, end="", flush=True)
        else:
            ids = list(prompt)

        if len(ids) == 0:
            # Without any context there is no "last position" to read logits from.
            raise ValueError("prompt must contain at least one token")

        can_decode = tokenizer is not None and hasattr(tokenizer, "decode")
        print_every = max(1, int(print_every))          # guards against modulo by zero
        block_size = getattr(model, "block_size", None)  # loop-invariant, read once

        x = torch.tensor([ids], dtype=torch.long, device=device)
        pending = []   # sampled ids that have not been printed yet

        for _step in range(max_new_tokens):
            x_cond = x if block_size is None else x[:, -block_size:]

            logits, _ = model(x_cond)
            logits = logits[:, -1, :]

            if temperature and temperature > 0:
                logits = logits / temperature

            if repetition_penalty != 1.0 and x.numel() > 0:
                seen = torch.unique(x)
                seen_logits = logits[:, seen]
                # Penalise in the direction that lowers probability: positive logits
                # are divided, negative logits are multiplied. Dividing a negative
                # logit would move it towards zero and make the token MORE likely.
                penalised = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )
                logits = logits.clone()
                logits[:, seen] = penalised

            if top_k is not None and top_k > 0:
                k = min(top_k, logits.size(-1))
                kth = torch.topk(logits, k, dim=-1).values[:, [-1]]
                logits = logits.masked_fill(logits < kth, float("-inf"))

            if top_p is not None and 0.0 < top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
                probs = F.softmax(sorted_logits, dim=-1)
                cum = torch.cumsum(probs, dim=-1)
                mask = cum > top_p
                # Shift right so the token that crosses the threshold is kept, and
                # always keep the most likely token.
                mask[..., 1:] = mask[..., :-1].clone()
                mask[..., 0] = False
                sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
                logits = torch.full_like(logits, float("-inf"))
                logits.scatter_(1, sorted_idx, sorted_logits)

            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            x = torch.cat([x, next_tok], dim=1)
            nid = next_tok.item()

            if eos_id is not None and nid == eos_id:
                break

            # Stream new tokens: buffer them and print the buffer every `print_every`
            # steps, so no token is dropped when print_every > 1.
            if can_decode:
                pending.append(nid)
                if len(pending) >= print_every:
                    print(tokenizer.decode(pending), end="", flush=True)
                    pending = []

        # Print whatever is still buffered.
        if can_decode and pending:
            print(tokenizer.decode(pending), end="", flush=True)

        out_ids = x[0].tolist()
        text = tokenizer.decode(out_ids) if can_decode else str(out_ids)
        print()  # final newline
        return text, out_ids
    finally:
        # Restore the caller's train/eval mode.
        model.train(was_training)

In [26]:
# Stream a 1000-token sample; the return value is discarded because the text is
# already printed while it is generated.
_ = stream_generate(
    model,
    tokenizer,
    prompt="User: hi\nI love you!:",
    max_new_tokens=1000,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.05,
    eos_id=getattr(tokenizer, "eos_id", None),
    device="cuda",
    print_every=1,
)

User: hi
I love you!: I have a bounder hangmen a
more-honest will a loss as you are gone.

POLIXENES:
If you fear you, promises so boy.

POLIXENES:
I do, a soon.

MARCIUS:
True, no more of that I was much in him, the
lips to gett him in him that sit this allieve.

ANGELO:
When is here? what art thou? thy thricing is the
two large a worthily complainanus?

ISABELLA:
Go to, get your faith, I did in the last
white of you are good to me, therefore to stay
so so young.

LEONTES:
But, I warrant to stay; I am and the more.

First too the boy: she straight, let's have mad me be
me meed to my lord. If he die, I saw
here sworn it.

CORIOLANUS:
True, sir, and my true scons.

COMINIUS:
Well, thou wilt me be alowed
Before I shall lack the thine of thy coat,
That his to make me hour with thee.

MARCIUS:
Come, there both: go you are more between like
The sain of their way of my son, I come to thee with
thee, let me stay not to stay awhile; therefore
shall be short stroke upon my fair brother,
With my