# Fine-Tuning for Retrieval Tasks

This notebook fine-tunes a sentence-embedding model (with LoRA adapters) for a retrieval task: for each question–answer pair, it retrieves the top-k most relevant misconceptions by embedding similarity. It includes:

1. Data Preprocessing
2. Model Fine-Tuning
3. Checkpoint Loading
4. Inference to Generate Top-k Results
5. Saving Results as Parquet Files


In [1]:
# Import Libraries
import os
import re
import numpy as np
import pandas as pd
import polars as pl
import torch
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics.pairwise import cosine_similarity
from pylatexenc.latex2text import LatexNodes2Text
from safetensors.torch import load_file
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import TripletLoss, MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import mine_hard_negatives
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import Dataset

# Configurations
# Restrict this kernel to physical GPU 2. Only effective if CUDA has not already
# been initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Use GPU 2

CONFIG = {
    "retrieve_num": 25,
    "base_lr": 2e-5,
    "mini_bs": 12,
    "bs_multi": 1,
    "base_steps": 15,
    "eval_steps": 50,
    "save_steps": 50,
    "lr_scheduler_type": "cosine_with_restarts",
    "fine_tuning": True,
    "run_evaluation": True,
    "wandb": True,
    "exp_name": (
        "SFR-Embedding-2_R_ZeroShotSelfConsistency_CleanLatex_"
        "UsingNEWLINES_IncreasedBatchSize_GradClip"
    ),
    "data_path": "/flash2/aml/lad24/dataset",
    "syn_path": "../eedi_synthetic.csv",
    "model_name": "/flash2/aml/chenjiah24_wangwd24_lad24/SFR-Embedding-2_R",
    "competition_name": "misconceptions-in-mathematics-project",
    "llm_answer_path": "/flash2/aml/wangwd24/full_prompt_initial_reasoning_SelfConsitency5.parquet",
    "output_path": ".",
    "target_column": "AllTextWithLlmMisconceptionCleaned",
    "checkpoint_path": None,
    "gradient_accumulation_steps": 8,
}

# Derived Configuration Parameters
# Visible GPU count, floored at 1: on a CPU-only kernel device_count() is 0, which
# would make "bs" zero and the max_steps division below raise ZeroDivisionError.
_n_devices = max(1, torch.cuda.device_count())
# Global batch size = mini-batch x multiplier x visible devices.
CONFIG["bs"] = CONFIG["bs_multi"] * CONFIG["mini_bs"] * _n_devices
# Linear LR scaling relative to a reference batch of 128, never going below base_lr.
CONFIG["lr"] = max(CONFIG["base_lr"], (CONFIG["bs"] / 128) * CONFIG["base_lr"])
# Step budget scaled inversely with batch size (base_steps corresponds to batch 128).
CONFIG["max_steps"] = int((128 / CONFIG["bs"]) * CONFIG["base_steps"])
CONFIG["model_output_path"] = f"{CONFIG['output_path']}/{CONFIG['exp_name']}"

# W&B Setup
if CONFIG["wandb"]:
    import wandb

    wandb.login()
    wandb.init(project=CONFIG["competition_name"], name=CONFIG["exp_name"])
    report_to = "wandb"
else:
    report_to = "none"


  from tqdm.autonotebook import tqdm, trange
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhectorrodriguezrodriguez52[0m ([33mhectorrodriguezrodriguez52-tsinghua-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Data Loading and Preprocessing

1. Load the training data, misconception mapping, and LLM-generated answers.
2. Convert data into a long format suitable for embeddings.
3. Clean LaTeX formatting from the text data.


In [2]:
# Data Preparation
def load_data():
    """
    Read the three raw inputs of the pipeline.

    Paths come from ``CONFIG``: ``train.csv`` and ``misconception_mapping.csv`` live
    under ``CONFIG["data_path"]``; the LLM answers are a parquet file at
    ``CONFIG["llm_answer_path"]``.

    Returns:
        tuple: ``(train, misconception_mapping, llm_answers)``, all ``pl.DataFrame``.
    """
    data_dir = CONFIG["data_path"]
    train_df = pl.read_csv(f"{data_dir}/train.csv")
    mapping_df = pl.read_csv(f"{data_dir}/misconception_mapping.csv")
    llm_df = pl.read_parquet(CONFIG["llm_answer_path"])
    return train_df, mapping_df, llm_df

def convert_to_long(
    df: pl.DataFrame,
    common_cols=(
        "QuestionId", "ConstructName", "SubjectName", "QuestionText", "CorrectAnswer"
    ),
    mapping=None,
):
    """
    Converts the wide training dataframe into one row per (question, answer option).

    Each ``Answer{A,B,C,D}Text`` column becomes its own row with a combined ``AllText``
    prompt (construct, subject, question and answer text). The matching
    ``Misconception{A,B,C,D}Id`` is attached through a ``"<QuestionId>_<letter>"`` key,
    and the misconception table is then joined on ``MisconceptionId``.

    Args:
        df (pl.DataFrame): Input dataframe with wide format data.
        common_cols (Sequence[str]): Columns kept on every long-format row. A tuple by
            default so the shared default value can never be mutated between calls.
        mapping (pl.DataFrame, optional): Misconception table containing a
            ``MisconceptionId`` column. Defaults to the notebook-level
            ``misconception_mapping`` returned by ``load_data()``.

    Returns:
        pl.DataFrame: Long-format dataframe with ``AnswerType``, ``AnswerText``,
        ``AllText``, ``AnswerAlphabet``, ``QuestionId_Answer``, ``MisconceptionId`` and
        the columns of ``mapping``. Both joins are polars' default inner joins, so
        answers whose misconception id is null or absent from ``mapping`` are dropped.
    """
    if mapping is None:
        # Backward-compatible fallback to the global loaded earlier in the notebook.
        mapping = misconception_mapping
    common_cols = list(common_cols)
    letters = ["A", "B", "C", "D"]

    # "<QuestionId>_<letter>" key shared by the answer rows and the misconception rows.
    question_answer_key = pl.concat_str(
        [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
    ).alias("QuestionId_Answer")

    df_long = (
        df
        .select(pl.col(common_cols + [f"Answer{alpha}Text" for alpha in letters]))
        .unpivot(
            index=common_cols,
            variable_name="AnswerType",
            value_name="AnswerText",
        )
        .with_columns(
            # Mark EVERY blank-line break; `str.replace` would only rewrite the first one.
            pl.col("QuestionText")
            .str.replace_all("\n\n", "<NEWLINE>", literal=True)
            .alias("QuestionText")
        )
        .with_columns(
            pl.concat_str(
                [
                    pl.format("Construct Name:\n{}\n\n", pl.col("ConstructName")),
                    pl.format("Subject Name:\n{}\n\n", pl.col("SubjectName")),
                    pl.format("Question Text:\n{}\n\n", pl.col("QuestionText")),
                    pl.format("Answer Text:\n{}\n\n", pl.col("AnswerText")),
                ],
                separator=""
            ).alias("AllText"),
            pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet"),
        )
        .with_columns(question_answer_key)
        .sort("QuestionId_Answer")
    )

    df_misconception_long = (
        df
        .select(pl.col(common_cols + [f"Misconception{alpha}Id" for alpha in letters]))
        .unpivot(
            index=common_cols,
            variable_name="MisconceptionType",
            value_name="MisconceptionId",
        )
        .with_columns(
            pl.col("MisconceptionType")
            .str.extract(r"Misconception([A-D])Id$")
            .alias("AnswerAlphabet"),
        )
        .with_columns(question_answer_key)
        .sort("QuestionId_Answer")
        .select(pl.col(["QuestionId_Answer", "MisconceptionId"]))
        # Cast to Int64 — presumably to match the mapping table's MisconceptionId dtype.
        .with_columns(pl.col("MisconceptionId").cast(pl.Int64))
    )

    df_long = df_long.join(df_misconception_long, on="QuestionId_Answer")
    return df_long.join(mapping, on="MisconceptionId")

def add_llm_misconceptions(df_long: pl.DataFrame, llm_answers: pl.DataFrame) -> pl.DataFrame:
    """
    Append the LLM-predicted misconception to every long-format prompt.

    ``llm_answers`` is left-joined on ``QuestionId_Answer``, and the column named by
    ``CONFIG["target_column"]`` is set to ``AllText`` followed by a
    ``"Misconception:\\n<llmMisconception>"`` section.

    Args:
        df_long (pl.DataFrame): Long-format frame with ``QuestionId_Answer`` and ``AllText``.
        llm_answers (pl.DataFrame): Frame with ``QuestionId_Answer`` and ``llmMisconception``.

    Returns:
        pl.DataFrame: ``df_long`` with ``llmMisconception`` and the target text column added.
    """
    llm_subset = llm_answers.select("QuestionId_Answer", "llmMisconception")
    misconception_section = pl.format("Misconception:\n{}", pl.col("llmMisconception"))
    combined_text = pl.concat_str([pl.col("AllText"), misconception_section])
    return (
        df_long
        .join(llm_subset, on="QuestionId_Answer", how="left")
        .with_columns(combined_text.alias(CONFIG["target_column"]))
    )

def preprocess_latex_column(df, column):
    """
    Converts LaTeX markup in one text column to plain text.

    Each value is passed through ``LatexNodes2Text().latex_to_text`` and then any run
    of newlines/whitespace following a colon is collapsed to a single ``":\\n"``.

    Args:
        df (pl.DataFrame): Input dataframe.
        column (str): Name of the column to preprocess (replaced in place in the result).

    Returns:
        pl.DataFrame: Dataframe with ``column`` replaced by its cleaned text. Null values
        are left as null (polars ``map_elements`` skips nulls by default).
    """
    # Build the converter and regex once instead of once per row.
    converter = LatexNodes2Text()
    colon_gap = re.compile(r":\n+\s*")

    def clean_latex(text: str) -> str:
        plain = converter.latex_to_text(text)
        return colon_gap.sub(":\n", plain)

    # Map over the single column rather than packing every column into a per-row struct.
    return df.with_columns(
        pl.col(column).map_elements(clean_latex, return_dtype=pl.String).alias(column)
    )

train, misconception_mapping, llm_answers = load_data()

# Wide -> long, keep only labelled answers, append the LLM misconception, clean LaTeX.
train_long = (
    convert_to_long(train)
    .drop_nulls("MisconceptionId")
    .pipe(add_llm_misconceptions, llm_answers)
    .pipe(preprocess_latex_column, CONFIG["target_column"])
)


## Data Splitting

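Split the dataset into training and validation sets using `StratifiedGroupKFold`. Rows are grouped by `QuestionId`, so all answers to one question land in the same split. They are stratified on subject name plus misconception ID. Only the first of the three folds is used, so roughly one third of the data goes to validation.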


In [3]:
def split_train_val(data, n_splits=3, random_state=42):
    """
    Splits the input data into training and validation sets using StratifiedGroupKFold.

    Rows are grouped by ``QuestionId`` so all answers to one question fall on the same
    side. They are stratified on the concatenation of ``SubjectName`` and
    ``MisconceptionId``. Only the first fold is used, so the validation share is
    roughly ``1 / n_splits``.

    Args:
        data (pl.DataFrame): The dataframe to split.
        n_splits (int, optional): Number of folds. Defaults to 3.
        random_state (int, optional): Shuffle seed for reproducible splits. Defaults to 42.

    Returns:
        tuple: A tuple containing:
            - train_ds (pl.DataFrame): The training subset of the data.
            - val_ds (pl.DataFrame): The validation subset of the data.
    """
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    # Pass a 1-D label Series (not a one-column DataFrame) as the stratification target.
    strata = data.select(
        pl.concat_str([pl.col("SubjectName"), pl.col("MisconceptionId")])
    ).to_series()
    train_idx, val_idx = next(sgkf.split(data, y=strata, groups=data["QuestionId"]))
    return data[train_idx], data[val_idx]

train_ds, val_ds = split_train_val(train_long)


def _to_query_answer(ds):
    """Rename the text columns to 'query'/'answer' and keep only those two columns."""
    renamed = ds.rename_columns({
        CONFIG['target_column']: 'query',
        'MisconceptionName': 'answer'
    })
    return renamed.select_columns(['query', 'answer'])


# Training split: only the (query, answer) pairs are needed downstream.
train_ds = _to_query_answer(Dataset.from_polars(train_ds))

# Validation split: keep the full Dataset (IDs are needed by the IR evaluator) plus a
# (query, answer)-only view used as the trainer's eval_dataset.
val_ds = Dataset.from_polars(val_ds)
val_ds_eval = _to_query_answer(val_ds)




## Model Setup

1. Load the `SentenceTransformer` model.
2. Configure LoRA for parameter-efficient fine-tuning.
3. Add functionality to load checkpoints.


In [4]:
def load_model():
    """
    Build the embedding model wrapped with a LoRA adapter.

    If ``CONFIG["checkpoint_path"]`` is set, an existing trainable PEFT adapter is
    loaded from it; otherwise a fresh LoRA adapter (r=8, alpha=32, dropout=0.2 on the
    attention and MLP projections) is attached to the base model.

    Returns:
        SentenceTransformer: The base model wrapped as a PEFT model.
    """
    base_model = SentenceTransformer(CONFIG["model_name"])
    adapter_dir = CONFIG["checkpoint_path"]

    if not adapter_dir:
        fresh_adapter = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",
            ],
            lora_dropout=0.2,
        )
        return get_peft_model(base_model, fresh_adapter)

    restored = PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=True)
    print(f"PEFT model loaded from '{adapter_dir}'.")
    return restored

model = load_model()


Loading checkpoint shards: 100%|██████████| 3/3 [00:42<00:00, 14.28s/it]


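## Hard-Negative Mining

For each training pair, select misconceptions that the current model scores as similar to the query but that are not the correct answer. These harder negatives make the contrastive training signal more informative.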


In [5]:
# Mine hard negatives: for each (query, answer) pair, pick answers that the current
# model scores as similar to the query but that are not the labelled answer.
# NOTE(review): no `corpus` is passed, so (per the sentence-transformers docs) candidates
# come from this split's own `answer` column, not the full misconception_mapping.
train_ds_with_negatives = mine_hard_negatives(
    dataset=train_ds,
    model=model,
    range_min=100,               # Skip the 100 highest-ranked candidates (avoid near-duplicates of the positive)
    range_max=1000,              # Only consider candidates ranked within the top 1000 by similarity
    max_score=0.8,               # Exclude candidates scoring above this (likely false negatives)
    margin=0.01,                 # A negative must score at least this much below the positive
    num_negatives=5,             # Number of hard negatives to select per query
    sampling_strategy="random",  # Sample randomly inside the rank window (alternative: "top")
    batch_size=128,              # Encoding batch size
    use_faiss=False,             # Exact similarity search without FAISS (FAISS helps for very large corpora)
)

Found 2902 unique queries out of 2912 total queries.
Found an average of 1.003 positives per query.


Batches: 100%|██████████| 10/10 [00:10<00:00,  1.07s/it]
Batches: 100%|██████████| 23/23 [02:32<00:00,  6.65s/it]


Metric       Positive       Negative     Difference
Count           2,912         14,453               
Mean           0.7293         0.6079         0.1224
Median         0.7343         0.6065         0.1226
Std            0.0507         0.0332         0.0529
Min            0.5238         0.4963         0.0164
25%            0.6969         0.5848         0.0827
50%            0.7344         0.6065         0.1226
75%            0.7656         0.6302         0.1604
Max            0.8684         0.7345         0.3084
Skipped 211390 potential negatives (7.26%) due to the margin of 0.01.
Skipped 65 potential negatives (0.00%) due to the maximum score of 0.8.
Could not find enough negatives for 107 samples (0.73%). Consider adjusting the range_max, range_min, margin and max_score parameters if you'd like to find more valid negatives.


## Model Training

Fine-tune the model using `SentenceTransformerTrainer`.


In [6]:
def prepare_queries_and_corpus(
    val_ds: "Dataset",
    misconception_mapping: "pl.DataFrame | pd.DataFrame",
    target_column: str,
):
    """
    Build the ``queries`` and ``corpus`` dictionaries used by the evaluator and inference.

    Args:
        val_ds (Dataset): Validation data indexable by column name (e.g. a
            ``datasets.Dataset``); must provide ``QuestionId_Answer`` and ``target_column``.
        misconception_mapping (pl.DataFrame | pd.DataFrame): Table with ``MisconceptionId``
            and ``MisconceptionName`` columns. The notebook passes the polars frame from
            ``load_data()``; a pandas frame also works because only column access and
            iteration are used.
        target_column (str): Column of ``val_ds`` holding the query text.

    Returns:
        tuple[dict, dict]: ``(queries, corpus)``. ``queries`` maps ``QuestionId_Answer`` to
        query text and ``corpus`` maps ``MisconceptionId`` to ``MisconceptionName``. With
        repeated keys, the last occurrence wins (plain ``dict`` construction).
    """
    # Corpus: every known misconception, keyed by its id.
    corpus = dict(
        zip(misconception_mapping["MisconceptionId"], misconception_mapping["MisconceptionName"])
    )

    # Queries: one per validation (question, answer) row.
    queries = dict(
        zip(val_ds["QuestionId_Answer"], val_ds[target_column])
    )

    return queries, corpus

# Validation queries and the full misconception corpus, shared by evaluation and inference.
queries, corpus = prepare_queries_and_corpus(val_ds, misconception_mapping, CONFIG["target_column"])

In [7]:
def setup_evaluator(queries: dict, corpus: dict, misconception_mapping, val_dataset=None,
                    k_values=(3, 5, 10, 25, 50)):
    """
    Set up the information-retrieval evaluator for the validation split.

    Args:
        queries (dict): Query ID -> query text. If ``None``, it is rebuilt from
            ``val_dataset`` using ``CONFIG["target_column"]``.
        corpus (dict): Corpus ID -> corpus text. If ``None``, it is rebuilt from
            ``misconception_mapping``.
        misconception_mapping (pl.DataFrame): Table with ``MisconceptionId`` and
            ``MisconceptionName``; only used when ``corpus`` is ``None``.
        val_dataset (Dataset, optional): Validation rows providing ``QuestionId_Answer``
            and ``MisconceptionId`` (the ground truth). Defaults to the notebook-level
            ``val_ds``.
        k_values (Sequence[int], optional): Cut-offs used for MAP, precision/recall,
            NDCG and MRR. Defaults to ``(3, 5, 10, 25, 50)``.

    Returns:
        InformationRetrievalEvaluator: Evaluator named ``"eval"``, so its metric keys look
        like ``"eval_cosine_map@25"``.
    """
    if val_dataset is None:
        # Backward-compatible default: the validation split defined earlier in the notebook.
        val_dataset = val_ds
    if corpus is None:
        corpus = dict(
            zip(misconception_mapping["MisconceptionId"], misconception_mapping["MisconceptionName"])
        )
    if queries is None:
        queries = dict(zip(val_dataset["QuestionId_Answer"], val_dataset[CONFIG["target_column"]]))

    # Ground truth: query id (str) -> set of relevant misconception ids (int), matching
    # the key types of `queries` and `corpus`.
    relevant_docs = {}
    for qid, cid in zip(val_dataset["QuestionId_Answer"], val_dataset["MisconceptionId"]):
        relevant_docs.setdefault(str(qid), set()).add(int(cid))

    ks = list(k_values)
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="eval",
        map_at_k=list(ks),
        precision_recall_at_k=list(ks),
        ndcg_at_k=list(ks),
        mrr_at_k=list(ks),
    )

# Training function
def train_model(model, train_ds, evaluator, loss, eval_dataset=None):
    """
    Trains the model with the given dataset, evaluator, and loss function.

    Evaluation (the IR ``evaluator`` plus eval loss on ``eval_dataset``) and checkpointing
    run every ``CONFIG["eval_steps"]`` / ``CONFIG["save_steps"]`` steps. The checkpoint
    with the best ``eval_cosine_map@25`` is reloaded at the end. The (PEFT) model is then
    saved to ``CONFIG["model_output_path"]``.

    Args:
        model (SentenceTransformer): The model to train.
        train_ds (Dataset): Training dataset.
        evaluator (InformationRetrievalEvaluator): Evaluator for validation.
        loss (Loss): Loss function for training.
        eval_dataset (Dataset, optional): Dataset used for the validation loss. Defaults
            to the notebook-level ``val_ds_eval``.

    Returns:
        SentenceTransformer: The trained model.
    """
    if eval_dataset is None:
        # Backward-compatible default: the (query, answer) validation view built earlier.
        eval_dataset = val_ds_eval

    args = SentenceTransformerTrainingArguments(
        output_dir=CONFIG["output_path"],
        max_steps=CONFIG["max_steps"],
        # NOTE(review): CONFIG["bs"] already multiplies by torch.cuda.device_count(); with
        # more than one visible GPU the effective batch would be scaled twice — confirm intent.
        per_device_train_batch_size=CONFIG["bs"],
        per_device_eval_batch_size=CONFIG["bs"],
        learning_rate=CONFIG["lr"],
        weight_decay=0.01,
        warmup_ratio=0.1,
        bf16=True,  # Enable BF16 training
        # No duplicate texts per batch: with in-batch-negative losses a duplicate would act
        # as a false negative.
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        lr_scheduler_type=CONFIG["lr_scheduler_type"],
        # eval_strategy="steps" already enables evaluation; no separate do_eval flag needed.
        eval_strategy="steps",
        eval_steps=CONFIG["eval_steps"],
        save_strategy="steps",
        save_steps=CONFIG["save_steps"],
        logging_strategy="steps",
        logging_steps=1,
        save_total_limit=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_cosine_map@25",
        greater_is_better=True,
        report_to="wandb" if CONFIG["wandb"] else "none",
        run_name=CONFIG["exp_name"],
        gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_dataset,
        evaluator=evaluator,
        loss=loss,
    )
    trainer.train()
    model.save_pretrained(CONFIG["model_output_path"])
    return model

# Optionally fine-tune, then optionally score the (possibly fine-tuned) model on validation.
if not CONFIG["fine_tuning"]:
    print("Fine-tuning is disabled. Skipping training steps.")
else:
    loss = MultipleNegativesRankingLoss(model)
    evaluator = setup_evaluator(queries, corpus, misconception_mapping)
    model = train_model(model, train_ds_with_negatives, evaluator, loss)

if not CONFIG["run_evaluation"]:
    print("Evaluation is disabled.")
else:
    print("Running evaluation on the loaded model...")
    evaluator = setup_evaluator(queries, corpus, misconception_mapping)
    results = evaluator(model)
    headline_metric = "eval_cosine_map@25"
    print("Evaluation results:", results[headline_metric])


[2024-12-10 00:54:58,295] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  super().__init__(
/flash2/aml/lad24/env/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/flash2/aml/lad24/env/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/flash2/aml/lad24/env/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/flash2/aml/lad24/env/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/flash2/aml/lad24/env/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/flash2/aml/lad24/env/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/flash2/aml/lad24/env/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so

Step,Training Loss,Validation Loss


## Inference

Use the fine-tuned model to retrieve the top-k misconceptions for each validation query, and save the ranked misconception IDs to a Parquet file.


In [None]:
def perform_inference(model, queries, corpus, top_k=25, output_path="./submission.parquet",
                      device="cuda", batch_size=32):
    """
    Retrieve the top-k corpus items for every query by cosine similarity of embeddings.

    Args:
        model (SentenceTransformer): Trained model; only its ``encode`` method is used.
        queries (dict): Query ID -> query text.
        corpus (dict): Corpus ID -> corpus text.
        top_k (int, optional): Number of results per query. Clamped to ``len(corpus)``.
            Defaults to 25.
        output_path (str | None, optional): Parquet file to write. ``None`` skips writing.
            Defaults to "./submission.parquet".
        device (str, optional): Device passed to ``model.encode``. Defaults to "cuda".
        batch_size (int, optional): Encoding batch size. Defaults to 32.

    Returns:
        pd.DataFrame: Columns ``QuestionId_Answer`` (query ID) and ``MisconceptionId``
        (space-separated corpus IDs, best match first), one row per query in input order.

    Raises:
        ValueError: If ``queries`` or ``corpus`` is empty, or ``top_k`` < 1.
    """
    if not queries or not corpus:
        raise ValueError("Queries or corpus cannot be empty.")
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k}.")

    print("Encoding queries...")
    query_embeddings = model.encode(
        list(queries.values()), batch_size=batch_size, convert_to_tensor=True, device=device
    )
    print("Encoding corpus...")
    corpus_embeddings = model.encode(
        list(corpus.values()), batch_size=batch_size, convert_to_tensor=True, device=device
    )

    print("Computing cosine similarity...")
    # L2-normalise so the matrix product is a true cosine similarity, consistent with the
    # evaluator's cosine metrics, whether or not the model normalises its outputs.
    query_embeddings = torch.nn.functional.normalize(query_embeddings.float(), p=2, dim=1)
    corpus_embeddings = torch.nn.functional.normalize(corpus_embeddings.float(), p=2, dim=1)
    cos_sim_matrix = query_embeddings @ corpus_embeddings.T

    print("Sorting results...")
    k = min(top_k, len(corpus))  # torch.topk raises if k exceeds the corpus size
    top_k_indices = torch.topk(cos_sim_matrix, k=k, dim=1).indices.cpu().tolist()

    corpus_keys = list(corpus.keys())
    result_df = pd.DataFrame(
        {
            "QuestionId_Answer": list(queries.keys()),
            "MisconceptionId": [
                " ".join(str(corpus_keys[idx]) for idx in row) for row in top_k_indices
            ],
        }
    )

    if output_path is not None:
        print(f"Saving results to {output_path}...")
        result_df.to_parquet(output_path, index=False)
        print(f"Inference results saved to {output_path}.")
    return result_df

perform_inference(model, queries, corpus, top_k=CONFIG["retrieve_num"])