# GPT训练
    - 无监督训练：以自回归的方式，基于transformer解码器原理，采用掩码自注意力机制,基于字节对编码算法（byte-pair Encoding）,将文本分割为小单元    
    - 有监督微调：根据下游任务(如文本分类、问答、摘要)调整模型参数，使其具有特定领域能力,联合优化任务损失和训练损失
# 解码策略
    - 贪心算法：每次预测都取概率最大的.方法比较简单，但结果全是相同的。只考虑当前最优解，忽视全局最优解.  
        --top K:从概率超过K的中，随机选择一个.,top P方法优化  
        --top P:采用累计分布概率，设定P值,选择累积概率不超过P的序列
        --温度调节:每一个词生产一个logits,使用logits/tem,用来控制logits的分布，tem越高，则越有创意，tem越低则越保守
    - 束搜索：生成过程，
        --有一个beam size=K，每次都在选择K长度的序列，计算联合概率

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as fc

class MiniGPT(nn.Module):
    def __init__(self,vocab_size = 10000,embed_dim = 128,num_layer = 3,max_len = 512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size,embed_dim)
        self.pos_embed = nn.Embedding(max_len,embed_dim)
        # 使用编码器层代替GPT解码器
        # pass

        # 最小改动解码器，实现GPT
        self.layers = nn.ModuleList([
                                    nn.TransformerDecoderLayer(
                                        d_model=embed_dim,
                                        nhead=4,
                                        dim_feedforward=embed_dim*4,
                                        dropout = 0.1,
                                        batch_first = True,                        
                                    ) for _ in range(num_layer)
                                    ])
        # 为什么使用ModuleList？使用普通列表，无法将所有参数绑定到类中，使用ModuleList可以将所有可训练parameter()收集
        # 语言模型的输出层
        self.lm_out  = nn.Linear(embed_dim,vocab_size)
        # 分类层：用于模型微调时添加
        self.classifier = None

    def forward(self,input_ids):
        # input_ids.shape = (batch_size,seq_len)
        seq_len = input_ids.size(1)
        pos = torch.arange(seq_len,device=input_ids.device)
        # self.token_embed(input_ids).shape = (seq_len,embed_dim)
        x  = self.token_embed(input_ids)+self.pos_embed(pos)
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len,device=input_ids.device)
        for layer in self.layers:
            x = layer(tgt=x,memory = x,tgt_mask = mask)
        # 语言模型输出
        logits = self.lm_out(x)
        # 分类任务输出
        if self.classifier is not None:
            # 取序列中的首token,作为分类
            cls_logits = self.classifier(x[:,0,:])
            return logits,cls_logits
        return logits

In [3]:
def test():
    vocab_size = 100
    embed_dim = 64
    n_layer = 6
    max_len = 16
    batch_size = 2
    seq_len = 4

    model = MiniGPT(vocab_size,embed_dim,n_layer,max_len)
    input_ids = torch.randint(0,vocab_size,(batch_size,seq_len))
    logits = model(input_ids)
    print('语言模型的形状',logits.shape)
    model.classifier = nn.Linear(embed_dim,3)
    logits,cls_logits = model(input_ids)
    print('分类logits',cls_logits.shape)
test()

语言模型的形状 torch.Size([2, 4, 100])
分类logits torch.Size([2, 3])


# 数据集类

In [21]:
import re
import numpy as np
from torch.utils.data import Dataset,DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm

In [31]:
class TextDataset(Dataset):
    '''文本数据集类，用于预训练模型'''
    def __init__(self,file_path,tokenizer,block_size = 512,stride = 384):
        '''
        parameter
        ------------
        tokenizer:AotoTonkenizer
            分词器
        block_size:int
            分块大小
        stride:int
            滑动的步长
        '''
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.stride = stride

        with open(file_path,mode = 'r',encoding='utf-8') as f:
            text = self.clear_text(f.read())
        self.tokens = self.tokenizer.encode(text,add_special_tokens = True)
        # 若数据量很大，使用动态截取的方式，而不是事先全部计算

    def clear_text(self,text):
        # 段落间的分隔符
        token = self.tokenizer.sep_token + self.tokenizer.cls_token
        text = text.strip()
        text = re.sub('\n',token,text)
        text = re.sub(r'\s+','',text)
        return text
    
    def __len__(self):
        # 计算总样本数:（总长度-块长度）/步长+1
        return (len(self.tokens)-self.block_size)//self.stride+1
    
    def __getitem__(self,idx):
        start = idx * self.stride
        end = start + self.block_size
        # 动态截取token块。
        chunk = self.tokens[start:end]
        x = torch.tensor(chunk[:-1],dtype = torch.long)
        y = torch.tensor(chunk[1:],dtype = torch.long)
        return x,y
    


# 预训练

In [20]:
tokenizer = AutoTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')

In [35]:
block_size = 64
stride = 64
batch_size = 64
embed_dim = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 50
def pretrain():
    dataset = TextDataset(
        file_path=r'E:\LMAI_study\05 深度学习\data\天龙八部分类.txt',
        tokenizer=tokenizer,
        block_size=block_size,
        stride=stride
    )

    train_loader = DataLoader(dataset,batch_size=batch_size,shuffle = True)

    model = MiniGPT(vocab_size=tokenizer.vocab_size,embed_dim=embed_dim).to(device)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        train_loss = 0
        bar = tqdm(train_loader)
        for input,target in bar:
            input ,target = input.to(device) ,target.to(device)
            logits =  model(input)
            # 对不连续的变量改变形状，不能直接用view,可以用reshape,或调用contiguous方法
            loss = criterion(logits.view(-1,tokenizer.vocab_size),target.contiguous().view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()
            # 训练一个批次后，显示当前批次的损失
            bar.set_description(f'epoch:{epoch+1} train loss:{loss.item():.4f}')
        train_loss /= len(train_loader)
        print(train_loss)
    torch.save(model.state_dict(),'model.pth')

pretrain()

epoch:1 train loss:9.6956: 100%|██████████| 2/2 [00:00<00:00,  4.92it/s] 


9.903632164001465


epoch:2 train loss:9.0865: 100%|██████████| 2/2 [00:00<00:00,  9.45it/s]


9.212864398956299


epoch:3 train loss:8.6985: 100%|██████████| 2/2 [00:00<00:00,  9.98it/s]


8.788248062133789


epoch:4 train loss:8.2966: 100%|██████████| 2/2 [00:00<00:00, 11.17it/s]


8.402206420898438


epoch:5 train loss:7.8265: 100%|██████████| 2/2 [00:00<00:00, 10.84it/s]


7.955410957336426


epoch:6 train loss:7.3768: 100%|██████████| 2/2 [00:00<00:00,  9.99it/s]


7.4914422035217285


epoch:7 train loss:6.9836: 100%|██████████| 2/2 [00:00<00:00, 10.30it/s]


7.076613426208496


epoch:8 train loss:6.5668: 100%|██████████| 2/2 [00:00<00:00, 11.04it/s]


6.6882734298706055


epoch:9 train loss:6.2824: 100%|██████████| 2/2 [00:00<00:00, 10.40it/s]


6.343282699584961


epoch:10 train loss:5.9418: 100%|██████████| 2/2 [00:00<00:00,  9.61it/s]


6.015002250671387


epoch:11 train loss:5.6256: 100%|██████████| 2/2 [00:00<00:00,  9.67it/s]


5.722519874572754


epoch:12 train loss:5.4582: 100%|██████████| 2/2 [00:00<00:00,  9.83it/s]


5.490045070648193


epoch:13 train loss:5.2236: 100%|██████████| 2/2 [00:00<00:00,  9.47it/s]


5.265669345855713


epoch:14 train loss:4.9554: 100%|██████████| 2/2 [00:00<00:00,  9.48it/s]


5.051986217498779


epoch:15 train loss:4.8225: 100%|██████████| 2/2 [00:00<00:00,  9.66it/s]


4.882689714431763


epoch:16 train loss:4.7110: 100%|██████████| 2/2 [00:00<00:00,  9.51it/s]


4.727940797805786


epoch:17 train loss:4.5061: 100%|██████████| 2/2 [00:00<00:00, 10.01it/s]


4.564103126525879


epoch:18 train loss:4.4196: 100%|██████████| 2/2 [00:00<00:00,  9.31it/s]


4.4329445362091064


epoch:19 train loss:4.2171: 100%|██████████| 2/2 [00:00<00:00,  9.75it/s]


4.277642250061035


epoch:20 train loss:4.0521: 100%|██████████| 2/2 [00:00<00:00,  9.30it/s]


4.134768724441528


epoch:21 train loss:4.0568: 100%|██████████| 2/2 [00:00<00:00, 10.26it/s]


4.026669859886169


epoch:22 train loss:3.8333: 100%|██████████| 2/2 [00:00<00:00,  9.53it/s]


3.8760769367218018


epoch:23 train loss:3.7208: 100%|██████████| 2/2 [00:00<00:00,  9.31it/s]


3.7517181634902954


epoch:24 train loss:3.6317: 100%|██████████| 2/2 [00:00<00:00,  9.75it/s]


3.6328240633010864


epoch:25 train loss:3.4529: 100%|██████████| 2/2 [00:00<00:00,  9.52it/s]


3.4954456090927124


epoch:26 train loss:3.3564: 100%|██████████| 2/2 [00:00<00:00, 10.26it/s]


3.376646876335144


epoch:27 train loss:3.2548: 100%|██████████| 2/2 [00:00<00:00,  9.52it/s]


3.2577600479125977


epoch:28 train loss:3.1069: 100%|██████████| 2/2 [00:00<00:00,  9.30it/s]


3.1291568279266357


epoch:29 train loss:2.9360: 100%|██████████| 2/2 [00:00<00:00,  9.45it/s]


2.9945675134658813


epoch:30 train loss:2.9221: 100%|██████████| 2/2 [00:00<00:00,  9.40it/s]


2.89854896068573


epoch:31 train loss:2.7803: 100%|██████████| 2/2 [00:00<00:00, 10.12it/s]


2.776131749153137


epoch:32 train loss:2.6845: 100%|██████████| 2/2 [00:00<00:00,  9.69it/s]


2.6639904975891113


epoch:33 train loss:2.5179: 100%|██████████| 2/2 [00:00<00:00,  9.49it/s]


2.53777813911438


epoch:34 train loss:2.4517: 100%|██████████| 2/2 [00:00<00:00,  9.29it/s]


2.4370702505111694


epoch:35 train loss:2.3002: 100%|██████████| 2/2 [00:00<00:00,  9.52it/s]


2.318273663520813


epoch:36 train loss:2.1749: 100%|██████████| 2/2 [00:00<00:00, 10.14it/s]


2.2085143327713013


epoch:37 train loss:2.1177: 100%|██████████| 2/2 [00:00<00:00,  9.48it/s]


2.112419843673706


epoch:38 train loss:2.0376: 100%|██████████| 2/2 [00:00<00:00,  9.58it/s]


2.0101239681243896


epoch:39 train loss:1.9322: 100%|██████████| 2/2 [00:00<00:00,  9.23it/s]


1.9117707014083862


epoch:40 train loss:1.7854: 100%|██████████| 2/2 [00:00<00:00, 10.24it/s]


1.805662751197815


epoch:41 train loss:1.6884: 100%|██████████| 2/2 [00:00<00:00,  9.23it/s]


1.708232343196869


epoch:42 train loss:1.6106: 100%|██████████| 2/2 [00:00<00:00,  9.31it/s]


1.6194342374801636


epoch:43 train loss:1.4606: 100%|██████████| 2/2 [00:00<00:00,  9.70it/s]


1.5109540820121765


epoch:44 train loss:1.4213: 100%|██████████| 2/2 [00:00<00:00,  9.72it/s]


1.4376680850982666


epoch:45 train loss:1.2803: 100%|██████████| 2/2 [00:00<00:00,  9.87it/s]


1.3484402298927307


epoch:46 train loss:1.2223: 100%|██████████| 2/2 [00:00<00:00,  9.23it/s]


1.2730331420898438


epoch:47 train loss:1.1346: 100%|██████████| 2/2 [00:00<00:00,  9.28it/s]


1.193272054195404


epoch:48 train loss:1.0892: 100%|██████████| 2/2 [00:00<00:00,  9.42it/s]


1.118008553981781


epoch:49 train loss:0.9934: 100%|██████████| 2/2 [00:00<00:00,  9.31it/s]


1.0414519309997559


epoch:50 train loss:0.9763: 100%|██████████| 2/2 [00:00<00:00,  9.88it/s]

0.9846748113632202



