In [1]:
import sys

sys.path.append("..")
from src import metrics
from src import constant
from src.utils import get_device, set_seed, haversine
from src.datasets.mp16 import MP16Dataset, collate_fn
from torch.utils.data import DataLoader
from tqdm import tqdm
import polars as pl
from src.pipeline.feature_extractor import FeatureExtractor
from src.eval_s4 import merge_responses
from qdrant_client import QdrantClient, models

import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  from pkg_resources import packaging


In [2]:
# Fix the seed through the project helper (presumably seeds the RNGs used downstream — confirm in src.utils),
# then ask the project helper which compute device to use.
# NOTE(review): `device` is not referenced by any other cell visible in this notebook — confirm it is still needed.
set_seed(42)
device = get_device()

In [3]:
# Feature extractor reused for both the reference pass and the later test pass.
extractor = FeatureExtractor()

df_ref = pl.read_csv("../datasets/mp16-reason-train.csv")
df_test = pl.read_csv("../datasets/mp16-reason-test.csv")

# Reference (train) split. No shuffle argument is passed to the DataLoader;
# later cells pair extractor outputs with df_ref rows by position, so they rely
# on batches being produced in df_ref row order.
dataset = MP16Dataset(df_ref, img_col="IMG_ID", img_base_path="../datasets/mp16-reason")
loader = DataLoader(dataset, batch_size=128, collate_fn=collate_fn)

The image processor of type `Mask2FormerImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 
Loading weights: 100%|██████████| 782/782 [00:00<00:00, 1863.31it/s, Materializing param=model.transformer_module.queries_features.weight]                                                       


In [4]:
# Accumulate the extractor's per-batch outputs into one sequence per output key.
all_outputs = {}
for batch, target, images in tqdm(loader, desc="Mask"):
    batch_out = extractor(batch, images, target)
    for name, values in batch_out.items():
        if name in all_outputs:
            # Key already seen: append this batch's entries to the running sequence.
            all_outputs[name].extend(values)
        else:
            # First batch for this key: keep its sequence as the accumulator.
            all_outputs[name] = values

Mask: 100%|██████████| 264/264 [26:07<00:00,  5.94s/it]


In [6]:
# The indexing cell pairs df_ref rows with all_outputs["alpha_embeddings"] by
# position, so the two must have the same length. Enforce it instead of only
# displaying the counts.
n_ref_rows, n_ref_embeddings = len(df_ref), len(all_outputs["alpha_embeddings"])
assert n_ref_rows == n_ref_embeddings, (
    f"df_ref has {n_ref_rows} rows but {n_ref_embeddings} embedding entries were extracted"
)
n_ref_rows, n_ref_embeddings

(33721, 33721)

In [37]:
import faiss

# HNSW graph index over d-dimensional vectors, scored by inner product.
# M is the HNSW connectivity parameter passed straight to faiss.
d, M = 768, 32
index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)

In [38]:
# Start from an empty index so that re-running this cell rebuilds it instead of
# appending the same vectors a second time (which would leave index ids and
# `new_items` positions out of step).
index.reset()

new_items = []
for i, ref_item in tqdm(enumerate(df_ref.to_dicts()), total=len(df_ref), desc="alignment"):
    # One metadata entry per label of this reference row, so that position j in
    # `new_items` is meant to describe vector j in the index.
    new_items.extend([ref_item] * len(all_outputs["labels"][i]))

    _embed = all_outputs["alpha_embeddings"][i].astype("float32")
    # normalize_L2 works in place; with unit-length vectors the index's
    # inner-product metric ranks by cosine similarity.
    faiss.normalize_L2(_embed)
    index.add(_embed)

alignment: 100%|██████████| 33721/33721 [00:41<00:00, 808.61it/s]


In [39]:
# Search results are mapped back to metadata via new_items[id], so the index
# must hold exactly one vector per metadata entry. Enforce it instead of only
# displaying the counts.
n_vectors, n_metadata = index.ntotal, len(new_items)
assert n_vectors == n_metadata, (
    f"index holds {n_vectors} vectors but there are {n_metadata} metadata entries"
)
n_vectors, n_metadata

(272276, 272276)

In [41]:
import json

# Persist the vector index next to the notebook.
faiss.write_index(index, "s4_hnsw.index")

# Persist the metadata list in the same order as the vectors in the index.
with open("s4_metadata.json", "w") as meta_file:
    json.dump({"metadata": new_items}, meta_file)

In [42]:
# Reload the index file that was just written and display its vector count as a round-trip check.
# NOTE(review): the similarity-search cell further down queries the in-memory `index`, not `nindex`.
nindex = faiss.read_index("s4_hnsw.index")
nindex.ntotal

272276

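Test: extract features for the test split, search the reference index, and score the retrieved locations.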

In [43]:
df_test = pl.read_csv("../datasets/mp16-reason-test.csv")

# Point the dataset/loader at the test split, with the same settings as the reference pass.
dataset = MP16Dataset(df_test, img_col="IMG_ID", img_base_path="../datasets/mp16-reason")
loader = DataLoader(dataset, batch_size=128, collate_fn=collate_fn)

# NOTE: this rebinds `all_outputs`; the reference-split features are no longer
# reachable under that name after this cell runs.
all_outputs = {}
for batch, target, images in tqdm(loader, desc="Mask"):
    batch_out = extractor(batch, images, target)
    for name, values in batch_out.items():
        if name in all_outputs:
            # Key already seen: append this batch's entries to the running sequence.
            all_outputs[name].extend(values)
        else:
            # First batch for this key: keep its sequence as the accumulator.
            all_outputs[name] = values

Mask: 100%|██████████| 94/94 [15:45<00:00, 10.06s/it]


In [91]:
import pandas as pd

# Neighbours requested per query vector, and also the number of reference
# locations kept per test image after merging.
TOP_K = 100

all_ref_gps = []
for ix in tqdm(range(len(df_test)), desc="sim search"):
    query_emb = all_outputs["alpha_embeddings"][ix].astype("float32")
    # In-place L2 normalisation, matching how the reference vectors were indexed.
    faiss.normalize_L2(query_emb)
    sim, ind = index.search(query_emb, TOP_K)

    # Pool the hits of every query row of this image into one candidate table.
    hits = pd.DataFrame({"idx": ind.reshape(-1), "score": sim.reshape(-1)})
    # FAISS pads the result with label -1 when it finds fewer than TOP_K
    # neighbours; without this filter new_items[-1] would silently return the
    # last metadata row.
    hits = hits[hits["idx"] >= 0]

    # Best score first; an id hit by several query rows is kept once, at its best score.
    ranked_ids = (
        hits.sort_values(by="score", ascending=False)
        .drop_duplicates(subset="idx")["idx"]
        .tolist()
    )[:TOP_K]

    # Look up metadata only for the ids that are actually kept.
    ref_gps = [[new_items[ii]["LAT"], new_items[ii]["LON"]] for ii in ranked_ids]
    all_ref_gps.append(ref_gps)

sim search: 100%|██████████| 12000/12000 [00:18<00:00, 659.07it/s]


In [92]:
# Ground-truth coordinates of the test rows as a list of [LAT, LON] pairs, in df_test row order.
gt_gps = df_test.select("LAT", "LON").to_numpy().tolist()

In [100]:
def haversine_np(gps1: list | tuple | np.ndarray, gps2: list | tuple | np.ndarray):
    if not isinstance(gps1, np.ndarray):
        gps1 = np.array(gps1)
    if not isinstance(gps2, np.ndarray):
        gps2 = np.array(gps2)

    gps1 = np.atleast_2d(gps1)
    gps2 = np.atleast_2d(gps2)

    lat1, lon1 = np.radians(gps1).T
    lat2, lon2 = np.radians(gps2).T

    dlat = lat2 - lat1
    dlon = lon2 - lon1

    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    c = 2 * np.arcsin(np.sqrt(a))

    dist = 6371 * c
    if dist.size == 1:
        return dist.item()
    return dist

def precision_k(
    gt_gps: np.ndarray, ret_gps: np.ndarray, k: int = 10, min_dist: int = 50
):
    """Fraction of the top-k retrieved locations lying outside the forbidden radius.

    A retrieved location counts when its haversine distance to the query's
    ground-truth location is at least ``min_dist``.

    Args:
        gt_gps: Ground-truth ``(lat, lon)`` per query, shape ``(N, 2)``; a single
            pair of shape ``(2,)`` is treated as one query.
        ret_gps: Retrieved ``(lat, lon)`` per query in rank order, shape
            ``(N, K, 2)``. An array with ``N * K`` pairs in that order is
            reshaped accordingly.
        k: Number of top-ranked results considered per query (``>= 1``). If
            fewer than ``k`` results are available, all of them are used.
        min_dist: Radius threshold, in the unit returned by ``haversine_np``
            (kilometres).

    Returns:
        The mean, over all queries and their top-k results, of the indicator
        ``distance >= min_dist`` (a NumPy float scalar).

    Raises:
        ValueError: If ``k < 1`` or the inputs cannot be arranged as
            ``(N, 2)`` / ``(N, K, 2)``.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    gt = np.asarray(gt_gps).reshape(-1, 2)
    ret = np.asarray(ret_gps).reshape(gt.shape[0], -1, 2)
    n_queries, n_retrieved = ret.shape[:2]

    # haversine_np returns shape (K, N) for these inputs, but collapses to a
    # plain float when there is a single query with a single result; reshaping
    # restores the 2-D layout before transposing to (N, K).
    distances = np.asarray(haversine_np(gt, ret)).reshape(n_retrieved, n_queries).T
    return np.mean(distances[:, :k] >= min_dist)

In [101]:
# Radius threshold passed as `min_dist` to precision_k for every reported metric.
MIN_DIST_KM = 250

# Named `retrieval_metrics` rather than `metrics` so the `metrics` module
# imported from `src` at the top of the notebook is not shadowed.
retrieval_metrics = {
    f"precision@{k}": precision_k(gt_gps, all_ref_gps, k=k, min_dist=MIN_DIST_KM).item()
    for k in (10, 100)
}

retrieval_metrics

{'precision@10': 0.5468666666666666, 'precision@100': 0.6068675}