In [None]:
import os
import sys
sys.path.append(os.getcwd() + '/..')
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer
from datasets import Dataset
import evaluate

In [None]:
# --- Data locations -------------------------------------------------------
train_data_path = './../data/gen/google-train.csv'
# NOTE(review): this is the same file as `train_data_path`, so every
# evaluation below is computed on the training data. Point this at the
# held-out split once its file name is confirmed.
test_data_path = './../data/gen/google-train.csv'

# --- Model ----------------------------------------------------------------
model_checkpoint = "/data2T/jingchuan/untuned/flan-t5-base/"
# Last path component of the checkpoint directory. `normpath` strips any
# trailing separator, so the result is the same with or without a final "/"
# (the previous `split("/")[-2]` only worked when the slash was present).
model_name = os.path.basename(os.path.normpath(model_checkpoint))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Hyper-parameters -----------------------------------------------------
train_batch_size = 8
eval_batch_size = 32
num_train_epochs = 8
lr = 1e-5
lr_schedule='linear'
max_gen_length = 64  # `max_length` handed to `generate` at inference time

# --- Reproducibility ------------------------------------------------------
np.random.seed(114514)
torch.manual_seed(114514)

In [None]:
# Load the pretrained seq2seq weights, then move them onto the chosen device.
seq2seqmodel = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
seq2seqmodel = seq2seqmodel.to(device)

# Tokenizer from the same checkpoint, configured with a 128-token maximum length.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    model_max_length=128,
)

# Collator used to pad batches for the trainer and for `infer`.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=seq2seqmodel,
    padding=True,
)

# Evaluation metrics from the `evaluate` library.
bertscore = evaluate.load('bertscore')
gleu_score = evaluate.load("google_bleu")

In [None]:
def tokenize(examples):
    """Tokenize a batch of examples for seq2seq training.

    Parameters
    ----------
    examples : mapping
        Batch with a ``"text"`` entry (model inputs) and a ``"summary"``
        entry (targets), as supplied by ``Dataset.map(..., batched=True)``.

    Returns
    -------
    The tokenizer output for ``"text"`` with an added ``"labels"`` entry
    holding the token ids of ``"summary"``.
    """
    # `truncation=True` makes the tokenizer enforce its maximum length;
    # without it, over-long sequences are passed through untruncated.
    model_inputs = tokenizer(examples["text"], truncation=True)
    labels = tokenizer(examples["summary"], truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
def compute_metrics(eval_pred):
    """Score generated token ids against reference token ids.

    Parameters
    ----------
    eval_pred : tuple
        ``(predictions, labels)`` pair of token-id arrays, as handed to the
        ``compute_metrics`` callback by the trainer.

    Returns
    -------
    dict
        BERTScore precision/recall/F1 averaged over the examples
        (``'Bs-P'``, ``'Bs-R'``, ``'Bs-F1'``) and the Google-BLEU score
        (``'Gleu'``), each rounded to 6 decimals.
    """
    predictions, labels = eval_pred
    # Some models hand back a tuple whose first element holds the token ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is a padding sentinel, not a real token id, and cannot be decoded.
    # Replace it with the pad token on BOTH arrays (previously only the labels
    # were sanitized) so `skip_special_tokens` drops it.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # google_bleu expects a list of references per prediction.
    gleu = gleu_score.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
    gleu['Gleu'] = gleu.pop('google_bleu')

    # BERTScore returns per-example lists; report their means.
    bscore = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang='en')
    bscore['Bs-P'] = np.mean(np.array(bscore.pop('precision')))
    bscore['Bs-R'] = np.mean(np.array(bscore.pop('recall')))
    bscore['Bs-F1'] = np.mean(np.array(bscore.pop('f1')))
    bscore.pop('hashcode')  # non-numeric entry, not a score

    result = {**bscore, **gleu}
    # Round once, at the end, and return plain Python floats.
    return {k: round(float(v), 6) for k, v in result.items()}
                                                        
def compute_metrics_plaintext(results):
    """Compute Google-BLEU and BERTScore from already-decoded text.

    Parameters
    ----------
    results : mapping
        Must provide ``'Prediction'`` (generated texts) and ``'Reference'``
        (target texts).

    Returns
    -------
    dict
        ``'Bs-P'``, ``'Bs-R'``, ``'Bs-F1'`` (mean BERTScore precision, recall
        and F1) followed by ``'Gleu'``, all rounded to 6 decimals.
    """
    hypotheses, targets = results['Prediction'], results['Reference']

    # Google-BLEU wants one list of references per hypothesis.
    bleu_scores = gleu_score.compute(
        predictions=hypotheses,
        references=[[target] for target in targets],
    )
    bleu_scores = {name: round(score, 6) for name, score in bleu_scores.items()}
    bleu_scores['Gleu'] = bleu_scores.pop('google_bleu')

    # BERTScore yields per-example lists; collapse each one to its mean.
    bert_scores = bertscore.compute(predictions=hypotheses, references=targets, lang='en')
    for short_name, raw_name in (('Bs-P', 'precision'), ('Bs-R', 'recall'), ('Bs-F1', 'f1')):
        bert_scores[short_name] = np.mean(np.array(bert_scores.pop(raw_name))).round(6)
    bert_scores.pop('hashcode')

    return {**bert_scores, **bleu_scores}

In [None]:
# Read both CSV splits from the paths set in the configuration cell.
train_data, eval_data = pd.read_csv(train_data_path), pd.read_csv(test_data_path)

# Wrap each frame in a `Dataset` and tokenize it batch-wise.
train_dataset = Dataset.from_pandas(train_data).map(tokenize, batched=True)
eval_dataset = Dataset.from_pandas(eval_data).map(tokenize, batched=True)

# Return only the model columns, as torch tensors, when the datasets are indexed.
train_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)
eval_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

In [None]:
# Timestamp used to give this run its own output directory.
now = datetime.now()
timestr = now.strftime('%Y%m%d-%H%M')

# Evaluate, log and checkpoint once per epoch, keeping at most 8 checkpoints.
args = Seq2SeqTrainingArguments(
    output_dir=f"/data2T/jingchuan/tuned/gen/{model_name}_{timestr}",
    evaluation_strategy="epoch",
    learning_rate=lr,
    lr_scheduler_type=lr_schedule,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    weight_decay=0.01,
    save_total_limit=8,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    # Generate with the same length budget as the final inference step
    # (`infer`), so per-epoch metrics are comparable with the final ones.
    # Without this, evaluation falls back to the model configuration's
    # default generation length.
    generation_max_length=max_gen_length,
    logging_strategy='epoch',
    save_strategy='epoch'
)

# `compute_metrics` receives generated token ids because
# `predict_with_generate=True` is set above.
trainer = Seq2SeqTrainer(
    seq2seqmodel,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
# Run fine-tuning and keep the trainer's return value for inspection.
training_outputs = trainer.train()

# Refresh the timestamp so the result and model files written by later cells
# are labelled with the time training finished rather than when it started.
now = datetime.now()
timestr = f'{now:%Y%m%d-%H%M}'

# Show the training summary as the cell output.
training_outputs

In [None]:
def infer(example):
    """Generate predictions for one batch of tokenized examples.

    Intended for ``Dataset.map(..., batched=True)``.

    Parameters
    ----------
    example : mapping
        Column-oriented batch containing at least ``'input_ids'``,
        ``'attention_mask'`` and ``'labels'``, one entry per example.

    Returns
    -------
    dict
        ``'Input'``, ``'Reference'`` and ``'Prediction'``: lists of decoded
        strings, one per example in the batch.
    """
    # Turn the column-oriented batch into one feature dict per example,
    # which is the layout the collator pads.
    reformatted_example = {k:v for k,v in example.items() if k in ['input_ids','attention_mask','labels']}
    reformatted_example = [{k:v[i] for k,v in reformatted_example.items()} for i in range(len(example['input_ids']))]
    inputs = data_collator(reformatted_example)

    # Pass the attention mask explicitly so padded positions are masked out
    # of attention, instead of relying on `generate` to infer the mask.
    outputs = seq2seqmodel.generate(
        input_ids=inputs['input_ids'].to(device),
        attention_mask=inputs['attention_mask'].to(device),
        max_length=max_gen_length,
    )
    predictions = tokenizer.batch_decode(outputs.cpu().numpy(), skip_special_tokens=True)
    return {'Input':tokenizer.batch_decode(example['input_ids'], skip_special_tokens=True), 'Reference':tokenizer.batch_decode(example['labels'], skip_special_tokens=True), 'Prediction':predictions}

In [None]:
# Generate a prediction for every evaluation example; the original columns are
# dropped so the frame holds only Input / Reference / Prediction.
eval_results = eval_dataset.map(infer,batched=True,batch_size=64,remove_columns=eval_dataset.column_names).to_pandas()

# Score the decoded text and report it under the run's identifier.
metrics = compute_metrics_plaintext(eval_results)
print(f'Model: {model_name}-{timestr}')
print(metrics)
display(eval_results)

# Persist the predictions; create the target directory first so the write does
# not fail on a fresh checkout.
results_dir = './../results/gen'
os.makedirs(results_dir, exist_ok=True)
eval_results.to_csv(f'{results_dir}/{timestr}.csv',index=False)

In [None]:
# The fine-tuned weights and the tokenizer are written to the same directory.
sota_dir = f'/data2T/jingchuan/tuned/gen/{model_name}-{timestr}-sota'
seq2seqmodel.save_pretrained(sota_dir)
tokenizer.save_pretrained(sota_dir)