<a href="https://colab.research.google.com/github/juliawol/WB_Embedder/blob/main/WB_Giga_Embeddings_Fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import inspect
import os

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from sklearn.metrics import accuracy_score, f1_score, ndcg_score
from transformers import (
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

# Hugging Face Hub id of the base embedding model, and the local directory that
# receives the fine-tuned checkpoints (one sub-directory per task group).
MODEL_NAME = "ai-sage/Giga-Embeddings-instruct"
OUTPUT_DIR = "./fine_tuned_giga_embeddings"

# NOTE(review): this checkpoint appears to ship a custom architecture (its model
# card loads it with trust_remote_code=True) -- verify. trust_remote_code runs
# Python code from the Hub repository, so consider pinning `revision=` to a
# reviewed commit for reproducibility and supply-chain safety.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Short dataset key -> CSV file path (relative to the working directory).
DATASETS = {
    "wbqasupport_facts": "wbqasupportfacts.csv",
    "wbqasupport_short": "wbqasupportshort.csv",
    "wbsearch_query_nm": "regenerated_wbsearchquerynmrelevance.csv",
    "wbreviews": "regenerated_wbreviews.csv",
    "wbgoods_categories": "regenerated_wbgoodscategoriesclassification.csv",
    "wbsim_goods": "regenerated_wbsimgoodstriplets.csv",
    "wbgender_classification": "regenerated_wbgenderclassification.csv",
    "wbbrand_syns": "synthetic_wb_brand_syns.csv",
}

def preprocess_data(file_path, task):
    """Load one CSV file and tokenize its text column(s) for a task group.

    Args:
        file_path: Path of the CSV file, passed to
            ``load_dataset("csv", data_files=file_path)``.
        task: ``"retrieval"``, ``"classification"`` or ``"pairwise"``. Selects
            which column (or pair of columns) is fed to the module-level
            ``tokenizer``.

    Returns:
        The ``DatasetDict`` produced by ``load_dataset``, with the tokenizer
        outputs added as extra columns (the original columns are kept).

    Raises:
        ValueError: If ``task`` is not one of the names above, or if the CSV has
            none of the column combinations that ``task`` knows how to
            tokenize. Previously the batch callback silently returned ``None``
            in that case, and the dataset reached training untokenized.
    """
    # Per task: candidate (first, second) text columns, tried in order.
    # A ``None`` second column means single-sequence tokenization.
    text_columns_by_task = {
        "retrieval": [("question", "context"), ("previous_product", "next_product")],
        "classification": [("review_text", None), ("product_name", None)],
        "pairwise": [("brand_name", "brand_synonym")],
    }
    if task not in text_columns_by_task:
        raise ValueError(
            f"Unknown task {task!r}; expected one of {sorted(text_columns_by_task)}"
        )

    data = load_dataset("csv", data_files=file_path)
    # Every split here comes from the same CSV file, so one split's schema suffices.
    available = set(next(iter(data.values())).column_names)

    selected = next(
        (
            (first, second)
            for first, second in text_columns_by_task[task]
            if first in available and (second is None or second in available)
        ),
        None,
    )
    if selected is None:
        raise ValueError(
            f"{file_path!r} has no usable text columns for task {task!r}; "
            f"expected one of {text_columns_by_task[task]}, got {sorted(available)}"
        )
    first_col, second_col = selected

    def tokenize(batch):
        # With batched=True each value in ``batch`` is a list of column values.
        if second_col is None:
            return tokenizer(batch[first_col], truncation=True)
        return tokenizer(batch[first_col], batch[second_col], truncation=True)

    return data.map(tokenize, batched=True)

# Training groups: each task name lists the dataset keys (from DATASETS) whose
# CSVs are tokenized the way preprocess_data() handles that task.
data_by_task = {
    "retrieval": ["wbqasupport_facts", "wbqasupport_short", "wbsearch_query_nm", "wbsim_goods"],
    "classification": ["wbreviews", "wbgoods_categories", "wbgender_classification"],
    "pairwise": ["wbbrand_syns"],
}

# task name -> DatasetDict mapping each dataset key to its tokenized data.
prepared_datasets = {
    task_name: DatasetDict(
        {
            dataset_key: preprocess_data(DATASETS[dataset_key], task=task_name)
            for dataset_key in group_keys
        }
    )
    for task_name, group_keys in data_by_task.items()
}

# Training arguments shared by every task group. Evaluation and checkpointing run
# once per epoch, and the checkpoint with the lowest eval loss is restored at the end.
# `evaluation_strategy` was deprecated in favour of `eval_strategy` (transformers
# 4.41) and later removed. Pass the evaluation schedule under whichever keyword
# the installed TrainingArguments accepts, so the cell runs on old and new versions.
_EVAL_STRATEGY_KEY = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    # load_best_model_at_end needs the evaluation and save schedules to match ("epoch").
    load_best_model_at_end=True,
    save_total_limit=2,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    **{_EVAL_STRATEGY_KEY: "epoch"},
)

def _train_eval_splits(dataset_dict, eval_fraction=0.1, seed=42):
    """Return ``(train, eval)`` datasets taken from the first dataset of a group."""
    # NOTE(review): only the first dataset of each group is used; the others in
    # ``dataset_dict`` are ignored (unchanged from the original behaviour).
    splits = dataset_dict[next(iter(dataset_dict))]
    if "validation" in splits:
        return splits["train"], splits["validation"]
    # A single CSV loaded with load_dataset("csv", ...) only has a "train" split,
    # so carve out a deterministic hold-out set instead of failing with KeyError.
    held_out = splits["train"].train_test_split(test_size=eval_fraction, seed=seed)
    return held_out["train"], held_out["test"]


def _make_compute_metrics(task):
    """Build the Trainer ``compute_metrics`` callback for ``task``."""

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        # Models may return several output arrays; score the first one.
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        predictions = np.argmax(np.asarray(logits), axis=-1)
        # EvalPrediction already holds NumPy arrays; ``labels.numpy()`` would raise.
        labels = np.asarray(labels)

        if task in ("retrieval", "pairwise"):
            # Compute NDCG metrics
            return {
                "NDCG@1": ndcg_score([labels], [predictions], k=1),
                "NDCG@10": ndcg_score([labels], [predictions], k=10),
            }
        if task == "classification":
            # Compute Accuracy and F1 metrics
            return {
                "accuracy": accuracy_score(labels, predictions),
                "f1": f1_score(labels, predictions, average="weighted"),
            }
        # Trainer expects a dict; never hand it ``None``.
        return {}

    return compute_metrics


def train_model(task, dataset_dict):
    """Fine-tune the shared module-level ``model`` on one task group and save it.

    Args:
        task: Task-group name (selects the evaluation metrics); also used as the
            sub-directory of ``OUTPUT_DIR`` the model is saved to.
        dataset_dict: Mapping of dataset key -> ``DatasetDict`` for this group.

    NOTE(review): every call trains the same ``model`` object, so each group
    continues from the weights left by the previous group.
    NOTE(review): ``AutoModel`` loads the encoder, and the tokenized data
    carries no explicit ``labels`` column. Unless the loaded model computes a
    loss by itself, Trainer will fail with "model did not return a loss".
    A task-specific loss (e.g. a Trainer subclass overriding ``compute_loss``)
    is still needed -- TODO.
    """
    train_split, eval_split = _train_eval_splits(dataset_dict)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        # Sequences were tokenized with truncation but no padding, and no
        # tokenizer is given to Trainer, so pad each batch explicitly.
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=_make_compute_metrics(task),
    )
    trainer.train()
    trainer.save_model(os.path.join(OUTPUT_DIR, task))

# Fine-tune and save once per task group, in the order the groups were prepared.
for group_name in prepared_datasets:
    train_model(group_name, prepared_datasets[group_name])

print("Fine-tuning completed and models saved.")
