In [2]:
import os
import sys
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from ellement.transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
import evaluate

In [37]:
train_data_path = './../../data/gen/ebay_us-train.csv'
test_data_path = './../../data/gen/ebay_us-test.csv'
# model_checkpoint = 'google/flan-t5-xl'
model_checkpoint = '/data/ebay-slc-a100/data/jingcshi/ICON_models/gen/flan-t5-xl-sota'
model_name = model_checkpoint.split("/")[-1]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_batch_size = 32
eval_batch_size = 32
num_train_epochs = 10
lr = 2e-5
lr_schedule='linear'
max_gen_length = 64
np.random.seed(114514)
torch.manual_seed(114514)

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [38]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint,model_max_length=128)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
bertscore = evaluate.load('bertscore')
gleu_score = evaluate.load("google_bleu")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
def tokenize(examples):
    model_inputs = tokenizer(examples["text"])
    labels = tokenizer(examples["summary"])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [6]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    gleu = gleu_score.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
    bscore = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang='en')
    bscore['Bs-P'] = np.mean(np.array(bscore.pop('precision'))).round(6)
    bscore['Bs-R'] = np.mean(np.array(bscore.pop('recall'))).round(6)
    bscore['Bs-F1'] = np.mean(np.array(bscore.pop('f1'))).round(6)
    bscore.pop('hashcode')
    result = {**bscore, **gleu}
    return {k: round(v, 6) for k, v in result.items()}

In [14]:
train_data = pd.read_csv(train_data_path)
eval_data = pd.read_csv(test_data_path)
train_dataset = Dataset.from_pandas(train_data)
eval_dataset = Dataset.from_pandas(eval_data)
train_dataset = train_dataset.map(tokenize, batched=True)
eval_dataset = eval_dataset.map(tokenize, batched=True)
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask","labels"])
eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask","labels"])

Map:   0%|          | 0/20993 [00:00<?, ? examples/s]

Map:   0%|          | 0/2333 [00:00<?, ? examples/s]

In [40]:
now = datetime.now()
timestr = now.strftime('%Y%m%d-%H%M')
args = Seq2SeqTrainingArguments(
    output_dir=f"/data/ebay-slc-a100/data/jingcshi/ICON_models/gen/{model_name}_{timestr}",
    evaluation_strategy="epoch",
    learning_rate=lr,
    lr_scheduler_type=lr_schedule,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    weight_decay=0.01,
    save_total_limit=8,
    num_train_epochs=num_train_epochs,
    generation_max_length=max_gen_length,
    predict_with_generate=True,
    logging_strategy='epoch',
    save_strategy='epoch'
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [8]:
training_outputs = trainer.train()
now = datetime.now()
timestr = now.strftime('%Y%m%d-%H%M')
training_outputs

Epoch,Training Loss,Validation Loss,Bs-p,Bs-r,Bs-f1,Google Bleu
1,0.5348,0.23122,0.976946,0.975041,0.975907,0.636584
2,0.264,0.162499,0.981803,0.98065,0.981163,0.69664
3,0.1847,0.136677,0.9842,0.98335,0.983723,0.73585
4,0.1382,0.118561,0.985128,0.984599,0.984815,0.75582
5,0.1124,0.114898,0.98643,0.985981,0.986162,0.777375
6,0.096,0.115332,0.987039,0.986568,0.986761,0.787733
7,0.0844,0.11302,0.9873,0.986851,0.987033,0.792457
8,0.0771,0.112417,0.987413,0.987106,0.98722,0.796414


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


TrainOutput(global_step=10504, training_loss=0.18645494872163654, metrics={'train_runtime': 9297.6444, 'train_samples_per_second': 18.063, 'train_steps_per_second': 1.13, 'total_flos': 1.4368585553097523e+17, 'train_loss': 0.18645494872163654, 'epoch': 8.0})

In [43]:
output = trainer.evaluation_loop(
            trainer.get_eval_dataloader(eval_dataset),
            description="Evaluation"
        )
preds = np.where(output.predictions != -100, output.predictions, tokenizer.pad_token_id)
predictions = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
eval_result_table = pd.DataFrame({'Input':list(map(lambda string: string[11:], eval_dataset['text'])), # Strip the "summarize: " prefix
                                'Prediction': predictions, 
                                'Reference': eval_dataset['summary']})
print(output.metrics)
eval_result_table

{'eval_loss': 0.1113143041729927, 'eval_Bs-P': 0.987661, 'eval_Bs-R': 0.987425, 'eval_Bs-F1': 0.987502, 'eval_google_bleu': 0.801726}


Unnamed: 0,Input,Prediction,Reference
0,Asian Travel Maps[SEP]Travel Accessories[SEP]T...,Travel,Travel
1,Bathroom Shelves[SEP]Bath Towels & Washcloths,Bathroom Supplies & Accessories,Bathroom Supplies & Accessories
2,Collectible Police Handcuffs & Keys[SEP]Collec...,Police Collectibles,Police Collectibles
3,Other Horse Wear[SEP]Horse Lead Ropes[SEP]Equi...,Equestrian Equipment,Horse Wear
4,Pliers[SEP]Routers & Joiners,Hand Tools,Tools & Workshop Equipment
...,...,...,...
2328,Queen (Musical Artist) Apparel[SEP]Other Queen...,Root Concept,Root Concept
2329,Nail Treatment Creams[SEP]Nail Strengtheners[S...,Root Concept,Root Concept
2330,Industrial Anvils[SEP]Industrial Woodworking V...,Root Concept,Root Concept
2331,Collectible Postcards[SEP]Collectible Topograp...,Root Concept,Root Concept


In [46]:
eval_result_table.to_csv(f'./../evaluation/results/gen/{model_name}_{timestr}.csv', index=False)