In [1]:
from datasets import load_dataset, load_metric,Dataset,DatasetDict
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer
import nltk
import numpy as np

#datasets = ["samsum","cnn","xsum"]
model_name = "facebook/bart-large"

In [2]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [3]:
raw_datasets = load_dataset("samsum")
metric = load_metric("rouge")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)



Using the latest cached version of the module from /home/weichen/.cache/huggingface/modules/datasets_modules/datasets/samsum/3f7dba43be72ab10ca66a2e0f8547b3590e96c2bd9f2cbb1f6bb1ec1f1488ba6 (last modified on Wed Apr 28 13:59:29 2021) since it couldn't be found locally at samsum/samsum.py or remotely (ConnectionError).
Reusing dataset samsum (/home/weichen/.cache/huggingface/datasets/samsum/samsum/0.0.0/3f7dba43be72ab10ca66a2e0f8547b3590e96c2bd9f2cbb1f6bb1ec1f1488ba6)
Using the latest cached version of the module from /home/weichen/.cache/huggingface/modules/datasets_modules/metrics/rouge/2b73d5eb463209373e9d21a95decb226d4164bdca4c361b8dfad295ec82bc62e (last modified on Wed Apr 28 14:00:09 2021) since it couldn't be found locally at rouge/rouge.py or remotely (ConnectionError).


In [4]:
max_input_length = 512
max_target_length = 128

def preprocess_function(examples):
    inputs = [doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [5]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [6]:
batch_size = 8
args = Seq2SeqTrainingArguments(
    "BART-LARGE-samsum",
    evaluation_strategy = "steps",
    eval_steps = 500,
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    predict_with_generate=True,
    fp16=True,
    warmup_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="eval_rouge1",
    greater_is_better=True
)

In [7]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [8]:
trainer.train()



Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len,Runtime,Samples Per Second
500,1.8287,1.504997,46.0253,23.5235,38.5538,42.3048,18.0269,55.4087,14.763
1000,1.6137,1.47132,47.1868,24.5705,39.9758,43.8132,18.3851,55.5973,14.713
1500,1.5754,1.420214,47.2808,25.5742,40.3989,43.6609,17.3166,55.4554,14.751
2000,1.4504,1.412556,48.1426,25.7698,41.3451,44.6887,17.7738,55.3234,14.786
2500,1.307,1.411188,47.6959,25.5439,40.6456,44.0342,17.6027,55.4501,14.752
3000,1.3123,1.40058,48.8291,26.6242,41.8216,44.9105,17.7139,55.2841,14.796
3500,1.2848,1.405458,49.3057,26.9458,41.9101,46.0089,17.7347,55.332,14.783
4000,1.156,1.408802,49.2716,26.8302,41.9453,45.5732,18.0391,55.352,14.778
4500,1.0638,1.421972,48.3158,26.3334,41.3479,44.7599,18.5403,55.3319,14.784
5000,1.0824,1.384098,49.7031,27.0819,42.2124,45.675,18.5122,55.8001,14.659


TrainOutput(global_step=9210, training_loss=1.144713898202107, metrics={'train_runtime': 3843.2667, 'train_samples_per_second': 2.396, 'total_flos': 7.023151183117517e+16, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': 1585229824, 'init_mem_gpu_alloc_delta': 1625367040, 'init_mem_cpu_peaked_delta': 1621676032, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 3025321984, 'train_mem_gpu_alloc_delta': 6509939712, 'train_mem_cpu_peaked_delta': 411684864, 'train_mem_gpu_peaked_delta': 8868905984})

In [9]:
out = trainer.predict(tokenized_datasets["test"],num_beams=5)

In [10]:
print(out)

PredictionOutput(predictions=array([[    2,     0,     0, ...,  1394,  6045,     2],
       [    2, 24375,     8, ...,     1,     1,     1],
       [    2,     0,     0, ...,  3045,     4,     2],
       ...,
       [    2,     0, 13012, ...,    30,  3213,     2],
       [    2,     0,     0, ...,    11,  8353,     2],
       [    2, 41415, 11210, ...,   416,  3996,     2]]), label_ids=array([[    0,   725, 25984, ...,  -100,  -100,  -100],
       [    0, 24375,     8, ...,  -100,  -100,  -100],
       [    0,   574, 11867, ...,  -100,  -100,  -100],
       ...,
       [    0, 13012,   102, ...,  -100,  -100,  -100],
       [    0,   970,    21, ...,  -100,  -100,  -100],
       [    0, 41415, 11210, ...,  -100,  -100,  -100]]), metrics={'eval_loss': 1.4359217882156372, 'eval_rouge1': 48.3409, 'eval_rouge2': 25.6455, 'eval_rougeL': 41.0306, 'eval_rougeLsum': 44.3759, 'eval_gen_len': 17.8181, 'eval_runtime': 59.4663, 'eval_samples_per_second': 13.773, 'test_mem_cpu_alloc_delta': -117116