In [1]:
from datasets import load_dataset, load_metric,Dataset,DatasetDict
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer
import nltk
import numpy as np

#datasets = ["samsum","cnn","xsum"]
model_name = "google/pegasus-xsum"

In [2]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [3]:
raw_datasets = load_dataset("samsum")
metric = load_metric("rouge")

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)



Reusing dataset samsum (/home/weichen/.cache/huggingface/datasets/samsum/samsum/0.0.0/3f7dba43be72ab10ca66a2e0f8547b3590e96c2bd9f2cbb1f6bb1ec1f1488ba6)
Using the latest cached version of the module from /home/weichen/.cache/huggingface/modules/datasets_modules/metrics/rouge/2b73d5eb463209373e9d21a95decb226d4164bdca4c361b8dfad295ec82bc62e (last modified on Wed Apr 28 14:00:09 2021) since it couldn't be found locally at rouge/rouge.py or remotely (ConnectionError).


In [4]:
max_input_length = 512
max_target_length = 128

def preprocess_function(examples):
    inputs = [doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
batch_size = 8
args = Seq2SeqTrainingArguments(
    "pegasus-xum-samsum",
    evaluation_strategy = "steps",
    eval_steps = 500,
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    fp16=True,
    warmup_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="eval_rouge1",
    greater_is_better=True
)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [5]:
trainer.train()

Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len,Runtime,Samples Per Second
500,,1.602473,46.7618,22.8783,39.1641,42.7114,16.4768,141.2847,5.79
1000,,1.539185,49.7257,24.4956,41.074,45.3839,20.5073,167.0176,4.898
1500,,1.504752,50.4609,25.462,41.9416,46.0296,19.6516,162.3277,5.039
2000,,1.494941,50.1915,25.6302,41.767,45.8472,18.9462,161.6146,5.061
2500,,1.493194,49.9111,25.9992,42.0325,45.7736,17.1467,160.9986,5.081
3000,,1.495335,50.1104,25.9026,42.1324,46.0669,17.9352,151.779,5.389
3500,,1.508917,50.7467,26.3707,42.2735,46.6873,20.8203,169.3891,4.829
4000,,1.522761,47.1751,24.1726,39.8049,43.0879,15.2494,145.8609,5.608
4500,,1.518134,47.2448,24.3726,39.8453,43.1856,15.0611,143.449,5.702
5000,,1.516037,47.2404,24.2761,39.8543,43.1117,15.0868,144.6217,5.656


TrainOutput(global_step=7368, training_loss=nan, metrics={'train_runtime': 4718.61, 'train_samples_per_second': 1.561, 'total_flos': 6.981073885384704e+16, 'epoch': 4.0, 'init_mem_cpu_alloc_delta': 1220816896, 'init_mem_gpu_alloc_delta': 2280005120, 'init_mem_cpu_peaked_delta': 1463779328, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 2706505728, 'train_mem_gpu_alloc_delta': 9117943808, 'train_mem_cpu_peaked_delta': 1948725248, 'train_mem_gpu_peaked_delta': 11142844416})

In [6]:
out = trainer.predict(tokenized_datasets["test"],num_beams=5)

In [7]:
print(out)

PredictionOutput(predictions=array([[    0, 12195,   591, ...,     0,     0,     0],
       [    0,  6303,   111, ...,     0,     0,     0],
       [    0, 43880,   138, ...,     0,     0,     0],
       ...,
       [    0, 42828, 11586, ...,     0,     0,     0],
       [    0,   353,   140, ...,     0,     0,     0],
       [    0, 38772,   148, ...,     0,     0,     0]]), label_ids=array([[12636,   397, 17379, ...,  -100,  -100,  -100],
       [ 6303,   111,  7374, ...,  -100,  -100,  -100],
       [43880,   137,   131, ...,  -100,  -100,  -100],
       ...,
       [42828,  1406,   114, ...,  -100,  -100,  -100],
       [  353,   140,   114, ...,  -100,  -100,  -100],
       [ 9199,  9274,   114, ...,  -100,  -100,  -100]]), metrics={'eval_loss': 1.5203441381454468, 'eval_rouge1': 50.6243, 'eval_rouge2': 26.3668, 'eval_rougeL': 42.2809, 'eval_rougeLsum': 46.4494, 'eval_gen_len': 21.8059, 'eval_runtime': 150.3924, 'eval_samples_per_second': 5.446, 'test_mem_cpu_alloc_delta': 782336,