# Zero Shot Prediction

In [None]:
from plonk.pipe import PlonkPipeline
from PIL import Image
import torch
import pandas as pd
import os
from tqdm import tqdm
import math

# === CONFIGURATION ===
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Root folder: the split CSV, the metadata workbook and the image paths built
# further down in this cell are all joined onto it.
data_folder = "plonk/plonk/data/split_IJGIS" 
test_csv = os.path.join(data_folder, "val_80_fixed.csv") # select 80 or 70 split
metadata_xlsx = os.path.join(data_folder, "all_20241120.xlsx")
# Number of images handed to the pipeline per call in the batch loop.
batch_size = 128

# === LOAD DATA ===
# The split CSV is read without a header row, so its columns are addressed by
# integer position.
test_df = pd.read_csv(test_csv, header=None)
metadata_df = pd.read_excel(metadata_xlsx, engine="openpyxl")
# Position of the test_df column that is read as the VGI image path.
vgi_column_index = 2
# NOTE(review): this derived "filename" column is not read again in this cell.
metadata_df["filename"] = metadata_df["ID"].astype(str) + ".jpg"

# === INIT PIPELINE ===
pipeline = PlonkPipeline("nicolas-dufour/PLONK_OSV_5M").to(device)
# Checkpoint names listed by the author as options for the line above:
# nicolas-dufour/PLONK_OSV_5M
# nicolas-dufour/PLONK_iNaturalist
# nicolas-dufour/PLONK_YFCC
# === METRICS ===
# Running counters; the batch loop increments them once per evaluated image.
recall_at_1 = 0
recall_lat = 0
recall_lon = 0
total = 0

# One dict per evaluated image; turned into a DataFrame in the report section.
results = []

def round_coords(lat, lon):
    """Round a (lat, lon) pair to whole numbers and return them as an int tuple.

    Uses the built-in ``round``, so exact halves follow its
    round-half-to-even rule.
    """
    lat_cell = int(round(lat))
    lon_cell = int(round(lon))
    return lat_cell, lon_cell

def haversine(lat1, lon1, lat2, lon2):
    """Compute the great-circle distance in km between two lat/lon pairs.

    All four arguments are angles in decimal degrees. The Earth is treated as
    a sphere of radius 6371 km.

    Returns:
        Distance in kilometres as a float. NaN inputs yield NaN.
    """
    R = 6371.0  # Earth radius in km
    lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt/asin raise ValueError.
    # `a` is passed as the first argument to max/min so that NaN still
    # propagates instead of being replaced by a bound.
    a = min(max(a, 0.0), 1.0)
    c = 2 * math.asin(math.sqrt(a))
    return R * c

# === PREPARE TASKS ===
# Four parallel lists, one entry per image that has both a metadata row and a
# file on disk; the batch loop slices them with the same indices.
tasks = []
gt_cells = []
gt_latlons = []
filenames = []

for idx, row in tqdm(test_df.iterrows(), total=len(test_df)):
    vgi_file = row[vgi_column_index]
    # Numeric ID: the text before the first ".", with the substring "VGI" and
    # any backslashes removed, parsed as an int.
    vgi_id = int(vgi_file.split(".")[0].replace("VGI", "").replace("\\", ""))
    # Boolean-mask lookup over the full metadata table for every row.
    gt_row = metadata_df[metadata_df["ID"] == vgi_id]

    if gt_row.empty:
        print(f"No metadata for {vgi_file}")
        continue

    # If several metadata rows share the ID, the first one is used.
    gt_lat = gt_row["Latitude"].values[0]
    gt_lon = gt_row["Longitude"].values[0]
    gt_cell = round_coords(gt_lat, gt_lon)

    # Backslash separators in the CSV path are turned into forward slashes
    # before joining onto the data folder.
    image_path = os.path.join(data_folder, vgi_file.replace("\\", "/"))
    if not os.path.isfile(image_path):
        print(f"Missing image: {image_path}")
        continue

    tasks.append(image_path)
    gt_cells.append(gt_cell)
    gt_latlons.append( (gt_lat, gt_lon) )
    filenames.append(vgi_file)

# === PROCESS IN BATCHES ===
# Walk the prepared lists in steps of `batch_size`, run the pipeline on each
# batch of images and score every prediction against its ground truth.
for i in tqdm(range(0, len(tasks), batch_size), desc="Batches"):
    batch_images = []
    batch_gt_cells = gt_cells[i:i+batch_size]
    batch_gt_latlons = gt_latlons[i:i+batch_size]
    batch_filenames = filenames[i:i+batch_size]

    for path in tasks[i:i+batch_size]:
        image = Image.open(path)
        # Only non-RGB images are converted; RGB images are passed through
        # as opened.
        if image.mode != "RGB":
            image = image.convert("RGB")
        batch_images.append(image)

    # Run pipeline on batch - use predictions
    gps_coords_batch = pipeline(batch_images, batch_size=len(batch_images))  # list of (lat, lon)

    for fname, gt_cell, gt_latlon, (pred_lat, pred_lon) in zip(batch_filenames, batch_gt_cells, batch_gt_latlons, gps_coords_batch):
        pred_cell = round_coords(pred_lat, pred_lon)
        total += 1

        # Hit only when both rounded coordinates match.
        if pred_cell == gt_cell:
            recall_at_1 += 1

        # Rounded latitude alone matches.
        if pred_cell[0] == gt_cell[0]:
            recall_lat += 1

        # Rounded longitude alone matches.
        if pred_cell[1] == gt_cell[1]:
            recall_lon += 1

        # compute distance (km) between the unrounded ground truth and prediction
        dist_km = haversine(gt_latlon[0], gt_latlon[1], pred_lat, pred_lon)

        results.append({
            "filename": fname,
            "gt_lat": gt_latlon[0],
            "gt_lon": gt_latlon[1],
            "pred_lat": pred_lat,
            "pred_lon": pred_lon,
            "gt_cell_lat": gt_cell[0],
            "gt_cell_lon": gt_cell[1],
            "pred_cell_lat": pred_cell[0],
            "pred_cell_lon": pred_cell[1],
            "distance_km": dist_km
        })

# === REPORT ===
df_results = pd.DataFrame(results)

if total == 0:
    # With no evaluated images the DataFrame has no "distance_km" column and
    # every ratio below would divide by zero, so report that and stop here.
    print("No predictions to report.")
else:
    mean_error = df_results["distance_km"].mean()
    median_error = df_results["distance_km"].median()

    print("\n=== RESULTS ===")
    print(f"Total images:     {total}")
    print(f"Recall@1:         {recall_at_1/total:.3f}")
    print(f"Latitude@1:       {recall_lat/total:.3f}")
    print(f"Longitude@1:      {recall_lon/total:.3f}")
    print(f"Mean distance:    {mean_error:.2f} km")
    print(f"Median distance:  {median_error:.2f} km")

    # Save per-image results
    df_results.to_csv("predictions_with_distances.csv", index=False)


# Fine-tuned model prediction

This cell already loads the fine-tuned retrieval checkpoint, so performance can be compared before and after retrieval-based refinement.

In [None]:
import os
import math
from tqdm import tqdm
import torch
import pandas as pd
from PIL import Image
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np

from plonk.pipe import PlonkPipeline

# ================== CONFIG ==================
class Config:
    """Static settings for this cell, read through the `config` instance below."""
    # Paths
    data_folder = "plonk/plonk/data/split_IJGIS"
    test_csv = os.path.join(data_folder, "val_80_fixed.csv")
    metadata_xlsx = os.path.join(data_folder, "all_20241120.xlsx")
    # State dict loaded into the retrieval model; None skips loading.
    checkpoint_start = "ian_weights80.pth"
    # Identifier passed to PlonkPipeline(...).
    plonk_weights = "iandisaster20osm"
    # Model
    model = "timm/vit_large_patch14_dinov2.lvd142m"
    # Used both for the TimmModel constructor and for the Resize transform.
    img_size = 384

    # Eval
    batch_size = 128
    # Distance cut-offs (km) for the "accuracy @ X km" counters.
    thresholds_km = [1, 25, 50, 200, 750, 2500]
    # Radius (km) around the Plonk prediction in which RSI candidates are kept.
    retrieval_threshold_km = 50.0
    # Values of k for the recall@k counters.
    recall_topk = [1, 5, 10]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Config()

# ================== MODEL ==================
class TimmModel(nn.Module):
    """Bundle of timm backbones, one of which is picked per view type.

    Which backbones exist depends on ``model_name`` (case-insensitive):

    * contains "vit": ``model_square``, ``model_wide`` and ``model_uav``
    * contains "convnext": ``model_main`` and ``model_uav``
    * anything else: a single ``model``

    Every backbone is created with ``num_classes=0``.
    """

    def __init__(self, model_name, pretrained=True, img_size=383, embed_dim=1024):
        super().__init__()
        self.img_size = img_size
        self.model_name = model_name
        # Stored only; nothing else in this class reads embed_dim.
        self.embed_dim = embed_dim
        # Scalar parameter initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        def build(**extra):
            # All backbones share these arguments; only img_size varies.
            return timm.create_model(model_name, pretrained=pretrained, num_classes=0, **extra)

        lowered = model_name.lower()
        if "vit" in lowered:
            wide_w = img_size * 2
            wide_h = round((512 / 1024) * wide_w)
            self.img_size_wide = (wide_h, wide_w)
            square = (img_size, img_size)
            self.model_square = build(img_size=square)
            self.model_wide = build(img_size=self.img_size_wide)
            self.model_uav = build(img_size=square)
        elif "convnext" in lowered:
            self.model_main = build()
            self.model_uav = build()
        else:
            self.model = build()

    def get_config(self):
        """Return timm's resolved data config for the first backbone present.

        Checked in the order ``model``, ``model_main``, then ``model_square``.
        """
        for attr in ("model", "model_main"):
            if hasattr(self, attr):
                return timm.data.resolve_model_data_config(getattr(self, attr))
        return timm.data.resolve_model_data_config(self.model_square)

    def forward(self, data_dict):
        """Run each ``{view_key: batch}`` entry through its backbone.

        Returns a dict with the same keys holding the backbone outputs.
        """
        return {key: self._forward_single(img, key) for key, img in data_dict.items()}

    def _forward_single(self, x, key):
        """Send ``x`` through the backbone that handles view ``key``.

        Raises:
            ValueError: for a view key the ViT variant does not know, or when
                the model name matches neither "vit" nor "convnext" and no
                single ``model`` exists.
        """
        if hasattr(self, "model"):
            return self.model(x)

        lowered = self.model_name.lower()
        if "vit" in lowered:
            if key == "SVI":
                backbone = self.model_wide
            elif key == "UAV":
                backbone = self.model_uav
            elif key in ("RSI", "VGI"):
                backbone = self.model_square
            else:
                raise ValueError(f"Unknown view type '{key}' for ViT model.")
            return backbone(x)

        if "convnext" in lowered:
            # Every non-UAV view shares the main backbone.
            backbone = self.model_uav if key == "UAV" else self.model_main
            return backbone(x)

        raise ValueError(f"Unknown model type in '{self.model_name}'")

# ================== HELPERS ==================
def haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in km between two lat/lon points.

    Each argument is an angle in decimal degrees and is passed through
    ``float()`` first. The Earth is treated as a sphere of radius 6371 km.
    NaN inputs yield NaN.
    """
    R = 6371.0
    lat1, lon1, lat2, lon2 = map(math.radians, [float(lat1), float(lon1), float(lat2), float(lon2)])
    dlat, dlon = lat2 - lat1, lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    # Rounding can leave `a` a hair outside [0, 1] for (near-)antipodal
    # points, which would make sqrt/asin raise ValueError. Keeping `a` as the
    # first argument of max/min lets NaN propagate unchanged.
    a = min(max(a, 0.0), 1.0)
    return 2 * R * math.asin(math.sqrt(a))

def haversine_vec(lat, lon, lat_vec, lon_vec):
    """Return great-circle distances in km from one point to many points.

    Args:
        lat, lon: Reference point in decimal degrees (anything ``float()``
            accepts).
        lat_vec, lon_vec: Tensors of latitudes / longitudes in decimal
            degrees; both are cast with ``.float()`` before use.

    Returns:
        Tensor of distances in km, one per element of the input tensors.
        NaN coordinates yield NaN distances.
    """
    R = 6371.0
    lat = math.radians(float(lat))
    lon = math.radians(float(lon))
    lat2 = torch.deg2rad(lat_vec.float())
    lon2 = torch.deg2rad(lon_vec.float())
    dlat = lat2 - lat
    dlon = lon2 - lon
    a = torch.sin(dlat / 2) ** 2 + torch.cos(torch.tensor(lat)) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
    # Rounding can push `a` slightly above 1 for (near-)antipodal points;
    # asin(sqrt(a)) would then be NaN and the entry would silently fail any
    # later `dists <= radius` comparison. clamp leaves genuine NaNs untouched.
    a = a.clamp(min=0.0, max=1.0)
    return 2 * R * torch.asin(torch.sqrt(a))

def _id_from_relpath(rel_path: str) -> int:
    """Parse the integer ID from the file name of a relative path.

    Backslashes are treated as path separators; the extension is dropped and
    the remaining stem is converted with ``int``.
    """
    posix_path = rel_path.replace("\\", "/")
    stem, _ext = os.path.splitext(os.path.basename(posix_path))
    return int(stem)

def _abs_path(rel_path: str) -> str:
    """Join a relative path from the split CSV onto ``config.data_folder``.

    Backslashes in ``rel_path`` are replaced by forward slashes first.
    """
    normalized = rel_path.replace("\\", "/")
    return os.path.join(config.data_folder, normalized)

# ================== LOAD METADATA ==================
# The split CSV has no header row; columns are addressed by position below.
test_df = pd.read_csv(config.test_csv, header=None)
metadata_df = pd.read_excel(config.metadata_xlsx)
# ID -> coordinate lookup tables. If an ID occurs more than once, the last
# row wins (dict construction overwrites earlier keys).
id2lat = dict(zip(metadata_df["ID"].astype(int).tolist(), metadata_df["Latitude"].tolist()))
id2lon = dict(zip(metadata_df["ID"].astype(int).tolist(), metadata_df["Longitude"].tolist()))

# ================== LOAD MODEL ==================
retrieval_model = TimmModel(config.model, pretrained=True, img_size=config.img_size)
if config.checkpoint_start is not None:
    print("Start from:", config.checkpoint_start)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where available).
    sd = torch.load(config.checkpoint_start, map_location=device)
    # NOTE(review): strict=False ignores missing/unexpected keys without
    # reporting them; the returned key lists are discarded here.
    retrieval_model.load_state_dict(sd, strict=False)
retrieval_model = retrieval_model.to(device).eval()

# Normalisation statistics come from the backbone's timm data config.
dc = retrieval_model.get_config()
mean, std = dc["mean"], dc["std"]
# Preprocessing applied to every image before it enters the retrieval model:
# resize to a square of config.img_size, convert to tensor, normalise.
transform = T.Compose([
    T.Resize((config.img_size, config.img_size)),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

# ================== BUILD RSI EMBEDDING DB ==================
print("Build RSI Embedding DB")
# Parallel lists: entry i of each list describes the same RSI image.
rsi_paths, rsi_ids, rsi_lats, rsi_lons = [], [], [], []
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    # Column 0 of the split CSV is read as the RSI image path.
    rsi_rel = row[0]
    rid = _id_from_relpath(rsi_rel)
    # Skip silently when coordinates or the image file are missing.
    if rid not in id2lat or rid not in id2lon: continue
    path = _abs_path(rsi_rel)
    if not os.path.isfile(path): continue
    rsi_paths.append(path)
    rsi_ids.append(rid)
    rsi_lats.append(id2lat[rid])
    rsi_lons.append(id2lon[rid])

# Embed the RSI images batch by batch; features are L2-normalised along the
# last dimension and moved to the CPU before being collected.
rsi_feats = []
for j in tqdm(range(0, len(rsi_paths), config.batch_size), desc="RSI-Embeddings"):
    imgs = []
    for p in rsi_paths[j:j+config.batch_size]:
        img = Image.open(p).convert("RGB")
        imgs.append(transform(img))
    if len(imgs) == 0: continue
    batch = torch.stack(imgs).to(device)
    with torch.no_grad():
        feats = retrieval_model({"RSI": batch})["RSI"]
        feats = F.normalize(feats, dim=-1)
    rsi_feats.append(feats.cpu())
# With no features at all, an empty (0, 1) tensor stands in so that later
# `rsi_feats.numel() > 0` checks work.
rsi_feats = torch.cat(rsi_feats, dim=0) if len(rsi_feats) else torch.empty(0, 1)
rsi_lats_t = torch.tensor(rsi_lats)
rsi_lons_t = torch.tensor(rsi_lons)
rsi_ids_t  = torch.tensor(rsi_ids)
print(f"RSI DB size: {len(rsi_paths)}")

# ================== BUILD VGI LIST ==================
# One dict per query image: its ID, absolute path and ground-truth coordinates.
vgi_items = []
for _, row in test_df.iterrows():
    # Column 2 of the split CSV is read as the VGI image path.
    vgi_rel = row[2]
    vid = _id_from_relpath(vgi_rel)
    path = _abs_path(vgi_rel)
    # Rows without an image file or without coordinates are skipped silently.
    if not os.path.isfile(path): continue
    if vid not in id2lat or vid not in id2lon: continue
    vgi_items.append({
        "id": vid,
        "path": path,
        "gt_lat": id2lat[vid],
        "gt_lon": id2lon[vid],
    })

# ================== EVAL PIPELINE ==================
# For every VGI image: (1) score the raw Plonk prediction, (2) compute recall@k
# against the full RSI DB, (3) restrict the DB to a radius around the Plonk
# prediction and take the most similar RSI image's coordinates as the refined
# prediction, (4) score the refined prediction.
pipeline = PlonkPipeline(config.plonk_weights).to(device)

# Per-threshold hit counters for the three evaluated variants.
acc_counts_before = {thr: 0 for thr in config.thresholds_km}
acc_counts_after  = {thr: 0 for thr in config.thresholds_km}
acc_counts_retrieval = {thr: 0 for thr in config.thresholds_km}

# Per-k hit counters, without ("before") and with ("after") the geo restriction.
recall_counts_before = {k: 0 for k in config.recall_topk}
recall_counts_after  = {k: 0 for k in config.recall_topk}

results_before, results_after, results_retrieval = [], [], []
total = 0

for i in tqdm(range(0, len(vgi_items), config.batch_size), desc="Batches"):
    batch = vgi_items[i:i+config.batch_size]
    batch_imgs = [Image.open(x["path"]).convert("RGB") for x in batch]
    plonk_preds = pipeline(batch_imgs, batch_size=len(batch_imgs))

    # Collected per image so the "after" accuracy can be scored once the
    # whole batch has been processed.
    refined_coords = []
    for x, (pred_lat, pred_lon), img in zip(batch, plonk_preds, batch_imgs):
        total += 1
        # ---- Accuracy BEFORE ----
        dist_before = haversine(x["gt_lat"], x["gt_lon"], pred_lat, pred_lon)
        within_before = {thr: (dist_before <= thr) for thr in config.thresholds_km}
        for thr, ok in within_before.items():
            if ok: acc_counts_before[thr] += 1
        results_before.append({
            "filename": os.path.relpath(x["path"], config.data_folder),
            "gt_lat": x["gt_lat"], "gt_lon": x["gt_lon"],
            "pred_lat": pred_lat, "pred_lon": pred_lon,
            "distance_km": dist_before,
            **{f"within_{thr}km": within_before[thr] for thr in config.thresholds_km},
        })

        # ---- Build VGI embedding ----
        # The retrieval model is run on one image at a time here.
        vgi_tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            vgi_feat = retrieval_model({"VGI": vgi_tensor})["VGI"]
            vgi_feat = F.normalize(vgi_feat, dim=-1).squeeze(0).cpu()  # shape [D]

        # ---- GLOBAL Retrieval (BEFORE restriction) ----
        # Compute similarity against the full RSI DB (no geo filter)
        if rsi_feats.numel() > 0:
            sims_all = torch.mv(rsi_feats, vgi_feat)                 # [N_rsi]
            sorted_all = torch.argsort(sims_all, descending=True)    # indices into full DB
            all_ids = rsi_ids_t                                      # [N_rsi]
            # Recall BEFORE restriction: a hit means an RSI entry with the
            # same ID as the VGI image is among the k most similar.
            for k in config.recall_topk:
                top_ids_all = all_ids[sorted_all[:k]].tolist()
                hit_all = (x["id"] in top_ids_all)
                recall_counts_before[k] += int(hit_all)
        else:
            # No RSI DB => no recall possible
            pass

        # ---- Retrieval Candidates (AFTER applies restriction) ----
        # Default to the Plonk prediction; overwritten below when at least one
        # RSI entry lies within the retrieval radius.
        refined_lat, refined_lon = pred_lat, pred_lon
        dists = haversine_vec(pred_lat, pred_lon, rsi_lats_t, rsi_lons_t) if rsi_feats.numel() > 0 else None
        cand_mask = (dists <= config.retrieval_threshold_km) if dists is not None else torch.tensor(False)

        if rsi_feats.numel() > 0 and cand_mask.any():
            cand_idx = torch.nonzero(cand_mask, as_tuple=False).squeeze(1)   # indices into full DB
            cand_feats = rsi_feats.index_select(0, cand_idx)                 # [Nc, D]
            sims_cand = torch.mv(cand_feats, vgi_feat)                        # [Nc]
            sorted_cand = torch.argsort(sims_cand, descending=True)           # indices into cand subset
            cand_ids = rsi_ids_t[cand_idx]                                    # [Nc]

            # ---- Recall AFTER restriction ----
            for k in config.recall_topk:
                top_ids_cand = cand_ids[sorted_cand[:k]].tolist()
                hit_cand = (x["id"] in top_ids_cand)
                recall_counts_after[k] += int(hit_cand)

            # ---- Refinement (Top-1 from restricted pool) ----
            best_local_in_cand = cand_idx[torch.argmax(sims_cand)].item()     # index into full DB
            refined_lat = float(rsi_lats_t[best_local_in_cand])
            refined_lon = float(rsi_lons_t[best_local_in_cand])

            # ---- Accuracy RETRIEVAL (pure Top-1 from restricted pool) ----
            # Only images that reach this branch get a retrieval row, so
            # results_retrieval can be shorter than results_before.
            dist_retrieval = haversine(
                x["gt_lat"], x["gt_lon"], refined_lat, refined_lon
            )
            within_retrieval = {thr: (dist_retrieval <= thr) for thr in config.thresholds_km}
            for thr, ok in within_retrieval.items():
                if ok: acc_counts_retrieval[thr] += 1
            results_retrieval.append({
                "filename": os.path.relpath(x["path"], config.data_folder),
                "gt_lat": x["gt_lat"], "gt_lon": x["gt_lon"],
                "pred_lat": refined_lat, "pred_lon": refined_lon,
                "distance_km": dist_retrieval,
                **{f"within_{thr}km": within_retrieval[thr] for thr in config.thresholds_km},
            })
        else:
            # No candidates in radius (or empty DB): AFTER recall contributes 0; keep Plonk coords for "after" accuracy
            pass

        refined_coords.append((refined_lat, refined_lon))


    # ---- Accuracy AFTER ----
    # Scores the refined coordinates; for images without candidates these are
    # the unchanged Plonk coordinates.
    for x, (pred_lat, pred_lon) in zip(batch, refined_coords):
        dist_after = haversine(x["gt_lat"], x["gt_lon"], pred_lat, pred_lon)
        within_after = {thr: (dist_after <= thr) for thr in config.thresholds_km}
        for thr, ok in within_after.items():
            if ok: acc_counts_after[thr] += 1
        results_after.append({
            "filename": os.path.relpath(x["path"], config.data_folder),
            "gt_lat": x["gt_lat"], "gt_lon": x["gt_lon"],
            "pred_lat": pred_lat, "pred_lon": pred_lon,
            "distance_km": dist_after,
            **{f"within_{thr}km": within_after[thr] for thr in config.thresholds_km},
        })

# ================== REPORT ==================
df_before = pd.DataFrame(results_before)
df_after  = pd.DataFrame(results_after)
df_retrieval = pd.DataFrame(results_retrieval)

if total == 0:
    print("No predictions to report.")
else:
    # results_retrieval only gets a row for images with at least one RSI
    # candidate inside the retrieval radius. When no image had one, the
    # DataFrame has no "distance_km" column and indexing it would raise
    # KeyError, so the retrieval distance statistics fall back to NaN.
    if df_retrieval.empty:
        mean_dist_retrieval = float("nan")
        median_dist_retrieval = float("nan")
    else:
        mean_dist_retrieval = df_retrieval["distance_km"].mean()
        median_dist_retrieval = df_retrieval["distance_km"].median()

    print("\n=== RESULTS BEFORE (Plonk) ===")
    print(f"Total images: {total}")
    for thr in config.thresholds_km:
        print(f"Accuracy @ {thr:>4} km: {acc_counts_before[thr]/total:.3f}")
    print(f"Mean distance:   {df_before['distance_km'].mean():.2f} km")
    print(f"Median distance: {df_before['distance_km'].median():.2f} km")

    print("\n=== RESULTS AFTER (Plonk + Refinement) ===")
    for thr in config.thresholds_km:
        print(f"Accuracy @ {thr:>4} km: {acc_counts_after[thr]/total:.3f}")
    print(f"Mean distance:   {df_after['distance_km'].mean():.2f} km")
    print(f"Median distance: {df_after['distance_km'].median():.2f} km")

    print("\n=== RESULTS RETRIEVAL (reines Top-1 Matching) ===")
    # Accuracy is divided by all images, while the distance statistics cover
    # only the images that had candidates; print the coverage so the two can
    # be interpreted together.
    print(f"Images with retrieval candidates: {len(df_retrieval)} / {total}")
    for thr in config.thresholds_km:
        print(f"Accuracy @ {thr:>4} km: {acc_counts_retrieval[thr]/total:.3f}")
    print(f"Mean distance:   {mean_dist_retrieval:.2f} km")
    print(f"Median distance: {median_dist_retrieval:.2f} km")

    print("\n=== RECALL @ K (BEFORE Refinement) ===")
    for k in config.recall_topk:
        print(f"Recall@{k}: {recall_counts_before[k]/total:.3f}")

    print("\n=== RECALL @ K (AFTER Refinement) ===")
    for k in config.recall_topk:
        print(f"Recall@{k}: {recall_counts_after[k]/total:.3f}")

    # ---- Global mean distances ----
    # "before" and "after" average over all samples; the retrieval mean
    # averages over the samples that had candidates.
    mean_dist_before    = df_before["distance_km"].mean()
    mean_dist_after     = df_after["distance_km"].mean()
    print("\n=== GLOBAL MEAN DISTANCES ===")
    print(f"Generative only (Plonk):         {mean_dist_before:.2f} km")
    print(f"Retrieval only (Top-1):          {mean_dist_retrieval:.2f} km")
    print(f"Combined (Plonk + Refinement):   {mean_dist_after:.2f} km")

    df_before.to_csv("predictions_before.csv", index=False)
    df_after.to_csv("predictions_after.csv", index=False)
    df_retrieval.to_csv("predictions_retrieval.csv", index=False)


# Fine-tuned model prediction with multiple thresholds

This cell already loads the fine-tuned retrieval checkpoint, so performance can be compared with and without retrieval-based refinement at several candidate radii.

In [None]:
import os
import math
from tqdm import tqdm
import torch
import pandas as pd
from PIL import Image
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np

from plonk.pipe import PlonkPipeline

# ================== CONFIG ==================
class Config:
    """Static settings for this cell, read through the `config` instance below."""
    # Paths
    data_folder = "plonk/plonk/data/split_IJGIS"
    test_csv = os.path.join(data_folder, "val_80_fixed.csv")
    metadata_xlsx = os.path.join(data_folder, "all_20241120.xlsx")
    # State dict loaded into the retrieval model; None skips loading.
    checkpoint_start = "ian_weights80.pth"
    # Identifier passed to PlonkPipeline(...).
    plonk_weights = "iandisaster20osm"
    # Model
    model = "timm/vit_large_patch14_dinov2.lvd142m"
    # Used both for the TimmModel constructor and for the Resize transform.
    img_size = 384

    # Eval
    batch_size = 128
    # In this cell the list serves two roles: the retrieval radii that are
    # swept and the distance cut-offs for the accuracy figures.
    thresholds_km = [1, 25, 50, 200, 750, 2500]
    # NOTE(review): the two settings below are not read by the code of this cell.
    retrieval_threshold_km = 50.0
    recall_topk = [1, 5, 10]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Config()

# ================== MODEL ==================
class TimmModel(nn.Module):
    """Bundle of timm backbones, one of which is picked per view type.

    Which backbones exist depends on ``model_name`` (case-insensitive):

    * contains "vit": ``model_square``, ``model_wide`` and ``model_uav``
    * contains "convnext": ``model_main`` and ``model_uav``
    * anything else: a single ``model``

    Every backbone is created with ``num_classes=0``.
    """

    def __init__(self, model_name, pretrained=True, img_size=383, embed_dim=1024):
        super().__init__()
        self.img_size = img_size
        self.model_name = model_name
        # Stored only; nothing else in this class reads embed_dim.
        self.embed_dim = embed_dim
        # Scalar parameter initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        def build(**extra):
            # All backbones share these arguments; only img_size varies.
            return timm.create_model(model_name, pretrained=pretrained, num_classes=0, **extra)

        lowered = model_name.lower()
        if "vit" in lowered:
            wide_w = img_size * 2
            wide_h = round((512 / 1024) * wide_w)
            self.img_size_wide = (wide_h, wide_w)
            square = (img_size, img_size)
            self.model_square = build(img_size=square)
            self.model_wide = build(img_size=self.img_size_wide)
            self.model_uav = build(img_size=square)
        elif "convnext" in lowered:
            self.model_main = build()
            self.model_uav = build()
        else:
            self.model = build()

    def get_config(self):
        """Return timm's resolved data config for the first backbone present.

        Checked in the order ``model``, ``model_main``, then ``model_square``.
        """
        for attr in ("model", "model_main"):
            if hasattr(self, attr):
                return timm.data.resolve_model_data_config(getattr(self, attr))
        return timm.data.resolve_model_data_config(self.model_square)

    def forward(self, data_dict):
        """Run each ``{view_key: batch}`` entry through its backbone.

        Returns a dict with the same keys holding the backbone outputs.
        """
        return {key: self._forward_single(img, key) for key, img in data_dict.items()}

    def _forward_single(self, x, key):
        """Send ``x`` through the backbone that handles view ``key``.

        Raises:
            ValueError: for a view key the ViT variant does not know, or when
                the model name matches neither "vit" nor "convnext" and no
                single ``model`` exists.
        """
        if hasattr(self, "model"):
            return self.model(x)

        lowered = self.model_name.lower()
        if "vit" in lowered:
            if key == "SVI":
                backbone = self.model_wide
            elif key == "UAV":
                backbone = self.model_uav
            elif key in ("RSI", "VGI"):
                backbone = self.model_square
            else:
                raise ValueError(f"Unknown view type '{key}' for ViT model.")
            return backbone(x)

        if "convnext" in lowered:
            # Every non-UAV view shares the main backbone.
            backbone = self.model_uav if key == "UAV" else self.model_main
            return backbone(x)

        raise ValueError(f"Unknown model type in '{self.model_name}'")

# ================== HELPERS ==================
def haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in km between two lat/lon points.

    Each argument is an angle in decimal degrees and is passed through
    ``float()`` first. The Earth is treated as a sphere of radius 6371 km.
    NaN inputs yield NaN.
    """
    R = 6371.0
    lat1, lon1, lat2, lon2 = map(math.radians, [float(lat1), float(lon1), float(lat2), float(lon2)])
    dlat, dlon = lat2 - lat1, lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    # Rounding can leave `a` a hair outside [0, 1] for (near-)antipodal
    # points, which would make sqrt/asin raise ValueError. Keeping `a` as the
    # first argument of max/min lets NaN propagate unchanged.
    a = min(max(a, 0.0), 1.0)
    return 2 * R * math.asin(math.sqrt(a))

def haversine_vec(lat, lon, lat_vec, lon_vec):
    """Return great-circle distances in km from one point to many points.

    Args:
        lat, lon: Reference point in decimal degrees (anything ``float()``
            accepts).
        lat_vec, lon_vec: Tensors of latitudes / longitudes in decimal
            degrees; both are cast with ``.float()`` before use.

    Returns:
        Tensor of distances in km, one per element of the input tensors.
        NaN coordinates yield NaN distances.
    """
    R = 6371.0
    lat = math.radians(float(lat))
    lon = math.radians(float(lon))
    lat2 = torch.deg2rad(lat_vec.float())
    lon2 = torch.deg2rad(lon_vec.float())
    dlat = lat2 - lat
    dlon = lon2 - lon
    a = torch.sin(dlat / 2) ** 2 + torch.cos(torch.tensor(lat)) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
    # Rounding can push `a` slightly above 1 for (near-)antipodal points;
    # asin(sqrt(a)) would then be NaN and the entry would silently fail any
    # later `dists <= radius` comparison. clamp leaves genuine NaNs untouched.
    a = a.clamp(min=0.0, max=1.0)
    return 2 * R * torch.asin(torch.sqrt(a))

def _id_from_relpath(rel_path: str) -> int:
    """Parse the integer ID from the file name of a relative path.

    Backslashes are treated as path separators; the extension is dropped and
    the remaining stem is converted with ``int``.
    """
    posix_path = rel_path.replace("\\", "/")
    stem, _ext = os.path.splitext(os.path.basename(posix_path))
    return int(stem)

def _abs_path(rel_path: str) -> str:
    """Join a relative path from the split CSV onto ``config.data_folder``.

    Backslashes in ``rel_path`` are replaced by forward slashes first.
    """
    normalized = rel_path.replace("\\", "/")
    return os.path.join(config.data_folder, normalized)

# ================== LOAD METADATA ==================
# The split CSV has no header row; columns are addressed by position below.
test_df = pd.read_csv(config.test_csv, header=None)
metadata_df = pd.read_excel(config.metadata_xlsx)
# ID -> coordinate lookup tables. If an ID occurs more than once, the last
# row wins (dict construction overwrites earlier keys).
id2lat = dict(zip(metadata_df["ID"].astype(int).tolist(), metadata_df["Latitude"].tolist()))
id2lon = dict(zip(metadata_df["ID"].astype(int).tolist(), metadata_df["Longitude"].tolist()))

# ================== LOAD MODEL ==================
retrieval_model = TimmModel(config.model, pretrained=True, img_size=config.img_size)
if config.checkpoint_start is not None:
    print("Start from:", config.checkpoint_start)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where available).
    sd = torch.load(config.checkpoint_start, map_location=device)
    # NOTE(review): strict=False ignores missing/unexpected keys without
    # reporting them; the returned key lists are discarded here.
    retrieval_model.load_state_dict(sd, strict=False)
retrieval_model = retrieval_model.to(device).eval()

# Normalisation statistics come from the backbone's timm data config.
dc = retrieval_model.get_config()
mean, std = dc["mean"], dc["std"]
# Preprocessing applied to every image before it enters the retrieval model:
# resize to a square of config.img_size, convert to tensor, normalise.
transform = T.Compose([
    T.Resize((config.img_size, config.img_size)),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

# ================== BUILD RSI EMBEDDING DB ==================
print("Build RSI Embedding DB")
# Parallel lists: entry i of each list describes the same RSI image.
rsi_paths, rsi_ids, rsi_lats, rsi_lons = [], [], [], []
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    # Column 0 of the split CSV is read as the RSI image path.
    rsi_rel = row[0]
    rid = _id_from_relpath(rsi_rel)
    # Skip silently when coordinates or the image file are missing.
    if rid not in id2lat or rid not in id2lon: continue
    path = _abs_path(rsi_rel)
    if not os.path.isfile(path): continue
    rsi_paths.append(path)
    rsi_ids.append(rid)
    rsi_lats.append(id2lat[rid])
    rsi_lons.append(id2lon[rid])

# Embed the RSI images batch by batch; features are L2-normalised along the
# last dimension and moved to the CPU before being collected.
rsi_feats = []
for j in tqdm(range(0, len(rsi_paths), config.batch_size), desc="RSI-Embeddings"):
    imgs = []
    for p in rsi_paths[j:j+config.batch_size]:
        img = Image.open(p).convert("RGB")
        imgs.append(transform(img))
    if len(imgs) == 0: continue
    batch = torch.stack(imgs).to(device)
    with torch.no_grad():
        feats = retrieval_model({"RSI": batch})["RSI"]
        feats = F.normalize(feats, dim=-1)
    rsi_feats.append(feats.cpu())
# With no features at all, an empty (0, 1) tensor stands in so that later
# `rsi_feats.numel() > 0` checks work.
rsi_feats = torch.cat(rsi_feats, dim=0) if len(rsi_feats) else torch.empty(0, 1)
rsi_lats_t = torch.tensor(rsi_lats)
rsi_lons_t = torch.tensor(rsi_lons)
rsi_ids_t  = torch.tensor(rsi_ids)
print(f"RSI DB size: {len(rsi_paths)}")

# ================== BUILD VGI LIST ==================
# One dict per query image: its ID, absolute path and ground-truth coordinates.
vgi_items = []
for _, row in test_df.iterrows():
    # Column 2 of the split CSV is read as the VGI image path.
    vgi_rel = row[2]
    vid = _id_from_relpath(vgi_rel)
    path = _abs_path(vgi_rel)
    # Rows without an image file or without coordinates are skipped silently.
    if not os.path.isfile(path): continue
    if vid not in id2lat or vid not in id2lon: continue
    vgi_items.append({
        "id": vid,
        "path": path,
        "gt_lat": id2lat[vid],
        "gt_lon": id2lon[vid],
    })

# ================== COLLECT VGI EMBEDDINGS + PREDICTIONS ==================
# Single pass over the VGI images that stores, per image, the Plonk prediction
# and the retrieval embedding, so the threshold sweep below needs no further
# model calls.
pipeline = PlonkPipeline(config.plonk_weights).to(device)

all_vgi_feats, all_gt_lats, all_gt_lons, all_pred_lats, all_pred_lons, all_ids = [], [], [], [], [], []

for i in tqdm(range(0, len(vgi_items), config.batch_size), desc="VGI Embeddings"):
    batch = vgi_items[i:i+config.batch_size]
    batch_imgs = [Image.open(x["path"]).convert("RGB") for x in batch]

    # Plonk predictions
    plonk_preds = pipeline(batch_imgs, batch_size=len(batch_imgs))

    # VGI embeddings (whole batch in one forward pass, L2-normalised, on CPU)
    batch_tensors = torch.stack([transform(img) for img in batch_imgs]).to(device)
    with torch.no_grad():
        vgi_feats = retrieval_model({"VGI": batch_tensors})["VGI"]
        vgi_feats = F.normalize(vgi_feats, dim=-1).cpu()

    for x, (pred_lat, pred_lon), feat in zip(batch, plonk_preds, vgi_feats):
        all_vgi_feats.append(feat)
        all_gt_lats.append(x["gt_lat"])
        all_gt_lons.append(x["gt_lon"])
        all_pred_lats.append(pred_lat)
        all_pred_lons.append(pred_lon)
        all_ids.append(x["id"])

# Convert the per-image lists to tensors aligned by index.
# NOTE(review): torch.stack raises on an empty list, i.e. when vgi_items is empty.
all_vgi_feats = torch.stack(all_vgi_feats)  # [N, D]
all_gt_lats = torch.tensor(all_gt_lats)
all_gt_lons = torch.tensor(all_gt_lons)
# NOTE(review): assumes the pipeline yields values torch.tensor can convert - confirm.
all_pred_lats = torch.tensor(all_pred_lats)
all_pred_lons = torch.tensor(all_pred_lons)
all_ids = torch.tensor(all_ids)

# ================== EVAL FOR MULTIPLE THRESHOLDS ==================
# `thr` (outer loop) is the retrieval radius around the Plonk prediction;
# `t` (inner loops) is the distance cut-off for the accuracy figures. Both are
# drawn from the same list.
thresholds = config.thresholds_km
results_by_thr = {}

for thr in thresholds:
    acc_counts = {t: 0 for t in thresholds}
    total = len(all_vgi_feats)
    distances = []

    # NOTE(review): `vid` is unpacked but not used inside this loop.
    for vid, vfeat, gt_lat, gt_lon, pred_lat, pred_lon in zip(
        all_ids, all_vgi_feats, all_gt_lats, all_gt_lons, all_pred_lats, all_pred_lons
    ):
        # Candidates within radius thr
        # NOTE(review): `dists` does not depend on thr, yet it is recomputed
        # for every radius.
        dists = haversine_vec(pred_lat.item(), pred_lon.item(), rsi_lats_t, rsi_lons_t)
        cand_mask = (dists <= thr)

        if rsi_feats.numel() > 0 and cand_mask.any():
            # Take the coordinates of the most similar RSI entry in the radius.
            cand_idx = torch.nonzero(cand_mask, as_tuple=False).squeeze(1)
            cand_feats = rsi_feats.index_select(0, cand_idx)
            sims = torch.mv(cand_feats, vfeat)
            best_idx = cand_idx[torch.argmax(sims)].item()

            refined_lat = float(rsi_lats_t[best_idx])
            refined_lon = float(rsi_lons_t[best_idx])
        else:
            # No candidate in range (or empty DB): keep the Plonk prediction.
            refined_lat, refined_lon = float(pred_lat), float(pred_lon)

        dist = haversine(gt_lat.item(), gt_lon.item(), refined_lat, refined_lon)
        distances.append(dist)
        for t in thresholds:
            if dist <= t:
                acc_counts[t] += 1

    # Summary for this retrieval radius: accuracy per cut-off plus mean and
    # median error over all images.
    results_by_thr[thr] = {
        "acc": {t: acc_counts[t] / total for t in thresholds},
        "mean_dist": np.mean(distances),
        "median_dist": np.median(distances),
    }

# ================== REPORT ==================
# One section per retrieval radius: accuracy at every cut-off, then the mean
# and median error recorded for that radius.
for radius_km in thresholds:
    summary = results_by_thr[radius_km]
    print(f"\n=== Threshold {radius_km} km ===")
    for cutoff_km in thresholds:
        accuracy = summary["acc"][cutoff_km]
        print(f"Accuracy @ {cutoff_km:>4} km: {accuracy:.3f}")
    mean_km = summary["mean_dist"]
    median_km = summary["median_dist"]
    print(f"Mean distance:   {mean_km:.2f} km")
    print(f"Median distance: {median_km:.2f} km")
