In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Read the whole training corpus into memory as a single string.
with open('./input/tanshi.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

In [3]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and their integer ids.
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(s):
    """Map a string to the list of its character ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a sequence of character ids back to the string they spell."""
    return ''.join(itos[i] for i in l)

In [4]:
# Encode the full corpus as one 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# First 90% of the tokens are used for training, the remainder for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [5]:
# --- Hyperparameters -------------------------------------------------------
batch_size = 10        # number of sequences in each training batch
block_size = 5         # context window: maximum number of tokens the model sees
max_iters = 50000      # number of training steps
eval_interval = 300    # evaluate train/val loss every this many steps
learning_rate = 1e-4   # optimizer learning rate
eval_iters = 200       # number of batches averaged when estimating the loss
n_embed = 10           # token embedding width (very small, demo only)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
def get_batch(split):
    """Sample a random batch of (input, target) windows.

    `split` selects the training data when it equals 'train'; any other
    value selects the validation data. Returns two tensors of shape
    (batch_size, block_size) on `device`; the targets are the inputs shifted
    one position to the right.
    """
    source = val_data if split != 'train' else train_data
    # One random start offset per sequence in the batch.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    labels = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), labels.to(device)

In [29]:
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects each position of the input from `n_embed` features to
    `head_size`-wide key, query and value vectors, and returns for every
    position a softmax-weighted sum of the values at that position and the
    positions before it.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular mask. Registered as a buffer so that it follows the
        # module across devices without being treated as a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply masked self-attention to `x` of shape (B, T, n_embed).

        Returns a tensor of shape (B, T, head_size). T must not exceed
        `block_size`, the size of the precomputed mask.
        """
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Scaled dot-product attention divides by sqrt(d_k), the width of the
        # key/query vectors (head_size) -- not the width of the input embedding.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # Block attention to future positions, then normalise each row.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)  # (B, T, head_size)
        return wei @ v

class MultiHead(nn.Module):
    """Several attention heads run in parallel, followed by a linear projection.

    The outputs of the `num_heads` heads are concatenated along the feature
    axis (giving `num_heads * head_size` features) and projected to `n_embed`.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The projection consumes the concatenated head outputs, so its input
        # width is num_heads * head_size. That equals n_embed only when
        # head_size divides n_embed exactly (e.g. it does not for odd n_embed
        # with head_size = n_embed // 2).
        self.proj = nn.Linear(num_heads * head_size, n_embed)

    def forward(self, x):
        """Run every head on `x` and project the concatenated result."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)

class FeedFoward(nn.Module):
    """Feed-forward layer: one linear map over the last axis followed by ReLU.

    Input and output both have `n_embd` features on the last axis.
    """

    def __init__(self, n_embd):
        super().__init__()
        layers = [nn.Linear(n_embd, n_embd), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear layer and the ReLU to `x`."""
        activated = self.net(x)
        return activated

class Model(nn.Module):
    """Minimal character-level transformer language model.

    Token and position embeddings are summed, passed through one multi-head
    self-attention layer and one feed-forward layer, and mapped to logits
    over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding_table = nn.Embedding(block_size, n_embed)
        self.sa_head = MultiHead(2, n_embed // 2)
        self.ffwd = FeedFoward(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: integer tensor of token ids, shape (B, T) with T <= block_size.
            targets: optional integer tensor of the same shape as `idx`.

        Returns:
            (logits, loss). Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the cross-entropy.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embed)
        # Build the position indices on the same device as the input;
        # a bare torch.arange(T) lives on the CPU and fails when the model
        # and its inputs are on CUDA.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.pos_embedding_table(positions)  # (T, n_embed)
        x = tok_emb + pos_emb  # position embedding broadcasts over the batch
        x = self.sa_head(x)
        x = self.ffwd(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, classes) logits and (N,) targets.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(-1)  # (B*T,)

        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoids building a graph
    def generate(self, idx, max_new_tokens):
        """Autoregressively append `max_new_tokens` sampled tokens to `idx`.

        Args:
            idx: integer tensor of shape (B, T) holding the prompt token ids.
            max_new_tokens: number of tokens to sample.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit the position embedding table.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # keep the final time step: (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# Build the model on the selected device and attach an AdamW optimizer to it.
model = Model().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [31]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    Averages the loss over `eval_iters` random batches per split with the
    model in eval mode, then switches the model back to train mode.
    Returns a dict with keys 'train' and 'val' holding 0-dim tensors.
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        batch_losses = torch.tensor(
            [model(*get_batch(split))[1].item() for _ in range(eval_iters)]
        )
        averages[split] = batch_losses.mean()
    model.train()
    return averages

# Training loop. The step count comes from the `max_iters` hyperparameter
# rather than a separate hard-coded literal, and the loop variable is named
# `step` so that the builtin `iter` is not shadowed for the rest of the kernel.
for step in range(max_iters):
    xb, yb = get_batch('train')
    # With targets supplied, the model returns logits flattened to
    # (batch_size * block_size, vocab_size) together with the loss.
    logits, loss = model(xb, yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
    # Periodically report the averaged train/validation loss.
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")



step 0: train loss 6.9736, val loss 6.9715
step 300: train loss 6.9047, val loss 6.9027
step 600: train loss 6.8348, val loss 6.8248
step 900: train loss 6.7073, val loss 6.7027
step 1200: train loss 6.4787, val loss 6.4619
step 1500: train loss 6.2028, val loss 6.1920
step 1800: train loss 6.0453, val loss 6.0331
step 2100: train loss 5.9377, val loss 5.9540
step 2400: train loss 5.9013, val loss 5.9222
step 2700: train loss 5.8596, val loss 5.8575
step 3000: train loss 5.8372, val loss 5.8139
step 3300: train loss 5.7944, val loss 5.7923
step 3600: train loss 5.7774, val loss 5.7482
step 3900: train loss 5.7328, val loss 5.7542
step 4200: train loss 5.7415, val loss 5.7009
step 4500: train loss 5.7082, val loss 5.6972
step 4800: train loss 5.7148, val loss 5.7010
step 5100: train loss 5.6986, val loss 5.6885
step 5400: train loss 5.6606, val loss 5.6859
step 5700: train loss 5.6700, val loss 5.6309
step 6000: train loss 5.6280, val loss 5.6369
step 6300: train loss 5.6340, val loss 5

In [35]:
# Start from a single token with id 0 and sample 200 further tokens.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = model.generate(context, max_new_tokens=200)
print(decode(generated[0].tolist()))


馀名剑不诗
辛白。
断未
关书。
字悠，出惟即林得何上人生江报，勤自犹，茫漏经年，作从隐晓笑山。
知易我丛中自，。。里去拥共，。。
年
手以伤车，白入烟，帆闲秋正极，，乃逢乐绕见。。
吟候今西此
雪出楚病紫
避本度句流空幽浮未。云不白生。
积馀。载临，
香高浮白何落能，
饮荒老。，
辞。
毛海国
两覆门诗松里过。，归同帝烟空诗险和
沧永县半任
西，寥舍尊古尝能归结僧连，交
灵久，三苍催舟，其
迢无
