# Adaptive Re-Ranking (Logistic Regression Gating)

Given query-level statistics from the full re-ranking pipeline, we train LR models to decide **which queries to re-rank** under a compute budget.

- Feature extraction from inlier statistics (top-1 vs best in top-$K$).
- Two LR objectives: **hardness** vs **benefit**.
- Budgeted trade-off curves (Recall@1 vs fraction reranked) and summary tables (e.g., $b=20\%$).


In [None]:
!pip -q install numpy pandas scikit-learn tqdm matplotlib

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Colab-only: mount Google Drive at /content/drive (forcing a remount) so the
# project folders configured in this notebook are reachable from the runtime.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Root folders on Drive: where the stats CSVs are read from and where this
# notebook writes its results.
PROJECT_ROOT = Path("/content/drive/MyDrive/VPR_Results")  # change if needed
LOGS_ROOT = PROJECT_ROOT.joinpath("logs_csv")
OUT_ROOT = PROJECT_ROOT.joinpath("results_6_1_adaptive_updated")

# Create the output folder (and any missing parents); no-op if it exists.
OUT_ROOT.mkdir(parents=True, exist_ok=True)

for label, folder in (("LOGS_ROOT:", LOGS_ROOT), ("OUT_ROOT:", OUT_ROOT)):
    print(label, folder)


LOGS_ROOT: /content/drive/MyDrive/VPR_Results/logs_csv
OUT_ROOT: /content/drive/MyDrive/VPR_Results/results_6_1_adaptive_updated


## Time model (per-query latencies)

We compute **average latency** and **saving** under a reranking budget using the measured per-query times:
- **NetVLAD** (NVIDIA T4): $t_{global}=0.105$s, $t_{SP{+}LG}=3.43$s, $t_{SG}=1.44$s, $t_{LoFTR}=3.63$s.
- **MixVPR** (Apple M4 Pro): $t_{global}=0.245$s, $t_{SP{+}LG}=4.58$s, $t_{SG}=9.26$s, $t_{LoFTR}=3.28$s.

Savings are computed relative to the **always rerank** upper bound ($t_{global}+t_{rerank}$).


## Discover runs

Expected layout:
<LOGS_ROOT>/<VPR>/<split>/<dataset>/
  preds_superpoint-lg/   (torch files)
  preds_loftr/           (torch files)
  stats_preds_superpoint-lg.csv  (or similar)
  stats_preds_loftr.csv

In [None]:
def find_runs(logs_root: Path):
    """Enumerate run folders laid out as ``<logs_root>/<vpr>/<split>/<dataset>``.

    Returns a list of dicts with keys ``vpr``, ``split``, ``dataset`` (the
    folder names) and ``run_dir`` (the dataset folder as a ``Path``). Entries
    follow the sorted folder order at each level; plain files are ignored.
    """
    def subdirs(folder: Path):
        # Sorted so the discovery order does not depend on the file system.
        return sorted(child for child in folder.iterdir() if child.is_dir())

    return [
        {"vpr": vpr.name, "split": split.name, "dataset": ds.name, "run_dir": ds}
        for vpr in subdirs(logs_root)
        for split in subdirs(vpr)
        for ds in subdirs(split)
    ]

# Discover every <vpr>/<split>/<dataset> run folder under LOGS_ROOT.
runs = find_runs(LOGS_ROOT)
print("Found runs:", len(runs))
# Last expression of the cell: preview the first five discovered runs.
runs[:5]

Found runs: 14


[{'vpr': 'MixVPR',
  'split': 'test',
  'dataset': 'sf_xs',
  'run_dir': PosixPath('/content/drive/MyDrive/VPR_Results/logs_csv/MixVPR/test/sf_xs')},
 {'vpr': 'MixVPR',
  'split': 'test',
  'dataset': 'svox_night',
  'run_dir': PosixPath('/content/drive/MyDrive/VPR_Results/logs_csv/MixVPR/test/svox_night')},
 {'vpr': 'MixVPR',
  'split': 'test',
  'dataset': 'svox_sun',
  'run_dir': PosixPath('/content/drive/MyDrive/VPR_Results/logs_csv/MixVPR/test/svox_sun')},
 {'vpr': 'MixVPR',
  'split': 'test',
  'dataset': 'tokyo_xs',
  'run_dir': PosixPath('/content/drive/MyDrive/VPR_Results/logs_csv/MixVPR/test/tokyo_xs')},
 {'vpr': 'MixVPR',
  'split': 'train',
  'dataset': 'svox_night',
  'run_dir': PosixPath('/content/drive/MyDrive/VPR_Results/logs_csv/MixVPR/train/svox_night')}]

## Find updated stats CSVs

We will search for files matching `stats_*.csv` inside each run folder.
Then infer matcher based on filename.

In [None]:
def infer_im_method(name: str):
    """Guess the image-matching method from a stats file name.

    Matching is case-insensitive. ``"loftr"`` is checked first; otherwise a
    name containing ``"superpoint"``, or both ``"sp"`` and ``"lg"``, maps to
    ``"superpoint-lg"``. Any other name yields ``"unknown"``.
    """
    lowered = name.lower()
    if "loftr" in lowered:
        return "loftr"
    looks_like_sp_lg = "superpoint" in lowered or all(tag in lowered for tag in ("sp", "lg"))
    return "superpoint-lg" if looks_like_sp_lg else "unknown"

def find_stats_csvs(run_dir: Path):
    """Return every file named ``stats_*.csv`` found recursively under ``run_dir``."""
    return list(run_dir.rglob("stats_*.csv"))

# Build an index of usable stats files: one row per (run, matcher) CSV.
rows = []
for run in runs:
    run_meta = {key: run[key] for key in ("vpr", "split", "dataset")}
    for csv_file in find_stats_csvs(run["run_dir"]):
        matcher = infer_im_method(csv_file.name)
        # Keep only files whose matcher could be recognised from the file name.
        if matcher != "unknown":
            rows.append({**run_meta, "im": matcher, "csv_path": str(csv_file)})

index_df = pd.DataFrame(rows).sort_values(["vpr","split","dataset","im"])
print("Found stats csv:", len(index_df))
index_df.head(20)

Found stats csv: 28


Unnamed: 0,vpr,split,dataset,im,csv_path
0,MixVPR,test,sf_xs,loftr,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
1,MixVPR,test,sf_xs,superpoint-lg,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
2,MixVPR,test,svox_night,loftr,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
3,MixVPR,test,svox_night,superpoint-lg,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
4,MixVPR,test,svox_sun,loftr,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
5,MixVPR,test,svox_sun,superpoint-lg,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
6,MixVPR,test,tokyo_xs,loftr,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
7,MixVPR,test,tokyo_xs,superpoint-lg,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
8,MixVPR,train,svox_night,loftr,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...
9,MixVPR,train,svox_night,superpoint-lg,/content/drive/MyDrive/VPR_Results/logs_csv/Mi...


## Load and validate updated CSV

We expect the new schema:
query_id, inliers_rank0, is_correct_rank0, max_inliers, is_correct_final

In [None]:
REQUIRED_COLS = ["query_id", "inliers_rank0", "is_correct_rank0", "max_inliers", "is_correct_final"]

def load_stats_csv(path: Path):
    """Load one per-query stats CSV and add the re-ranking outcome labels.

    The delimiter is sniffed (``sep=None`` with the python engine) and header
    names are stripped of surrounding whitespace before validation.

    Args:
        path: CSV file containing at least the columns in ``REQUIRED_COLS``.

    Returns:
        DataFrame sorted by ``query_id`` with a fresh index. Inlier counts are
        floats and correctness flags are ints; unparsable values in those four
        columns become 0. Two label columns are added: ``benefit`` (rank-0
        wrong, final correct) and ``harm`` (rank-0 correct, final wrong).

    Raises:
        ValueError: if a required column is missing, or if any ``query_id`` is
            missing or not numeric.
    """
    df = pd.read_csv(path, sep=None, engine="python")
    df.columns = [c.strip() for c in df.columns]
    missing = [c for c in REQUIRED_COLS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns {missing} in {path}")

    # types
    # Coercion turns bad ids into NaN, which cannot be cast to int; report the
    # problem explicitly instead of failing inside the cast with no context.
    query_id = pd.to_numeric(df["query_id"], errors="coerce")
    n_bad = int(query_id.isna().sum())
    if n_bad:
        raise ValueError(f"{n_bad} row(s) with missing or non-numeric query_id in {path}")
    df["query_id"] = query_id.astype(int)

    for col in ("inliers_rank0", "max_inliers"):
        df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0).astype(float)
    for col in ("is_correct_rank0", "is_correct_final"):
        df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)

    # derived: benefit / harm of reranking relative to the rank-0 result
    wrong_before = df["is_correct_rank0"] == 0
    right_after = df["is_correct_final"] == 1
    df["benefit"] = (wrong_before & right_after).astype(int)
    df["harm"] = ((df["is_correct_rank0"] == 1) & (df["is_correct_final"] == 0)).astype(int)
    return df.sort_values("query_id").reset_index(drop=True)

# quick test: load the first indexed CSV to confirm the schema parses
if len(index_df):
    test_path = Path(index_df.iloc[0]["csv_path"])
    dft = load_stats_csv(test_path)
    # An expression inside an `if` body is not auto-rendered by the notebook,
    # so the preview has to be displayed explicitly.
    display(dft.head())


## Logistic Regression

We implement two LR tasks:

### Task 1 (Hardness classifier)
Predict if query is **hard** (baseline retrieval fails):
- y_hard = 1 if is_correct_rank0 == 0 else 0

### Task 2 (Benefit classifier) [recommended]
Predict if reranking is **useful**:
- y_benefit = 1 if (rank0 wrong) and (final correct)

Features:
- inliers_rank0
- max_inliers
- gap = max_inliers - inliers_rank0
- log(1+features)

In [None]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score

def make_features(df: pd.DataFrame):
    """Build the LR feature matrix from per-query inlier statistics.

    Columns, in order: ``inliers_rank0``, ``max_inliers``,
    ``gap = max_inliers - inliers_rank0``, then ``log1p`` of each of the three.

    Args:
        df: frame with numeric ``inliers_rank0`` and ``max_inliers`` columns.

    Returns:
        ``(X, feat_names)``: a float32 array of shape ``(len(df), 6)`` and the
        list of the six feature names in column order.
    """
    x1 = df["inliers_rank0"].to_numpy(dtype=np.float32)
    x2 = df["max_inliers"].to_numpy(dtype=np.float32)
    gap = (x2 - x1).astype(np.float32)

    X = np.stack([x1, x2, gap], axis=1)
    # log1p is -inf at -1 and NaN below it. A negative gap (max_inliers below
    # inliers_rank0, e.g. from a defaulted value) would therefore poison the
    # scaler and the LR fit, so the log input is clamped at 0. The raw gap
    # column keeps its sign.
    Xlog = np.log1p(np.maximum(X, 0))
    Xall = np.concatenate([X, Xlog], axis=1)
    feat_names = ["inliers_rank0", "max_inliers", "gap", "log_inliers_rank0", "log_max_inliers", "log_gap"]
    return Xall, feat_names

def fit_lr(X, y, C=1.0):
    """Fit a standardise-then-logistic-regression pipeline on ``(X, y)``.

    ``C`` is forwarded to ``LogisticRegression``, which is configured with
    ``class_weight="balanced"`` and ``max_iter=2000``. Returns the fitted
    ``Pipeline`` (steps named ``"scaler"`` and ``"lr"``).
    """
    classifier = LogisticRegression(C=C, class_weight="balanced", max_iter=2000)
    steps = [("scaler", StandardScaler()), ("lr", classifier)]
    return Pipeline(steps).fit(X, y)

def lr_metrics(y, p):
    """Score predicted probabilities ``p`` against binary labels ``y``.

    Returns a dict with accuracy at a 0.5 cut-off (``"acc@0.5"``), ROC-AUC
    (``"roc_auc"``) and average precision (``"auprc"``). The two ranking
    metrics are NaN when ``y`` contains a single class.
    """
    acc = float(accuracy_score(y, (p >= 0.5).astype(int)))
    if len(np.unique(y)) > 1:
        roc_auc = float(roc_auc_score(y, p))
        auprc = float(average_precision_score(y, p))
    else:
        # Ranking metrics are not computed when only one class is present.
        roc_auc = auprc = np.nan
    return {"acc@0.5": acc, "roc_auc": roc_auc, "auprc": auprc}

In [None]:
def collect_split(vpr, im, split):
    """Concatenate the stats of every dataset for one (vpr, matcher, split).

    Rows are selected from the global ``index_df`` (the split name is compared
    case-insensitively). Datasets whose name contains "gsv" are skipped. Each
    loaded frame is tagged with a ``dataset`` column before concatenation.
    Returns ``None`` when there is nothing to load.
    """
    wanted = (
        (index_df["vpr"] == vpr)
        & (index_df["im"] == im)
        & (index_df["split"].str.lower() == split.lower())
    )
    selected = index_df[wanted]

    frames = []
    for ds_name, csv_path in zip(selected["dataset"], selected["csv_path"]):
        if "gsv" in ds_name.lower():
            continue
        frame = load_stats_csv(Path(csv_path))
        frame["dataset"] = ds_name
        frames.append(frame)

    return pd.concat(frames, ignore_index=True) if frames else None

# Distinct VPR names present in the stats index, in sorted order; the training
# loop iterates over this list.
available_vprs = sorted(index_df["vpr"].unique())
print("VPRs:", available_vprs)

In [None]:
# Output folder for LR model artifacts.
lr_dir = OUT_ROOT / "lr_models"
lr_dir.mkdir(parents=True, exist_ok=True)

models = {}  # (vpr,im,task) -> dict
Cs = [0.1, 0.3, 1.0, 3.0, 10.0]  # candidate values for the LR `C` parameter
rows = []

for vpr in available_vprs:
    for im in ["superpoint-lg", "loftr"]:
        train_df = collect_split(vpr, im, "train")
        val_df = collect_split(vpr, im, "val")
        # Both splits are required: train to fit, val to select C.
        if train_df is None or val_df is None:
            continue

        Xtr, feat_names = make_features(train_df)
        Xva = make_features(val_df)[0]

        # (train labels, val labels) per task.
        # Task1 "hard": not correct at rank0. Task2 "benefit": precomputed column.
        targets = {
            "hard": tuple((d["is_correct_rank0"] == 0).astype(int).to_numpy() for d in (train_df, val_df)),
            "benefit": tuple(d["benefit"].to_numpy() for d in (train_df, val_df)),
        }

        for task_name, (ytr, yva) in targets.items():
            best = None
            for C in Cs:
                candidate = fit_lr(Xtr, ytr, C=C)
                val_proba = candidate.predict_proba(Xva)[:, 1]
                val_met = lr_metrics(yva, val_proba)
                # Select on validation AUPRC; strict ">" keeps the earliest C on ties.
                if best is None or val_met["auprc"] > best["val_metrics"]["auprc"]:
                    best = {"model": candidate, "C": C, "feat_names": feat_names, "val_metrics": val_met}

            models[(vpr, im, task_name)] = best
            rows.append({
                "vpr": vpr, "im": im, "task": task_name,
                "C_best": best["C"],
                "val_auprc": best["val_metrics"]["auprc"],
                "val_roc_auc": best["val_metrics"]["roc_auc"],
                "val_acc@0.5": best["val_metrics"]["acc@0.5"],
            })

pd.DataFrame(rows).sort_values(["vpr","im","task"])

Unnamed: 0,vpr,im,task,C_best,val_auprc,val_roc_auc,val_acc@0.5
3,MixVPR,loftr,benefit,0.1,0.761468,0.974416,0.942
2,MixVPR,loftr,hard,0.3,0.661566,0.941714,0.892
1,MixVPR,superpoint-lg,benefit,0.1,0.876636,0.982431,0.948
0,MixVPR,superpoint-lg,hard,1.0,0.588385,0.936171,0.893
7,NetVLAD,loftr,benefit,0.1,0.920467,0.967032,0.895
6,NetVLAD,loftr,hard,10.0,0.961529,0.972254,0.865
5,NetVLAD,superpoint-lg,benefit,0.3,0.942782,0.973969,0.922
4,NetVLAD,superpoint-lg,hard,0.1,0.948911,0.970058,0.92


## LR gating curves on TEST

Policy:
- rerank if p(hard) ≥ τ_prob  (or p(benefit) ≥ τ_prob)

Final correctness:
- if reranked -> is_correct_final
- else -> is_correct_rank0

We will output curves and summary budgets.

In [None]:
# --- Time model (per-query latencies) used to compute savings ---
# NOTE: values are provided by the report (baseline global + full top-20 re-ranking latencies).
# NetVLAD is measured on NVIDIA T4; MixVPR is measured on Apple M4 Pro.
TIME_GLOBAL = dict(NetVLAD=0.105, MixVPR=0.245)

# full top-20 local verification (seconds/query), keyed by matcher then VPR
TIME_RERANK = {
    "superpoint-lg": dict(NetVLAD=3.43, MixVPR=4.58),
    "superglue":     dict(NetVLAD=1.44, MixVPR=9.26),
    "loftr":         dict(NetVLAD=3.63, MixVPR=3.28),
}

def lr_gating_curve(df_split, proba, vpr, im, thresholds=None):
    """Trace Recall@1 and latency against the reranked fraction under LR gating.

    A query is reranked when its probability is at or above the threshold; a
    reranked query counts ``is_correct_final``, any other ``is_correct_rank0``.

    - df_split: frame with ``is_correct_rank0`` and ``is_correct_final`` columns.
    - proba: array of positive-class probabilities (hard or benefit), one per row.
    - vpr, im: keys into ``TIME_GLOBAL`` / ``TIME_RERANK`` for the time model.
    - thresholds: probability thresholds; defaults to 101 points evenly
      spaced over [0, 1].

    Returns one row per threshold (``thr_prob``, ``frac_reranked``,
    ``R@1_adaptive``, ``time_avg_s``, ``saving_%``), sorted by ``frac_reranked``.
    The saving is relative to always reranking (global + rerank time).
    """
    grid = np.linspace(0, 1, 101) if thresholds is None else thresholds

    t_global = TIME_GLOBAL[vpr]
    t_rerank = TIME_RERANK[im][vpr]
    t_always = t_global + t_rerank

    correct_if_reranked = df_split["is_correct_final"].values
    correct_if_skipped = df_split["is_correct_rank0"].values

    def operating_point(thr):
        rerank_mask = proba >= thr
        frac = float(rerank_mask.mean())
        outcome = np.where(rerank_mask, correct_if_reranked, correct_if_skipped)
        # Only the reranked fraction pays the local-verification cost.
        t_avg = t_global + frac * t_rerank
        return {
            "thr_prob": float(thr),
            "frac_reranked": frac,
            "R@1_adaptive": 100.0 * float(outcome.mean()),
            "time_avg_s": t_avg,
            "saving_%": 100.0 * (1.0 - t_avg / t_always),
        }

    points = [operating_point(thr) for thr in grid]
    return pd.DataFrame(points).sort_values("frac_reranked").reset_index(drop=True)


## Evaluate LR gating under a compute budget

We compute, for each (VPR, dataset, matcher) configuration:
- the **trade-off curve** (Recall@1 vs reranked fraction) for **LR-hardness** and **LR-benefit**;
- the **best operating point** under a fixed budget (default: $b=20\%$);
- macro-averages across datasets.


In [None]:
# Evaluation configuration: rerank budget and the grids iterated below.
BUDGET = 0.20  # 20%
DATASET_ORDER = ["sf_xs", "tokyo_xs", "svox_sun", "svox_night"]
VPRS_TO_USE = ["NetVLAD", "MixVPR"]
IMS_TO_USE  = ["loftr", "superpoint-lg"]
TASKS = ["hard", "benefit"]

def best_under_budget(curve_df, b=BUDGET):
    """Pick the best operating point whose reranked fraction fits budget ``b``.

    Among rows with ``frac_reranked <= b`` the highest ``R@1_adaptive`` wins,
    ties going to the smaller fraction. When no row fits, the first row of
    ``curve_df`` is used instead. Returns that single row as a Series.
    """
    within = curve_df.loc[curve_df["frac_reranked"] <= b]
    candidates = within if len(within) else curve_df.iloc[[0]]
    # maximize R@1, then minimize rerank fraction
    ranked = candidates.sort_values(["R@1_adaptive", "frac_reranked"], ascending=[False, True])
    return ranked.iloc[0]

# Evaluate every (vpr, matcher, task) model on the TEST split: build one
# trade-off curve per dataset and record the best point under BUDGET.
curves_all = []
summary_rows = []

for vpr in VPRS_TO_USE:
    for im in IMS_TO_USE:
        # collect once per (vpr,im)
        df_test = collect_split(vpr, im, "test")
        if df_test is None:
            # collect_split returns None when no test CSV exists for this pair.
            print("Skipping missing test split:", (vpr, im))
            continue
        # make_features returns (matrix, feature_names); only the matrix is scored.
        X_test, _ = make_features(df_test)
        ds_labels = df_test["dataset"].to_numpy()

        for task in TASKS:
            key = (vpr, im, task)
            if key not in models:
                print("Skipping missing model:", key)
                continue
            # models[key] is a record dict; the fitted pipeline is under "model".
            proba = models[key]["model"].predict_proba(X_test)[:, 1]

            # dataset-wise curves
            for ds in DATASET_ORDER:
                ds_mask = ds_labels == ds
                if not ds_mask.any():
                    continue
                # The same mask slices rows and probabilities so they stay aligned.
                df_ds = df_test[ds_mask].reset_index(drop=True)
                p_ds = proba[ds_mask]

                curve = lr_gating_curve(df_ds, p_ds, vpr, im)
                curve["vpr"] = vpr
                curve["im"] = im
                curve["dataset"] = ds
                curve["task"] = task

                # Reference points: never rerank vs always rerank.
                r_base = 100.0 * float(df_ds["is_correct_rank0"].mean())
                r_alw  = 100.0 * float(df_ds["is_correct_final"].mean())
                curve["R@1_baseline"] = r_base
                curve["R@1_always"] = r_alw

                curves_all.append(curve)

                best = best_under_budget(curve, BUDGET)
                summary_rows.append({
                    "vpr": vpr,
                    "im": im,
                    "dataset": ds,
                    "task": task,
                    "budget_%": 100*BUDGET,
                    "R@1": float(best["R@1_adaptive"]),
                    "rerank_%": 100.0 * float(best["frac_reranked"]),
                    "saving_%": float(best["saving_%"]),
                    "R@1_baseline": r_base,
                    "R@1_always": r_alw,
                })

curves_df = pd.concat(curves_all, ignore_index=True)
summary_df = pd.DataFrame(summary_rows)

# Macro-average across datasets
macro = (summary_df
         .groupby(["vpr","im","task"], as_index=False)
         .agg({"R@1":"mean","rerank_%":"mean","saving_%":"mean",
               "R@1_baseline":"mean","R@1_always":"mean"}))
macro["dataset"] = "macro"
macro["budget_%"] = 100*BUDGET
macro = macro[["vpr","im","dataset","task","budget_%","R@1","rerank_%","saving_%","R@1_baseline","R@1_always"]]

# Save artifacts
OUT_ROOT.mkdir(parents=True, exist_ok=True)
curves_df.to_csv(OUT_ROOT / "lr_curves.csv", index=False)
summary_df.to_csv(OUT_ROOT / "summary_budgets.csv", index=False)
macro.to_csv(OUT_ROOT / "summary_macro_budget20.csv", index=False)

display(summary_df.head(12))
display(macro.sort_values(["vpr","im","task"]))


In [None]:
def plot_tradeoff_lr(curves_df, vpr, dataset, im, outpath=None):
    """Plot final R@1 against the reranked fraction for both LR gating tasks.

    Args:
        curves_df: curve table with columns ``vpr``, ``dataset``, ``im``,
            ``task``, ``frac_reranked``, ``R@1_adaptive``, ``R@1_baseline``
            and ``R@1_always``.
        vpr, dataset, im: select the configuration to plot.
        outpath: optional destination; when given, parent folders are created
            and the figure is saved there at 200 dpi before being shown.

    Raises:
        ValueError: if ``curves_df`` has no rows for the configuration.
    """
    sub = curves_df[(curves_df["vpr"]==vpr) & (curves_df["dataset"]==dataset) & (curves_df["im"]==im)]
    if len(sub)==0:
        raise ValueError(f"No curves for {(vpr,dataset,im)}")

    plt.figure(figsize=(6,4))
    # baselines
    r_base = float(sub["R@1_baseline"].iloc[0])
    r_alw  = float(sub["R@1_always"].iloc[0])
    # axhline does not advance the colour cycle, so without explicit colours
    # both reference lines would share the default line colour with the first
    # LR curve and be indistinguishable in the legend.
    plt.axhline(r_base, linestyle="--", linewidth=1, color="tab:gray", label="Global (base)")
    plt.axhline(r_alw,  linestyle="--", linewidth=1, color="black", label="Always rerank")

    for task in ["hard","benefit"]:
        tsub = sub[sub["task"]==task].sort_values("frac_reranked")
        plt.plot(tsub["frac_reranked"], tsub["R@1_adaptive"], label=f"LR-{task}")

    plt.xlabel("Fraction reranked")
    plt.ylabel("Final R@1 (%)")
    plt.title(f"{vpr} | {dataset} | {im}")
    plt.legend()
    plt.tight_layout()
    if outpath is not None:
        outpath = Path(outpath)
        outpath.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(outpath, dpi=200)
    plt.show()

# Save the two trade-off plots used in the report
plots_dir = OUT_ROOT / "plots_lr_only"
for vpr_name, ds_name, im_name in [("NetVLAD", "sf_xs", "loftr"), ("MixVPR", "sf_xs", "superpoint-lg")]:
    plot_tradeoff_lr(curves_df, vpr_name, ds_name, im_name,
                     outpath=plots_dir / f"tradeoff_{vpr_name}_{ds_name}_{im_name}.png")


In [None]:
# Optional: export a compact LaTeX table for macro-average (b=20%)
# Reload the macro summary from OUT_ROOT, ordered by (vpr, im, task).
macro = pd.read_csv(OUT_ROOT / "summary_macro_budget20.csv").sort_values(["vpr","im","task"])

def im_short(im):
    """Return the short display label for a matcher id; unknown ids pass through."""
    labels = {"loftr":"LoFTR", "superpoint-lg":"SP+LG", "superglue":"SG"}
    return labels[im] if im in labels else im

# One table row per macro entry. A LaTeX row must end with a double backslash;
# in a regular (non-raw) string literal that takes four backslash characters,
# since each escaped pair produces a single backslash in the output.
latex_rows = []
for _,r in macro.iterrows():
    latex_rows.append(f"{r['vpr']} & {im_short(r['im'])} & {r['task']} & {r['R@1']:.2f} & {r['rerank_%']:.1f} & {r['saving_%']:.2f} \\\\")
latex = "\n".join(latex_rows)
print(latex)
