In [6]:
from transformers import AutoTokenizer
from datasets import load_from_disk
from transformers import DataCollatorForSeq2Seq
import evaluate
import numpy as np
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [7]:
billsum = load_from_disk("billsum")['train']
billsum = billsum.remove_columns(['Unnamed: 0', 'title'])

In [8]:
billsum = billsum.train_test_split(test_size=0.2)
billsum

DatasetDict({
    train: Dataset({
        features: ['text', 'summary'],
        num_rows: 15159
    })
    test: Dataset({
        features: ['text', 'summary'],
        num_rows: 3790
    })
})

In [9]:
# transformer bart格式为:
# - single sequence: `<s> X </s>`
# - pair of sequences: `<s> A </s></s> B </s>`
checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [10]:
def preprocess_function(examples):
    model_inputs = tokenizer(examples["text"], max_length=1024, truncation=True)

    labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [11]:
# BartForConditionalGeneration forward函数函数签名:
'''
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[List[torch.FloatTensor]] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, Seq2SeqLMOutput]:
'''
tokenized_billsum = billsum.map(preprocess_function, batched=True)
tokenized_billsum

  0%|          | 0/16 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 15159
    })
    test: Dataset({
        features: ['text', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 3790
    })
})

In [12]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

Downloading:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [13]:
# 相当于torch.utils.data.DataLoader中collate_fn的作用(可以重写,参考K_demo/way_of_training/pytorch_transformer.ipynb)
# Data collator that will dynamically pad the inputs received, as well as the labels.
'''
model ([`PreTrainedModel`]):
    The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
    prepare the *decoder_input_ids*

    This is useful when using *label_smoothing* to avoid calculating loss twice.
'''
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [15]:
rouge = evaluate.load("rouge")
rouge.description

'ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for\nevaluating automatic summarization and machine translation software in natural language processing.\nThe metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation.\n\nNote that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.\n\nThis metrics is a wrapper around Google Research reimplementation of ROUGE:\nhttps://github.com/google-research/google-research/tree/master/rouge\n'

In [16]:
def compute_metrics(eval_pred):
    # predictions.shape=[batch_size, max(该批次生成句子长度)]
    # labels.shape=[batch_size, max(该批次句子长度)]
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [17]:
training_args = Seq2SeqTrainingArguments(
    output_dir="my_awesome_billsum_model_bart",
    save_total_limit=1,
    evaluation_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    num_train_epochs=4,

    # 区别于TrainingArguments特有参数:
    # predict_with_generate (bool, optional, defaults to False) — Whether to use generate to calculate generative metrics (ROUGE, BLEU).
    predict_with_generate=True
)

# 继承自Trainer
# predict方法:trainer.predict(tokenized_billsum["test"], **gen_kwargs)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_billsum["train"],
    eval_dataset=tokenized_billsum["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# generate参数含义参考:huggingface GenerationConfig类
# 参考generate_help.md
trainer._gen_kwargs = {
    "repetition_penalty": 1.0,
    "max_length": 50,
    "min_length": 0}  # 用于评估

trainer.train()

The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, summary. If text, summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 15159
  Num Epochs = 4
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3792
You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
500,2.2023,1.756192,0.235,0.1873,0.2274,0.2283,20.0
1000,1.9186,1.669167,0.2378,0.1916,0.2307,0.2315,20.0
1500,1.7907,1.61571,0.2394,0.1944,0.2323,0.2332,20.0
2000,1.7505,1.590419,0.2402,0.1959,0.2335,0.2343,20.0
2500,1.6824,1.575646,0.2411,0.1972,0.2345,0.2352,20.0
3000,1.6713,1.560662,0.2409,0.1973,0.2343,0.235,20.0
3500,1.6422,1.553455,0.2403,0.197,0.234,0.2347,20.0


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, summary. If text, summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3790
  Batch size = 16
Saving model checkpoint to my_awesome_billsum_model_bart/checkpoint-500
Configuration saved in my_awesome_billsum_model_bart/checkpoint-500/config.json
Model weights saved in my_awesome_billsum_model_bart/checkpoint-500/pytorch_model.bin
tokenizer config file saved in my_awesome_billsum_model_bart/checkpoint-500/tokenizer_config.json
Special tokens file saved in my_awesome_billsum_model_bart/checkpoint-500/special_tokens_map.json
The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, summary. If text, summary are not expected by `BartForConditionalGener

TrainOutput(global_step=3792, training_loss=1.7948016637488255, metrics={'train_runtime': 3554.2228, 'train_samples_per_second': 17.06, 'train_steps_per_second': 1.067, 'total_flos': 3.697197988184064e+16, 'train_loss': 1.7948016637488255, 'epoch': 4.0})

In [18]:
# BartForConditionalGeneration
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05,

In [19]:
# inference
inputs = tokenizer(
    "Shields a business entity from civil liability relating to any injury or death occurring at a facility of that entity in connection with a use of such facility by a nonprofit organization if: (1) the use occurs outside the scope of business of the business entity; (2) such injury or death occurs during a period that such facility is used by such organization; and (3) the business entity authorized the use of such facility by the organization. Makes this Act inapplicable to an injury or death that results from an act or omission of a business entity that constitutes gross negligence or intentional misconduct, including misconduct that: (1) constitutes a hate crime or a crime of violence or act of international terrorism for which the defendant has been convicted in any court; or (2) involves a sexual offense for which the defendant has been convicted in any court or misconduct for which the defendant has been found to have violated a Federal or State civil rights law. Preempts State laws to the extent that such laws are inconsistent with this Act, except State law that provides additional protection from liability. Specifies that this Act shall not be construed to supersede any Federal or State health or safety law. Makes this Act inapplicable to any civil action in a State court against a business entity in which all parties are citizens of the State if such State, citing this Act's authority and containing no other provision, enacts a statute declaring the State's election that this Act shall not apply to such action in the State.",
    return_tensors="pt").input_ids
inputs = inputs.to(model.device)
outputs = model.generate(inputs, repetition_penalty=1.0, min_length=0, max_length=50)
outputs

tensor([[    2,     0, 39278,    29,    10,   265, 10014,    31,  2366,  9416,
          8941,     7,   143,  1356,    50,   744, 14196,    23,    10,  2122,
             9,    14, 10014,    11,  2748,    19,    10,   304,     9,    10,
          2122,    30,    10,  6651,  1651,   114,    35,    36,   134,    43,
             5,   304, 11493,   751,     5,  7401,     9,   265,     9,     2]],
       device='cuda:0')

In [20]:
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Shields a business entity from civil liability relating to any injury or death occurring at a facility of that entity in connection with a use of a facility by a nonprofit organization if: (1) the use occurs outside the scope of business of
