In [1]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
import numpy as np
import torch.utils.data as Data
import random
import torch
import time
import torch.optim as optim
import json
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [2]:
class CFG:
    seed = 42
    model_name = '../all_models/t5-base/'
    text_max_len = 768
    summary_max_len = 256
    batch_size = 8
    epochs = 5
    verbose = 100
    lr = 2e-5  # 学习率

In [3]:
def set_seed(seed):
    """PyTorch随机数种子设置大全"""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)  # CPU上设置随机种子
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # 当前GPU上设置随机种子
        # A bool that, if True, causes cuDNN to only use deterministic convolution algorithms.
        torch.backends.cudnn.deterministic = True
        # torch.cuda.manual_seed_all(seed) # 所有GPU上设置随机种子


set_seed(CFG.seed)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
class Dataset(Data.Dataset):
    """定义数据集"""

    def __init__(self, dataset_path):
        with open(dataset_path) as f:
            self.dataset = json.loads(f.read())

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        """定义索引方式"""
        text = self.dataset[i]['dialogue']
        summary = self.dataset[i]['summary']
        return text, summary


dataset_train = Dataset("../huggingface_dataset_samsum/train.json")
dataset_valid = Dataset("../huggingface_dataset_samsum/val.json")

print(len(dataset_train), len(dataset_valid))

for text_, summary_ in dataset_train:
    # 调用__getitem__方法
    print(len(str(text_)), len(str(summary_)))
    print(text_, end='\n\n')
    print(summary_)
    break

14732 818
94 56
Amanda: I baked  cookies. Do you want some?
Jerry: Sure!
Amanda: I'll bring you tomorrow :-)

Amanda baked cookies and will bring Jerry some tomorrow.


In [6]:
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name)
print(tokenizer.model_input_names)
print(tokenizer)

model_t5 = T5ForConditionalGeneration.from_pretrained(CFG.model_name)
model_t5 = model_t5.to(device)
print(model_t5.num_parameters())

# 优化器
optimizer_adamw = optim.AdamW(model_t5.parameters(), lr=CFG.lr)

['input_ids', 'attention_mask']
T5TokenizerFast(name_or_path='../all_models/t5-base/', vocab_size=32100, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<e

In [7]:
def get_collate_fn(tokenizer, text_max_len=512, summary_max_len=256):
    """返回collate_fun函数(通过闭包函数引入形参)"""

    def collate_fn(data):
        prefix = "summarize: "

        texts = [prefix + i[0] for i in data]
        summarys = [i[1] for i in data]

        texts_encode = tokenizer(texts,
                                 max_length=text_max_len,
                                 padding=True,
                                 truncation=True,
                                 return_tensors='pt')

        summarys_encode = tokenizer(text_target=summarys,
                                    max_length=summary_max_len,
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt')

        return {"texts_input_ids": texts_encode['input_ids'],
                "texts_attention_mask": texts_encode['attention_mask'],
                "summarys_input_ids": summarys_encode['input_ids']}

    return collate_fn


dataloader_valid = torch.utils.data.DataLoader(dataset=dataset_valid,
                                               batch_size=CFG.batch_size,
                                               collate_fn=get_collate_fn(tokenizer, CFG.text_max_len,
                                                                         CFG.summary_max_len),
                                               shuffle=False,
                                               drop_last=False)

dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                               batch_size=CFG.batch_size,
                                               collate_fn=get_collate_fn(tokenizer, CFG.text_max_len,
                                                                         CFG.summary_max_len),
                                               shuffle=True,
                                               drop_last=False)

print(len(dataloader_train))

for i in dataloader_train:
    print(i['texts_input_ids'].shape)  # [batch_size, text_max_len]
    print(i['texts_attention_mask'].shape)
    print(i['summarys_input_ids'].shape)  # [batch_size, summary_max_len]
    break

1842
torch.Size([8, 374])
torch.Size([8, 374])
torch.Size([8, 74])


In [8]:
# 模型训练
def train(model, dataloader, optimizer, device):
    model.train()

    for idx, encode_data in enumerate(dataloader):
        # 数据设备切换
        input_ids = encode_data['texts_input_ids'].to(device)
        attention_mask = encode_data['texts_attention_mask'].to(device)
        labels = encode_data['summarys_input_ids'].to(device)
        # replace padding token id's of the labels by -100 so it's ignored by the loss
        labels[labels == tokenizer.pad_token_id] = - 100

        '''
        huggingfae源码:
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)  ### Teacher Forcing 
        '''
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]  # T5ForConditionalGeneration源码内封装损失函数:CrossEntropyLoss(ignore_index=-100)

        loss.backward()
        optimizer.step()

        if idx % CFG.verbose == 0 and idx > 0:
            print('| step {:5d} | loss {:8.5f} |'.format(idx, loss.item()))

In [9]:
# 模型评估
def evaluate(model, dataloader, device):
    model.eval()

    all_blue_score = []
    for idx, encode_data in enumerate(dataloader):
        # 数据设备切换
        input_ids = encode_data['texts_input_ids'].to(device)
        attention_mask = encode_data['texts_attention_mask'].to(device)
        labels = encode_data['summarys_input_ids'].to(device)

        predictions = model.generate(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     max_length=256,
                                     num_beams=2,
                                     repetition_penalty=2.5,
                                     length_penalty=1,
                                     early_stopping=True)

        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        all_blue_score.extend([sentence_bleu(references=[decoded_labels[i].split(' ')],
                                             hypothesis=decoded_preds[i].split(' '),
                                             weights=[0.25, 0.25, 0.25, 0.25],
                                             smoothing_function=SmoothingFunction().method1)
                               for i in range(predictions.shape[0])])

    return np.array(all_blue_score).mean()

In [10]:
for epoch in range(1, CFG.epochs + 1):
    epoch_start_time = time.time()

    train(model_t5, dataloader_train, optimizer_adamw, device)
    blue_val = evaluate(model_t5, dataloader_valid, device)

    print('-' * 60)
    print('| end of epoch {:5d} | time: {:5.2f}s | valid blue {:8.5f} |'.format(epoch,
                                                                                time.time() - epoch_start_time,
                                                                                blue_val))
    print('-' * 60)

| step   100 | loss  2.06461 |
| step   200 | loss  2.08177 |
| step   300 | loss  2.08045 |
| step   400 | loss  2.00902 |
| step   500 | loss  1.66362 |
| step   600 | loss  1.75473 |
| step   700 | loss  1.74266 |
| step   800 | loss  1.44380 |
| step   900 | loss  2.16509 |
| step  1000 | loss  1.96669 |
| step  1100 | loss  2.21257 |
| step  1200 | loss  2.31056 |
| step  1300 | loss  2.12070 |
| step  1400 | loss  2.15469 |
| step  1500 | loss  1.67766 |
| step  1600 | loss  1.69820 |
| step  1700 | loss  1.79850 |
| step  1800 | loss  1.75700 |
------------------------------------------------------------
| end of epoch     1 | time: 204.84s | valid blue  0.09223 |
------------------------------------------------------------
| step   100 | loss  1.55292 |
| step   200 | loss  1.67972 |
| step   300 | loss  1.46360 |
| step   400 | loss  1.78757 |
| step   500 | loss  1.69956 |
| step   600 | loss  1.87487 |
| step   700 | loss  1.39110 |
| step   800 | loss  1.19038 |
| step   90