
# 05 — Evaluate, Calibrate Threshold, Export

**Goal:** Use the validation set to pick an operating threshold τ, visualize the PR curve, and export:
- Reuse `mil_head.pt` + `model_card.json` (already saved by the previous notebook).
- Detections for the **test** set: image-level probability + the **top-scoring tile** per image (for guidance).
- **CSV/GeoJSON** of presence points, using top-tile centers (in pixel coordinates) as proxy locations. GPX requires real geolocation and is not written yet.


In [None]:

%pip -q install --extra-index-url https://download.pytorch.org/whl/cu121   torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0
%pip -q install numpy pandas pyarrow pillow tqdm scikit-learn geopandas shapely pyproj matplotlib timm==1.0.9


In [None]:

import os, json, numpy as np, pandas as pd, torch, timm, math
from torch import nn
from pathlib import Path
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import Point


In [None]:

# Directory layout (change BASE when not running on Colab). EXPORT_DIR is
# created up front so the export cells further down can write into it directly.
BASE = Path('/content')  # change if needed
IM_TILE_DIR = BASE / 'data/tiles'
CACHE_DIR = BASE / 'cache/embeddings'
MODEL_DIR = BASE / 'models'
EXPORT_DIR = BASE / 'exports'
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# timm id of the DINOv2 encoder (presumably the one that produced the cached embeddings).
ENC_NAME = "vit_small_patch14_dinov2.lvd142m"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tile preprocessing: bicubic resize (short side -> 256), 224x224 center crop,
# tensor conversion, ImageNet mean/std normalization. Only needed to re-load
# tiles for visual sanity checks; the embeddings are read from the cache.
preproc = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225),
        ),
    ]
)


In [None]:

# Load cached embeddings
import glob


def load_split(split):
    """Concatenate every cached ``emb_{split}_*.parquet`` shard under CACHE_DIR.

    Shards are read in sorted filename order and re-indexed 0..n-1.
    Returns an empty DataFrame (no columns) when no shard matches.
    """
    shard_paths = sorted(glob.glob(str(CACHE_DIR / f"emb_{split}_*.parquet")))
    if not shard_paths:
        return pd.DataFrame()
    shards = [pd.read_parquet(path) for path in shard_paths]
    return pd.concat(shards, axis=0, ignore_index=True)


val_df = load_split('val')
test_df = load_split('test')
# Report shape and (when the split is labelled) the class balance.
for name, d in (('val', val_df), ('test', test_df)):
    print(
        name,
        d.shape,
        d['label'].value_counts().to_dict() if 'label' in d.columns else {},
    )


In [None]:

# Recreate MIL head and load weights
# Embedding width fed to the MIL head; must equal the number of emb_* columns in
# the cache (presumably the ViT-S/14 DINOv2 output width).
FEAT_DIM = 384

# Pooling mode used at training time, read from the model card. Falls back to
# "max" when the card (or its 'pooling' key) is missing. The file is opened in a
# `with` block so the handle is closed instead of leaked.
_model_card_path = MODEL_DIR / 'model_card.json'
if _model_card_path.exists():
    with open(_model_card_path, encoding='utf-8') as _card_file:
        POOLING = json.load(_card_file).get('pooling', "max")
else:
    POOLING = "max"

class MILHead(nn.Module):
    """Multiple-instance-learning head: one linear score per tile, pooled per image.

    The only parameter is ``fc`` (a d -> 1 linear layer); its state-dict keys
    (``fc.weight``, ``fc.bias``) are what ``mil_head.pt`` is loaded into.
    """

    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, 1)

    def forward(self, tile_feats, pooling="max"):
        """Score every tile, then pool the tile logits into one image logit.

        Args:
            tile_feats: tile embeddings, presumably (n_tiles, d); pooling
                reduces over dim 0 of the per-tile logits.
            pooling: "max"; "lse" (log-sum-exp with scale s=10, i.e.
                logsumexp(s * logits) / s); any other value -> mean.

        Returns:
            (img_logit, tile_logits): pooled image logit and the per-tile logits.
        """
        tile_logits = self.fc(tile_feats).squeeze(-1)
        if pooling == "max":
            return tile_logits.max(), tile_logits
        if pooling == "lse":
            scale = 10.0
            return torch.logsumexp(tile_logits * scale, dim=0) / scale, tile_logits
        return tile_logits.mean(), tile_logits

# Build the head on the target device and load the trained weights.
head = MILHead(FEAT_DIM).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. It also avoids executing arbitrary pickled code
# that could be embedded in the checkpoint file.
state_dict = torch.load(MODEL_DIR/'mil_head.pt', map_location=device, weights_only=True)
head.load_state_dict(state_dict)
head.eval()  # inference mode for the rest of the notebook


In [None]:

# Utility to evaluate per-image probs from cached embeddings
def to_groups(df, target_device=None):
    """Group cached tile embeddings by image for MIL scoring.

    Args:
        df: one row per tile with an ``image_id`` column, embedding columns
            whose names start with ``emb_`` (used in column order) and,
            optionally, a ``label`` column (assumed constant within an image).
        target_device: torch device for the feature tensors; defaults to the
            notebook-level ``device``.

    Returns:
        dict image_id -> {'feats': float32 tensor (n_tiles, n_emb_cols),
        'label': int from the image's first row, or None when ``df`` has no
        ``label`` column (e.g. an unlabelled test split),
        'tiles': the image's rows of ``df`` in their original order}.
        An empty ``df`` yields an empty dict.
    """
    if df.empty:
        # load_split returns a column-less DataFrame when no shard exists;
        # groupby('image_id') would raise KeyError on it.
        return {}
    dev = device if target_device is None else target_device
    emb_cols = [c for c in df.columns if c.startswith('emb_')]
    has_label = 'label' in df.columns
    groups = {}
    for img_id, g in df.groupby('image_id'):
        # Build directly on the target device instead of a CPU copy + .to().
        feats = torch.tensor(g[emb_cols].values, dtype=torch.float32, device=dev)
        label = int(g['label'].iloc[0]) if has_label else None
        groups[img_id] = {'feats': feats, 'label': label, 'tiles': g}
    return groups

# Per-image groups for both evaluation splits (validation first, then test).
val_groups, test_groups = to_groups(val_df), to_groups(test_df)

def score_groups(head, groups, pooling="max"):
    """Run ``head`` on every image group and return labels and image probabilities.

    Side effect: each group dict gains 'prob' (float image probability, the
    sigmoid of the pooled logit) and 'tile_probs' (numpy array of per-tile
    sigmoid probabilities, in the group's tile order).

    Returns:
        (y_true, y_prob) as numpy arrays, in ``groups`` iteration order.
    """
    labels, probs = [], []
    for grp in groups.values():
        with torch.no_grad():
            img_logit, tile_logits = head(grp['feats'], pooling=pooling)
            img_prob = torch.sigmoid(img_logit).item()
        labels.append(grp['label'])
        probs.append(img_prob)
        grp['prob'] = img_prob
        grp['tile_probs'] = torch.sigmoid(tile_logits).detach().cpu().numpy()
    return np.array(labels), np.array(probs)

# Score both splits with the pooling mode the head was trained with. Scoring the
# test groups also fills in their 'prob' / 'tile_probs' entries used for export.
yv, pv = score_groups(head, val_groups, pooling=POOLING)
tv, tp = score_groups(head, test_groups, pooling=POOLING)

ap_val = average_precision_score(yv, pv)
print("Validation AP: %.4f" % ap_val)


In [None]:

# Validation precision-recall curve (matplotlib defaults). prec / rec / thr stay
# as notebook globals because the threshold-selection cell reuses them.
prec, rec, thr = precision_recall_curve(yv, pv)
fig, ax = plt.subplots()
ax.plot(rec, prec)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Validation PR Curve")
ax.grid(True)
plt.show()


In [None]:

# Choose operating threshold τ on the validation PR curve: the highest-recall
# threshold whose precision meets `target_precision`; if no threshold reaches it,
# fall back to the threshold with the best F1 instead of an arbitrary default.
def choose_threshold_index(prec, rec, thr, target_precision):
    """Return (index into ``thr``, criterion name) of the operating point.

    ``prec``, ``rec`` and ``thr`` are the outputs of sklearn's
    ``precision_recall_curve``. ``prec[i]`` and ``rec[i]`` are measured at
    ``thr[i]``; the final (precision=1, recall=0) point has no threshold and is
    dropped.

    Raises:
        ValueError: if ``thr`` is empty.
    """
    if len(thr) == 0:
        raise ValueError("precision_recall_curve returned no thresholds")
    p_at = np.asarray(prec)[:-1]
    r_at = np.asarray(rec)[:-1]
    ok = np.flatnonzero(p_at >= target_precision)
    if ok.size:
        # argmax picks the first (lowest-threshold) index on recall ties,
        # matching a strict "better recall" scan in increasing-threshold order.
        return int(ok[np.argmax(r_at[ok])]), "target_precision"
    denom = p_at + r_at
    f1 = np.divide(2 * p_at * r_at, denom,
                   out=np.zeros_like(denom, dtype=float), where=denom > 0)
    return int(np.argmax(f1)), "best_f1"


target_precision = 0.90
best_i, tau_criterion = choose_threshold_index(prec, rec, thr, target_precision)
best_tau = float(thr[best_i])

print("Chosen threshold τ =", best_tau)
print(f"  criterion={tau_criterion}  val precision={prec[best_i]:.3f}  val recall={rec[best_i]:.3f}")

# Save threshold for deployment. The extra keys record how τ was chosen, so a
# missed precision target is visible downstream.
with open(MODEL_DIR/'threshold.json', 'w') as f:
    json.dump({
        "threshold": best_tau,
        "target_precision": target_precision,
        "criterion": tau_criterion,
        "val_precision": float(prec[best_i]),
        "val_recall": float(rec[best_i]),
    }, f, indent=2)


In [None]:

# Export detections for TEST set using τ.
# One row per test image: image probability, presence flag at τ, and the center
# of its highest-scoring tile (pixel coordinates) as a guidance point.
TILE = 640  # tile edge in pixels; must match the tiling step that produced the tiles
tau = best_tau
rows = []
for img_id, g in test_groups.items():
    present = g['prob'] >= tau
    # Per-tile probabilities. Named `tile_probs` (not `tp`) so the test-set
    # probability array returned by score_groups is not overwritten.
    tile_probs = g['tile_probs']
    if tile_probs.size > 0:
        j = int(tile_probs.argmax())
        top_prob = float(tile_probs[j])
        # Positional lookup: tile_probs follows the row order of g['tiles'].
        top_row = g['tiles'].iloc[j]
        # Approximate the point as the tile center in pixel space (no geo).
        # x/y are presumably the tile's top-left corner and W/H the image size;
        # confirm against the tiling notebook. If you have geo, replace with lat/lon.
        W, H = top_row['W'], top_row['H']
        x, y = top_row['x'], top_row['y']
        cx, cy = float(x) + 0.5 * TILE, float(y) + 0.5 * TILE
    else:
        top_prob, cx, cy, W, H = 0.0, None, None, None, None
    rows.append({
        "image_id": img_id,
        "present_prob": float(g['prob']),
        "present_flag": bool(present),
        "top_tile_prob": top_prob,
        "center_x": cx, "center_y": cy, "W": W, "H": H,
    })
det_df = pd.DataFrame(rows)
det_df.to_csv(EXPORT_DIR/'detections_test.csv', index=False)
det_df.head()


In [None]:

# If you have geolocation per tile/image, convert to GeoJSON/GPX here.
# For now the point geometries are tile centers in *pixel* coordinates, so no CRS
# is assigned. Tagging pixels as EPSG:3857 (Web Mercator metres) would make GIS
# tools draw every point next to lon/lat (0, 0) as if it were real data.
geometry = [
    # pandas stores a missing center as NaN (not None) in a numeric column, so
    # test with pd.notna; `x is not None` would let Point(nan, nan) through.
    Point(float(x), float(y)) if pd.notna(x) and pd.notna(y) else None
    for x, y in zip(det_df['center_x'], det_df['center_y'])
]
gdf = gpd.GeoDataFrame(det_df, geometry=geometry, crs=None)
gdf_present = gdf[gdf['present_flag'] & gdf.geometry.notna()].copy()
geojson_path = EXPORT_DIR/'detections_test.geojson'
if gdf_present.empty:
    # Remove any stale file from a previous run so it isn't mistaken for this run's output.
    geojson_path.unlink(missing_ok=True)
    print("No present detections with coordinates; nothing written to", geojson_path)
else:
    gdf_present.to_file(geojson_path, driver="GeoJSON")
    print("Wrote:", geojson_path)



## Deployment Artifacts

You now have:
- `models/mil_head.pt` (weights)
- `models/model_card.json` (hyperparams)
- `models/threshold.json` (operating threshold)
- `exports/detections_test.csv` and `exports/detections_test.geojson`

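These are ready to plug into your interface. The inference side just needs to:
- tile each new image,
- embed each tile with the DINOv2 encoder (`vit_small_patch14_dinov2.lvd142m`), using the same preprocessing that produced the cached embeddings,
- apply `mil_head.pt` with the pooling recorded in `model_card.json` to compute the presence probability,
- compare it against the threshold stored in `threshold.json`,
- return the top tile's center as a point for GeoJSON/GPX export (convert pixel coordinates to lat/lon first when geolocation is available).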
