# 文本摘要实战---transformers

In [1]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments



In [2]:
ds = Dataset.load_from_disk("./nlpcc_2017/")
ds

Dataset({
    features: ['title', 'content'],
    num_rows: 5000
})

In [3]:
ds[0]

{'title': '澳大利亚央行将利率降至纪录低点,以应对疲软的经济前景,并遏制澳元进一步走强。',
 'content': '澳大利亚央行将利率降至纪录低点,以应对疲软的经济前景,并遏制澳元进一步走强。05/0513:37|评论(0)A+澳大利亚央行周二发布声明称,将关键利率由2.25%调降至2%,符合此前交易员及接受彭博调查的29位经济学家中25位的预期。据彭博社报道,上月澳央行官员曾警告,矿业之外的行业投资可能下滑。澳大利亚政府不太可能推出新的刺激措施,来扶助受本币升值和铁矿石价格下跌打击而低于潜在水平的经济增长。“鉴于大宗商品价格下跌,矿业投资还可能有低于当前预期的风险,”预计到降息的澳新银行高级经济学家FelicityEmmett在决议公布前编写的研究报告中称。他表示此次决议可能反映出“央行经济增长预估轨迹有所下调”。'}

In [4]:

ds = ds.train_test_split(test_size=100, seed=42)
ds



DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 4900
    })
    test: Dataset({
        features: ['title', 'content'],
        num_rows: 100
    })
})

In [5]:


ds["train"][0]


{'title': '组图:黑河边防军人零下30℃户外训练,冰霜沾满眉毛和睫毛,防寒服上满是冰霜。',
 'content': '中国军网2014-12-1709:08:0412月16日,黑龙江省军区驻黑河某边防团机动步兵连官兵,冒着-30℃严寒气温进行体能训练,挑战极寒,锻造钢筋铁骨。该连素有“世界冠军的摇篮”之称,曾有5人24人次登上世界军事五项冠军的领奖台。(魏建顺摄)黑龙江省军区驻黑河某边防团机动步兵连官兵冒着-30℃严寒气温进行体能训练驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜官兵睫毛上都被冻上了冰霜官兵们睫毛上都被冻上了冰霜驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练'}

In [6]:
tokenizer = AutoTokenizer.from_pretrained("Langboat/mengzi-t5-base", cache_dir="mengzi-t5-base")
tokenizer

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


T5TokenizerFast(name_or_path='Langboat/mengzi-t5-base', vocab_size=32128, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', 

In [16]:

def process_func(examples):

    contents = ["摘要生成：\n" + e for e in examples["content"]]  # 每个元素：摘要生成:\n + 一个content

    inputs =  tokenizer(contents, max_length=384, truncation=True)
    labels = tokenizer(text_target=examples["title"], max_length=64, truncation=True)
    inputs["labels"] = labels["input_ids"]

    return inputs



In [17]:
tokenized_ds = ds.map(process_func, batched=True)
tokenized_ds

Map:   0%|          | 0/4900 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 4900
    })
    test: Dataset({
        features: ['title', 'content', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 100
    })
})

In [18]:

tokenized_ds["train"][0]


{'title': '组图:黑河边防军人零下30℃户外训练,冰霜沾满眉毛和睫毛,防寒服上满是冰霜。',
 'content': '中国军网2014-12-1709:08:0412月16日,黑龙江省军区驻黑河某边防团机动步兵连官兵,冒着-30℃严寒气温进行体能训练,挑战极寒,锻造钢筋铁骨。该连素有“世界冠军的摇篮”之称,曾有5人24人次登上世界军事五项冠军的领奖台。(魏建顺摄)黑龙江省军区驻黑河某边防团机动步兵连官兵冒着-30℃严寒气温进行体能训练驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜官兵睫毛上都被冻上了冰霜官兵们睫毛上都被冻上了冰霜驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练',
 'input_ids': [7,
  17965,
  5353,
  13,
  5893,
  349,
  295,
  2927,
  13049,
  114,
  1238,
  6473,
  13,
  5844,
  13,
  6849,
  364,
  50,
  1095,
  70,
  3,
  15059,
  12850,
  3168,
  409,
  542,
  717,
  505,
  1139,
  877,
  20399,
  17187,
  375,
  14061,
  3,
  18694,
  11081,
  10788,
  722,
  24698,
  6472,
  103,
  16773,
  1332,
  3,
  2197,
  1164,
  1342,
  3,
  26664,
  12966,
  789,
  1195,
  4,
  206,
  375,
  18665,
  31,
  304,
  2808,
  5,
  24456,
  41,
  12562,
  3,
  17222,
  108,
  21,
  1367,
  14434,
  8768,
  304,
  2898,
  270,
  1221,
  2808,
  5,
  1324,
  1065,
  498,
  4

In [20]:
tokenizer.decode(tokenized_ds["train"][0]["labels"])

'组图:黑河边防军人零下30°C户外训练,冰霜沾满眉毛和睫毛,防寒服上满是冰霜。</s>'

In [22]:
tokenizer.eos_token, tokenizer.eos_token_id, tokenizer.bos_token, tokenizer.bos_token_id

('</s>', 1, None, None)

In [23]:


model = AutoModelForSeq2SeqLM.from_pretrained("Langboat/mengzi-t5-base", cache_dir="mengzi-t5-base")


model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

In [30]:

import numpy as np
from rouge_chinese import Rouge

rouge = Rouge()

# 创建评估函数


def compute_metric(evalPred):

    predictions, labels = evalPred
    decode_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels!=-100, labels, tokenizer.pad_token_id)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decode_preds = [" ".join(p) for p in decode_preds]
    deocde_labels = [" ".join(p) for p in decode_labels]
    scores = rouge.get_scores(decode_preds, decode_labels, avg=True)
    return {"rouge-1": scores["rouge-1"]["f"], "rouge-2": scores["rouge-2"]["f"], "rouge-l": scores["rouge-l"]["f"]}




In [38]:

# 配置训练参数

args = Seq2SeqTrainingArguments(output_dir="./text_summarition", per_device_train_batch_size=16,
                               per_device_eval_batch_size=16,
                               gradient_accumulation_steps=4,
                               logging_steps=8,
                                num_train_epochs=5,
                               evaluation_strategy="epoch",
                               save_strategy="epoch",
                               metric_for_best_model="rouge-l",
                                predict_with_generate=True, learning_rate=1e-4
                               )





In [39]:

# 创建训练器
def compute_metric(evalPred):
    predictions, labels = evalPred
    decode_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decode_preds = [" ".join(p) for p in decode_preds]
    decode_labels = [" ".join(l) for l in decode_labels]
    scores = rouge.get_scores(decode_preds, decode_labels, avg=True)
    return {
        "rouge-1": scores["rouge-1"]["f"],
        "rouge-2": scores["rouge-2"]["f"],
        "rouge-l": scores["rouge-l"]["f"],
    }

from transformers import DataCollatorForSeq2Seq

trainer = Seq2SeqTrainer(model=model,args=args,train_dataset=tokenized_ds["train"],eval_dataset=tokenized_ds["test"],
                        compute_metrics=compute_metric,tokenizer=tokenizer,
                        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer))





In [40]:
trainer.train()

Epoch,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
0,0.8665,1.819458,0.504531,0.348619,0.4308
1,0.6636,1.854909,0.509536,0.357789,0.436331
2,0.6438,1.815585,0.531996,0.378386,0.458767
4,0.7242,1.762014,0.537444,0.392941,0.462732




TrainOutput(global_step=380, training_loss=0.6630201590688605, metrics={'train_runtime': 359.8006, 'train_samples_per_second': 68.093, 'train_steps_per_second': 1.056, 'total_flos': 1.0731426283229184e+16, 'train_loss': 0.6630201590688605, 'epoch': 4.95114006514658})