# Embeddings + Classic ML Classifiers


## 1) Configuration

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dotenv import load_dotenv

from classifiers import benchmark_classifiers, make_classifiers
from embeddings import EmbeddingCache, OpenAIEmbeddingCache, SPECTER2EmbeddingCache
from utils import clean_dataframe, join_title_abstract

load_dotenv("../.env")

In [None]:
def _require_env(name):
    """Return the value of environment variable ``name``.

    Raises
    ------
    RuntimeError
        If the variable is unset or empty.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name!r} is not set; define it in ../.env or in the shell environment.")
    return value


# Required settings, read from the environment (load_dotenv("../.env") runs in the first cell).
# Checking them here gives one clear error at configuration time instead of an
# opaque TypeError from int(None) or a late failure when a path is first used.
FULL_DATASET_CSV = _require_env("FULL_DATASET_CSV")
LABEL_COL = _require_env("LABEL_COL")
RESULTS_DIR = _require_env("RESULTS_DIR")
RANDOM_STATE = int(_require_env("RANDOM_STATE"))

# embeddings cache directory
CACHE_DIR_EMBEDDINGS = "../.embeddings_cache"
CACHE_DIR_CLASSIFIERS = "../.classifiers_cache"

# CV
N_SPLITS = 5


## 2) Choose embedding models


In [None]:
# Embedding models to benchmark. Each entry is a model identifier; the trailing
# comment on each line links to the corresponding model card / API guide.
EMBEDDING_MODELS = [
    # generalist
    "sentence-transformers/all-MiniLM-L6-v2",                 # https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
    "sentence-transformers/all-mpnet-base-v2",                # https://huggingface.co/sentence-transformers/all-mpnet-base-v2
    "text-embedding-3-small",                                 # https://platform.openai.com/docs/guides/embeddings
    "text-embedding-3-large",                                 # https://platform.openai.com/docs/guides/embeddings

    # biomedical/scientific
    "allenai/specter2_base",                                  # https://huggingface.co/allenai/specter2_base
    "allenai/biomed_roberta_base",                            # https://huggingface.co/allenai/biomed_roberta_base
    # NOTE(review): the identifier below is spelled "PubMedBERT" while the linked
    # URL spells it "PubMedBert" - confirm the casing resolves to the intended model.
    "pritamdeka/S-PubMedBERT-MS-MARCO",                       # https://huggingface.co/pritamdeka/S-PubMedBert-MS-MARCO
    "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",   # https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract
    "sentence-transformers/embeddinggemma-300m-medical"       # https://huggingface.co/sentence-transformers/embeddinggemma-300m-medical
]


## 3) Load & clean data


In [None]:
# Read only the columns this notebook uses and expose the target column under
# the fixed name "label", whatever LABEL_COL is called in the CSV.
needed_columns = ["id", "title", "abstract", LABEL_COL]
df_raw = pd.read_csv(FULL_DATASET_CSV, usecols=needed_columns).rename(columns={LABEL_COL: "label"})

print(f"Raw dataset shape: {df_raw.shape}")
df_raw.head(5)

In [None]:
# Clean the raw frame with the project helper (its rules live in utils.clean_dataframe,
# not shown here) and preview the first rows of the result.
df = clean_dataframe(df_raw)
df.head(5)

In [None]:
print("Rows:", len(df))

# Class balance: absolute count and percentage share of each label.
label_names = df["label"].map({True: "Positive", False: "Negative"})
class_counts = label_names.value_counts()
class_percent = (label_names.value_counts(normalize=True) * 100).round(2)
summary = pd.DataFrame({"count": class_counts, "percent": class_percent})
summary

In [None]:
# Model input: one text per row, built by the project helper join_title_abstract.
texts = join_title_abstract(df)
# Targets: the "label" column cast to an integer array.
y = df["label"].values.astype(int)


## 4) Compute (and cache) embeddings
Embeddings are cached under `CACHE_DIR_EMBEDDINGS` and keyed by model+dataset hash.

Store unnormalized embeddings: it is always easier to normalize embeddings later than to "unnormalize" ones that have already been normalized.


In [None]:
# these settings help with some libraries that use multiple threads by default (specter2)
# NOTE(review): numpy (and whatever the `embeddings` module imports) was already
# imported in the first cell. Threading runtimes often read these variables only
# when they initialize, so confirm the limits actually take effect when set this late.
import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

In [None]:
# One cache object per embedding backend; all three are given the same cache directory.
cache_general = EmbeddingCache(cache_dir=CACHE_DIR_EMBEDDINGS)
cache_specter = SPECTER2EmbeddingCache(cache_dir=CACHE_DIR_EMBEDDINGS)
cache_openai = OpenAIEmbeddingCache(cache_dir=CACHE_DIR_EMBEDDINGS)

In [None]:
# Embed the texts once per model; a model that fails is reported and skipped
# so the remaining models are still processed.
X_by_model = {}
for model_name in EMBEDDING_MODELS:
    print(f"Processing model: {model_name}")

    # pick the cache implementation that matches the model family
    cache = (
        cache_specter
        if EmbeddingCache.is_specter2_model(model_name)
        else cache_openai
        if EmbeddingCache.is_openai_model(model_name)
        else cache_general
    )

    # embeddings are requested unnormalized; normalization is a separate step
    compute_kwargs = dict(batch_size=128, normalize_embeddings=False, device=None)
    try:
        X, meta = cache.compute(texts, model_name, **compute_kwargs)
        X_by_model[model_name] = X
        print("Shape:", X.shape, "| dim:", X.shape[1])
    except Exception as exc:
        print(f"Failed to process {model_name}: {exc}")

In [None]:
# Build a normalized copy of every embedding matrix, keyed by the same model name.
X_by_model_normalized = {}
for model_name, X in X_by_model.items():
    # use normalize() from the cache family that matches the model
    cache = (
        cache_specter
        if EmbeddingCache.is_specter2_model(model_name)
        else cache_openai
        if EmbeddingCache.is_openai_model(model_name)
        else cache_general
    )

    X_normalized = cache.normalize(X)
    X_by_model_normalized[model_name] = X_normalized

In [None]:
# One row per model with the shape of its normalized embedding matrix.
rows = [
    {"embedding_model": model_name, "shape": X.shape}
    for model_name, X in X_by_model_normalized.items()
]

shapes_df = pd.DataFrame(rows)
shapes_df

## 5) Evaluate classifiers (Stratified K-Fold)

In [None]:
# create classifiers
# Built by the project helper with the shared RANDOM_STATE; later cells index the
# result by classifier name (classifiers[best_clf]).
classifiers = make_classifiers(random_state=RANDOM_STATE)
classifiers

In [None]:
import warnings

# Silence the "pkg_resources is deprecated" UserWarning. `message` is a regular
# expression; the leading ".*" lets the phrase appear anywhere in the warning text.
warnings.filterwarnings("ignore", category=UserWarning, message=".*pkg_resources is deprecated.*")

### Normalized

In [None]:
# normalized embeddings
# Benchmark every classifier on every normalized embedding matrix.
# The call returns two tables, bound here as the per-fold results and the summary.
# With use_cache=True, results are cached under <CACHE_DIR_CLASSIFIERS>/base.
per_fold_df_norm, summary_df_norm = benchmark_classifiers(
    X_by_model=X_by_model_normalized,
    y=y,
    classifiers=classifiers,
    n_splits=N_SPLITS,
    random_state=RANDOM_STATE,
    use_cache=True,
    cache_dir=os.path.join(CACHE_DIR_CLASSIFIERS, "base"),
)
summary_df_norm.round(2)

### Unnormalized

In [None]:
# unnormalized embeddings
# Same benchmark as for the normalized embeddings, run on the raw matrices in X_by_model.
# A separate cache folder (<CACHE_DIR_CLASSIFIERS>/unnormalized) keeps the two runs apart.
per_fold_df_unnorm, summary_df_unnorm = benchmark_classifiers(
    X_by_model=X_by_model,
    y=y,
    classifiers=classifiers,
    n_splits=N_SPLITS,
    random_state=RANDOM_STATE,
    use_cache=True,
    cache_dir=os.path.join(CACHE_DIR_CLASSIFIERS, "unnormalized"),
)
summary_df_unnorm.round(2)

In [None]:
# export data to generate tables in the paper
# NOTE: this cell rebinds summary_df_norm itself (short model names, values rounded
# to 2 decimals, reduced column set), so every later cell sees the trimmed table
# rather than the raw benchmark summary.
summary_df_norm.embedding_model = summary_df_norm.embedding_model.str.split("/").str[-1]
summary_df_norm = summary_df_norm.round(2)
summary_df_norm = summary_df_norm.sort_values(by="roc_auc_mean", ascending=False)[
    ["embedding_model", "classifier", "roc_auc_mean", "recall_mean", "precision_mean", "specificity_mean", "acc_mean"]
]
summary_df_norm = summary_df_norm.reset_index(drop=True)

# Create the output folder first: to_csv does not create missing parent directories.
export_dir = os.path.join(RESULTS_DIR, "thesis_figures_tables_generation", "1")
os.makedirs(export_dir, exist_ok=True)
summary_df_norm.to_csv(os.path.join(export_dir, "summary_normalized_embeddings.csv"), index=False)
summary_df_norm


## 6) Simple visualizations

In [None]:
# normalized embeddings
# Metrics to plot, one subplot each: (summary column key, axis label).
# How many combinations each subplot shows is set by top_k further down in this cell.
metrics = [
    ("acc_mean", "Accuracy"),
    ("recall_mean", "Recall"),
    ("precision_mean", "Precision"),
    ("specificity_mean", "Specificity"),
]

# Keep only the part of the model id after the last "/" for shorter tick labels.
summary_df_norm.embedding_model = summary_df_norm.embedding_model.str.split("/").str[-1]


# helper to find the column in summary_df_norm (handles names like pr_auc_mean / pr_auc_std etc.)
def find_metric_col(df, key):
    """Return the column of ``df`` that best matches ``key``.

    A column matches when ``key`` occurs in its name, ignoring case. Among the
    matches, a column whose name ends in ``_mean`` is preferred; if there is
    none, the first match in column order is used.

    Parameters
    ----------
    df : object with a ``columns`` iterable of strings (e.g. a pandas DataFrame)
    key : str
        Substring to look for in the column names.

    Returns
    -------
    The matching column name, or ``None`` if no column contains ``key``.
    """
    key = key.lower()
    matches = [c for c in df.columns if key in c.lower()]
    # Look at every match before falling back, so an earlier non-mean column
    # (e.g. "acc_std") cannot shadow the "_mean" column.
    for c in matches:
        if c.lower().endswith("_mean"):
            return c
    return matches[0] if matches else None


# Draw a 2x2 grid of horizontal bar charts, one per entry in `metrics`, each
# showing the top_k embedding/classifier combinations for that metric.
top_k = 8
fig, axes = plt.subplots(2, 2, figsize=(18, 10), dpi=150)
axes = axes.flatten()

bar_height = 0.6  # decrease this value for more spacing (default is ~0.8)

for ax, (metric_key, metric_label) in zip(axes, metrics):
    col = find_metric_col(summary_df_norm, metric_key)
    if col is None:
        # no matching column: show a placeholder message instead of a chart
        ax.text(0.5, 0.5, f"No column found for '{metric_label}'", ha="center", va="center")
        ax.set_xticks([])
        ax.set_yticks([])
        continue

    # best top_k rows for this metric; tick labels are "model<newline>classifier"
    summary_sorted = summary_df_norm.sort_values(col, ascending=False).head(top_k)
    labels = (summary_sorted["embedding_model"] + "\n" + summary_sorted["classifier"]).tolist()
    vals = summary_sorted[col].values

    # plot bars with reduced height for more spacing
    ax.barh(range(len(vals)), vals, height=bar_height, color="C0")
    ax.set_yticks(range(len(vals)))
    ax.set_yticklabels(labels, linespacing=1.3)
    ax.invert_yaxis()  # put the best combination at the top

    # force x limits to 0-1
    ax.set_xlim(0, 1.0)

    # major ticks every 0.1, minor ticks every 0.05
    major_xticks = np.arange(0, 1.0001, 0.1)
    minor_xticks = np.arange(0, 1.0001, 0.05)
    ax.set_xticks(major_xticks)
    ax.set_xticks(minor_xticks, minor=True)

    # lighter vertical gridlines at 0.05 (minor); keep major ticks without heavy grid
    ax.grid(axis="x", which="minor", linestyle="--", color="gray", linewidth=0.5, alpha=0.35)

    # data callouts at end of bars
    x_offset = (1.0 - 0) * 0.01
    for i, v in enumerate(vals):
        ax.text(v + x_offset, i, f"{v:.3f}", va="center", ha="left", fontsize=9, color="black")

    # NOTE: both branches of this conditional produce the same text (metric_label).
    ax.set_xlabel(f"{metric_label}" if "mean" in col.lower() else metric_label)
    # ax.set_title(f"Top {top_k} combinations by {metric_label.upper()}")

# plt.suptitle("NORMALIZED EMBEDDINGS", fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# unnormalized embeddings
# Metrics to plot, one subplot each: (summary column key, axis label).
# How many combinations each subplot shows is set by top_k further down in this cell.
metrics = [
    # ("pr_auc", "PR AUC"),
    ("acc_mean", "Accuracy"),
    ("recall_mean", "Recall"),
    ("precision_mean", "Precision"),
    ("specificity_mean", "Specificity"),
]

# Keep only the part of the model id after the last "/". This edits
# summary_df_unnorm in place; the later merge with summary_df_norm joins on
# embedding_model, whose values in summary_df_norm are shortened the same way.
summary_df_unnorm.embedding_model = summary_df_unnorm.embedding_model.str.split("/").str[-1]


# helper to find the column in summary_df_unnorm (handles names like pr_auc_mean / pr_auc_std etc.)
def find_metric_col(df, key):
    """Return the column of ``df`` that best matches ``key``.

    A column matches when ``key`` occurs in its name, ignoring case. Among the
    matches, a column whose name ends in ``_mean`` is preferred; if there is
    none, the first match in column order is used.

    Parameters
    ----------
    df : object with a ``columns`` iterable of strings (e.g. a pandas DataFrame)
    key : str
        Substring to look for in the column names.

    Returns
    -------
    The matching column name, or ``None`` if no column contains ``key``.
    """
    key = key.lower()
    matches = [c for c in df.columns if key in c.lower()]
    # Look at every match before falling back, so an earlier non-mean column
    # (e.g. "acc_std") cannot shadow the "_mean" column.
    for c in matches:
        if c.lower().endswith("_mean"):
            return c
    return matches[0] if matches else None


# Draw a 2x2 grid of horizontal bar charts, one per entry in `metrics`, each
# showing the top_k embedding/classifier combinations on unnormalized embeddings.
top_k = 10
fig, axes = plt.subplots(2, 2, figsize=(18, 10), dpi=150)
axes = axes.flatten()

for ax, (metric_key, metric_label) in zip(axes, metrics):
    col = find_metric_col(summary_df_unnorm, metric_key)
    if col is None:
        # no matching column: show a placeholder message instead of a chart
        ax.text(0.5, 0.5, f"No column found for '{metric_label}'", ha="center", va="center")
        ax.set_xticks([])
        ax.set_yticks([])
        continue

    # best top_k rows for this metric; tick labels are "model | classifier"
    summary_sorted = summary_df_unnorm.sort_values(col, ascending=False).head(top_k)
    labels = (summary_sorted["embedding_model"] + " | " + summary_sorted["classifier"]).tolist()
    vals = summary_sorted[col].values

    # plot bars
    ax.barh(range(len(vals)), vals, color="C0")
    ax.set_yticks(range(len(vals)))
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # put the best combination at the top

    # force x limits to 0-1
    ax.set_xlim(0, 1.0)

    # major ticks every 0.1, minor ticks every 0.05
    major_xticks = np.arange(0, 1.0001, 0.1)
    minor_xticks = np.arange(0, 1.0001, 0.05)
    ax.set_xticks(major_xticks)
    ax.set_xticks(minor_xticks, minor=True)

    # lighter vertical gridlines at 0.05 (minor); keep major ticks without heavy grid
    ax.grid(axis="x", which="minor", linestyle="--", color="gray", linewidth=0.5, alpha=0.35)

    # data callouts at end of bars
    x_offset = (1.0 - 0) * 0.01
    for i, v in enumerate(vals):
        ax.text(v + x_offset, i, f"{v:.3f}", va="center", ha="left", fontsize=9, color="black")

    # columns whose name contains "mean" get an explanatory suffix in the axis label
    ax.set_xlabel(f"{metric_label} (mean across folds)" if "mean" in col.lower() else metric_label)
    ax.set_title(f"Top {top_k} combos by {metric_label.upper()}")

plt.suptitle("UNNORMALIZED EMBEDDINGS", fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# Pair each normalized result with its unnormalized counterpart; columns present
# in both tables get the "_norm" / "_unnorm" suffixes.
join_keys = ["embedding_model", "classifier"]
comparison_df = summary_df_norm.merge(summary_df_unnorm, on=join_keys, suffixes=("_norm", "_unnorm")).round(3)

# compute differences between normalized and unnormalized
for metric in ("acc", "recall", "precision", "specificity"):
    comparison_df[f"diff_{metric}"] = comparison_df[f"{metric}_mean_norm"] - comparison_df[f"{metric}_mean_unnorm"]

diff_columns = ["diff_acc", "diff_recall", "diff_precision", "diff_specificity"]
comparison_df[diff_columns].describe().round(3)

## 7) Fixed Recall (Sensitivity, TPR) Analysis (95% Recall Target)

Now let's implement the threshold tuning approach to fix recall at 95% and see how other metrics perform.

In [None]:
import os
import sys

# Add ../src to the module search path, presumably so the import below resolves.
# NOTE(review): `classifiers` was already imported in the first cell, before this
# path was added - confirm whether this append is still needed.
sys.path.append(os.path.join("..", "src"))

from classifiers.fixed_recall_classifiers import analyze_recall_precision_tradeoff, benchmark_classifiers_fixed_recall, find_recall_threshold

Benchmark all embedding/classifier combinations at the fixed recall target.

In [None]:
import warnings

# Silence pkg_resources UserWarnings. `message` is a regular expression: the
# trailing "is *" means "is" followed by zero or more spaces, so any warning
# text containing "pkg_resources is" is matched.
warnings.filterwarnings("ignore", category=UserWarning, message=".*pkg_resources is *")

In [None]:
from classifiers.classifier_cache import ClassifierCache

# Create cache for fixed recall classifiers
fixed_recall_cache = ClassifierCache(cache_dir=os.path.join(CACHE_DIR_CLASSIFIERS, "fixed_recall"))

# NOTE: the message says "selected models", but every entry of
# X_by_model_normalized is passed to the benchmark below.
print("Running fixed recall evaluation (95% target) for selected models...")
# Runs on the normalized embeddings with target_recall=0.95; the call returns two
# tables, bound here as the per-fold results and the summary.
per_fold_df_fixed, summary_df_fixed = benchmark_classifiers_fixed_recall(
    X_by_model_normalized,
    y,
    classifiers,
    target_recall=0.95,
    random_state=RANDOM_STATE,
    n_splits=N_SPLITS,
    cache=fixed_recall_cache,
)

summary_df_fixed

In [None]:
# export data to generate tables in the paper
# NOTE: this cell rebinds summary_df_fixed (values rounded to 2 decimals, sorted by
# specificity_mean descending, reduced column set); later cells see this version.
summary_df_fixed = summary_df_fixed.round(2)
summary_df_fixed = summary_df_fixed.sort_values(by="specificity_mean", ascending=False)[
    ["embedding_model", "classifier", "recall_mean", "precision_mean", "specificity_mean", "acc_mean", "threshold_mean"]
]
summary_df_fixed = summary_df_fixed.reset_index(drop=True)

# Create the output folder first: to_csv does not create missing parent directories.
export_dir = os.path.join(RESULTS_DIR, "thesis_figures_tables_generation", "1")
os.makedirs(export_dir, exist_ok=True)
summary_df_fixed.to_csv(os.path.join(export_dir, "summary_fixed_threshold_95.csv"), index=False)
summary_df_fixed

In [None]:
# fixed-recall results
# Metrics to plot, one subplot each: (summary column key, axis label).
# How many combinations each subplot shows is set by top_k further down in this cell.
metrics = [
    ("acc_mean", "Accuracy"),
    ("recall_mean", "Recall"),
    ("precision_mean", "Precision"),
    ("specificity_mean", "Specificity"),
]

# Short model names go into a separate column, so embedding_model keeps the full
# ids; a later cell uses those ids to index X_by_model_normalized.
summary_df_fixed["embedding_model_short"] = summary_df_fixed.embedding_model.str.split("/").str[-1]


# helper to find the column in summary_df_fixed (handles names like pr_auc_mean / pr_auc_std etc.)
def find_metric_col(df, key):
    """Return the column of ``df`` that best matches ``key``.

    A column matches when ``key`` occurs in its name, ignoring case. Among the
    matches, a column whose name ends in ``_mean`` is preferred; if there is
    none, the first match in column order is used.

    Parameters
    ----------
    df : object with a ``columns`` iterable of strings (e.g. a pandas DataFrame)
    key : str
        Substring to look for in the column names.

    Returns
    -------
    The matching column name, or ``None`` if no column contains ``key``.
    """
    key = key.lower()
    matches = [c for c in df.columns if key in c.lower()]
    # Look at every match before falling back, so an earlier non-mean column
    # (e.g. "acc_std") cannot shadow the "_mean" column.
    for c in matches:
        if c.lower().endswith("_mean"):
            return c
    return matches[0] if matches else None


# Draw a 2x2 grid of horizontal bar charts, one per entry in `metrics`, each
# showing the top_k embedding/classifier combinations of the fixed-recall run.
top_k = 8
fig, axes = plt.subplots(2, 2, figsize=(20, 10), dpi=150)
axes = axes.flatten()

for ax, (metric_key, metric_label) in zip(axes, metrics):
    col = find_metric_col(summary_df_fixed, metric_key)
    if col is None:
        # no matching column: show a placeholder message instead of a chart
        ax.text(0.5, 0.5, f"No column found for '{metric_label}'", ha="center", va="center")
        ax.set_xticks([])
        ax.set_yticks([])
        continue

    # best top_k rows for this metric; tick labels are "short model name | classifier"
    summary_sorted = summary_df_fixed.sort_values(col, ascending=False).head(top_k)
    labels = (summary_sorted["embedding_model_short"] + " | " + summary_sorted["classifier"]).tolist()
    vals = summary_sorted[col].values

    # plot bars
    ax.barh(range(len(vals)), vals, color="C0")
    ax.set_yticks(range(len(vals)))
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # put the best combination at the top

    # force x limits to 0-1
    ax.set_xlim(0, 1.0)

    # major ticks every 0.1, minor ticks every 0.05
    major_xticks = np.arange(0, 1.0001, 0.1)
    minor_xticks = np.arange(0, 1.0001, 0.05)
    ax.set_xticks(major_xticks)
    ax.set_xticks(minor_xticks, minor=True)

    # lighter vertical gridlines at 0.05 (minor); keep major ticks without heavy grid
    ax.grid(axis="x", which="minor", linestyle="--", color="gray", linewidth=0.5, alpha=0.35)

    # data callouts at end of bars
    x_offset = (1.0 - 0) * 0.01
    for i, v in enumerate(vals):
        ax.text(v + x_offset, i, f"{v:.2f}", va="center", ha="left", fontsize=9, color="black")

    # NOTE: both branches of this conditional produce the same text (metric_label).
    ax.set_xlabel(f"{metric_label}" if "mean" in col.lower() else metric_label)
    ax.set_title(f"Top {top_k} combos by {metric_label.upper()}")

plt.suptitle("Fixed Recall Evaluation (Target Recall = 0.95)", fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# Analyze recall-precision tradeoff for the best performing model
# "Best" is simply the first row of summary_df_fixed. The export cell sorts that
# table by specificity_mean (descending), so after running it the first row is
# the highest-specificity combination, not the highest-precision one.
best_fixed = summary_df_fixed.iloc[0]  # first row of the fixed-recall summary
best_embed = best_fixed["embedding_model"]
best_clf = best_fixed["classifier"]

print(f"Analyzing recall-precision tradeoff for: {best_embed} + {best_clf}")

# Analyze different recall targets
# best_embed must be a key of X_by_model_normalized and best_clf a key of classifiers.
tradeoff_analysis = analyze_recall_precision_tradeoff(
    X_by_model_normalized[best_embed],
    y,
    classifiers[best_clf],
    recall_targets=[0.95, 0.96, 0.97, 0.98, 0.99],
    random_state=RANDOM_STATE,
    n_splits=N_SPLITS,
)

print("\n=== RECALL-PRECISION TRADEOFF ANALYSIS ===")
tradeoff_analysis.round(2)

In [None]:
# save tradeoff analysis results
# Output folder: <RESULTS_DIR>/1/<model id with "/" replaced by "_">/<classifier name>
tradeoff_analysis_dir = os.path.join(RESULTS_DIR, "1", best_embed.replace("/", "_"), best_clf)
os.makedirs(tradeoff_analysis_dir, exist_ok=True)  # create the folder if it is missing
tradeoff_analysis.to_csv(os.path.join(tradeoff_analysis_dir, "recall_precision_tradeoff_analysis.csv"), index=False)

In [None]:
# export data to generate tables in the paper
# NOTE: this writes the same file (summary_normalized_embeddings.csv) as the
# export cell in section 5, using the same rounding, sort order and column selection.
summary_df_norm = summary_df_norm.round(2)
summary_df_norm = summary_df_norm.sort_values(by="roc_auc_mean", ascending=False)[
    ["embedding_model", "classifier", "roc_auc_mean", "recall_mean", "precision_mean", "specificity_mean", "acc_mean"]
]
summary_df_norm = summary_df_norm.reset_index(drop=True)

# Create the output folder first: to_csv does not create missing parent directories.
export_dir = os.path.join(RESULTS_DIR, "thesis_figures_tables_generation", "1")
os.makedirs(export_dir, exist_ok=True)
summary_df_norm.to_csv(os.path.join(export_dir, "summary_normalized_embeddings.csv"), index=False)
summary_df_norm

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# 2x2 overview of how the other metrics move as the recall target is raised.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle(f"Recall-Precision Tradeoff Analysis\n{best_embed} + {best_clf}", fontsize=16)

# Range of evaluated targets, taken from the data so the reference diagonal in
# plot 3 always spans exactly the targets that were analysed.
lowest_target = tradeoff_analysis["target_recall"].min()
highest_target = tradeoff_analysis["target_recall"].max()

# Plot 1: Accuracy vs Target Recall
ax1.plot(tradeoff_analysis["target_recall"], tradeoff_analysis["accuracy_mean"], "bo-", linewidth=2, markersize=6)
ax1.set_xlabel("Target Recall")
ax1.set_ylabel("Achieved Accuracy")
ax1.set_title("Accuracy vs Target Recall")
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0, 1.0)
ax1.axvline(x=0.95, color="red", linestyle="--", alpha=0.7, label="95% Target")
ax1.legend()

# Plot 2: F1 Score vs Target Recall
ax2.plot(tradeoff_analysis["target_recall"], tradeoff_analysis["f1_mean"], "go-", linewidth=2, markersize=6)
ax2.set_xlabel("Target Recall")
ax2.set_ylabel("F1 Score")
ax2.set_title("F1 Score vs Target Recall")
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, 1.0)
ax2.axvline(x=0.95, color="red", linestyle="--", alpha=0.7, label="95% Target")
ax2.legend()

# Plot 3: Actual vs Target Recall
ax3.plot(tradeoff_analysis["target_recall"], tradeoff_analysis["actual_recall_mean"], "ro-", linewidth=2, markersize=6)
# dashed diagonal = achieved recall exactly equal to the target
ax3.plot([lowest_target, highest_target], [lowest_target, highest_target], "k--", alpha=0.5, label="Perfect Match")
ax3.set_xlabel("Target Recall")
ax3.set_ylabel("Achieved Recall")
ax3.set_title("Achieved vs Target Recall")
ax3.grid(True, alpha=0.3)
ax3.legend()

# Plot 4: Specificity vs Target Recall
ax4.plot(tradeoff_analysis["target_recall"], tradeoff_analysis["specificity_mean"], "mo-", linewidth=2, markersize=6)
ax4.set_xlabel("Target Recall")
ax4.set_ylabel("Achieved Specificity")
ax4.set_title("Specificity vs Target Recall")
ax4.grid(True, alpha=0.3)
ax4.axvline(x=0.95, color="red", linestyle="--", alpha=0.7, label="95% Target")
ax4.legend()

# format x-axis ticks to show two decimals
for ax in (ax1, ax2, ax3, ax4):
    ax.xaxis.set_major_locator(mticker.MultipleLocator(0.01))
    ax.xaxis.set_major_formatter(mticker.FormatStrFormatter("%.2f"))

# leave room at the top of the figure for the two-line suptitle
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()