In [None]:
#!pip install datasets
#!pip install py7zr
#! pip install -U accelerate
#! pip install -U transformers

# Train

In [None]:
import inspect

import nltk
import numpy as np
import torch
import transformers
from datasets import load_dataset, load_from_disk, load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
# `nltk.sent_tokenize` (used when computing ROUGE) needs the Punkt sentence model.
# NLTK >= 3.8.2 loads it from the 'punkt_tab' resource instead of the pickled 'punkt';
# downloading both keeps older and newer NLTK releases working.
nltk.download('punkt')
nltk.download('punkt_tab')


In [None]:
def preprocess_data_cnn(data_to_process):
    """Tokenize a batch of CNN examples for seq2seq fine-tuning.

    Meant for ``Dataset.map(preprocess_data_cnn, batched=True)``:
    ``data_to_process`` maps column names to lists, with the source texts in
    ``'article'`` and the reference summaries in ``'highlights'``. Relies on
    the notebook globals ``tokenizer``, ``max_input`` and ``max_target``.

    Returns the tokenized articles (``input_ids``, ``attention_mask``; padded /
    truncated to ``max_input`` tokens) plus a ``labels`` column holding the
    tokenized highlights (padded / truncated to ``max_target`` tokens) in which
    every padding position is replaced by -100 so the loss ignores it.
    """
    # Encoder inputs: the news articles.
    model_inputs = tokenizer(
        list(data_to_process['article']),
        max_length=max_input,
        padding='max_length',
        truncation=True,
    )
    # Decoder targets: `text_target=` tokenizes the summaries as targets and
    # replaces the deprecated `as_target_tokenizer()` context manager.
    targets = tokenizer(
        text_target=list(data_to_process['highlights']),
        max_length=max_target,
        padding='max_length',
        truncation=True,
    )

    # Padding to max_length fills the labels with pad_token_id. Unless those
    # positions are masked to -100 (the index the cross-entropy loss ignores),
    # the model is trained to predict padding. compute_metrics maps -100 back
    # to the pad token before decoding.
    pad_id = tokenizer.pad_token_id
    model_inputs['labels'] = [
        [token_id if token_id != pad_id else -100 for token_id in label_ids]
        for label_ids in targets['input_ids']
    ]
    return model_inputs

def compute_metrics(eval_pred):
    """Score generated summaries against the reference summaries with ROUGE.

    Passed to ``Seq2SeqTrainer(compute_metrics=...)``. With
    ``predict_with_generate=True``, ``eval_pred`` is a ``(predictions, labels)``
    pair of token-id arrays. Relies on the notebook globals ``tokenizer``,
    ``metric`` (ROUGE) and ``nltk``.

    Returns a dict mapping each ROUGE key to its F-measure in percent, plus
    ``gen_len`` (mean number of non-padding tokens per prediction). Every value
    is rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs; the generated token ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is not a decodable token id. Labels use it as the loss-ignore value,
    # and the Trainer pads predictions from different eval batches with it when
    # concatenating them. Map it back to the pad token before decoding.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence (used by rougeLsum).
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    raw_scores = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {}
    for key, value in raw_scores.items():
        # `datasets.load_metric("rouge")` returns AggregateScore tuples (take the
        # mid F-measure); the `evaluate` ROUGE implementation returns plain floats.
        fmeasure = value.mid.fmeasure if hasattr(value, "mid") else value
        result[key] = fmeasure * 100

    # Mean generated length, counting only non-padding tokens.
    result["gen_len"] = np.mean([np.count_nonzero(pred != pad_id) for pred in predictions])

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# --- Sequence lengths and batch size ----------------------------------------
max_input = 512    # token length each source article is truncated/padded to
max_target = 128   # token length each reference summary is truncated/padded to
batch_size = 8     # per-device batch size for both training and evaluation

# --- Pretrained checkpoint ---------------------------------------------------
model_checkpoints = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoints)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoints)
# Passing the model lets the collator prepare decoder_input_ids from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# --- Evaluation metric (read by compute_metrics) -----------------------------
metric = load_metric("rouge")

In [None]:
# Load the three saved CNN splits, then tokenize each with preprocess_data_cnn
# (batched=True, so the function receives lists of articles/highlights).
train_dataset_cnn, test_dataset_cnn, validation_dataset_cnn = (
    load_from_disk('/content/drive/MyDrive/dataset/cnn/train'),
    load_from_disk('/content/drive/MyDrive/dataset/cnn/test'),
    load_from_disk('/content/drive/MyDrive/dataset/cnn/validation'),
)
tokenize_data_cnn_train, tokenize_data_cnn_test, tokenize_data_cnn_validation = (
    split.map(preprocess_data_cnn, batched=True)
    for split in (train_dataset_cnn, test_dataset_cnn, validation_dataset_cnn)
)

In [None]:
# Inspect one raw test example (its 'article' text and reference 'highlights').
example = test_dataset_cnn[8]
print(example)

{'article': '(CNN)Filipinos are being warned to be on guard for flash floods and landslides as tropical storm Maysak approached the Asian island nation Saturday. Just a few days ago, Maysak gained super typhoon status thanks to its sustained 150 mph winds. It has since lost a lot of steam as it has spun west in the Pacific Ocean. It\'s now classified as a tropical storm, according to the Philippine national weather service, which calls it a different name, Chedeng. It boasts steady winds of more than 70 mph (115 kph) and gusts up to 90 mph as of 5 p.m. (5 a.m. ET) Saturday. Still, that doesn\'t mean Maysak won\'t pack a wallop. Authorities took preemptive steps to keep people safe such as barring outdoor activities like swimming, surfing, diving and boating in some locales, as well as a number of precautionary evacuations. Gabriel Llave, a disaster official, told PNA that tourists who arrive Saturday in and around the coastal town of Aurora "will not be accepted by the owners of hotels

In [None]:
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (and later dropped the old name); pass whichever keyword this install accepts.
eval_strategy_kwarg = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(Seq2SeqTrainingArguments).parameters
    else 'evaluation_strategy'
)

args = Seq2SeqTrainingArguments(
    '/storage/changyu/results/bart',  # save directory for checkpoints and logs
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=3,  # effective train batch = 3 * batch_size per device
    weight_decay=0.01,
    num_train_epochs=10,
    predict_with_generate=True,  # evaluate on generated summaries so ROUGE can be computed
    # Allow generated summaries to be as long as the tokenized references;
    # otherwise generation falls back to the model's default max length.
    generation_max_length=max_target,
    eval_accumulation_steps=3,
    fp16=torch.cuda.is_available(),  # mixed precision is only available with CUDA
    save_steps=500,
    save_total_limit=10,
    logging_first_step=True,
    logging_steps=500,
    **{eval_strategy_kwarg: 'epoch'},  # evaluate on the validation split every epoch
)

# Likewise, `tokenizer=` was superseded by `processing_class=` for the Trainer.
tokenizer_kwarg = (
    'processing_class'
    if 'processing_class' in inspect.signature(Seq2SeqTrainer).parameters
    else 'tokenizer'
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenize_data_cnn_train,
    eval_dataset=tokenize_data_cnn_validation,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{tokenizer_kwarg: tokenizer},
)


In [None]:
# Fine-tune the model; per the training arguments, validation runs once per epoch.
train_result = trainer.train()
train_result

# Evaluate the model on the test split

In [None]:
# Evaluate on the held-out test split. The trainer built for training already holds
# the fine-tuned model, args, collator and compute_metrics, so the test set is passed
# to evaluate() instead of constructing a second trainer. The "test" prefix keeps
# these metrics distinguishable from the per-epoch validation ("eval") metrics.
trainer.evaluate(eval_dataset=tokenize_data_cnn_test, metric_key_prefix="test")

# Show some example summaries

In [None]:
# Summarize the first test article with the fine-tuned model.
test_case = test_dataset_cnn[0]["article"]
model_inputs = tokenizer(test_case, max_length=max_input, padding='max_length', truncation=True)
raw_pred, _, _ = trainer.predict([model_inputs])
# skip_special_tokens drops BOS/EOS/padding tokens so only the summary text is printed.
print(tokenizer.decode(raw_pred[0], skip_special_tokens=True))

Generated summary for the first test example:
The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January.

In [None]:
# Summarize the second test article with the fine-tuned model.
test_case1 = test_dataset_cnn[1]["article"]
model_inputs1 = tokenizer(test_case1, max_length=max_input, padding='max_length', truncation=True)
raw_pred1, _, _ = trainer.predict([model_inputs1])
# skip_special_tokens drops BOS/EOS/padding tokens so only the summary text is printed.
print(tokenizer.decode(raw_pred1[0], skip_special_tokens=True))

Generated summary for the second test example: Mohammad Javad Zarif is the Iranian foreign minister. He has been U.S. Secretary of State John Kerry's opposite number in nuclear talks. Zarif has gone a long way to bring Iran in from the cold. Mohammad Javad Zarif is the Iranian foreign minister. He was nominated to be foreign minister by Ahmadinejad's successor, Hassan Rouhami. Zarif was outside the country during the demonstrations against the Shah of Iran.