In [1]:
import numpy as np
import torch.utils.data as Data
import random
import torch
import time
import torch.optim as optim
import json
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.optim.lr_scheduler import LambdaLR
import math

In [2]:
class CFG:
    seed = 42
    model_name = '../all_models/bart-base/'
    text_max_len = 768
    summary_max_len = 256
    batch_size = 8
    epochs = 5
    verbose = 100
    lr = 1e-6  # 学习率
    num_warmup_steps = 0
    num_training_steps = math.ceil(14732 / batch_size) * epochs  # 向上取整

In [3]:
def set_seed(seed):
    """PyTorch随机数种子设置大全"""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)  # CPU上设置随机种子
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # 当前GPU上设置随机种子
        # A bool that, if True, causes cuDNN to only use deterministic convolution algorithms.
        torch.backends.cudnn.deterministic = True
        # torch.cuda.manual_seed_all(seed) # 所有GPU上设置随机种子


set_seed(CFG.seed)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
class Dataset(Data.Dataset):
    """定义数据集"""

    def __init__(self, dataset_path):
        with open(dataset_path) as f:
            self.dataset = json.loads(f.read())

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        """定义索引方式"""
        text = self.dataset[i]['dialogue']
        summary = self.dataset[i]['summary']
        return text, summary


dataset_train = Dataset("../huggingface_dataset_samsum/train.json")
dataset_valid = Dataset("../huggingface_dataset_samsum/val.json")

print(len(dataset_train), len(dataset_valid))

for text_, summary_ in dataset_train:
    # 调用__getitem__方法
    print(len(str(text_)), len(str(summary_)))
    print(text_, end='\n\n')
    print(summary_)
    break

14732 818
94 56
Amanda: I baked  cookies. Do you want some?
Jerry: Sure!
Amanda: I'll bring you tomorrow :-)

Amanda baked cookies and will bring Jerry some tomorrow.


In [6]:
tokenizer = BartTokenizer.from_pretrained(CFG.model_name)
print(tokenizer.model_input_names)
print(tokenizer)

model_bart = BartForConditionalGeneration.from_pretrained(CFG.model_name)
model_bart = model_bart.to(device)
print(model_bart.num_parameters())

# 优化器
optimizer_adamw = optim.AdamW(model_bart.parameters(), lr=CFG.lr)


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # 学习率预热(线性增加)
            return float(current_step) / float(max(1, num_warmup_steps))
        # 学习率线性衰减(最小为0)
        # num_training_steps后学习率恒为0
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda)


scheduler_lr = get_linear_schedule_with_warmup(optimizer_adamw, CFG.num_warmup_steps, CFG.num_training_steps)

['input_ids', 'attention_mask']
BartTokenizer(name_or_path='../all_models/bart-base/', vocab_size=50265, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)}, clean_up_tokenization_spaces=True)
139420416


In [7]:
def get_collate_fn(tokenizer, text_max_len=512, summary_max_len=256):
    """返回collate_fun函数(通过闭包函数引入形参)"""

    def collate_fn(data):
        texts = [i[0] for i in data]
        summarys = [i[1] for i in data]

        texts_encode = tokenizer(texts,
                                 max_length=text_max_len,
                                 padding=True,
                                 truncation=True,
                                 return_tensors='pt')

        summarys_encode = tokenizer(text_target=summarys,
                                    max_length=summary_max_len,
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt')

        return {"texts_input_ids": texts_encode['input_ids'],
                "texts_attention_mask": texts_encode['attention_mask'],
                "summarys_input_ids": summarys_encode['input_ids']}

    return collate_fn


dataloader_valid = torch.utils.data.DataLoader(dataset=dataset_valid,
                                               batch_size=CFG.batch_size,
                                               collate_fn=get_collate_fn(tokenizer, CFG.text_max_len,
                                                                         CFG.summary_max_len),
                                               shuffle=False,
                                               drop_last=False)

dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                               batch_size=CFG.batch_size,
                                               collate_fn=get_collate_fn(tokenizer, CFG.text_max_len,
                                                                         CFG.summary_max_len),
                                               shuffle=True,
                                               drop_last=False)

print(len(dataloader_train))

for i in dataloader_train:
    print(i['texts_input_ids'].shape)  # [batch_size, text_max_len]
    print(i['texts_attention_mask'].shape)
    print(i['summarys_input_ids'].shape)  # [batch_size, summary_max_len]
    break

1842
torch.Size([8, 233])
torch.Size([8, 233])
torch.Size([8, 49])


In [8]:
# 模型训练
def train(model, dataloader, optimizer, device):
    model.train()

    for idx, encode_data in enumerate(dataloader):
        # 数据设备切换
        input_ids = encode_data['texts_input_ids'].to(device)
        attention_mask = encode_data['texts_attention_mask'].to(device)
        labels = encode_data['summarys_input_ids'].to(device)
        # replace padding token id's of the labels by -100 so it's ignored by the loss
        labels[labels == tokenizer.pad_token_id] = - 100

        '''
        huggingfae源码:
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )   ### Teacher Forcing 
        '''
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]  # BartForConditionalGeneration源码内封装损失函数:CrossEntropyLoss(ignore_index=-100)

        loss.backward()
        optimizer.step()
        scheduler_lr.step()

        if idx % CFG.verbose == 0 and idx > 0:
            print('| step {:5d} | loss {:8.5f} |'.format(idx, loss.item()))

In [9]:
# 模型评估
def evaluate(model, dataloader, device):
    model.eval()

    all_blue_score = []
    for idx, encode_data in enumerate(dataloader):
        # 数据设备切换
        input_ids = encode_data['texts_input_ids'].to(device)
        attention_mask = encode_data['texts_attention_mask'].to(device)
        labels = encode_data['summarys_input_ids'].to(device)

        predictions = model.generate(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     max_length=256,
                                     num_beams=2,
                                     repetition_penalty=2.5,
                                     length_penalty=1,
                                     early_stopping=True)

        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        all_blue_score.extend([sentence_bleu(references=[decoded_labels[i].split(' ')],
                                             hypothesis=decoded_preds[i].split(' '),
                                             weights=[0.25, 0.25, 0.25, 0.25],
                                             smoothing_function=SmoothingFunction().method1)
                               for i in range(predictions.shape[0])])

    return np.array(all_blue_score).mean()

In [10]:
for epoch in range(1, CFG.epochs + 1):
    epoch_start_time = time.time()

    train(model_bart, dataloader_train, optimizer_adamw, device)
    blue_val = evaluate(model_bart, dataloader_valid, device)

    print('-' * 60)
    print('| end of epoch {:5d} | time: {:5.2f}s | valid blue {:8.5f} |'.format(epoch,
                                                                                time.time() - epoch_start_time,
                                                                                blue_val))
    print('-' * 60)

| step   100 | loss  2.50329 |
| step   200 | loss  2.20362 |
| step   300 | loss  2.45003 |
| step   400 | loss  2.15884 |
| step   500 | loss  2.31759 |
| step   600 | loss  2.31060 |
| step   700 | loss  2.31725 |
| step   800 | loss  2.13361 |
| step   900 | loss  2.23650 |
| step  1000 | loss  2.02029 |
| step  1100 | loss  1.68227 |
| step  1200 | loss  2.31555 |
| step  1300 | loss  2.15683 |
| step  1400 | loss  1.78398 |
| step  1500 | loss  2.19451 |
| step  1600 | loss  1.88260 |
| step  1700 | loss  1.83507 |
| step  1800 | loss  1.82666 |
------------------------------------------------------------
| end of epoch     1 | time: 119.62s | valid blue  0.10503 |
------------------------------------------------------------
| step   100 | loss  2.10892 |
| step   200 | loss  1.42785 |
| step   300 | loss  1.94260 |
| step   400 | loss  1.66786 |
| step   500 | loss  1.98646 |
| step   600 | loss  1.55752 |
| step   700 | loss  1.92907 |
| step   800 | loss  1.45522 |
| step   90