In [None]:
import os
import sys
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
import evaluate

In [None]:
# --- Data locations and base checkpoint ---------------------------------
train_data_path = './../../data/gen/google-train.csv'
test_data_path = './../../data/gen/google-test.csv'
model_checkpoint = 'YOUR_MODEL_CHECKPOINT'
# Last path component of the checkpoint id; reused to label the output directory.
model_name = model_checkpoint.rsplit("/", 1)[-1]
# Use the GPU whenever torch reports one as available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Optimisation / generation hyper-parameters -------------------------
train_batch_size = eval_batch_size = 32
num_train_epochs = 10
lr = 2e-5
lr_schedule = 'linear'
max_gen_length = 64
# Seed the numpy and torch global RNGs with the same value.
np.random.seed(114514)
torch.manual_seed(114514)

# NOTE(review): presumably silences the tokenizers fork-parallelism warning -- confirm.
os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})

In [None]:
# Load the pretrained seq2seq checkpoint, then move it to the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)
# Tokenizer from the same checkpoint, with model_max_length set to 128.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=128)
# Batch collator built from the tokenizer and model, with padding enabled.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
# Metrics loaded through `evaluate`; both are consumed by compute_metrics.
bertscore = evaluate.load('bertscore')
gleu_score = evaluate.load("google_bleu")

In [None]:
def tokenize(examples):
    """Tokenize a batch of source texts and their reference summaries.

    Args:
        examples: Batch mapping with a "text" entry (model inputs) and a
            "summary" entry (targets), as supplied by
            ``Dataset.map(tokenize, batched=True)``.

    Returns:
        The tokenizer output for the inputs, with the target token ids stored
        under the "labels" key.
    """
    # truncation=True caps each sequence at the tokenizer's model_max_length;
    # without it over-long examples are passed through at full length.
    model_inputs = tokenizer(examples["text"], truncation=True)
    # text_target routes the summaries through the tokenizer's target-side
    # processing, which can differ from the source side for some checkpoints.
    labels = tokenizer(text_target=examples["summary"], truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
def compute_metrics(eval_pred):
    """Score generated sequences against references with BERTScore and Google BLEU.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays in which
            -100 marks positions to ignore.

    Returns:
        Dict with the mean BERTScore precision/recall/F1 under 'Bs-P', 'Bs-R'
        and 'Bs-F1', plus the entries returned by the Google BLEU metric, all
        rounded to 6 decimal places.
    """
    def _to_text(token_ids):
        # Swap the -100 sentinel for the pad id so the ids can be decoded.
        cleaned = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(cleaned, skip_special_tokens=True)

    raw_preds, raw_labels = eval_pred
    pred_texts = _to_text(raw_preds)
    ref_texts = _to_text(raw_labels)

    # Google BLEU takes a list of references per prediction.
    gleu = gleu_score.compute(predictions=pred_texts, references=[[ref] for ref in ref_texts])
    bert = bertscore.compute(predictions=pred_texts, references=ref_texts, lang='en')

    # The hashcode entry is not a score; drop it before aggregating.
    bert.pop('hashcode')
    averaged = {
        short_name: np.mean(np.array(bert.pop(field))).round(6)
        for short_name, field in (('Bs-P', 'precision'), ('Bs-R', 'recall'), ('Bs-F1', 'f1'))
    }
    merged = {**bert, **averaged, **gleu}
    return {name: round(value, 6) for name, value in merged.items()}

In [None]:
def _as_model_dataset(frame):
    """Turn a DataFrame into a tokenized ``Dataset`` formatted as torch tensors."""
    split = Dataset.from_pandas(frame).map(tokenize, batched=True)
    # set_format works in place; only the listed columns get the torch format.
    split.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return split


# Raw CSV contents are kept as DataFrames next to the tokenized datasets.
train_data = pd.read_csv(train_data_path)
eval_data = pd.read_csv(test_data_path)
train_dataset = _as_model_dataset(train_data)
eval_dataset = _as_model_dataset(eval_data)

In [None]:
# Minute-resolution timestamp used to label this run's output directory.
now = datetime.now()
timestr = now.strftime('%Y%m%d-%H%M')
args = Seq2SeqTrainingArguments(
    output_dir=f'YOUR_PATH/{model_name}-{timestr}',
    # Evaluate, log and checkpoint once per epoch.
    evaluation_strategy="epoch",
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=8,
    learning_rate=lr,
    lr_scheduler_type=lr_schedule,
    weight_decay=0.01,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    num_train_epochs=num_train_epochs,
    # Generate sequences during evaluation so compute_metrics receives token ids.
    predict_with_generate=True,
    generation_max_length=max_gen_length,
    # The Trainer re-seeds the RNGs from `seed` (default 42) when constructed,
    # which would override the 114514 seed set in the configuration cell.
    # Pass the same value explicitly so it actually governs training.
    seed=114514,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; the value returned by train() is shown as the cell output.
training_outputs = trainer.train()
# Refresh the timestamp now that training has returned.
now = datetime.now()
timestr = format(now, '%Y%m%d-%H%M')
training_outputs

In [None]:
# Use the public predict() API instead of calling evaluation_loop() directly:
# predict() configures the generation settings itself rather than relying on
# state left behind by an earlier evaluate() call. metric_key_prefix="eval"
# keeps the metric names identical to those reported during training.
output = trainer.predict(eval_dataset, metric_key_prefix="eval")
# -100 cannot be decoded, so replace it with the pad id first.
preds = np.where(output.predictions != -100, output.predictions, tokenizer.pad_token_id)
predictions = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)

# Strip the "summarize: " task prefix only where it is actually present.
task_prefix = "summarize: "
eval_result_table = pd.DataFrame({
    'Input': [
        text[len(task_prefix):] if text.startswith(task_prefix) else text
        for text in eval_dataset['text']
    ],
    'Prediction': predictions,
    'Reference': list(eval_dataset['summary']),
})
print(output.metrics)
eval_result_table