In [1]:
import os
import sys
sys.path.append(os.getcwd() + '/..')
sys.path.append(os.getcwd() + '/../..')
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer
from datasets import Dataset
import evaluate

In [5]:
train_data_path = './../../data/gen/ebay_us-train.csv'
test_data_path = './../../data/gen/ebay_us-test.csv'
model_checkpoint = "/data2T/jingchuan/tuned/gen/flan-t5-base_20231208-1351/checkpoint-10931/"
model_name = model_checkpoint.split("/")[-2]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_batch_size = 8
eval_batch_size = 32
num_train_epochs = 8
lr = 1e-5
lr_schedule='linear'
max_gen_length = 64
np.random.seed(114514)
torch.manual_seed(114514)

<torch._C.Generator at 0x7eff2ed109b0>

In [3]:
seq2seqmodel = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint,model_max_length=128)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=seq2seqmodel,padding=True)
bertscore = evaluate.load('bertscore')
gleu_score = evaluate.load("google_bleu")

In [6]:
def tokenize(examples):
    model_inputs = tokenizer(examples["text"])
    labels = tokenizer(examples["summary"])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [7]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    gleu = gleu_score.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
    gleu['Gleu'] = gleu.pop('google_bleu')
    bscore = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang='en')
    bscore['Bs-P'] = np.mean(np.array(bscore.pop('precision'))).round(6)
    bscore['Bs-R'] = np.mean(np.array(bscore.pop('recall'))).round(6)
    bscore['Bs-F1'] = np.mean(np.array(bscore.pop('f1'))).round(6)
    bscore.pop('hashcode')
    result = {**bscore, **gleu}
    return {k: round(v, 6) for k, v in result.items()}
                                                        
def compute_metrics_plaintext(results):
    predictions = results['Prediction']
    labels = results['Reference']
    gleu = gleu_score.compute(predictions=predictions, references=[[l] for l in labels])
    gleu = {k:round(v,6) for k,v in gleu.items()}
    gleu['Gleu'] = gleu.pop('google_bleu')
    bscore = bertscore.compute(predictions=predictions, references=labels, lang='en')
    bscore['Bs-P'] = np.mean(np.array(bscore.pop('precision'))).round(6)
    bscore['Bs-R'] = np.mean(np.array(bscore.pop('recall'))).round(6)
    bscore['Bs-F1'] = np.mean(np.array(bscore.pop('f1'))).round(6)
    bscore.pop('hashcode')
    metrics = {**bscore, **gleu}
    return metrics

In [8]:
train_data = pd.read_csv(train_data_path)
eval_data = pd.read_csv(test_data_path)
train_dataset = Dataset.from_pandas(train_data)
eval_dataset = Dataset.from_pandas(eval_data)
train_dataset = train_dataset.map(tokenize, batched=True)
eval_dataset = eval_dataset.map(tokenize, batched=True)
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask","labels"])
eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask","labels"])

Map:   0%|          | 0/20553 [00:00<?, ? examples/s]

Map:   0%|          | 0/2283 [00:00<?, ? examples/s]

In [9]:
now = datetime.now()
timestr = now.strftime('%Y%m%d-%H%M')
args = Seq2SeqTrainingArguments(
    output_dir=f"/data2T/jingchuan/tuned/gen/{model_name}_{timestr}",
    evaluation_strategy="epoch",
    learning_rate=lr,
    lr_scheduler_type=lr_schedule,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    weight_decay=0.01,
    save_total_limit=8,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_strategy='epoch',
    save_strategy='epoch'
)

trainer = Seq2SeqTrainer(
    seq2seqmodel,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [10]:
training_outputs = trainer.train()
now = datetime.now()
timestr = now.strftime('%Y%m%d-%H%M')
training_outputs

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bs-p,Bs-r,Bs-f1,Gleu
1,0.4965,0.441173,0.964593,0.962788,0.963542,0.598019
2,0.4779,0.438278,0.96514,0.963454,0.964151,0.604361
3,0.4613,0.428093,0.965313,0.963583,0.9643,0.60423
4,0.4472,0.425803,0.965335,0.964002,0.964525,0.605386
5,0.4379,0.421037,0.965543,0.964296,0.964778,0.610057
6,0.4301,0.419362,0.965804,0.964412,0.964965,0.611327
7,0.4233,0.419048,0.965504,0.964394,0.964806,0.610288
8,0.4211,0.419539,0.965799,0.96459,0.965053,0.613549


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


TrainOutput(global_step=10280, training_loss=0.4493949266723158, metrics={'train_runtime': 26932.6927, 'train_samples_per_second': 6.105, 'train_steps_per_second': 0.382, 'total_flos': 9360690516882432.0, 'train_loss': 0.4493949266723158, 'epoch': 8.0})

In [11]:
def infer(example):
    reformatted_example = {k:v for k,v in example.items() if k in ['input_ids','attention_mask','labels']}
    reformatted_example = [{k:v[i] for k,v in reformatted_example.items()} for i in range(len(example['input_ids']))]
    inputs = data_collator(reformatted_example)
    outputs = seq2seqmodel.generate(inputs['input_ids'].to(device),max_length=max_gen_length)
    predictions = tokenizer.batch_decode(outputs.cpu().numpy(), skip_special_tokens=True)
    return {'Input':tokenizer.batch_decode(example['input_ids'], skip_special_tokens=True), 'Reference':tokenizer.batch_decode(example['labels'], skip_special_tokens=True), 'Prediction':predictions}

In [12]:
eval_results = eval_dataset.map(infer,batched=True,batch_size=64,remove_columns=eval_dataset.column_names).to_pandas()
eval_results['Prediction'] = eval_results['Prediction']
metrics = compute_metrics_plaintext(eval_results)
print(f'Model: {model_name}-{timestr}')
print(metrics)
display(eval_results)
eval_results.to_csv(f'./../results/gen/{timestr}.csv',index=False)

Map:   0%|          | 0/2283 [00:00<?, ? examples/s]

Model: checkpoint-10931-20231209-0719
{'Bs-P': 0.965864, 'Bs-R': 0.964687, 'Bs-F1': 0.965134, 'Gleu': 0.616313}


Unnamed: 0,Input,Reference,Prediction
0,summarize: Original Beanie Babies; Retired Ty ...,Ty Beanbag Plushies,Beanbag Plushies
1,summarize: Needlepoint Kits; Ribbon Embroidery...,Hand Embroidery Sets & Kits,Embroidery & Cross Stitch Supplies
2,summarize: Ethnic Americana Collectibles; Coll...,Ethnic & Cultural Collectibles,Ethnic & Cultural Collectibles
3,"summarize: Industrial Rock, Gravel & Sand; Ind...","Industrial Cement, Concrete & Masonry",Industrial Building Materials & Supplies
4,summarize: Bowling Clothing; Youth Bowling Clo...,Bowling Clothing,Bowling Clothing
...,...,...,...
2278,summarize: Game Used NFL Jerseys; Game Used NF...,Root Concept,Root Concept
2279,summarize: Industrial Wood Composite Panels & ...,Root Concept,Root Concept
2280,summarize: Other Tesla Cars & Trucks; Tesla Ro...,Root Concept,Root Concept
2281,summarize: Women's Golf Socks; Women's Golf Co...,Root Concept,Root Concept


In [13]:
seq2seqmodel.save_pretrained(f'/data2T/jingchuan/tuned/gen/{model_name}-{timestr}-sota')
tokenizer.save_pretrained(f'/data2T/jingchuan/tuned/gen/{model_name}-{timestr}-sota')

('/data2T/jingchuan/tuned/gen/flan-t5-sota/tokenizer_config.json',
 '/data2T/jingchuan/tuned/gen/flan-t5-sota/special_tokens_map.json',
 '/data2T/jingchuan/tuned/gen/flan-t5-sota/tokenizer.json')