1. 导入需要用到的库，主要是pytorch，自己下载的llama tokenizer，以及训练需要用到的管理学习率和可视化的库。

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
import yaml
from llama.tokenizer import Tokenizer
from torch.utils.data import DataLoader, Dataset
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import random
import numpy as np
import os

2. 读取config文件，以及训练集

注意：因为本任务是风格化文本生成，很难比较生成文本的质量，我干脆把train.csv和validation.csv的内容放到一起作为训练集

In [2]:
def read_csv(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def read_yaml(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config

def split_text2chunk(text, bos_id=None, eos_id=None, chunk_size=512):
    '''
    一条文本太长，进行分块
    '''
    for i in range(0, len(text), chunk_size):
        if bos_id is not None and eos_id is not None:
            res = [bos_id] + text[i:i+chunk_size-2] + [eos_id]
        elif bos_id is not None and eos_id is None:
            res = [bos_id] + text[i:i+chunk_size-1]
        elif bos_id is None and eos_id is not None:
            res = text[i:i+chunk_size-1] + [eos_id]
        else:
            res = text[i:i+chunk_size]
        yield res

# 加载tokenizer, text_chunks, config
tokenizer = Tokenizer("./llama/tokenizer.model")
config = read_yaml("./config.yaml")
train_text = read_csv("./archive/train.csv")
print(len(train_text))
print(tokenizer.pad_id)

1059631
-1


In [3]:
import re
def split_text(text):
    # 使用正则表达式匹配所有标点符号、换行符和制表符进行切分
    tokens = re.findall(r"[\w']+|[.,!?;()\n\t]", text)
    return tokens
word_list = split_text(train_text)
print(len(word_list))

265769


3. 构建数据集(dataset)和dataloader，这里有两种选择：
   - 先tokenize，后切分成512一组---->容易保证输入seq_len一致，但可能破坏文本一致性。
   - 先按words语义切分为512的chunk, 再进行tokenize--->容易保持文本一致性，但seq_len不一定相同，需要截断与pad.

实际上文本很长的时候，两种方法差不多。这里为了方便选择第一种

In [4]:
class TextDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

#先tokenize，后切分成512一组----》容易保证输入seq_len一致，但可能破坏文本一致性
train_data = tokenizer.encode(train_text, bos=False, eos=False)
print("分词后token数:", len(train_data))
#给每个chunk加上<bos>, <pos>
train_data_list = list(split_text2chunk(train_data, bos_id = tokenizer.bos_id, eos_id = tokenizer.eos_id, chunk_size = config["model_config"]["seq_len"]+1))

padded_train_data = torch.nn.utils.rnn.pad_sequence([torch.tensor(ids) for ids in train_data_list], batch_first=True)
train_dataset = TextDataset(padded_train_data)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

#先按words语义切分为512的chunk, 再进行tokenize---》容易保持文本一致性，但seq_len不一定相同，需要截断与pad

'''
train_chunks = list(split_text2chunk(word_list, config["model_config"]["seq_len"]))
train_data_list = []
for chunk in train_chunks:
    tokens = ' '.join(chunk)
    tokenized_chunk = tokenizer.encode(tokens, False, False)
    train_data_list.append(tokenized_chunk)
padded_train_data = torch.nn.utils.rnn.pad_sequence([torch.tensor(ids) for ids in train_data_list], batch_first=True, padding_value = tokenizer.pad_id)
'''

分词后token数: 349246


'\ntrain_chunks = list(split_text2chunk(word_list, config["model_config"]["seq_len"]))\ntrain_data_list = []\nfor chunk in train_chunks:\n    tokens = \' \'.join(chunk)\n    tokenized_chunk = tokenizer.encode(tokens, False, False)\n    train_data_list.append(tokenized_chunk)\npadded_train_data = torch.nn.utils.rnn.pad_sequence([torch.tensor(ids) for ids in train_data_list], batch_first=True, padding_value = tokenizer.pad_id)\n'

In [5]:
#验证train_dataloader经过encode之后是对的，可以decode解码
batch_data1 = None
for i, tmp in enumerate(train_dataloader):
    if i==0:
        batch_data1 = tmp
        break
for bd in batch_data1:
    words = tokenizer.decode(bd.tolist())
    break
print(words)

doth she tempt: but it is I
That, lying by the violet in the sun,
Do as the carrion does, not as the flower,
Corrupt with virtuous season. Can it be
That modesty may more betray our sense
Than woman's lightness? Having waste ground enough,
Shall we desire to raze the sanctuary
And pitch our evils there? O, fie, fie, fie!
What dost thou, or what art thou, Angelo?
Dost thou desire her foully for those things
That make her good? O, let her brother live!
Thieves for their robbery have authority
When judges steal themselves. What, do I love her,
That I desire to hear her speak again,
And feast upon her eyes? What is't I dream on?
O cunning enemy, that, to catch a saint,
With saints dost bait thy hook! Most dangerous
Is that temptation that doth goad us on
To sin in loving virtue: never could the strumpet,
With all her double vigour, art and nature,
Once stir my temper; but this virtuous maid
Subdues me quite. Even till now,
When men were fond, I smiled and wonder'd how.

DUKE VINCENTIO:
Hai

In [6]:
print(len(train_dataset))
print(len(train_dataloader)) # 每个chunk 512个token 每个batch 32个chunk

681
22


4. 构建decoder-only transformer：
- 自适应输入长度
- 使用pre LayerNorm代替post LayerNorm
- 使用RoPE代替原始的正余弦位置编码
- 使用kv-cache进行推理加速

In [12]:
class tokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model): #d_model为单个token的embedding维度
        super(tokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x)

class RotationEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super(RotationEmbedding, self).__init__()
        self.max_len = max_len
        self.d_model = d_model
    def cal_theta(self, base=10000):
        # return shape (d_model,)
        _2i = torch.arange(0, self.d_model, step=2)
        theta = 1 / (base ** (_2i / self.d_model))
        theta =torch.repeat_interleave(theta, repeats=2)
        return theta
    
    def cal_cos_sin(self):
        pos = torch.arange(0, self.max_len).unsqueeze(1)
        # pos shape (max_len, 1)
        angles = pos * self.cal_theta() # shape (max_len, d_model) 广播运算
        embeddings = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1) # shape (max_len, d_model, 2)
        #dim=-1表示在最后一维插入，变成[cos(angles), sin(angles)]
        # shape (max_len, d_model, 2)
        embeddings = embeddings.unsqueeze(0).unsqueeze(0) # shape (1, 1, max_len, d_model, 2)
        return embeddings
    
    def forward(self, q, k):
        # q shape (batch_size, n_head, max_len, d_head)
        embeddings = self.cal_cos_sin().to(q.device)
        #print("q shape:", q.shape)
        #print("k shape:", k.shape)
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape)
        cos_pos = embeddings[..., 0].squeeze(-1) #(1, 1, max_len, d_head)
        
        sin_pos = embeddings[..., 1].squeeze(-1) #(1, 1, max_len, d_head)
        
        q_len = q.shape[2]  #为了处理推理时的变长序列
        k_len = k.shape[2]

        q = q * cos_pos[:, :, q_len-1, :] + q2 * sin_pos[:, :, q_len-1, :]

        k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
        k2 = k2.reshape(k.shape)
        # 更新kw, *对应位置相乘
        k = k * cos_pos[:, :, k_len-1, :] + k2 * sin_pos[:, :, k_len-1, :]
        return q, k
    
class FeedForward(nn.Module):
    def __init__(self, d_model, d_hidden=2048, drop_prob=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_hidden)
        self.linear2 = nn.Linear(d_hidden, d_model)
        self.activate = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)
        if self.linear1.bias is not None:
            nn.init.zeros_(self.linear1.bias)
        if self.linear2.bias is not None:
            nn.init.zeros_(self.linear2.bias)
        nn.init.xavier_normal_(self.linear1.weight)
        nn.init.xavier_normal_(self.linear2.weight)
        
    def forward(self, x):
        x = self.linear1(x)
        x = self.activate(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
    
class MultiheadAttention(nn.Module):
    def __init__(self, d_model, seq_len, n_heads):
        super(MultiheadAttention, self).__init__()
        #print("mutihead attention para: d_model:{}, seq_len:{}, n_heads:{}".format(d_model, seq_len, n_heads))
        self.d_model = d_model
        self.n_heads = n_heads
        self.seq_len = seq_len
        self.head_dim = d_model // n_heads
        #print("mutihead attention para: d_model:{}, seq_len:{}, n_heads:{}, head_dim:{}".format(d_model, seq_len, n_heads, self.head_dim))

        self.values = nn.Linear(d_model, self.head_dim * n_heads, bias=False)
        self.keys = nn.Linear(d_model, self.head_dim * n_heads, bias=False)
        self.queries = nn.Linear(d_model, self.head_dim * n_heads, bias=False)
        self.fc_out = nn.Linear(n_heads * self.head_dim, d_model)
        self.RoPE = RotationEmbedding(max_len=seq_len, d_model=self.head_dim)
        nn.init.kaiming_normal_(self.values.weight)
        nn.init.kaiming_normal_(self.keys.weight)
        nn.init.kaiming_normal_(self.queries.weight)
        nn.init.xavier_normal_(self.fc_out.weight)
        if self.fc_out.bias is not None:
            nn.init.zeros_(self.fc_out.bias)
        
    def forward(self, x, mask, use_rope=True):
        '''
        x : (batch_size, seq_len, d_model)
        mask : (batch_size, 1, seq_len, seq_len)
        '''
        bs = x.shape[0]

        seq_len = x.shape[1] 

        values = self.values(x).view(bs, seq_len, self.n_heads, self.head_dim)
        keys = self.keys(x).view(bs, seq_len, self.n_heads, self.head_dim)
        queries = self.queries(x).view(bs, seq_len, self.n_heads, self.head_dim)

        values = values.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        queries = queries.permute(0, 2, 1, 3)
        # q,k,v : (batch_size, n_heads, seq_len, d_head)
        #print("queries shape:", queries.shape)
        if use_rope:
            queries, keys = self.RoPE(queries, keys)

        weights = torch.matmul(queries, keys.permute(0, 1, 3, 2)) #(bs, n, seq_len, d_head) @ (bs, n, d_head, seq_len) -> (bs, n, seq_len, seq_len)

        if mask is not None: 
            weights = weights.masked_fill(mask != 1, float("-1e10"))

        attention = F.softmax(weights / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)), dim=-1)

        out = torch.matmul(attention, values)
        out = out.permute(0, 2, 1, 3).contiguous().view(bs, seq_len, int(self.n_heads * self.head_dim))

        out = self.fc_out(out)
        return out

    def predict(self, x, mask, k_cache, v_cache, use_rope=True):
        bs = x.shape[0]
        seq_len = x.shape[1] #推理的时候应该为1

        values = self.values(x).view(bs, seq_len, self.n_heads, self.head_dim)
        keys = self.keys(x).view(bs, seq_len, self.n_heads, self.head_dim)
        # [B, 1, H, D]

        queries = self.queries(x).view(bs, seq_len, self.n_heads, self.head_dim)
        # [B, 1, H, D]

        values = values.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        queries = queries.permute(0, 2, 1, 3)
        # [B, H, 1, D]

        if use_rope:
            queries, keys = self.RoPE(queries, keys)

        if k_cache is not None:
            keys = torch.cat([k_cache, keys], dim=2)
            values = torch.cat([v_cache, values], dim=2)
            # [B, H, L, D]
        weights = torch.matmul(queries, keys.permute(0, 1, 3, 2))

        if mask is not None:
            weights = weights.masked_fill(mask != 1, float("-1e10"))

        attention = F.softmax(weights / (self.head_dim ** (1 / 2)), dim=-1)

        out = torch.matmul(attention, values)
        out = out.permute(0, 2, 1, 3).contiguous().view(bs, seq_len, self.n_heads * self.head_dim)

        out = self.fc_out(out)
        return out, keys, values
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, seq_len, n_heads, d_hidden, drop_prob=0.1):
        super().__init__()
        #self.embed = tokenEmbedding(vocab_size=voc_size, d_model=d_model)
        self.pre_norm = nn.LayerNorm(d_model)
        self.attn = MultiheadAttention(d_model, seq_len=seq_len, n_heads=n_heads)
        self.ffn = FeedForward(d_model, d_hidden=d_hidden, drop_prob=drop_prob)
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.norm_ffn = nn.LayerNorm(d_model)
        nn.init.normal_(self.pre_norm.weight, mean=0, std=1e-2)
        nn.init.zeros_(self.pre_norm.bias)
        nn.init.normal_(self.norm_ffn.weight, mean=0, std=1e-2)
        nn.init.zeros_(self.norm_ffn.bias)
        self.k_cache = None
        self.v_cache = None

    def forward(self, x, mask):
        _x = x
        x = self.pre_norm(x)
        x = self.attn(x, mask)
        x = self.dropout1(x)
        x1 = (x + _x)
        x = self.norm_ffn(x1)
        x = self.ffn(x)
        x = self.dropout1(x)
        return x + x1

    def clear_cache(self):
        self.k_cache = None
        self.v_cache = None

    def predict(self, x, mask, use_rope=True):
        _x = x
        x = self.pre_norm(x)
        x, k_cache, v_cache = self.attn.predict(x, mask, self.k_cache, self.v_cache)
        self.k_cache = k_cache
        self.v_cache = v_cache
        x = self.dropout1(x)
        x1 = (x + _x)
        x = self.norm_ffn(x1)
        x = self.ffn(x)
        x = self.dropout1(x)
        return x + x1

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, bos_id, eos_id, pad_id, voc_size, d_model, seq_len, n_heads, n_layers, d_hidden, drop_prob=0.1):
        super().__init__()
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, seq_len, n_heads, d_hidden, drop_prob) for _ in range(n_layers)])
        self.embed = tokenEmbedding(vocab_size=voc_size, d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.last_norm = nn.LayerNorm(d_model)
        self.last_linear = nn.Linear(d_model, voc_size)
        #self.softmax = nn.Softmax(dim=-1)
        self.pad_id = pad_id
        self.start_id = bos_id
        self.end_id = eos_id
        self.mode = "train"
        nn.init.xavier_uniform_(self.last_linear.weight)
        nn.init.normal_(self.last_norm.weight, mean=0, std=1e-2)
        
    def set_mode(self, mode):
        if mode not in ['train', 'generate']:
            raise ValueError("Unsupported mode: {}".format(mode))
        self.mode = mode
    
    def create_mask(self, x):
        #x_shape: (batch_size, seq_len)
        #mask==False 的位置, 在注意力计算后会被忽略
        pad_mask = torch.ne(x, self.pad_id).unsqueeze(1).unsqueeze(3).to(x.device) #batch_size, 1, seq_len, 1
        seq_len = x.shape[1]
        low_tri_mask = (torch.tril(torch.ones(seq_len, seq_len)) == 1).bool().to(x.device) 
        mask = torch.logical_and(pad_mask, low_tri_mask)
        return mask
    
    def forward(self, x):
        mask = self.create_mask(x)
        x = self.embed(x)
        x = self.dropout1(x)
        for layer in self.decoder_layers:
            x = layer(x, mask)
        x = self.last_norm(x)
        x = self.last_linear(x)
        return x

    def clear_cache(self):
        for layer in self.decoder_layers:
            layer.clear_cache()
            
    def generate(self, x, max_len=512, use_rope=True):
        mask = self.create_mask(x)
        x = self.embed(x)
        x = self.dropout1(x)
        for layer in self.decoder_layers:
            x= layer.predict(x, mask, use_rope=use_rope)
        x = self.last_norm(x)
        x = self.last_linear(x)
        return x

In [13]:
#设置超参， config的有些超参没用上，不过无所谓
pad_id = tokenizer.pad_id
start_id = tokenizer.bos_id
end_id = tokenizer.eos_id
voc_size = tokenizer.n_words
d_model = config["model_config"]['d_model']         #512
n_heads = config["model_config"]['n_heads']         #8
d_hidden = config["model_config"]['d_hidden']       #2048
drop_prob1 = config["model_config"]['drop_prob1']   #0.2
drop_prob2 = config["model_config"]['drop_prob2']   #0.4
n_layers = config["model_config"]['n_layers']       #6
seq_len = config["data_config"]["max_seq_length"]   #512


5. 进行训练
- 设置device, criterion(交叉熵), optimizer(Adam)
- 采用teacher-forcing 和 schedule sampling 两种不同的训练方式分别训练
- 选择适当的初始化方案，如Xavier and Kaiming initialization
- 每1k个step保存一次checkpoint
- 采用线性学习率warm_up
- 注：实际训练需要单卡有12G显存，没有采用混合精度训练。
- 训练时，由于不知道epoch应该设置多少最好，可以采用保存checkpoint的方式多次训练，直到认为结果够好。或者自己用perplexity/BLEU评估。

In [14]:
#设置损失函数和优化器
model = DecoderOnlyTransformer(bos_id = start_id, eos_id = end_id, pad_id=pad_id, d_model=d_model, n_heads=n_heads, d_hidden=d_hidden, drop_prob=drop_prob1, voc_size=voc_size, seq_len=seq_len, n_layers=n_layers)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config["training_config"]["learning_rate"]*10, weight_decay=config["training_config"]["weight_decay"])

#设置checkpoint保存路径
import os
checkpoint_dir = os.getcwd()
print(f"teacher-forcing checkpoint 路径:{checkpoint_dir}")

teacher-forcing checkpoint 路径:/home/liyihang/assignment3-transformer


5.1 teacher-forcing方式训练

- 优点：能够迅速收敛；训练速度快；可以并行
- 缺点：模型的能力依赖数据集，泛化性差（只学习预测数据集的下一个token）

In [15]:
writer = SummaryWriter()

# 计算总训练步骤数和 warm-up 步骤数
num_epochs = 50
total_steps = num_epochs * len(train_dataloader)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
print(f"训练总steps:{total_steps}, warm-up steps:{warmup_steps}")

init_checkpoint = "checkpoint_steps_8800.pth"
save_interval = 500

#可以设置init checkpoint加载预训练模型
init_steps = 8800
if init_checkpoint is not None:
    checkpoint = torch.load(init_checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
cur_steps = 0  
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for idx, batch_data in enumerate(train_dataloader):
        # 假设 batch 包含 input_ids
        input_ids = batch_data.to(device)
        # 创建 labels，即 input_ids 向右偏移一个时间步
        labels = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()
        
        # 将梯度置零
        optimizer.zero_grad()
        #with torch.cuda.amp.autocast():
        # 前向传播
        logits = model(input_ids)

        # 计算损失
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        # 反向传播和优化
        loss.backward()
        optimizer.step()
        scheduler.step()  # 更新学习率
        cur_steps += 1
        running_loss += loss.item()
        writer.add_scalar('Loss/train/iter', loss.item(), cur_steps)
        if cur_steps % save_interval == 0 and cur_steps+init_steps>=3000:
            # 保存模型
            torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()}, f"checkpoint_steps_{init_steps + cur_steps}.pth")
    # 打印平均损失
    avg_loss = running_loss / len(train_dataloader)
    writer.add_scalar('Loss/train/epoch', avg_loss, epoch)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()}, f"checkpoint_steps_{init_steps + cur_steps}.pth")
    

训练总steps:1100, warm-up steps:110
Epoch 1/50, Loss: 4.0771
Epoch 2/50, Loss: 4.0690
Epoch 3/50, Loss: 4.0768
Epoch 4/50, Loss: 4.0779
Epoch 5/50, Loss: 4.0786
Epoch 6/50, Loss: 4.0694
Epoch 7/50, Loss: 4.0773
Epoch 8/50, Loss: 4.0787
Epoch 9/50, Loss: 4.0768
Epoch 10/50, Loss: 4.0828
Epoch 11/50, Loss: 4.0733
Epoch 12/50, Loss: 4.0754
Epoch 13/50, Loss: 4.0738
Epoch 14/50, Loss: 4.0773
Epoch 15/50, Loss: 4.0829
Epoch 16/50, Loss: 4.0730
Epoch 17/50, Loss: 4.0853
Epoch 18/50, Loss: 4.0743
Epoch 19/50, Loss: 4.0762
Epoch 20/50, Loss: 4.0721
Epoch 21/50, Loss: 4.0773
Epoch 22/50, Loss: 4.0843
Epoch 23/50, Loss: 4.0771
Epoch 24/50, Loss: 4.0728
Epoch 25/50, Loss: 4.0804
Epoch 26/50, Loss: 4.0786
Epoch 27/50, Loss: 4.0752
Epoch 28/50, Loss: 4.0789
Epoch 29/50, Loss: 4.0782
Epoch 30/50, Loss: 4.0739
Epoch 31/50, Loss: 4.0738
Epoch 32/50, Loss: 4.0746
Epoch 33/50, Loss: 4.0744
Epoch 34/50, Loss: 4.0768
Epoch 35/50, Loss: 4.0786
Epoch 36/50, Loss: 4.0728
Epoch 37/50, Loss: 4.0752
Epoch 38/50, L

5.2 使用schedule sampling训练

schedule sampling是结合了autogressive和teacher-forcing的方法，在刚开始的时候大概率选择teacher-foring，等到模型有一定能力后，autogressive的概率逐渐增加。

- 优点：提高模型自身的能力，不一味模仿数据集，同时容易收敛
- 缺点：无法并行，训练速度慢

In [19]:
def sig(k, x):
    return k/(k+ np.exp(x/k))

def sigmoid_decay(epoch_num, decay_ratio, is_continue=False):
    if epoch_num<15 and is_continue==False:
        return 1.0
    return sig(decay_ratio, epoch_num)

# 计算总训练步骤数和 warm-up 步骤数
num_epochs = 100

#设置checkpoint保存路径
checkpoint_dir = os.getcwd()
print(f"checkpoint 路径:{checkpoint_dir}")

#导入预训练模型 如果不想导入 设为None. 
init_checkpoint = 'sam_checkpoint_steps8800.pth'
init_steps = 8800
is_continue = False
prev_tf_ratio = 1.0
model_sam = DecoderOnlyTransformer(bos_id = start_id, eos_id = end_id, pad_id=pad_id, d_model=d_model, n_heads=n_heads, d_hidden=d_hidden, drop_prob=drop_prob1, voc_size=voc_size, seq_len=seq_len, n_layers=n_layers)
model_sam.to(device)
if init_checkpoint is not None:
    is_continue = True
    checkpoint = torch.load(init_checkpoint)
    model_sam.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    prev_tf_ratio = 1.0 if checkpoint.get('teacher_ratio')==None else checkpoint.get('teacher_ratio')
    print(prev_tf_ratio)
total_steps = num_epochs * len(train_dataloader)
warmup_steps = 0
if is_continue==False:
    warmup_steps = int(0.1 * total_steps)
    print(f"训练总steps:{total_steps}, warm-up steps:{warmup_steps}")

#设置优化器和迭代器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_sam.parameters(), lr=config["training_config"]["learning_rate"], weight_decay=config["training_config"]["weight_decay"])
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
writer_sam = SummaryWriter()
cur_steps = 0  
decay_ratio = 8

#开始训练
for epoch in range(num_epochs):
    model_sam.train()
    running_loss = 0.0
    batch_loss = 0.0
    for batch_data in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        
        input_ids = batch_data.to(device) # (batch_size, seq_len)
        loss = 0.0
        for t in range(len(input_ids)):
            #t时刻, 输入大小为 (batch_size, t)
            
            if t==0 or random.uniform(0, 1) < prev_tf_ratio:
                #teacher-forcing, t时刻输入是input[0:t+1], 输出和input[1:t+2]作比较
                if t == 0:
                    output_logits = model_sam(input_ids[:, 0].unsqueeze(1))
                    labels = input_ids[:, 1].contiguous()
                    loss += criterion(output_logits.view(-1, output_logits.size(-1)), labels)
                else:
                    output_logits = model_sam(input_ids[:, :t+1])    #(batch_size, seq_len, vocab_size)
                    labels = input_ids[:, 1:t+2].contiguous()
                    loss += criterion(output_logits.view(-1, output_logits.size(-1)), labels.view(-1)) #输入[0, t-1]对应的gt为[1,t]
            else:
                #schedule-sampling, t时刻输入是input[0:t] + output[t], 输出和input[1:t+2]作比较
                token_logits = model_sam(input_ids[:, :t])  #(bs, t, vocab_sz)
                cur_token = torch.argmax(token_logits, dim=-1) #t时刻的预测token (bs, t)
                cur_token = cur_token[:,-1].unsqueeze(1) #(bs,1)
                
                input_sample = torch.cat([input_ids[:, :t], cur_token], dim=1) #(bs, t+1)

                output_logits = model_sam(input_sample)
                labels = input_ids[:, 1:t+2].contiguous()
                loss += criterion(output_logits.view(-1, output_logits.size(-1)), labels.view(-1))
                
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        #batch_loss += loss.item()
        prev_tf_ratio = sigmoid_decay(epoch, decay_ratio, is_continue) #更新teacher_forcing
        scheduler.step()                      # 更新学习率

        cur_steps += 1
        writer_sam.add_scalar('Loss/train/iter', loss.item(), cur_steps)
        if cur_steps % 500 == 0:
            torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
                       'teacher_ratio':prev_tf_ratio}, f"sam_checkpoint_steps{cur_steps+init_steps}.pth")
            
    
    # 打印平均损失
    avg_loss = running_loss / len(train_dataloader)
    writer_sam.add_scalar('Loss/train/epoch', avg_loss, epoch)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    
# 保存模型
torch.save({'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(), 'teacher_ratio':prev_tf_ratio}, f"sam_checkpoint_steps{cur_steps + init_steps}.pth")

checkpoint 路径:/home/liyihang/assignment3-transformer
3.378166897674838e-05


Epoch 1/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.46it/s]


Epoch 1/100, Loss: 136.9066


Epoch 2/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:14<00:00,  1.49it/s]


Epoch 2/100, Loss: 131.7287


Epoch 3/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.46it/s]


Epoch 3/100, Loss: 129.5599


Epoch 4/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.45it/s]


Epoch 4/100, Loss: 127.2383


Epoch 5/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.43it/s]


Epoch 5/100, Loss: 125.3461


Epoch 6/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.40it/s]


Epoch 6/100, Loss: 123.7448


Epoch 7/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.41it/s]


Epoch 7/100, Loss: 122.3723


Epoch 8/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.39it/s]


Epoch 8/100, Loss: 122.0406


Epoch 9/100: 100%|██████████████████████████████████████████████████████████████████████████████| 22/22 [00:15<00:00,  1.38it/s]


Epoch 9/100, Loss: 120.3001


Epoch 10/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:16<00:00,  1.35it/s]


Epoch 10/100, Loss: 119.8143


Epoch 11/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:16<00:00,  1.35it/s]


Epoch 11/100, Loss: 118.9688


Epoch 12/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:16<00:00,  1.30it/s]


Epoch 12/100, Loss: 119.7972


Epoch 13/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:16<00:00,  1.30it/s]


Epoch 13/100, Loss: 118.4105


Epoch 14/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:17<00:00,  1.27it/s]


Epoch 14/100, Loss: 118.3456


Epoch 15/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:17<00:00,  1.29it/s]


Epoch 15/100, Loss: 117.2182


Epoch 16/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:17<00:00,  1.26it/s]


Epoch 16/100, Loss: 117.7356


Epoch 17/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:17<00:00,  1.24it/s]


Epoch 17/100, Loss: 117.2710


Epoch 18/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:17<00:00,  1.23it/s]


Epoch 18/100, Loss: 116.3906


Epoch 19/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:18<00:00,  1.22it/s]


Epoch 19/100, Loss: 116.3373


Epoch 20/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:18<00:00,  1.20it/s]


Epoch 20/100, Loss: 116.2455


Epoch 21/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:18<00:00,  1.18it/s]


Epoch 21/100, Loss: 116.4420


Epoch 22/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:18<00:00,  1.17it/s]


Epoch 22/100, Loss: 116.4266


Epoch 23/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:19<00:00,  1.12it/s]


Epoch 23/100, Loss: 115.9074


Epoch 24/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:19<00:00,  1.13it/s]


Epoch 24/100, Loss: 116.4256


Epoch 25/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:19<00:00,  1.12it/s]


Epoch 25/100, Loss: 115.8560


Epoch 26/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:19<00:00,  1.12it/s]


Epoch 26/100, Loss: 115.6485


Epoch 27/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.10it/s]


Epoch 27/100, Loss: 115.3604


Epoch 28/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.09it/s]


Epoch 28/100, Loss: 115.9419


Epoch 29/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.09it/s]


Epoch 29/100, Loss: 114.8551


Epoch 30/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.08it/s]


Epoch 30/100, Loss: 114.7445


Epoch 31/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.07it/s]


Epoch 31/100, Loss: 114.7483


Epoch 32/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.07it/s]


Epoch 32/100, Loss: 114.7572


Epoch 33/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.06it/s]


Epoch 33/100, Loss: 114.6760


Epoch 34/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.07it/s]


Epoch 34/100, Loss: 114.4501


Epoch 35/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.06it/s]


Epoch 35/100, Loss: 113.9925


Epoch 36/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.06it/s]


Epoch 36/100, Loss: 113.8208


Epoch 37/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.06it/s]


Epoch 37/100, Loss: 113.9838


Epoch 38/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.05it/s]


Epoch 38/100, Loss: 113.6670


Epoch 39/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 39/100, Loss: 113.3134


Epoch 40/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:20<00:00,  1.05it/s]


Epoch 40/100, Loss: 113.3082


Epoch 41/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 41/100, Loss: 113.1152


Epoch 42/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 42/100, Loss: 112.8059


Epoch 43/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 43/100, Loss: 112.9387


Epoch 44/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 44/100, Loss: 112.5281


Epoch 45/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 45/100, Loss: 112.4901


Epoch 46/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:22<00:00,  1.00s/it]


Epoch 46/100, Loss: 112.2702


Epoch 47/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 47/100, Loss: 111.9103


Epoch 48/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 48/100, Loss: 111.8032


Epoch 49/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 49/100, Loss: 111.4751


Epoch 50/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 50/100, Loss: 111.5983


Epoch 51/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 51/100, Loss: 111.6070


Epoch 52/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 52/100, Loss: 111.2445


Epoch 53/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 53/100, Loss: 111.0337


Epoch 54/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 54/100, Loss: 110.8757


Epoch 55/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 55/100, Loss: 110.5671


Epoch 56/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 56/100, Loss: 110.7327


Epoch 57/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 57/100, Loss: 110.4169


Epoch 58/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 58/100, Loss: 110.4374


Epoch 59/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 59/100, Loss: 110.0537


Epoch 60/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 60/100, Loss: 110.0495


Epoch 61/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 61/100, Loss: 110.0046


Epoch 62/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 62/100, Loss: 109.8302


Epoch 63/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 63/100, Loss: 109.5994


Epoch 64/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 64/100, Loss: 109.5268


Epoch 65/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 65/100, Loss: 109.3336


Epoch 66/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 66/100, Loss: 109.3174


Epoch 67/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 67/100, Loss: 109.1846


Epoch 68/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 68/100, Loss: 109.0942


Epoch 69/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:22<00:00,  1.01s/it]


Epoch 69/100, Loss: 108.9846


Epoch 70/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 70/100, Loss: 109.0313


Epoch 71/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 71/100, Loss: 108.7109


Epoch 72/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 72/100, Loss: 108.5919


Epoch 73/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 73/100, Loss: 108.7141


Epoch 74/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 74/100, Loss: 108.5627


Epoch 75/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 75/100, Loss: 108.5540


Epoch 76/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 76/100, Loss: 108.4532


Epoch 77/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 77/100, Loss: 108.2365


Epoch 78/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 78/100, Loss: 108.2825


Epoch 79/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 79/100, Loss: 108.2203


Epoch 80/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.02it/s]


Epoch 80/100, Loss: 108.0625


Epoch 81/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 81/100, Loss: 108.0536


Epoch 82/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 82/100, Loss: 107.8493


Epoch 83/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 83/100, Loss: 107.9263


Epoch 84/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 84/100, Loss: 107.7930


Epoch 85/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 85/100, Loss: 107.7153


Epoch 86/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.04it/s]


Epoch 86/100, Loss: 107.7961


Epoch 87/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 87/100, Loss: 107.6667


Epoch 88/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 88/100, Loss: 107.5762


Epoch 89/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 89/100, Loss: 107.6913


Epoch 90/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 90/100, Loss: 107.4574


Epoch 91/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.00it/s]


Epoch 91/100, Loss: 107.4857


Epoch 92/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 92/100, Loss: 107.4202


Epoch 93/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 93/100, Loss: 107.5108


Epoch 94/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.02it/s]


Epoch 94/100, Loss: 107.3265


Epoch 95/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 95/100, Loss: 107.2343


Epoch 96/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 96/100, Loss: 107.2855


Epoch 97/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 97/100, Loss: 107.2108


Epoch 98/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 98/100, Loss: 107.3578


Epoch 99/100: 100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 99/100, Loss: 107.2372


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████| 22/22 [00:21<00:00,  1.03it/s]


Epoch 100/100, Loss: 107.2720


6. 可视化
使用tensorboard记录每个iteration的loss和每个epoch的平均loss，这里每个epoch差不多有22个iteration

对应的图片可以在teacher_forcing_Loss_train_iter.svg, teacher_forcing_Loss_train_epoch.svg中查看。

其次，在训练时可以直接在命令行输入tensorboard --logdir=runs 进行查看; 推荐在vscode中运行该笔记，可以直接查看。

7. 推理

transformer的推理方法：
- greedy
- beam-search
- top-k
- top-p

这里选择top-k方案。具体如代码注释所示

In [16]:
def choose_token(logits, top_k = 50):
    """
    选择logits中top_k个最大的token, 计算softmax后以此为权重随机采样
    logits: (1, 1, vocab_sz)
    """
    values, indices = torch.topk(logits, top_k, dim = -1) #(1, top_k)
    values = F.softmax(values, dim = -1)  #(1, top_k)
    token_idx = random.choices(indices[0].tolist(), weights=values[0].tolist()) 
    return token_idx[0] #(bs, 1)


#如果多次使用，先清除kv cache (model.clear_cache)
def generate_poem(model, model_checkpoint_path, instruction, max_seq_len = 512):
    #推理环节
    #如果cuda out of memory 尝试torch.cuda.empty_cache()
    #load model:
    checkpoint = torch.load(model_checkpoint_path)
    model.eval()
    model.load_state_dict(checkpoint['model_state_dict'])
    #前面的cell已经把model转移到device上了
    #进行文本生成
    max_seq_len = 512
    input_ids = torch.tensor(tokenizer.encode(instruction, bos=True, eos=False)).unsqueeze(0) #(1, len)
    model.clear_cache()
    
    logits = model.generate(input_ids.to(device), use_rope=True)
    #print(logits.shape) #[1, 5, 32000]
    
    next_token_idx = choose_token(logits[:, -1, :],top_k = 50) #int
    
    input_ids = torch.cat([input_ids, torch.tensor([next_token_idx]).unsqueeze(0)], dim=1) #t+1步的输入
    
    initial_len = input_ids.shape[1]
    
    for t in range(0, max_seq_len-initial_len):
        #由于启用了kv-cache, 每次只需要输入一个token
        logits = model.generate(input_ids[:,-1].unsqueeze(0).to(device),use_rope=True)
        
        next_token_idx = choose_token(logits[:, -1, :],top_k = 50)                    #每次新生成一个单词，从概率最高的前50个单词里选
        input_ids = torch.cat([input_ids, torch.tensor([next_token_idx]).unsqueeze(0)], dim=1)   #t+1步的输入
        if next_token_idx == tokenizer.eos_id:                       #遇到终止符就停止输出
            break
    
    output = tokenizer.decode(input_ids.cpu().tolist())
    print(output[0])
    return output

In [24]:
#对teacher-forcing模型进行评测
model = DecoderOnlyTransformer(bos_id = start_id, eos_id = end_id, pad_id=pad_id, d_model=d_model, n_heads=n_heads, d_hidden=d_hidden, drop_prob=drop_prob1, voc_size=voc_size, seq_len=seq_len, n_layers=n_layers)
model.to(device)
model_checkpoint_path = 'checkpoint_steps_9800.pth'
instruction = "Sir,"
output_poem = generate_poem(model, model_checkpoint_path, instruction)

Sir,
And yet; come at King Aumerlemen, let me,
Your voices;

PETRUCHELLO,

My soul
Away on my lord,
Whose the cause.
I may
HORTENSIO:
CAPULIETRUCHIO:
And take thee from Edward's face, and thou art too?


If they be so much.
With two of the other! What was not well for her son will she is:
If I's:
Theirrah, he was a kindred of God'd with the warrine


The Duke of their good news?
In those that do my brother'll bears that I
PETRUCHEven with a poor Clarence, here.
Thou hast thou art thou darest
Theirator:
HARD IV:
A:
As true man,
That here was been so long-morrow, if thou hast
The king.

A:
Ay me the lamentation!
The gracious sister:
I know it for your brother,
And when she in that I hear the world did make my true lordly for all to keep them to thy frown me so? I had, by the people,
And I may be so;
This day,

The very well:
Godeign and
As it will be like a man;
I am

CENTIO:
Of his eyes
No, I will not.

Whose, go with his life should be found my tongue,
DUCHIO:
Tis hinder the sea,
But t

In [21]:
modelsam_checkpoint_path = 'sam_checkpoint_steps11000.pth'
instruction = "Sir,"
model_sam = DecoderOnlyTransformer(bos_id = start_id, eos_id = end_id, pad_id=pad_id, d_model=d_model, n_heads=n_heads, d_hidden=d_hidden, drop_prob=drop_prob1, voc_size=voc_size, seq_len=seq_len, n_layers=n_layers)
model_sam.to(device)
output_p = generate_poem(model_sam, modelsam_checkpoint_path, instruction)

Sir, then no longer of the sun.
If we may
As he should stand hatches my Lord:
A:

With wicked's
MENES:
'dainst the ground. I would call you are he did not,
And bring
To take you know, here.
T:

O, but how the king, not.
Lie to the heavensio but thou think it out of love.

A:
And with
If I will be gone,
No, you have your

BROMEO:


To use I swear'T:



Why must be, which if thou in them a thing. Be satisfied.
FRIVERSICHisure with thy country, as thou'er:

ILIET:

He have you be married.
IUS:

God save mine.



NORKATHAR:
But, it?


PETRUCHESS OF YORK:

My father?
I should meet at least of my lord,

In mine and I thank you do repass and not the night:
Bark.
H:
MERRY VI
SICINIOLY:

Nurse!

Whom? what means on
And then?
And yet I do not so,
Of mine?
Than death,

No,
I know it! or that is a word, a thousand men's all to be, what, this, and more than that thou wilt speak.
K:
And that all this,
And this, a man; my lord,

O:
Farewell:
Of all the first:

Or that did this I will be the poor brot

7. FLOPS的计算

我们计算batch_size=1的情形，然后推广即可。

一些前置知识：(H, I) * (I, W) --> (H, W), FLOPs 为 H\*W\*2I (如果没有bias, 就是H\*W\*(2I-1))

前向传播：
- Embedding 层： (seq_len, 1) * (seq_len, d_model) --> (seq_len, d_model) 这一步只需要查表，所以没有FLOPs
- dropout层：  推理时会被禁用，忽略
- Attention层：
    - (seq_len, d_model) --> 首先经过三个线性层，nn.Linear(d_model, d_model) , 得到$3*(2d_m-1)*d_ms$
    -  RoPE， 实际上是(n_head, seq_len, d_head) 和 (1, seq_len, d_head)的逐元素乘法， 先广播后计算，得到$2*s*d_m$ --->利用了[$d_h*n = d_m$]
    -  $a = qk^T$: (n_head, seq_len, d_head) @ (n_head, d_head, seq_len) --> (n, seq_len, seq_len)，得到$2s^2d_m$
    -  $o = a/\sqrt(d_m)·v$: (n, seq_len, seq_len) @ (n, seq_len, d_head) --> (n, seq_len, d_head)，得到$2s^2d_m + n_hs^2$ (softmax只计算s次，忽略不计）
    -  最后一个线性层：$2sd_m^2$, 得到输出大小为(seq_len, d_m)
    -  综上所述，attention层的FLOPs约为$10d_m^2s + 2sd_m+n_hs^2$
- FFN层：
    - (seq_len, d_m)-->(seq_len, d_ffn)-->(seq_len,d_m) 主要是两个线性层的FLOPs, 约为$4d_md_fs$
- LayerNorm层：
    - 只在d_model上做计算(求平均值和方差)，但会对每一个元素应用此操作，复杂度为$O(sd_m)$
- Decoder-layer：
    - layernorm + attention + FFN层即为Decoder-layer层的FLOPs，以注意力头n=8, 隐藏层维度$d_f=4d_m$计算，得到$26sd^2 + 2ds+8s^2$
- 输出层:
    - 一个FFN， 负责把模型维度d映射到词表维度V上， FLOPs=$2shV$

- 总和：
    - 6个Decoder-layer + Output layer = $(26sd^2 + 2ds+8s^2)*6+2shV$
    - 带入s=d=512, V=35000, 得到39303774208，约为$3.93*10^{10}$

反向传播：
- 一般认为，反向传播的计算量是前向传播的2倍

In [20]:
#实际计算---FLOPs为18065915904.0, 约为1.8e10, 计算量更少的原因可能是实际前向传播时有一些计算优化。
from thop import profile
model = DecoderOnlyTransformer(bos_id = start_id, eos_id = end_id, pad_id=pad_id, d_model=d_model, n_heads=n_heads, d_hidden=d_hidden, drop_prob=drop_prob1, voc_size=voc_size, seq_len=seq_len, n_layers=n_layers)
model_checkpoint_path = 'checkpoint_steps_9800.pth'
checkpoint = torch.load(model_checkpoint_path)
model.to('cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

input_tensor = torch.randint(20, 3485, (1, 512)).long() #(batch_size, seq_len)
flops, params = profile(model, inputs=(input_tensor,))
print(f"FLOPs: {flops}")
print(f"Parameters: {params}")

[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
FLOPs: 18065915904.0
Parameters: 35322112.0


8. 总结：

- 效果不是特别好，可能因为训练数据过少，导致模型输出的单词没有很强的语义连贯性
- 可能beam_search的推理方案更适用于本实验
- 总体来讲模型学到了古英语风格（使用thy, thou等词而不是your, you）
- 关于超参：尝试学习率1e-5, 1e-4, 1e-3, 发现1e-3似乎更好一些