In [None]:
#@title 链接Google Drive
# Mount Google Drive at /content/drive. This cell imports `google.colab`,
# so it can only run inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')


In [47]:
#@title 自定义GPT2模型
import sys

sys.path.append("..")
from transformers import BertTokenizer, GPT2LMHeadModel
from torch import nn

# from utils.utils import get_project_rootpath
import os


class GPT2(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained GPT-2 language-model head.

    The wrapped ``GPT2LMHeadModel`` is stored as ``self.gpt``. Because it is a
    registered submodule, the inherited ``nn.Module.to`` already moves it, so no
    custom ``to`` override is defined here: ``model.to(device)`` keeps working
    and still returns the wrapper itself, while the full ``nn.Module.to``
    signature (``dtype=``, ``non_blocking=`` ...) is available as well.
    """

    def __init__(self):
        super().__init__()
        # Loaded by model id; a local checkpoint directory can be passed to
        # ``from_pretrained`` instead if the weights are stored locally.
        self.gpt = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

    def forward(self, batch_inputs):
        """Run the wrapped model on ``batch_inputs`` (passed as ``input_ids``).

        Returns whatever ``self.gpt`` returns; callers in this notebook read
        its ``.logits`` attribute.
        """
        return self.gpt(input_ids=batch_inputs)

    @property
    def config(self):
        """Configuration object of the wrapped model."""
        return self.gpt.config

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device



## 数据

In [4]:
#@title 数据加载
import json
import torch
import torch.utils.data as Data
from torch import nn, optim
import numpy as np


def make_data(file_path, tokenizer):
    """Read a UTF-8 text corpus and convert each non-empty line to token ids.

    Each line is stripped and split into single characters; every tab
    character is replaced by the ``"[SEP]"`` token and one more ``"[SEP]"`` is
    appended. The resulting token list is passed to ``tokenizer.encode`` and
    the last returned id is dropped (presumably the special token that
    ``encode`` appends on its own - TODO confirm against the tokenizer).

    Blank / whitespace-only lines are skipped: they would otherwise produce
    samples that contain nothing but special tokens.

    Args:
        file_path: Path of the text file to read.
        tokenizer: Object exposing ``encode(list_of_tokens) -> list of ids``.

    Returns:
        A list with one id list per non-empty line of the file.
    """
    train_datas = []
    # Stream the file instead of materialising every line up front.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            tokens = ["[SEP]" if ch == '\t' else ch for ch in line] + ['[SEP]']
            token_ids = tokenizer.encode(tokens)
            train_datas.append(token_ids[:-1])

    return train_datas


class MyDataSet(Data.Dataset):
    """Dataset of token-id sequences for next-token prediction.

    Every sample is split into an input (all ids but the last) and a target
    (all ids but the first). ``padding_batch`` is meant to be used as the
    ``collate_fn`` of a ``DataLoader``.
    """

    def __init__(self, datas, vocab2id):
        # datas: indexable collection of id lists; vocab2id: token -> id map
        # that must contain the "[PAD]" key used when padding a batch.
        self.datas, self.vocab2id = datas, vocab2id

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, item):
        """Return the shifted input/target pair (and their lengths) for one sample."""
        sequence = self.datas[item]
        inputs, targets = sequence[:-1], sequence[1:]
        return {
            "decoder_input": inputs,
            "decoder_input_len": len(inputs),
            "decoder_output": targets,
            "decoder_output_len": len(targets),
        }

    def padding_batch(self, batch):
        """Pad every sample of ``batch`` in place and stack them into two long tensors.

        Returns:
            ``(decoder_inputs, decoder_outputs)``, each of shape
            ``[len(batch), longest sequence in the batch]``.
        """
        input_width = max(sample["decoder_input_len"] for sample in batch)
        output_width = max(sample["decoder_output_len"] for sample in batch)

        for sample in batch:
            pad_id = self.vocab2id["[PAD]"]
            # The sample lists are extended in place, right-padding with [PAD].
            sample["decoder_input"] += [pad_id] * (input_width - sample["decoder_input_len"])
            sample["decoder_output"] += [pad_id] * (output_width - sample["decoder_output_len"])

        stacked_inputs = torch.tensor([sample["decoder_input"] for sample in batch], dtype=torch.long)
        stacked_outputs = torch.tensor([sample["decoder_output"] for sample in batch], dtype=torch.long)
        return stacked_inputs, stacked_outputs



In [19]:
#@title 自定义数据

%%writefile selfTxt.txt
谢谢你所做的一切
你开心就好
开心
嗯因为你的心里只有学习
某某某，还有你
这个某某某用的好

你们宿舍都是这么厉害的人吗
眼睛特别搞笑这土也不好捏但就是觉得挺可爱
特别可爱啊

今天好点了吗？
一天比一天严重
吃药不管用，去打一针。别拖着

是的。下辈子想做只萤火虫
可是萤火虫太容易被抓了还是改一个吧
不，我只想奋不顾身扑火

加油，三月动起来，五月笑起来
正解你为什么就那么厉害呢
哈哈，没办法，智商就是这么高
你这是要开始得瑟了吗！好啦！你最厉害！
哈哈哈哈

好身材，秀出来
哈哈哈其实我是胖的
不会的


Writing selfTxt.txt


## 训练

In [12]:
#@title AverageMeter
class AverageMeter(object):
    """Tracks the latest value together with a weighted running average."""

    def __init__(self):
        # Declare the attributes first, then let reset() give them their
        # starting values.
        self.count = self.sum = self.avg = self.val = None
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the weighted sum and the weight count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` and refresh the average."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [None]:
#@title 训练过程
import json
import os

import torch
import sys

sys.path.append("..")
import torch.utils.data as Data
from torch import nn, optim
import numpy as np
import time
from tqdm import tqdm

from transformers import BertTokenizer

class TrainArgs:
    """Plain container for the training settings used by this notebook.

    Every setting is a keyword argument whose default equals the value that
    used to be hard-coded, so ``TrainArgs()`` yields the same object as
    before while ``TrainArgs(epochs=3, device="cuda")`` avoids mutating
    attributes after construction.

    Args:
        device: Device string used for the model and tensors.
        batch_size: Number of samples per batch.
        epochs: Number of passes over the training data.
        print_every: Log the running loss every this many steps.
        clip: Maximum gradient norm passed to gradient clipping.
        train_file_path: Path of the training text file
            (alternative on Google Drive: "/content/drive/MyDrive/train.txt").
        save_path: File the model ``state_dict`` is written to.
        lr: Learning rate of the optimizer.
    """

    def __init__(self, device="cpu", batch_size=4, epochs=1, print_every=10, clip=1,
                 train_file_path="/content/selfTxt.txt", save_path="GPT2.pt", lr=1e-4):
        self.device = device
        self.batch_size = batch_size
        self.epochs = epochs
        self.print_every = print_every
        self.clip = clip
        self.train_file_path = train_file_path
        self.save_path = save_path
        self.lr = lr

# Instantiate the settings container.
train_args = TrainArgs()

# Apply the settings for this run in one place.
# Alternative training corpus on Google Drive: "/content/drive/MyDrive/train.txt"
vars(train_args).update({
    "device": "cpu",
    "batch_size": 4,
    "epochs": 1,
    "print_every": 10,
    "clip": 1,
    "train_file_path": "/content/selfTxt.txt",
    "save_path": "GPT2.pt",
    "lr": 1e-4,
})


def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Both parts are truncated with ``int``.

    Returns:
        ``(minutes, seconds)`` as integers.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = int(span - (whole_minutes * 60))
    return whole_minutes, leftover_seconds


def train_step(model, data_loader, epoch, optimizer, criterion, clip=1, print_every=None):
    """Run one training epoch and return the mean per-batch loss.

    Batches are moved to the device of the model's first parameter rather
    than to a notebook-global ``device``, and the running average is weighted
    by the actual size of each batch (the last batch may be smaller than the
    configured batch size).

    Args:
        model: Module whose forward output exposes ``.logits``.
        data_loader: Iterable yielding ``(dec_inputs, dec_outputs)`` tensors,
            each ``[batch_size, tgt_len]``.
        epoch: Epoch index, used only in the progress message.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss taking ``[batch_size * tgt_len, vocab]`` logits and
            ``[batch_size * tgt_len]`` targets.
        clip: Maximum gradient norm for clipping.
        print_every: Print progress every this many steps; ``None`` disables
            printing and ``0`` is treated as ``1``.

    Returns:
        Sum of the batch losses divided by the number of batches, or ``0.0``
        when ``data_loader`` yields no batches.
    """
    model.train()

    if print_every == 0:
        print_every = 1

    # Follow the model's own placement; a parameter-less model leaves the
    # batch tensors where they are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    epoch_loss = 0
    losses = AverageMeter()
    temp_time = time.time()
    for step, (dec_inputs, dec_outputs) in enumerate(data_loader):
        optimizer.zero_grad()
        if target_device is not None:
            dec_inputs, dec_outputs = dec_inputs.to(target_device), dec_outputs.to(target_device)

        # Flatten logits to [batch_size * tgt_len, tgt_vocab_size] for the loss.
        logits = model(dec_inputs).logits
        logits = logits.view(-1, logits.size(-1))
        loss = criterion(logits, dec_outputs.view(-1))

        loss_value = loss.item()
        epoch_loss += loss_value
        losses.update(loss_value, dec_inputs.size(0))

        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        if print_every and (step + 1) % print_every == 0:
            minutes, seconds = epoch_time(temp_time, time.time())
            print('Epoch: [{0}][{1}/{2}] '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Elapsed {minutes:s}min {seconds:s}s '
                  .format(epoch, step + 1, len(data_loader),
                          minutes=str(minutes),
                          seconds=str(seconds),
                          loss=losses))
            temp_time = time.time()

    num_batches = len(data_loader)
    # Guard against an empty loader instead of dividing by zero.
    return epoch_loss / num_batches if num_batches else 0.0


def train(model, dataloader, train_args):
    """Train ``model`` for ``train_args.epochs`` epochs, saving its weights after each one.

    All settings are read from ``train_args`` (``device``, ``epochs``, ``lr``,
    ``clip``, ``print_every``, ``save_path``) so the function no longer
    depends on notebook-global ``device`` / ``epochs`` variables.

    Args:
        model: Module to optimise.
        dataloader: Batches forwarded to ``train_step``.
        train_args: Object exposing the attributes listed above.
    """
    device = train_args.device
    epochs = train_args.epochs
    # ignore_index=0 excludes targets with id 0 from the loss; this assumes 0
    # is the [PAD] id used when batches are padded - TODO confirm against the vocab.
    criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)
    lr = train_args.lr
    CLIP = train_args.clip
    print_every = train_args.print_every
    save_path = train_args.save_path
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        start_time = time.time()
        train_loss = train_step(model, dataloader, epoch, optimizer, criterion, CLIP, print_every=print_every)
        end_time = time.time()

        # Checkpoint after every epoch so an interrupted run keeps its progress.
        torch.save(model.state_dict(), save_path)

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)
        print(f'Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f}')


def print_num_parameters(model):
    """Print the total and the trainable parameter counts of ``model``."""
    # One pass over the parameters: (element count, requires_grad) pairs.
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]

    total = sum(count for count, _ in sizes)
    print(f'{total:,} total parameters.')

    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    print(f'{trainable:,} training parameters.')


# Training entry point. The names assigned here are module-level (notebook
# globals); `device` is also read later by the validation cell.
if __name__ == '__main__':
    device = train_args.device
    # tokenizer = BertTokenizer.from_pretrained(os.path.join(get_project_rootpath(), "gpt2-chinese-cluecorpussmall"))
    tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

    epochs = train_args.epochs
    batch_size = train_args.batch_size

    # Encode the text file into id sequences and wrap them in a DataLoader
    # that pads each batch through the dataset's own collate function.
    train_file_path = train_args.train_file_path
    datas = make_data(train_file_path, tokenizer)
    dataset = MyDataSet(datas, tokenizer.vocab)
    dataloader = Data.DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.padding_batch)

    # Build the model on the configured device and run the training loop.
    model = GPT2().to(device)
    train(model, dataloader, train_args)



In [15]:
#@title 训练参数参考
import argparse

import sys

sys.path.append("..")
# from utils.utils import get_project_rootpath
import os

# checkpoints_dir = os.path.join(get_project_rootpath(), "model_checkpoints")


def train_parse_args(args=None):
    """Build the training argument parser and parse ``args``.

    Args:
        args: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv``. Inside a notebook pass ``[]`` (or an
            explicit list), because the kernel's own command line is not a
            valid argument list for this parser.

    Returns:
        ``argparse.Namespace`` with ``device``, ``batch_size``, ``epochs``,
        ``print_every``, ``clip``, ``train_file_path``, ``save_path`` and ``lr``.
    """
    parser = argparse.ArgumentParser(description="训练参数配置")
    parser.add_argument("--device", type=str, default="cuda", help="device to train on")
    parser.add_argument("--batch_size", type=int, default=4, help="batch size")
    parser.add_argument("--epochs", type=int, default=1, help="epochs")
    parser.add_argument("--print_every", type=int, default=10, help="print every")
    # The clip value is a gradient-norm bound, so fractional values must parse.
    parser.add_argument("--clip", type=float, default=1.0, help="clip")

    parser.add_argument("--train_file_path", type=str, default="/content/drive/MyDrive/train.txt",
                        help="train_file_path")

    parser.add_argument('--save_path', type=str, default="GPT2.pt",
                        help='path the model state_dict is saved to')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')

    return parser.parse_args(args)


## 验证



In [30]:
def get_project_rootpath():
    """Return the project root: the nearest ancestor of the current directory containing '.idea'.

    The search starts at the real path of the current working directory and
    walks upwards one directory at a time.

    :return: absolute path of the first directory that contains a '.idea' entry
    :raises FileNotFoundError: if the filesystem root is reached without
        finding '.idea' (previously this case looped forever, because the
        parent of the root directory is the root itself)
    """
    path = os.path.realpath(os.curdir)
    while True:
        # PyCharm projects keep a '.idea' entry in their root directory.
        if os.path.exists(os.path.join(path, '.idea')):
            return path
        parent = os.path.dirname(path)
        if parent == path:
            raise FileNotFoundError(
                "no '.idea' directory found in the current directory or any of its parents")
        path = parent


In [None]:
#@title validate result

from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
# NOTE: `device` is a notebook-global set by the training cell; run that cell first.
tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
model = GPT2().to(device)
# Load the fine-tuned weights. map_location remaps the stored tensors onto
# `device`, so a checkpoint saved on a GPU also loads on a CPU-only runtime.
model.load_state_dict(torch.load('GPT2.pt', map_location=device))

# Switch the model to evaluation mode
model.eval()

# Build the text-generation pipeline on the wrapped GPT2LMHeadModel
text_generator = TextGenerationPipeline(model.gpt, tokenizer)

# Generate a continuation of the prompt
result = text_generator("谢谢你所做的一切", max_length=100, do_sample=True)
print(result)
