# Embedding Models Evaluation

In [None]:
import os
import json
import pandas as pd
import datetime
from typing import List, Dict
import torch
import gc
import re

from sentence_transformers import SentenceTransformer

from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from langchain.embeddings.base import Embeddings

##################################################
# Custom SentenceTransformer Embeddings
##################################################
class CustomSentenceTransformerEmbeddings(Embeddings):
    """
    Allows using a SentenceTransformer model within a LangChain-based FAISS store.

    Known model names are loaded with model-specific keyword arguments; any
    other name is loaded with a plain default configuration. The model is
    placed on CUDA when it is available and on CPU otherwise.
    """

    def __init__(self, embedding_model_name: str):
        """Remember the model name and load the model right away."""
        self.embedding_model_name = embedding_model_name
        self.model = self._initialize_model()

    def _initialize_model(self) -> SentenceTransformer:
        """
        Initializes the SentenceTransformer model based on the embedding_model_name.

        Returns:
            The loaded SentenceTransformer instance.

        Raises:
            Exception: whatever SentenceTransformer raises while loading; the
                error is printed and then re-raised unchanged.
        """
        # Fall back to CPU so the code still runs on machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Per-model constructor arguments.
        model_configs = {
            "jinaai/jina-embeddings-v3": {
                "trust_remote_code": True,
                "revision": "main",
                "device": device,
                "model_kwargs": {"use_flash_attn": False},
            },
            "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5": {
                "local_files_only": True,
                "device": device,
                "model_kwargs": {"attn_implementation": "eager"},
            },
            "Alibaba-NLP/gte-large-en-v1.5": {
                "trust_remote_code": True,
                "revision": "main",
                "device": device,
                "model_kwargs": {"attn_implementation": "eager"},
            },
        }

        config = model_configs.get(
            self.embedding_model_name,
            {"device": device, "model_kwargs": {}},  # default fallback
        )

        try:
            model = SentenceTransformer(self.embedding_model_name, **config)
        except Exception as e:
            print(f"Error initializing model {self.embedding_model_name}: {e}")
            raise

        print(f"Initialized SentenceTransformer model: {self.embedding_model_name}")
        return model

    def embed_query(self, text: str) -> List[float]:
        """Encode a single query string and return its vector as a list."""
        return self.model.encode(text).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode a list of texts and return one vector (as a list) per text."""
        return self.model.encode(texts).tolist()

    def unload_model(self):
        """
        Remove the model from memory after processing to free up GPU resources.

        Safe to call more than once; later calls only print a debug message.
        """
        # Compare against None explicitly: truthiness of a model object is not
        # a reliable "is it loaded" signal.
        if getattr(self, "model", None) is None:
            print("[DEBUG] Embedding model was already None or not set.")
            return

        self.model = None
        # Collect garbage first so memory that is only kept alive by reference
        # cycles is released before the CUDA cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"Unloaded embedding model: {self.embedding_model_name}")


################################
# Evaluation Helper Functions
################################
def strip_content_between_dashes(chunk):
    """
    Drop every `---` ... `---` span (markers included) from the chunk.

    A chunk without any `---` is returned untouched; otherwise the cleaned
    text is returned with leading and trailing whitespace stripped.
    """
    if "---" not in chunk:
        return chunk

    cleaned = re.sub(r"---.*?---", "", chunk, flags=re.DOTALL)
    return cleaned.strip()


def evaluate_retrieval(retrieved_chunks, expected_chunks, comparison_length=100):
    """
    Exact-match check on chunk prefixes.

    Retrieved chunks first have their `---` ... `---` spans removed; expected
    chunks are used as given. Both sides are then reduced to the whitespace-
    stripped first `comparison_length` characters and compared as sets.

    Returns:
        (correctly_retrieved, missed_chunks): two sets of expected prefixes,
        those that were found among the retrieved prefixes and those that
        were not.
    """

    def _prefix(text):
        return text[:comparison_length].strip()

    retrieved_prefixes = {
        _prefix(strip_content_between_dashes(chunk)) for chunk in retrieved_chunks
    }
    expected_prefixes = {_prefix(chunk) for chunk in expected_chunks}

    correctly_retrieved = expected_prefixes & retrieved_prefixes
    missed_chunks = expected_prefixes - retrieved_prefixes

    total_expected = len(expected_chunks)

    print(f"\nExact-Match Evaluation (First {comparison_length} characters):")
    print(f"  Correctly retrieved chunks: {len(correctly_retrieved)}/{total_expected}")
    for hit in correctly_retrieved:
        print(f"    ✔ Retrieved Prefix: {hit}")

    print(f"  Missed chunks: {len(missed_chunks)}/{total_expected}")
    for miss in missed_chunks:
        print(f"    ✘ Missed Prefix: {miss}")

    return correctly_retrieved, missed_chunks


def evaluate_ranked_retrieval(docs_with_ranks, expected_chunks, comparison_length=100):
    """
    Rank-based check on chunk prefixes.

    `docs_with_ranks` is an iterable of (rank, doc) pairs where each doc
    exposes `page_content`. Both retrieved and expected texts have their
    `---` ... `---` spans removed and are compared on the whitespace-stripped
    first `comparison_length` characters. For every expected chunk the first
    matching rank is recorded (None when nothing matches).

    Returns:
        (mean_rank, rank, mrr):
          - mean_rank: average of the found ranks (the rank itself when only
            one chunk was found), or None when nothing was found.
          - rank: the single found rank as int, only when exactly one chunk
            was found; otherwise None.
          - mrr: mean of 1/rank over the found chunks, or None.
    """

    def _prefix(text):
        return strip_content_between_dashes(text)[:comparison_length].strip()

    ranked_prefixes = [
        (position, _prefix(document.page_content))
        for position, document in docs_with_ranks
    ]

    ranks = []
    for expected in expected_chunks:
        wanted = _prefix(expected)
        first_hit = next(
            (position for position, prefix in ranked_prefixes if prefix == wanted),
            None,
        )
        ranks.append(first_hit)

    found_ranks = [r for r in ranks if r is not None]
    missed_count = len(ranks) - len(found_ranks)

    mean_rank = rank = mrr = None
    if found_ranks:
        hits = len(found_ranks)
        if hits == 1:
            mean_rank = found_ranks[0]
            rank = int(mean_rank)
        else:
            mean_rank = sum(found_ranks) / hits
        mrr = sum((1.0 / r) for r in found_ranks) / hits

    total_expected = len(expected_chunks)
    print(f"\nRanked Evaluation (First {comparison_length} characters):")
    print(f"  Found: {len(found_ranks)}/{total_expected}")
    print(f"  Missed: {missed_count}/{total_expected}")
    print(f"  Mean Rank: {mean_rank:.2f}" if mean_rank is not None else "  Mean Rank: N/A")
    print(f"  Rank: {rank}" if rank is not None else "  Rank: N/A")
    print(f"  MRR: {mrr:.3f}" if mrr is not None else "  MRR: N/A")

    return mean_rank, rank, mrr


def save_csv(evaluation_data, k_value, lambda_mult, output_csv_path):
    """
    Append one batch of evaluation records to the results CSV.

    Parameters:
        evaluation_data: list of per-row result dicts (may be empty).
        k_value: value written to the `k_value` column of every record.
        lambda_mult: value written to the `lambda_mult` column of every record.
        output_csv_path: target file. Its parent directory is created when
            needed. The header row is written only when the file does not
            exist yet or is empty; otherwise the records are appended.
    """
    df_eval = pd.DataFrame(evaluation_data)
    df_eval["k_value"] = k_value
    df_eval["lambda_mult"] = lambda_mult

    columns_order = [
        "model",
        "row_index",
        "product_name",
        "k_value",
        "lambda_mult",
        "retrieved_docs",
        "found_exact",
        "missed_exact",
        "mean_rank",
        "rank",
        "mrr",
    ]
    # Fixed column layout; columns absent from the records are filled with "N/A".
    df_eval = df_eval.reindex(columns=columns_order, fill_value="N/A")

    # A bare filename has an empty dirname, which os.makedirs rejects.
    output_dir = os.path.dirname(output_csv_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # An existing but empty file still needs its header row.
    needs_header = (
        not os.path.exists(output_csv_path) or os.path.getsize(output_csv_path) == 0
    )
    df_eval.to_csv(
        output_csv_path,
        index=False,
        mode="w" if needs_header else "a",
        header=needs_header,
        encoding="utf-8",
    )


############################################################
# compare_rows: For each row in the DataFrame,
# build a query from Product Name + Technology Description
# (+ Classification unless it is "Other"). Then retrieve & evaluate.
############################################################
def compare_rows(
    df: pd.DataFrame,
    bench_dict: Dict[str, str],
    embedding_model_name: str,
    schema: str,
    k_value: int,
    output_csv_path: str,
    comparison_length: int = 100,
    fetch_k: int = 100,
    lambda_mult: float = 0.8,
):
    """
    Run retrieval for every row of `df` and record how well the expected
    chunk was retrieved.

    For each row a query is built from "Product Name" and
    "Technology Description", plus "Classification Result" unless it equals
    "Other". Documents are fetched with an MMR retriever created from the
    module-level `vectorstore` global (it must be assigned before this
    function is called) and compared against the expected chunk stored in
    `bench_dict` under the product name. Rows without a benchmark entry are
    skipped with a warning. All collected records are passed to `save_csv`.

    Parameters:
        df: rows to evaluate; must provide the three columns named above.
        bench_dict: product name -> expected chunk text. The lookup is
            case-insensitive.
        embedding_model_name: label written to the results and the log.
        schema: currently unused (the metadata filter is disabled).
        k_value: number of documents the retriever returns.
        output_csv_path: results file handed to `save_csv`.
        comparison_length: number of leading characters compared.
        fetch_k: candidate pool size passed to the MMR search.
        lambda_mult: MMR `lambda_mult` setting passed to the retriever.
    """
    evaluation_data = []

    # Product names are matched case-insensitively, so normalise the keys
    # once instead of relying on the caller to have lower-cased them.
    bench_lookup = {str(name).lower(): chunk for name, chunk in bench_dict.items()}

    # The retriever settings do not depend on the row, so build it once.
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={
            # "filter": {"schema_type": schema}, # There is no filter
            "k": k_value,
            "fetch_k": fetch_k,
            "lambda_mult": lambda_mult,
        },
    )

    for idx, row in df.iterrows():
        product_name = row["Product Name"]
        classification = row["Classification Result"]
        tech_description = row["Technology Description"]

        # Build the query from the row
        query_str = (
            f"Product: {product_name}\n"
            f"Description: {tech_description}\n"
        )

        if classification != "Other":
            query_str += f"Classification: {classification}\n"

        # str() keeps a missing (NaN) product name from crashing on .lower();
        # such rows simply find no benchmark entry and are skipped below.
        expected_chunk = bench_lookup.get(str(product_name).lower(), "")

        if not expected_chunk:
            print(f"[Warning] No expected benchmark info for product '{product_name}'")
            continue

        print(f"\n[Model: {embedding_model_name}]")
        print(f"[Row Index: {idx}]")
        print(f"Query:\n{query_str}")  # Print query for inspection

        retrieved_docs = retriever.invoke(query_str)
        print(f"Number of retrieved documents: {len(retrieved_docs)}")

        retrieved_chunks = [doc.page_content for doc in retrieved_docs]

        # Evaluate retrieval
        found_exact, missed_chunks = evaluate_retrieval(
            retrieved_chunks, [expected_chunk], comparison_length=comparison_length
        )

        # Evaluate ranking (ranks are 1-based)
        docs_with_ranks = list(enumerate(retrieved_docs, start=1))
        mean_rank, rank, mrr = evaluate_ranked_retrieval(
            docs_with_ranks, [expected_chunk], comparison_length=comparison_length
        )

        # Show top 5 for inspection (optional)
        top_docs = retrieved_chunks[:5]
        processed_top_docs = [strip_content_between_dashes(doc) for doc in top_docs]
        print(f"\n[Top 5 Retrieved Docs]")
        for i, doc_text in enumerate(processed_top_docs, start=1):
            print(f"  {i}. {doc_text}")

        # Collect the evaluation results
        evaluation_data.append(
            {
                "model": embedding_model_name,
                "row_index": idx,
                "product_name": product_name,
                "retrieved_docs": len(retrieved_docs),
                "found_exact": len(found_exact),
                "missed_exact": len(missed_chunks),
                "mean_rank": mean_rank,
                "rank": rank,
                "mrr": mrr,
            }
        )

    # Save results for all rows
    save_csv(evaluation_data, k_value, lambda_mult, output_csv_path)

##################################################
# Load the JSON summaries (alternative input, currently disabled)
##################################################
# summaries_path = "../../data/pipeline2/json/100_tech_sum_one_paragraph_regex.json"
# summaries_path = "../../data/pipeline2/json/100_technology_compositions.json"
# with open(summaries_path, "r", encoding="utf-8") as f:
#     tech_summaries = json.load(f)

# Create a DataFrame from the JSON data
# df = pd.DataFrame(tech_summaries)

##################################################
# Load the CSV raw data
##################################################
summaries_path = "../../data/pipeline2/sql/filtered_epd_data02.csv"

# Load the CSV file into a DataFrame
df = pd.read_csv(summaries_path)


##################################################
# Load the JSON benchmark data (200_matched_bench_EN.json)
# and map each "epd_name" to "epd_category" to use
# as our "expected chunk" for that product.
##################################################
bench_path = "../../data/pipeline2/json/200_matched_bench_EN.json"
with open(bench_path, "r", encoding="utf-8") as f:
    bench_data = json.load(f)

# compare_rows() looks products up by their lower-cased name, so the keys
# are lower-cased here; otherwise mixed-case names would never match.
bench_dict = {
    str(item["epd_name"]).lower(): item["epd_category"]
    for item in bench_data
}


#########
# Usage
#########
# Every (lambda_mult, k_value) combination below is evaluated for each model.
k_values = [10, 20, 30, 40, 50]
# k_values = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80]
# k_values = [40, 50, 60, 70, 80]
# lambda_mults = [0.7, 0.8, 0.9, 1.0]
lambda_mults = [1.0]

# Paths for the FAISS vectorstores (one for each model)
embeddings_name = "faiss_index_COS_EN"
embeddings_name_512 = "faiss_index_COS_EN"
vectorstore_paths = {
    # "bge-m3:latest": f"../../embeddings/pipeline2/bge-m3/{embeddings_name}",
    # "snowflake-arctic-embed2:latest": f"../../embeddings/pipeline2/snowflake-arctic-embed2/{embeddings_name}",
    # "jina/jina-embeddings-v2-base-de:latest": f"../../embeddings/pipeline2/jina_jina-embeddings-v2-base-de/{embeddings_name}",
    # "paraphrase-multilingual:latest": f"../../embeddings/pipeline2/paraphrase-multilingual/{embeddings_name}",
    # "jeffh/intfloat-multilingual-e5-large-instruct:f32": f"../../embeddings/pipeline2/jeffh_intfloat-multilingual-e5-large-instruct/{embeddings_name_512}",
    # "granite-embedding:278m": f"../../embeddings/pipeline2/granite-embedding/{embeddings_name_512}",
    # "bge-large:latest": f"../../embeddings/pipeline2/bge-large/{embeddings_name_512}",
    "mxbai-embed-large:latest": f"../../embeddings/pipeline2/mxbai-embed-large/{embeddings_name_512}",
    # "jinaai/jina-embeddings-v3": f"../../embeddings/pipeline2/jinaai_jina-embeddings-v3/{embeddings_name}",
    # "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5": f"../../embeddings/pipeline2/HIT-TMG_KaLM-embedding-multilingual-mini-instruct-v1.5/{embeddings_name}",
    # "Alibaba-NLP/gte-large-en-v1.5": f"../../embeddings/pipeline2/Alibaba-NLP_gte-large-en-v1.5/{embeddings_name_512}",
}

# One results file per run; every compare_rows() call appends to it.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_csv_path = f"../../data/pipeline2/embed_eval/embed_eval_res_{timestamp}.csv"

# We'll evaluate row-by-row for each embedding model
schema_type = "epd_bench_schema"  # a placeholder schema type for filtering
for embedding_model_name, faiss_path in vectorstore_paths.items():
    print(f"\n=== Processing Embedding Model: {embedding_model_name} ===")

    # Choose embeddings class: these names are loaded through
    # SentenceTransformer, everything else is served by Ollama.
    if embedding_model_name in [
        "jinaai/jina-embeddings-v3",
        "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
        "Alibaba-NLP/gte-large-en-v1.5",
    ]:
        embeddings = CustomSentenceTransformerEmbeddings(embedding_model_name)
    else:
        embeddings = OllamaEmbeddings(model=embedding_model_name)
        print(f"Initialized Ollama embedding model: {embedding_model_name}")

    try:
        # Load vectorstore. compare_rows() reads it through this global name.
        # NOTE(review): allow_dangerous_deserialization unpickles the stored
        # index; only point this at index files you created yourself.
        vectorstore = FAISS.load_local(
            faiss_path,
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
        )

        for lm in lambda_mults:
            print(f"\n--- Evaluating with lambda_mult = {lm} ---")
            for kv in k_values:
                print(f"\n--- Evaluating with k_value = {kv} ---")
                compare_rows(
                    df=df,
                    bench_dict=bench_dict,
                    embedding_model_name=embedding_model_name,
                    schema=schema_type,
                    k_value=kv,
                    output_csv_path=output_csv_path,
                    comparison_length=100,
                    fetch_k=100,
                    lambda_mult=lm,
                )
    finally:
        # Unload the model to free up GPU memory, even when evaluation failed
        if isinstance(embeddings, CustomSentenceTransformerEmbeddings):
            embeddings.unload_model()

print("\n>>> All embedding models have been processed and unloaded.")
print(f"\n>>> Completed. See results in {output_csv_path}.")


## Evaluation Visualizations for Embedding Models


In [None]:
# Path of an evaluation results CSV (same naming pattern as the
# `output_csv_path` written by the evaluation cell). interpret_results()
# derives the name of its markdown summary from this file name.
CSV = "../../data/pipeline2/embed_eval/embed_eval_res_20250208_130510.csv"

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tabulate import tabulate
import os


def abbreviate_model_name(name):
    """
    Abbreviates a long model name to its first 10 and last 6 characters,
    joined by "...".

    Names for which that form would not be shorter than the original
    (19 characters or fewer) are returned unchanged.

    Parameters:
        name (str): The original model name.

    Returns:
        str: The abbreviated model name.
    """
    head, tail, ellipsis = 10, 6, "..."
    # Only abbreviate when the result is actually shorter than the input;
    # otherwise head and tail would overlap and the label would grow.
    if len(name) <= head + tail + len(ellipsis):
        return name
    return f"{name[:head]}{ellipsis}{name[-tail:]}"


def load_and_prepare_data(csv_file_path):
    """
    Read the evaluation CSV and get it ready for plotting.

    Steps: load the file, coerce the metric columns that are present to
    numeric (unparseable values become NaN), apply the seaborn theme, and turn
    'model' and 'lambda_mult' into ordered categoricals.

    Parameters:
        csv_file_path (str): The file path to the CSV file.

    Returns:
        df (pd.DataFrame): The cleaned and prepared DataFrame.
        model_order (list): Models in order of first appearance in the CSV.
        lambda_order (list): Distinct non-null lambda_mult values, ascending.
    """
    df = pd.read_csv(csv_file_path)
    print(f"Loaded {len(df)} rows from {csv_file_path}")

    # Metric columns that should be numeric; absent ones are simply skipped.
    metric_columns = (
        "k_value",
        "lambda_mult",
        "retrieved_docs",
        "found_exact",
        "missed_exact",
        "mean_rank",
        "mrr",
    )
    for column in metric_columns:
        if column not in df.columns:
            continue
        df[column] = pd.to_numeric(df[column], errors="coerce")

    # Global plot styling for the charts drawn afterwards.
    sns.set_theme(style="whitegrid", font_scale=1.1)

    # Models keep their order of appearance.
    model_order = df["model"].drop_duplicates().tolist()
    df["model"] = pd.Categorical(df["model"], categories=model_order, ordered=True)

    # Lambda values are ordered numerically.
    distinct_lambdas = df["lambda_mult"].dropna().unique().tolist()
    lambda_order = sorted(distinct_lambdas)
    df["lambda_mult"] = pd.Categorical(
        df["lambda_mult"], categories=lambda_order, ordered=True
    )

    return df, model_order, lambda_order


def plot_found_vs_missed(df, model_order, lambda_order):
    """
    Creates a series of bar charts comparing 'found_exact' vs. 'missed_exact' for each model segmented by 'lambda_mult'.
    Each lambda_mult value is plotted in its own subplot with abbreviated model names.
    A single consolidated legend is anchored at the bottom edge of the figure.

    Parameters:
        df (pd.DataFrame): The prepared DataFrame.
        model_order (list): List of models in the desired order.
        lambda_order (list): List of lambda_mult values in sorted order.

    Returns:
        agg_dict (dict): Dictionary containing aggregated DataFrames for each
            lambda_mult. Empty when `lambda_order` is empty (nothing is drawn).
    """
    num_lambdas = len(lambda_order)
    if num_lambdas == 0:
        # A grid with zero rows cannot be created; there is nothing to plot.
        print("No lambda_mult values to plot; skipping found-vs-missed chart.")
        return {}

    cols = 2  # Define number of columns in subplot grid
    rows = (num_lambdas + cols - 1) // cols  # Ceiling division

    fig, axes = plt.subplots(rows, cols, figsize=(16, rows * 5), sharey=True)
    axes = axes.flatten()  # Flatten in case of multiple rows

    agg_dict = {}  # To store aggregated data for each lambda_mult

    for ax, lambda_val in zip(axes, lambda_order):
        subset = df[df["lambda_mult"] == lambda_val]
        # Totals per model, arranged in the requested model order.
        agg = (
            subset.groupby("model", observed=True)[["found_exact", "missed_exact"]]
            .sum()
            .reindex(model_order)
        )

        # Abbreviate model names
        agg.index = [abbreviate_model_name(name) for name in agg.index]

        # Plotting
        agg.plot(kind="bar", ax=ax, width=0.8, legend=False)
        ax.set_title(f"λ = {lambda_val}")
        ax.set_xlabel("Models")
        ax.set_ylabel("Count of Chunks")
        ax.tick_params(axis="x", rotation=45)
        plt.setp(
            ax.get_xticklabels(), ha="right"
        )  # Set horizontal alignment separately

        agg_dict[lambda_val] = agg  # Store aggregated data

    # Remove any unused subplots
    for unused_ax in axes[num_lambdas:]:
        fig.delaxes(unused_ax)

    # Create a single legend for all subplots
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=5, bbox_to_anchor=(0.5, 0))

    fig.suptitle("Found vs. Missed Chunks per Model and Lambda Multiplier", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    return agg_dict


def plot_mean_rank_vs_kvalue(df, model_order, lambda_order):
    """
    Generates a series of line plots showing 'mean_rank' versus 'k_value' for each model segmented by 'lambda_mult'.
    Each lambda_mult value is plotted in its own subplot with a single legend anchored at the bottom edge of the figure.

    Parameters:
        df (pd.DataFrame): The prepared DataFrame.
        model_order (list): List of models in the desired order.
        lambda_order (list): List of lambda_mult values in sorted order.

    Returns:
        agg_dict (dict): Dictionary containing aggregated DataFrames for each
            lambda_mult. Empty when `lambda_order` is empty (nothing is drawn).
    """
    num_lambdas = len(lambda_order)
    if num_lambdas == 0:
        # A grid with zero rows cannot be created; there is nothing to plot.
        print("No lambda_mult values to plot; skipping mean-rank chart.")
        return {}

    cols = 2  # Define number of columns in subplot grid
    rows = (num_lambdas + cols - 1) // cols  # Ceiling division

    fig, axes = plt.subplots(
        rows, cols, figsize=(16, rows * 5), sharex=True, sharey=True
    )
    axes = axes.flatten()  # Flatten in case of multiple rows

    # One colour per model, shared by all subplots.
    palette = sns.color_palette("husl", n_colors=len(model_order))

    agg_dict = {}  # To store aggregated data for each lambda_mult

    for ax, lambda_val in zip(axes, lambda_order):
        subset = df[df["lambda_mult"] == lambda_val]
        # Average mean_rank per (k_value, model) pair.
        agg = (
            subset.groupby(["k_value", "model"], observed=True)["mean_rank"]
            .mean()
            .reset_index()
        )

        for i, model in enumerate(model_order):
            model_data = agg[agg["model"] == model]
            ax.plot(
                model_data["k_value"],
                model_data["mean_rank"],
                marker="o",
                label=abbreviate_model_name(model),
                color=palette[i % len(palette)],
            )

        ax.set_title(f"λ = {lambda_val}")
        ax.set_xlabel("k_value")
        ax.set_ylabel("Mean Rank")
        ax.tick_params(axis="x", rotation=45)
        plt.setp(
            ax.get_xticklabels(), ha="right"
        )  # Set horizontal alignment separately

        agg_dict[lambda_val] = agg  # Store aggregated data

    # Remove any unused subplots
    for unused_ax in axes[num_lambdas:]:
        fig.delaxes(unused_ax)

    # Create a single legend for all subplots
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=5, bbox_to_anchor=(0.5, 0))

    fig.suptitle("Mean Rank vs. k_value per Model and Lambda Multiplier", fontsize=16)
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.show()

    return agg_dict


def plot_missed_chunks_vs_kvalue(df, model_order, lambda_order):
    """
    Produces a series of line plots illustrating 'missed_exact' versus 'k_value' for each model segmented by 'lambda_mult'.
    Each lambda_mult value is plotted in its own subplot with a single legend anchored near the bottom edge of the figure.

    Parameters:
        df (pd.DataFrame): The prepared DataFrame.
        model_order (list): List of models in the desired order.
        lambda_order (list): List of lambda_mult values in sorted order.

    Returns:
        agg_dict (dict): Dictionary containing aggregated DataFrames for each
            lambda_mult. Empty when `lambda_order` is empty (nothing is drawn).
    """
    num_lambdas = len(lambda_order)
    if num_lambdas == 0:
        # A grid with zero rows cannot be created; there is nothing to plot.
        print("No lambda_mult values to plot; skipping missed-chunks chart.")
        return {}

    cols = 2  # Define number of columns in subplot grid
    rows = (num_lambdas + cols - 1) // cols  # Ceiling division

    fig, axes = plt.subplots(
        rows, cols, figsize=(16, rows * 5), sharex=True, sharey=True
    )
    axes = axes.flatten()  # Flatten in case of multiple rows

    # One colour per model, shared by all subplots.
    palette = sns.color_palette("husl", n_colors=len(model_order))

    agg_dict = {}  # To store aggregated data for each lambda_mult

    for ax, lambda_val in zip(axes, lambda_order):
        subset = df[df["lambda_mult"] == lambda_val]
        # Total missed chunks per (k_value, model) pair.
        agg = (
            subset.groupby(["k_value", "model"], observed=True)["missed_exact"]
            .sum()
            .reset_index()
        )

        for i, model in enumerate(model_order):
            model_data = agg[agg["model"] == model]
            ax.plot(
                model_data["k_value"],
                model_data["missed_exact"],
                marker="o",
                label=abbreviate_model_name(model),
                color=palette[i % len(palette)],
            )

        ax.set_title(f"λ = {lambda_val}")
        ax.set_xlabel("k_value")
        ax.set_ylabel("Missed Chunks")
        ax.tick_params(axis="x", rotation=45)
        plt.setp(
            ax.get_xticklabels(), ha="right"
        )  # Set horizontal alignment separately

        agg_dict[lambda_val] = agg  # Store aggregated data

    # Remove any unused subplots
    for unused_ax in axes[num_lambdas:]:
        fig.delaxes(unused_ax)

    # Create a single legend for all subplots
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=5, bbox_to_anchor=(0.5, 0.02))

    fig.suptitle(
        "Missed Chunks vs. k_value per Model and Lambda Multiplier", fontsize=16
    )
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.show()

    return agg_dict


def interpret_results(df, model_order, lambda_order):
    """
    Determines the best model based on the following criteria for each lambda_mult and returns the lambda_mult
    where the overall best model is found:
    1. Prefer models with zero total missed chunks (summed across all attributes).
       - Among these, choose the one with the lowest k_value.
       - If tied, choose the one with the lowest mean_rank.
    2. If no model has zero misses for a lambda_mult, choose the model with the lowest number of misses.
       - Among these, choose the one with the lowest k_value.
       - If tied, choose the one with the lowest mean_rank.

    Prints:
        Models sorted by k_value and then by mean_rank within each lambda_mult.
        DataFrame table of each model's best performance, sorted by lowest misses and then by lowest mean rank.

    Parameters:
        df (pd.DataFrame): The prepared DataFrame.
        model_order (list): List of models in the desired order.
        lambda_order (list): List of lambda_mult values in sorted order.

    Returns:
        best_lambda_mult (float): The lambda_mult corresponding to the overall best model.
    """
    # Build summary table
    agg = (
        df.groupby(["model", "lambda_mult", "k_value"], observed=True)
        .agg(
            total_missed_exact=pd.NamedAgg(column="missed_exact", aggfunc="sum"),
            average_mean_rank=pd.NamedAgg(column="mean_rank", aggfunc="mean"),
        )
        .reset_index()
    )

    results = {}
    for model in model_order:
        model_rows = agg[agg["model"] == model]
        if model_rows.empty:
            continue
        model_rows = model_rows.copy()
        # Sort key: lower misses, then lower mean rank, then lower k_value
        model_rows["sort_key"] = model_rows.apply(
            lambda row: (
                row["total_missed_exact"],
                row["average_mean_rank"],
                row["k_value"],
            ),
            axis=1,
        )
        best = model_rows.sort_values("sort_key").iloc[0]
        criteria = "Zero Misses" if best["total_missed_exact"] == 0 else "Lowest Misses"
        results[model] = {
            "lambda_mult": best["lambda_mult"],
            "k_value": best["k_value"],
            "misses": best["total_missed_exact"],
            "mean_rank": best["average_mean_rank"],
            "criteria": criteria,
        }

    rows = []
    for model in model_order:
        if model in results:
            r = results[model]
            rows.append(
                {
                    "Model": model,
                    "Lambda": r["lambda_mult"],
                    "k_value": r["k_value"],
                    "Misses": r["misses"],
                    "Mean Rank": r["mean_rank"],
                }
            )
        else:
            rows.append(
                {
                    "Model": model,
                    "Lambda": "N/A",
                    "k_value": "N/A",
                    "Misses": float("inf"),
                    "Mean Rank": float("inf"),
                }
            )

    summary_df = pd.DataFrame(rows)
    summary_df["Misses"] = pd.to_numeric(summary_df["Misses"], errors="coerce")
    summary_df["Mean Rank"] = pd.to_numeric(summary_df["Mean Rank"], errors="coerce")
    summary_df = summary_df.sort_values(
        by=["Misses", "Mean Rank"], ascending=[True, True]
    )
    summary_df["Mean Rank"] = summary_df["Mean Rank"].apply(
        lambda x: f"{x:.2f}" if pd.notnull(x) else "N/A"
    )

    # Set column header justification to left
    pd.set_option('display.colheader_justify', 'left')
    print("\n====== BEST PERFORMANCE PER MODEL ======")
    print(summary_df.to_string(index=False))

    # Save summary as markdown
    markdown_str = tabulate(summary_df, headers="keys", tablefmt="pipe", showindex=False)

    md_dir = os.path.join("..", "..", "data", "pipeline2", "md")
    os.makedirs(md_dir, exist_ok=True)
    csv_timestamp = CSV.split("/")[-1].replace(".csv", "")
    md_path = os.path.join(md_dir, f"summary_table_{csv_timestamp}.md")
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(markdown_str)
    print(f"\nMarkdown summary saved to: {md_path}")

    # Determine overall best model per lambda
    best_overall_criteria = None
    best_overall_list = []

    for lambda_val in lambda_order:
        print(f"\n===== Lambda Multiplier: {lambda_val} =====")
        subset = df[df["lambda_mult"] == lambda_val]
        group_agg = (
            subset.groupby(["model", "k_value"], observed=True)
            .agg(
                total_missed_exact=pd.NamedAgg(column="missed_exact", aggfunc="sum"),
                average_mean_rank=pd.NamedAgg(column="mean_rank", aggfunc="mean"),
            )
            .reset_index()
        )

        sorted_group = group_agg.sort_values(
            by=["total_missed_exact", "k_value", "average_mean_rank"],
            ascending=[True, True, True],
        )

        best_model_row = sorted_group.iloc[0]
        criteria = (
            "Zero Misses"
            if best_model_row["total_missed_exact"] == 0
            else "Lowest Misses"
        )
        best_model = best_model_row["model"]
        best_k = best_model_row["k_value"]
        best_misses = best_model_row["total_missed_exact"]
        best_mean_rank = best_model_row["average_mean_rank"]

        print("\n------ SUMMARY PER LAMBDA ------")
        if criteria == "Zero Misses":
            print(
                f"The best model for lambda {lambda_val} is '{best_model}' with "
                f"k_value={best_k} and mean_rank={best_mean_rank:.2f} (Zero Misses)."
            )
        else:
            print(
                f"No model had zero misses. Best for lambda {lambda_val} is '{best_model}' with "
                f"{int(best_misses)} misses, k_value={best_k}, mean_rank={best_mean_rank:.2f}."
            )

        # Now compare with the overall best
        if not best_overall_list:
            # This is the first best found; store it
            best_overall_criteria = criteria
            best_overall_list.append(
                {
                    "model": best_model,
                    "lambda_mult": lambda_val,
                    "k_value": best_k,
                    "misses": best_misses,
                    "mean_rank": best_mean_rank,
                    "criteria": criteria,
                }
            )
        else:
            # Compare with current best
            current_best = best_overall_list[0]  # All in list share same metrics
            # If new candidate has Zero Misses and old was Lowest Misses
            if criteria == "Zero Misses" and best_overall_criteria != "Zero Misses":
                best_overall_list.clear()
                best_overall_criteria = criteria
                best_overall_list.append(
                    {
                        "model": best_model,
                        "lambda_mult": lambda_val,
                        "k_value": best_k,
                        "misses": best_misses,
                        "mean_rank": best_mean_rank,
                        "criteria": criteria,
                    }
                )
            elif criteria == best_overall_criteria:
                # Both are "Zero Misses" or both "Lowest Misses"
                # Compare total misses, k_value, mean_rank
                if best_misses < current_best["misses"]:
                    best_overall_list.clear()
                    best_overall_list.append(
                        {
                            "model": best_model,
                            "lambda_mult": lambda_val,
                            "k_value": best_k,
                            "misses": best_misses,
                            "mean_rank": best_mean_rank,
                            "criteria": criteria,
                        }
                    )
                elif best_misses == current_best["misses"]:
                    if best_k < current_best["k_value"]:
                        best_overall_list.clear()
                        best_overall_list.append(
                            {
                                "model": best_model,
                                "lambda_mult": lambda_val,
                                "k_value": best_k,
                                "misses": best_misses,
                                "mean_rank": best_mean_rank,
                                "criteria": criteria,
                            }
                        )
                    elif best_k == current_best["k_value"]:
                        if best_mean_rank < current_best["mean_rank"]:
                            # strictly better mean rank
                            best_overall_list.clear()
                            best_overall_list.append(
                                {
                                    "model": best_model,
                                    "lambda_mult": lambda_val,
                                    "k_value": best_k,
                                    "misses": best_misses,
                                    "mean_rank": best_mean_rank,
                                    "criteria": criteria,
                                }
                            )
                        elif abs(best_mean_rank - current_best["mean_rank"]) < 1e-9:
                            # Same misses, same k, same mean rank => It's a tie
                            best_overall_list.append(
                                {
                                    "model": best_model,
                                    "lambda_mult": lambda_val,
                                    "k_value": best_k,
                                    "misses": best_misses,
                                    "mean_rank": best_mean_rank,
                                    "criteria": criteria,
                                }
                            )

            # If new candidate is "Lowest Misses" but current best is "Zero Misses", do nothing
            # because Zero Misses is always better.
    print("\n====== FINAL OVERALL BEST MODEL(S) ======")
    # All entries in best_overall_list share the same performance, but may differ in lambda
    # or even be different models with identical performance.
    # Print them all:
    if best_overall_list:
        for i, item in enumerate(best_overall_list, start=1):
            if item["criteria"] == "Zero Misses":
                print(
                    f"({i}) `{item['model']}` with `k_value = {item['k_value']}`, `mean_rank = {item['mean_rank']:.2f}`, "
                    f"(Zero Misses), `lambda_mult = {item['lambda_mult']}`"
                )
            else:
                print(
                    f"({i}) `{item['model']}` with `{int(item['misses'])} misses`, `k_value = {item['k_value']}`, "
                    f"`mean_rank = {item['mean_rank']:.2f}`, `lambda_mult = {item['lambda_mult']}`"
                )
        return best_overall_list[0]["lambda_mult"]
    else:
        print("No best model was found.")
        return None


def visualize_embedding_results(csv_file_path):
    """
    Loads the results CSV, draws the three comparison plots, prints the
    aggregates each plot returned, and finally runs the interpretation step.

    Parameters:
        csv_file_path (str): The file path to the CSV file.

    Returns:
        Whatever interpret_results returns for the loaded data (used by the
        caller as the best lambda multiplier).
    """
    df, model_order, lambda_order = load_and_prepare_data(csv_file_path)

    # All plots are produced first; each helper returns a mapping that is
    # iterated below as {lambda value: aggregated data}.
    found_vs_missed = plot_found_vs_missed(df, model_order, lambda_order)
    mean_rank = plot_mean_rank_vs_kvalue(df, model_order, lambda_order)
    missed_chunks = plot_missed_chunks_vs_kvalue(df, model_order, lambda_order)

    print("\n=== Aggregated Data from Plots ===")

    # (section heading, per-lambda aggregates) in the order they are reported.
    report_sections = (
        (
            "\n--- Found Exact vs. Missed Exact per Model and Lambda Multiplier ---",
            found_vs_missed,
        ),
        (
            "\n--- Mean Rank per k_value for Each Model and Lambda Multiplier ---",
            mean_rank,
        ),
        (
            "\n--- Missed Chunks per k_value for Each Model and Lambda Multiplier ---",
            missed_chunks,
        ),
    )
    for heading, per_lambda in report_sections:
        print(heading)
        for lambda_val, agg in per_lambda.items():
            print(f"\nLambda Multiplier: {lambda_val}")
            print(agg)

    # Interpretation runs last so its summary follows the raw aggregates.
    return interpret_results(df, model_order, lambda_order)


# Analyse the file referenced by CSV (defined outside this section) across all
# lambda values and keep the lambda multiplier the interpretation step returns.
plot_csv_file = CSV

best_lambda_mult = visualize_embedding_results(plot_csv_file)
print("BEST LAMBDA MULT FROM FIRST CODE:", best_lambda_mult)

In [None]:
# Manual override: replaces the best_lambda_mult computed in the previous cell
# with a fixed 1.0 for the filtered analysis below.
best_lambda_mult = 1.0

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def abbreviate_product_name(name, max_length=70):
    """
    Abbreviates a product name for display.

    Names of at most ``max_length`` characters are returned unchanged. Longer
    names are lower-cased, cut to their first ``max_length`` characters and
    suffixed with "..." to signal the truncation.

    Parameters:
        name (str): The original product name.
        max_length (int): Number of characters kept before the ellipsis.
            Defaults to 70, the truncation width this function has always used.

    Returns:
        str: The original name, or its lower-cased, truncated form followed
        by "...".
    """
    # Compare against the same width that is used for slicing: a smaller
    # threshold would lower-case and append "..." to names that were never
    # actually shortened.
    if len(name) <= max_length:
        return name
    return f"{name.lower()[:max_length]}..."


def load_and_prepare_data(csv_file_path):
    """
    Reads the results CSV and prepares it for plotting.

    Every known metric column that is present is coerced to numeric, with
    unparsable or missing entries replaced by -1. 'model' becomes an ordered
    categorical following first appearance in the file; 'lambda_mult', when
    present, becomes an ordered categorical over its sorted unique values.
    The seaborn theme is also configured here as a side effect.

    Parameters:
        csv_file_path (str): The file path to the CSV file.

    Returns:
        tuple: (df, model_order, lambda_order); lambda_order is an empty list
        when the CSV has no 'lambda_mult' column.
    """
    df = pd.read_csv(csv_file_path)
    print(f"Loaded {len(df)} rows from {csv_file_path}")

    metric_columns = (
        "k_value",
        "retrieved_docs",
        "found_exact",
        "missed_exact",
        "mean_rank",
        "rank",
        "mrr",
        "lambda_mult",
    )
    for column in metric_columns:
        if column not in df.columns:
            continue
        # errors="coerce" turns unparsable cells into NaN, which then become -1.
        df[column] = pd.to_numeric(df[column], errors="coerce").fillna(-1)

    sns.set_theme(style="whitegrid", font_scale=1.1)

    # Models keep the order in which they first appear in the file.
    model_order = df["model"].drop_duplicates().tolist()
    df["model"] = pd.Categorical(df["model"], categories=model_order, ordered=True)

    if "lambda_mult" not in df.columns:
        return df, model_order, []

    # Lambda values are ordered numerically. NaNs were already replaced by -1
    # above, so dropna() is only a safeguard here.
    lambda_order = sorted(df["lambda_mult"].dropna().unique().tolist())
    df["lambda_mult"] = pd.Categorical(
        df["lambda_mult"], categories=lambda_order, ordered=True
    )
    return df, model_order, lambda_order

def plot_found_vs_missed(df, model_order, best_lambda_mult):
    """
    Draws grouped bars comparing total 'found_exact' and 'missed_exact' per model.

    Both columns are summed over every row of ``df`` for each model. No lambda
    filtering happens here: ``best_lambda_mult`` is only shown in the title, so
    the caller is expected to pass an already filtered frame. The aggregated
    table is printed after the figure is shown.

    Parameters:
        df (pd.DataFrame): Results with 'model', 'found_exact' and
            'missed_exact' columns.
        model_order (list): Model names in the order used on the X-axis.
        best_lambda_mult: Lambda value displayed in the plot title.
    """
    figure, axis = plt.subplots(figsize=(12, 6), constrained_layout=True)

    totals = df.groupby("model", as_index=False, observed=True)[
        ["found_exact", "missed_exact"]
    ].sum()

    # Re-impose the requested model order; models outside it become NaN and
    # are dropped before sorting.
    totals["model"] = pd.Categorical(
        totals["model"], categories=model_order, ordered=True
    )
    totals = totals.dropna(subset=["model"]).sort_values("model")

    positions = range(len(totals))
    bar_width = 0.35
    # One bar series per metric, shifted half a bar width to either side of
    # the tick so the pair sits centred on it.
    for shift, column, series_label in (
        (-bar_width / 2, "found_exact", "Found Exact"),
        (bar_width / 2, "missed_exact", "Missed Exact"),
    ):
        axis.bar(
            [pos + shift for pos in positions],
            totals[column],
            width=bar_width,
            label=series_label,
        )

    axis.set_xticks(positions)
    axis.set_xticklabels(totals["model"], rotation=45, ha="right")
    axis.set_title(
        f"Found vs. Missed Chunks \n(Summed over all k_values with best_lambda_mult={best_lambda_mult})"
    )
    axis.set_xlabel("Models")
    axis.set_ylabel("Count of Chunks")

    # The legend is attached to the figure (not the axes) and anchored in
    # figure coordinates.
    legend_handles, legend_labels = axis.get_legend_handles_labels()
    figure.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        ncol=2,
        frameon=True,
        bbox_to_anchor=(0.45, 0.1),
    )
    # NOTE(review): the figure is created with constrained_layout=True, and
    # matplotlib documents subplots_adjust as incompatible with that layout
    # engine -- confirm this call has the intended effect.
    figure.subplots_adjust(bottom=0.2)
    plt.show()

    print("\n--- Found Exact vs. Missed Exact per Model (Filtered) ---")
    print(totals)


def plot_mean_rank_vs_kvalue(df, model_order, best_lambda_mult):
    """
    Draws one line per model showing the average 'mean_rank' at each 'k_value'.

    No lambda filtering happens here: ``best_lambda_mult`` is only shown in the
    title, so the caller is expected to pass an already filtered frame. The
    per-model tables behind the lines are printed after the figure is shown.

    Parameters:
        df (pd.DataFrame): Results with 'model', 'k_value' and 'mean_rank' columns.
        model_order (list): Model names, in plotting/legend order.
        best_lambda_mult: Lambda value displayed in the plot title.
    """
    plt.figure(figsize=(12, 6))

    # One distinct colour per model, requested up front so colours follow
    # model_order.
    colors = sns.color_palette("husl", n_colors=len(model_order))

    per_model_tables = {}
    for model_name, line_color in zip(model_order, colors):
        rows = df[df["model"] == model_name]
        rank_by_k = rows.groupby("k_value")["mean_rank"].mean().reset_index()
        per_model_tables[model_name] = rank_by_k
        plt.plot(
            rank_by_k["k_value"],
            rank_by_k["mean_rank"],
            marker="o",
            label=model_name,
            color=line_color,
        )

    plt.title(
        f"Mean Rank vs. k_value for Each Model \n(best_lambda_mult={best_lambda_mult})"
    )
    plt.xlabel("k_value")
    plt.ylabel("Mean Rank")
    # Legend is placed outside the axes, to the right of the plot area.
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()

    print("\n--- Mean Rank per k_value for Each Model (Filtered) ---")
    for model_name, table in per_model_tables.items():
        print(f"\nModel: {model_name}")
        print(table.to_string(index=False))


def plot_missed_chunks_vs_kvalue(df, model_order, best_lambda_mult):
    """
    Draws one line per model showing the summed 'missed_exact' at each 'k_value'.

    Models without any rows in ``df`` are skipped entirely (no line, no table).
    No lambda filtering happens here: ``best_lambda_mult`` is only shown in the
    title, so the caller is expected to pass an already filtered frame. The
    per-model tables behind the lines are printed after the figure is shown.

    Parameters:
        df (pd.DataFrame): Results with 'model', 'k_value' and 'missed_exact' columns.
        model_order (list): Model names, in plotting/legend order.
        best_lambda_mult: Lambda value displayed in the plot title.
    """
    plt.figure(figsize=(12, 6))

    # Colours are assigned by position in model_order, so a skipped model
    # still "uses up" its colour.
    colors = sns.color_palette("husl", n_colors=len(model_order))

    per_model_tables = {}
    for model_name, line_color in zip(model_order, colors):
        rows = df[df["model"] == model_name]
        if rows.empty:
            continue
        misses_by_k = rows.groupby("k_value")["missed_exact"].sum().reset_index()
        per_model_tables[model_name] = misses_by_k
        plt.plot(
            misses_by_k["k_value"],
            misses_by_k["missed_exact"],
            marker="o",
            label=model_name,
            color=line_color,
        )

    plt.title(
        f"Missed Chunks vs. k_value for Each Model \n(best_lambda_mult={best_lambda_mult})"
    )
    plt.xlabel("k_value")
    plt.ylabel("Missed Chunks")
    # Legend is placed outside the axes, to the right of the plot area.
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()

    print("\n--- Missed Chunks per k_value for Each Model (Filtered) ---")
    for model_name, table in per_model_tables.items():
        print(f"\nModel: {model_name}")
        print(table.to_string(index=False))


def plot_heatmap_rank(df, model_order, best_lambda_mult, csv_file_path):
    """
    Plots a heatmap of the mean rank per product (rows) and model (columns).

    Products appear on the Y-axis in the order they first occur in ``df``,
    labelled with their abbreviated names. The CSV file's base name (without
    extension) is appended to the plot title.

    The input frame is left untouched: the helper columns needed for ordering
    and labelling are added to a private working copy.

    Parameters:
        df (pd.DataFrame): Results with 'product_name', 'model' and 'rank' columns.
        model_order (list): Not used by this function; kept for a signature
            parallel to the other plotting helpers.
        best_lambda_mult: Not used by this function; kept for the same reason.
        csv_file_path (str): Path of the source CSV; only its base name is
            used, for the title.

    Returns:
        str: The pivoted heatmap data rendered with DataFrame.to_string().
    """
    # Base CSV filename without extension, shown in the title.
    csv_name = os.path.splitext(os.path.basename(csv_file_path))[0]

    # Work on a copy of just the needed columns so the caller's DataFrame does
    # not silently gain "Product Order" / "Product Name (Short)" columns.
    work = df[["product_name", "model", "rank"]].copy()

    # factorize() numbers products by first appearance, which preserves the
    # row order of the input when sorting later.
    work["Product Order"] = work["product_name"].factorize()[0]
    work["Product Name (Short)"] = work["product_name"].apply(abbreviate_product_name)

    # Collapse to one value per (product, model): the mean rank over all rows
    # (e.g. different k-values) that share the combination.
    rank_df = (
        work.groupby(["Product Order", "Product Name (Short)", "model"], observed=True)["rank"]
            .mean()
            .reset_index()
            .rename(columns={
                "model": "Model",
                "Product Name (Short)": "Product Name",
                "rank": "Rank"
            })
    )
    rank_df = rank_df.sort_values("Product Order")

    # "Product Order" is part of the index so rows stay in input order and
    # products with identical short names are not merged.
    pivot_data = rank_df.pivot_table(
        index=["Product Order", "Product Name"],
        columns="Model",
        values="Rank",
        aggfunc="mean",
        observed=True,
    )
    pivot_data = pivot_data.sort_index(level="Product Order")

    # Drop the ordering level so only product names are shown on the Y-axis.
    pivot_data.index = pivot_data.index.droplevel("Product Order")

    # Tall figure: one heatmap row per product.
    plt.figure(figsize=(12, 50))
    sns.heatmap(
        pivot_data,
        annot=True,
        fmt=".1f",
        cmap="YlGnBu",
        linewidths=0.5,
        cbar_kws={"label": "Rank"}
    )
    plt.title(f"Rank by Product Name per Model ({csv_name})")
    plt.xlabel("Model")
    plt.ylabel("Product Name (Original CSV Order)")
    plt.xticks(rotation=0, ha="right")
    plt.yticks(rotation=0, ha="right")
    plt.tight_layout()
    plt.show()

    return pivot_data.to_string()


def interpret_results(df):
    """
    Prints a ranking summary for an already filtered results frame.

    Rows are aggregated per (model, k_value): 'missed_exact' is summed and
    'mean_rank' is averaged. The best combination is the one with the fewest
    total misses, ties broken by the smaller k_value and then the lower mean
    rank. For every model that reaches zero total misses, the smallest k_value
    at which it does so is listed as well.

    Parameters:
        df (pd.DataFrame): Results with 'model', 'k_value', 'missed_exact' and
            'mean_rank' columns; presumably already restricted to a single
            lambda_mult by the caller.

    Returns:
        None. Everything is reported via print(). When there is nothing to
        aggregate, a notice is printed instead of raising.
    """
    empty_notice = "\nNo rows to interpret: the filtered results are empty."

    # Guard: without rows there is no "best" row to pick further down
    # (iloc[0] on an empty frame raises IndexError).
    if df.empty:
        print(empty_notice)
        return

    group_agg = (
        df.groupby(["model", "k_value"], observed=True)
        .agg(
            missed_exact=pd.NamedAgg(column="missed_exact", aggfunc="sum"),
            mean_rank=pd.NamedAgg(column="mean_rank", aggfunc="mean"),
        )
        .reset_index()
    )

    # groupby drops rows whose keys are NaN, so the aggregate can still be
    # empty even when df is not.
    if group_agg.empty:
        print(empty_notice)
        return

    print("\n====== AGGREGATED RESULTS (Filtered) ======")
    print(group_agg)

    # Priority: fewest misses, then smallest k_value, then lowest mean rank.
    sorted_group = group_agg.sort_values(
        by=["missed_exact", "k_value", "mean_rank"],
        ascending=[True, True, True],
    )

    best_model_row = sorted_group.iloc[0]
    if best_model_row["missed_exact"] == 0:
        criteria = "Zero Misses"
    else:
        criteria = "Lowest Misses"

    best_model = best_model_row["model"]
    best_k = best_model_row["k_value"]
    best_misses = best_model_row["missed_exact"]
    best_mean_rank = best_model_row["mean_rank"]

    print("\n====== MINIMUM k_value WITH ZERO MISSES (Filtered) ======")
    zero_miss_df = group_agg[group_agg["missed_exact"] == 0]
    if not zero_miss_df.empty:
        # Smallest zero-miss k_value per model ...
        models_zero_miss = (
            zero_miss_df.groupby("model", observed=True)
            .agg(min_k_value=pd.NamedAgg(column="k_value", aggfunc="min"))
            .reset_index()
        )
        # ... joined back to recover the mean_rank at exactly that k_value.
        merged_zero_miss = pd.merge(
            models_zero_miss,
            zero_miss_df,
            left_on=["model", "min_k_value"],
            right_on=["model", "k_value"],
        )
        merged_zero_miss_sorted = merged_zero_miss.sort_values(
            by=["min_k_value", "mean_rank"]
        )
        for _, row in merged_zero_miss_sorted.iterrows():
            print(
                f"{row['model']}: k_value = {row['min_k_value']}, mean_rank = {row['mean_rank']:.2f}"
            )
    else:
        print("No model achieved zero total misses under this lambda.")

    print("\n====== SUMMARY (Filtered) ======")
    if criteria == "Zero Misses":
        print(
            f"Best model (filtered) is `{best_model}` with "
            f"`k_value = {best_k}` and `mean_rank = {best_mean_rank:.2f}` (Zero Misses)."
        )
    else:
        print(
            f"No zero-miss model under this lambda. Best is `{best_model}` with "
            f"`{int(best_misses)} misses`, `k_value = {best_k}`, `mean_rank = {best_mean_rank:.2f}`."
        )


def visualize_embedding_results(csv_file_path, best_lambda_mult):
    """
    Orchestrates loading data, filtering on the given lambda, and creating
    the plots and printed summary for that single lambda.

    Parameters:
        csv_file_path (str): The file path to the CSV file.
        best_lambda_mult: The lambda_mult value to keep; rows with any other
            value are discarded before plotting.

    Returns:
        str: The heatmap data as text, as returned by plot_heatmap_rank.

    Raises:
        ValueError: If no usable row remains after filtering on
            best_lambda_mult (for example, the value is not present in the CSV).
    """
    df, model_order, lambda_order = load_and_prepare_data(csv_file_path)

    # Keep only the requested lambda and rows with a known model. The explicit
    # copy gives this function its own frame, so the column assignment below
    # never writes into a slice of the loaded data.
    df = df[df["lambda_mult"] == best_lambda_mult].dropna(subset=["model"]).copy()

    # Fail early with an actionable message; otherwise the empty frame only
    # surfaces later as an obscure error inside the plotting/summary helpers.
    if df.empty:
        raise ValueError(
            f"No usable rows for lambda_mult={best_lambda_mult}; "
            f"available lambda_mult values: {lambda_order}"
        )

    # Rebuild the model ordering so it only contains models that survived
    # the filter.
    new_model_order = df["model"].drop_duplicates().tolist()
    df["model"] = pd.Categorical(df["model"], categories=new_model_order, ordered=True)

    print(f"\nUsing best_lambda_mult = {best_lambda_mult}")
    print(f"Remaining rows after filtering: {len(df)}")

    plot_found_vs_missed(df, new_model_order, best_lambda_mult)
    plot_mean_rank_vs_kvalue(df, new_model_order, best_lambda_mult)
    plot_missed_chunks_vs_kvalue(df, new_model_order, best_lambda_mult)
    heatmap_rank_df_string = plot_heatmap_rank(df, new_model_order, best_lambda_mult, csv_file_path)
    interpret_results(df)

    return heatmap_rank_df_string


##########################################################################
# Usage: plot and summarise the results for the lambda chosen in the cells
# above; the returned text is the heatmap data.
##########################################################################
visualize_embedding_results_txt = visualize_embedding_results(plot_csv_file, best_lambda_mult)

In [None]:
# Print the heatmap data in text format: the string returned by
# visualize_embedding_results in the previous cell.
print("\nHeatmap Data (Text Format):\n")
print(visualize_embedding_results_txt)