In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

from configs.config import load_config
from dataloaders.dataloader import get_dataloader
from models import model, train, validate, test
from preprocessing.dataset_split import prepare_and_split
from preprocessing.augment_inplace import augment_train_df
# NOTE(review): star import presumably supplies collate, resolve_best_ckpt, sweep_score_thresh,
# build_gt_index, match_predictions, evaluate_map, pr_curve_for_class, detection_only_counts
# and classification_on_matched — consider importing them explicitly.
from Utils import *

Configuration (paths, training and model settings)

In [None]:
# Single configuration object: paths, training and model settings used throughout the notebook.
CONFIG_PATH = "configs/default.yaml"
config = load_config(CONFIG_PATH)


Select the device: CUDA if available, otherwise CPU

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Read in the bounding-box annotations for the crops

In [None]:
# Bounding-box annotation table; image_path repeats across rows (see the nunique counts below).
annot_csv_path = config.paths.csv_annot
crop_bbox_df = pd.read_csv(annot_csv_path)


In [None]:
# An image is "eligible" when at least one of its rows has all four box coordinates.
box_cols = ["xmin", "ymin", "xmax", "ymax"]
rows_with_box = crop_bbox_df.dropna(subset=box_cols)
n_total = crop_bbox_df["image_path"].nunique()
n_elig = rows_with_box["image_path"].nunique()
print(f"Total number of images: {n_total}   Eligible images: {n_elig}")
# Output from a previous run: Total number of images: 2703   Eligible images: 2061

Split the dataset into train / validation / test sets

In [None]:
# Split the annotations into train/val/test frames; the class name list comes back alongside.
train_df, val_df, test_df, class_names = prepare_and_split(
    config=config,
    og_df=crop_bbox_df,
)

Augment Train Images

In [None]:
# Image root directory from the config; handed to the augmentation step below.
paths_cfg = config.paths
images_root = paths_cfg.images_root

In [None]:
def _image_counts(frame):
    """Return (unique images, unique images with at least one fully-specified box) for a split."""
    with_box = frame.dropna(subset=["xmin", "ymin", "xmax", "ymax"])
    return frame["image_path"].nunique(), with_box["image_path"].nunique()


n_total_train, n_elig_train = _image_counts(train_df)
n_total_val, n_elig_val = _image_counts(val_df)
n_total_test, n_elig_test = _image_counts(test_df)

print(f"Total train images = {n_total_train} Eligible train images = {n_elig_train} ")
print(f"Total val images={n_total_val}  Eligible val images={n_elig_val}")
print(f"Total test images = {n_total_test}  Eligible test images={n_elig_test}")


In [None]:
# NOTE(review): this rebinds train_df, so re-running the cell feeds the already-augmented
# frame back into augment_train_df. Re-run the split cell first to start from the raw split.
# The module name (augment_inplace) suggests files may be written under images_root — TODO confirm.
train_df = augment_train_df(images_root=images_root, train_df=train_df)

In [None]:
# Train-split image counts after augmentation, for comparison with the counts above.
aug_rows_with_box = train_df.dropna(subset=["xmin", "ymin", "xmax", "ymax"])
new_total_train = train_df["image_path"].nunique()
new_elig_train = aug_rows_with_box["image_path"].nunique()
print(f"New total train images = {new_total_train}  New eligible train images={new_elig_train}")

Build data loaders

In [None]:
def _make_loader(frame, **overrides):
    """Build a loader for `frame` with the shared config settings; extra kwargs (e.g. shuffle) pass through."""
    loader, _unused = get_dataloader(
        data=frame,
        images_root=config.paths.images_root,
        class_names=class_names,
        batch_size=config.train.batch_size,
        num_workers=config.train.num_workers,
        collate_fn=collate,
        **overrides,
    )
    return loader


# Train relies on get_dataloader's default `shuffle`; val/test are explicitly unshuffled.
train_loader = _make_loader(train_df)
val_loader = _make_loader(val_df, shuffle=False)
test_loader = _make_loader(test_df, shuffle=False)



Build model

In [None]:
# Detector sized to the dataset's classes; anchor, NMS and score settings come from config.model.
model_cfg = config.model
model_kwargs = {
    key: getattr(model_cfg, key)
    for key in ("anchor_sizes", "aspect_ratios", "score_thresh", "nms_thresh", "detections_per_img")
}
cunei_model = model.build_model(
    num_classes=len(class_names),
    imagenet_weights=bool(model_cfg.use_imagenet_weights),
    **model_kwargs,
)

Build the validation-loss function

In [None]:
# Validation callable from validate_loss_factory (presumably evaluates loss over val_loader);
# passed to train.train as val_fn.
val_function = validate.validate_loss_factory(
    device=device,
    val_loader=val_loader,
)

Resume from checkpoint (optional)

In [None]:
# Optional checkpoint to resume training from; None when config.train has no `resume` entry.
resume_path = config.train.resume if hasattr(config.train, "resume") else None

Run Training and Validation

In [None]:
# Validation context (loader + class names) handed to the trainer for its evaluation metrics.
val_eval_ctx = train.EvalCtx(loader=val_loader, class_names=class_names)
history = train.train(
    model=cunei_model,
    train_loader=train_loader,
    val_fn=val_function,
    eval_ctx=val_eval_ctx,
    config=config,
    device=device,
    resume_path=resume_path,
)



Run Test

In [None]:
# Reload the best checkpoint (ranked by config.train.best_metric) for evaluation.
# Checkpoint root: config.train.ckpt_dir, else config.train.checkpoint_dir, else "runs".
train_cfg = config.train
run_root = getattr(train_cfg, "ckpt_dir", getattr(train_cfg, "checkpoint_dir", "runs"))
best_path = resolve_best_ckpt(run_root=run_root, metric_name=train_cfg.best_metric)

# NOTE: weights_only=False unpickles arbitrary Python objects — only load checkpoints you produced.
ckpt = torch.load(best_path, map_location=device, weights_only=False)
best_state = ckpt["model"]
cunei_model.load_state_dict(best_state)
cunei_model.to(device).eval()


In [None]:
sweep_rows, best = sweep_score_thresh(
    model=cunei_model, loader=val_loader, device=device, iou=0.5, max_batches=None
)

# Parallel per-threshold series (the sweep plot cell re-derives these as well).
ths, P, R, F1 = ([row[key] for row in sweep_rows] for key in ("th", "precision", "recall", "f1"))

# Adopt the F1-optimal validation threshold for everything downstream.
# NOTE(review): if evaluate_map also filters detections by this attribute, the test mAP is
# computed on a truncated PR curve (AP is normally taken over low-threshold detections) — confirm.
cunei_model.score_thresh = best["th"]

print(f"Best F1 at th={best['th']:.2f}: P={best['precision']:.3f} R={best['recall']:.3f} F1={best['f1']:.3f}")


In [None]:
# Rebuild the test loader so targets carry "image_path".
# NOTE(review): these arguments match the earlier test-loader call exactly — confirm the rebuild is still needed.
test_loader_kwargs = dict(
    data=test_df,
    images_root=config.paths.images_root,
    class_names=class_names,
    batch_size=config.train.batch_size,
    shuffle=False,
    num_workers=config.train.num_workers,
    collate_fn=collate,
)
test_loader, _ = get_dataloader(**test_loader_kwargs)


In [None]:
# Record which images ended up in the test split (only when the dataset exposes `groups`).
# Writing is kept inside the `if`: previously a dataset without `groups` hit an undefined
# (or, on re-run, stale) `names` variable.
ds = test_loader.dataset
if hasattr(ds, "groups"):
    # Each entry of ds.groups unpacks to (name, _); only the name is kept.
    test_image_names = [name for name, _ in ds.groups]
    names_csv = Path("outputs") / "test_image_names.csv"
    names_csv.parent.mkdir(parents=True, exist_ok=True)  # to_csv does not create directories
    pd.Series(test_image_names, name="test_image_names").to_csv(names_csv, index=False)
else:
    print("Test dataset has no `groups` attribute; test_image_names.csv not written.")

In [None]:
# Test-set inference at the model's current score threshold (set from the validation sweep above).
# out_csv presumably makes run_inference also write the predictions to test_predictions.csv.
inference_thresh = cunei_model.score_thresh
preds = test.run_inference(
    cunei_model,
    test_loader,
    class_names,
    device=device,
    out_csv="test_predictions.csv",
    score_thresh=inference_thresh,
)



In [None]:
# Ground-truth index built from the test loader; the next cell matches predictions
# against it at IoU=0.5.
gt = build_gt_index(test_loader)

In [None]:
# Match predictions to GT at IoU=0.5. gt_counts is looked up by class id downstream
# (presumably the number of GT boxes per class).
preds_eval, gt_counts = match_predictions(preds, gt, iou_thr=0.5)
n_tp_matched = int(preds_eval["tp"].sum())
print("TP:", n_tp_matched)


In [None]:
# Overall counts at IoU=0.5: every non-TP prediction counts as an FP,
# every GT box left without a TP counts as an FN.
TP = int(preds_eval["tp"].sum())
GT_tot = int(sum(gt_counts.values()))
Pred_tot = len(preds_eval)

FP, FN = (max(total - TP, 0) for total in (Pred_tot, GT_tot))
prec, rec = TP / max(TP + FP, 1), TP / max(GT_tot, 1)

print(f"TP={TP} FP={FP} FN={FN}  |  Precision@0.5={prec:.3f} Recall@0.5={rec:.3f}")
assert TP <= GT_tot, "TP exceeds total GT — matching bug"


In [None]:
# Test-set mAP@0.5 and per-class AP (keyed by class name) from evaluate_map.
# NOTE(review): evaluated after score_thresh was tuned — confirm evaluate_map is not limited by it.
res50 = evaluate_map(cunei_model, test_loader, class_names, device=device, iou_thr=0.5, max_batches=None)

# Accept either key spelling; fall back to an empty mapping so later `.get` / `.items()`
# calls do not fail on None.
per_class_AP = res50.get("per_class_AP", res50.get("per_class_ap"))
if per_class_AP is None:
    per_class_AP = {}

map50 = res50.get("mAP@0.5", res50.get("mAP", None))
print(f"Test mAP@0.5: {map50:.3f}" if map50 is not None else "mAP not available")


In [None]:
# Per-class P/R at the current score_thresh (from matched preds), joined with per-class AP.
# Cast "tp" to bool before masking: with a 0/1 integer column, `.loc[int Series]` does label
# lookup and `~` is bitwise NOT, silently selecting the wrong rows.
is_tp = preds_eval["tp"].astype(bool)
class_ids = range(len(class_names))
# NOTE(review): label_ids outside 0..len(class_names)-1 are dropped by reindex — confirm ids are 0-based.
tp_by = preds_eval.loc[is_tp].groupby("label_id").size().reindex(class_ids, fill_value=0)
fp_by = preds_eval.loc[~is_tp].groupby("label_id").size().reindex(class_ids, fill_value=0)
det_by = preds_eval.groupby("label_id").size().reindex(class_ids, fill_value=0)

ap_by_name = per_class_AP if per_class_AP is not None else {}  # keys are class names

summary_rows = []
for cid, cname in enumerate(class_names):
    gt_c = int(gt_counts.get(cid, 0))
    # Per-class locals are suffixed so the overall `prec`/`rec` globals are not clobbered.
    tp_c, fp_c = int(tp_by[cid]), int(fp_by[cid])
    summary_rows.append({
        "class": cname,
        "AP@0.5": ap_by_name.get(cname, np.nan),
        "Precision@0.5": tp_c / max(tp_c + fp_c, 1),
        "Recall@0.5": tp_c / max(gt_c, 1),
        "GT": gt_c,
        "Detections": int(det_by[cid]),
    })

# Overall row
TP, FP = int(tp_by.sum()), int(fp_by.sum())
GT_tot = int(sum(gt_counts.values()))
overall = {
    "class": "ALL",
    "AP@0.5": float(map50) if map50 is not None else np.nan,
    "Precision@0.5": TP / max(TP + FP, 1),
    "Recall@0.5": TP / max(GT_tot, 1),
    "GT": GT_tot,
    "Detections": int(det_by.sum()),
}
summary = pd.DataFrame(summary_rows + [overall])

summary["class"] = summary["class"].astype(str)
summary_class_view = summary.sort_values("class").reset_index(drop=True)

summary_class_view.assign(iou=0.5, score_thresh=cunei_model.score_thresh).to_csv("test_summary.csv", index=False)

summary



Visuals - train/val

In [None]:
def plot_train_val(hist, train_key, val_key, title):
    """Plot a train/val curve pair against epoch on a new figure."""
    fig, ax = plt.subplots()
    ax.plot(hist["epoch"], hist[train_key], label="train")
    ax.plot(hist["epoch"], list(hist[val_key]), label="val")
    ax.set(title=title, xlabel="epoch")
    ax.legend()
    return ax


def plot_eval_metric(hist, key, title):
    """Plot an evaluation metric against epoch, skipping epochs where it is None.

    Epochs and values are filtered together. Filtering only the values, as before,
    made y shorter than x and raised whenever any epoch was not evaluated.
    """
    points = [(ep, v) for ep, v in zip(hist["epoch"], hist[key]) if v is not None]
    fig, ax = plt.subplots()
    if points:
        epochs, values = zip(*points)
        ax.plot(epochs, values, marker="o")  # markers keep sparse evaluations visible
    ax.set(title=title, xlabel="epoch")
    return ax


plot_train_val(history, "train_loss", "val_loss", "Loss")
plot_train_val(history, "train_cls", "val_cls", "Classification")
plot_train_val(history, "train_reg", "val_reg", "Regression")
plot_eval_metric(history, "map50", "mAP@0.5")
plot_eval_metric(history, "precision", "Precision@0.5")
plot_eval_metric(history, "recall", "Recall@0.5")
plt.show()

Visuals - sweep score

In [None]:
# Recover the per-threshold series from the validation sweep.
ths = [row["th"] for row in sweep_rows]
P = [row["precision"] for row in sweep_rows]
R = [row["recall"] for row in sweep_rows]
F1 = [row["f1"] for row in sweep_rows]

fig, ax = plt.subplots()
for series, tag in ((P, "P"), (R, "R"), (F1, "F1")):
    ax.plot(ths, series, label=tag)
# Dashed marker at the threshold selected by the sweep.
ax.axvline(best["th"], ls="--")
ax.legend()
ax.set_title("Val sweep @ IoU=0.5")
ax.set_xlabel("score_thresh")
fig.savefig("val_sweep.png", dpi=150, bbox_inches="tight")
plt.show()

Visuals - predictions (without GT)

In [None]:
# Prediction-only diagnostics (no GT needed). Skipped when there are no detections:
# on an empty frame `counts.max()` is NaN and `range()` would raise.
if preds.empty:
    print("No test detections — skipping prediction-only plots.")
else:
    # 1) Score histogram
    fig, ax = plt.subplots()
    ax.hist(preds["score"], bins=50, range=(0, 1))
    ax.set(title="Scores", xlabel="score", ylabel="#detections")

    # 2) Detections per image (images without detections presumably have no rows, hence bins start at 1)
    counts = preds.groupby("image_path").size()
    fig, ax = plt.subplots()
    ax.hist(counts, bins=range(1, int(counts.max()) + 2))
    ax.set(title="Detections per image", xlabel="#detections", ylabel="#images")

    # 3) Per-class detection counts
    pc = preds["label_name"].value_counts().sort_values(ascending=False)
    fig, ax = plt.subplots(figsize=(8, 3))
    pc.plot(kind="bar", ax=ax)
    ax.set(title="Per-class detections", ylabel="#detections")
    fig.tight_layout()

    # 4) Box area vs score + area histogram
    w = preds["xmax"] - preds["xmin"]
    h = preds["ymax"] - preds["ymin"]
    area = (w * h).clip(lower=1)  # floor at 1 px^2 so log10 stays finite for degenerate boxes
    fig, ax = plt.subplots()
    ax.scatter(np.log10(area), preds["score"], s=5, alpha=0.3)
    ax.set(title="Score vs log10(area)", xlabel="log10(area px^2)", ylabel="score")
    fig, ax = plt.subplots()
    ax.hist(np.sqrt(area), bins=40)
    ax.set(title="Box size (sqrt area)", xlabel="pixels", ylabel="#detections")

    # 5) Spatial heatmap of centres; `extent` maps bin indices back to box coordinates.
    cx = (preds["xmin"] + preds["xmax"]) / 2.0
    cy = (preds["ymin"] + preds["ymax"]) / 2.0
    H, xedges, yedges = np.histogram2d(cx, cy, bins=50)
    fig, ax = plt.subplots()
    im = ax.imshow(H.T, origin="lower", aspect="auto",
                   extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])
    ax.set(title="Detection centre heatmap", xlabel="x centre", ylabel="y centre")
    fig.colorbar(im, ax=ax, label="#detections")
    plt.show()


Visuals - predictions (with GT)

In [None]:
# Per-class AP bar chart, highest AP first.
# NaN/None APs are ranked last: NaN compares False with everything, so sorting raw values
# can leave the order scrambled.
ap_source = per_class_AP if per_class_AP is not None else {}
ap_ranked = sorted(
    ap_source.items(),
    key=lambda kv: -np.inf if pd.isna(kv[1]) else kv[1],
    reverse=True,
)
if not ap_ranked:
    print("No per-class AP available — skipping AP bar chart.")
else:
    ap_names, ap_values = zip(*ap_ranked)
    fig, ax = plt.subplots(figsize=(9, 3))
    ax.bar([str(n) for n in ap_names], ap_values)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set(ylabel="AP@0.5", title="Test per-class AP")
    fig.tight_layout()
    plt.show()


In [None]:
# Diagnostics: compare label ids seen in predictions against label ids present in the GT index.
first_name_per_id = preds_eval.groupby("label_id")["label_name"].first()
lid2name = first_name_per_id.to_dict()
name2lid = {str(name): int(lid) for lid, name in lid2name.items()}

print("pred ids→names:", lid2name)
print("unique GT ids :", sorted({int(c) for g in gt.values() for c in g["labels"].tolist()}))
print("TP count      :", int(preds_eval["tp"].sum()))


In [None]:
# PR curves @ IoU=0.5 for the (up to) 4 classes with the highest per-class AP.
id2name = {i: n for i, n in enumerate(class_names)}
name2id = {str(n): i for i, n in id2name.items()}

ap_source = per_class_AP if per_class_AP is not None else {}
# NaN/None APs rank last; NaN compares False with everything and would scramble a plain sort.
ap_ranked = sorted(ap_source.items(), key=lambda kv: -np.inf if pd.isna(kv[1]) else kv[1], reverse=True)
top_names = [n for n, _ in ap_ranked[:4]]

fig, ax = plt.subplots()
for cname in top_names:
    cid = name2id.get(str(cname))
    if cid is None:
        print(f"skip {cname}: no id")
        continue
    # Suffixed names so the overall `prec`/`rec` globals from earlier cells are not clobbered.
    rec_c, prec_c, ap_c = pr_curve_for_class(preds_eval, gt_counts, class_id=cid)
    ax.plot(rec_c, prec_c, label=f"{id2name[cid]} (AP={ap_c:.2f})")
ax.set(xlabel="Recall", ylabel="Precision", title="PR curves @ IoU=0.5")
if ax.get_lines():  # avoid the "no artists with labels" legend warning when nothing was plotted
    ax.legend()
fig.tight_layout()
plt.show()


In [None]:
# IoU distribution of true positives at the 0.5 matching threshold.
# Cast "tp" to bool so a 0/1 integer column is used as a mask rather than as row labels.
is_tp = preds_eval["tp"].astype(bool)
tp_iou = preds_eval.loc[is_tp, "match_iou"].dropna().to_numpy()
if tp_iou.size:
    fig, ax = plt.subplots()
    ax.hist(tp_iou, bins=np.linspace(0.5, 1.0, 21))
    ax.set(xlabel="IoU", ylabel="#TPs", title="TP IoU histogram @0.5")
    fig.tight_layout()
    plt.show()
else:
    print("No true positives to plot IoU histogram.")


In [None]:
# TP confusion (GT class vs predicted class), using the GT label stored by the matcher
# instead of recomputing IoU. Works on local selections; preds_eval is not mutated.
is_tp = preds_eval["tp"].astype(bool)
tp_rows = preds_eval.loc[is_tp]
if "gt_label_id" in tp_rows.columns:
    gt_ids = tp_rows["gt_label_id"]
else:
    gt_ids = pd.Series(pd.NA, index=tp_rows.index)  # matcher did not record GT labels

# Name both axes from class_names so GT classes that were never predicted still get names.
id_to_name = {i: str(n) for i, n in enumerate(class_names)}
cm = pd.crosstab(gt_ids, tp_rows["label_id"], rownames=["GT"], colnames=["Pred"], dropna=False).rename(
    index=lambda i: id_to_name.get(i, str(i)),
    columns=lambda i: id_to_name.get(i, str(i)),
)

if cm.empty:
    print("No TP rows with a GT label — nothing to plot.")
else:
    fig, ax = plt.subplots(figsize=(5, 4))
    im = ax.imshow(cm.values, aspect="auto")
    ax.set_xticks(range(len(cm.columns)))
    ax.set_xticklabels(cm.columns, rotation=45, ha="right")
    ax.set_yticks(range(len(cm.index)))
    ax.set_yticklabels(cm.index)
    ax.set(xlabel="Pred", ylabel="GT", title="TP confusion (class vs class)")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    plt.show()

In [None]:
# Score distribution over all test predictions.
fig, ax = plt.subplots()
ax.hist(preds["score"], bins=50, range=(0, 1))
ax.set_xlabel("score")
ax.set_ylabel("#detections")
ax.set_title("Detection scores")
fig.tight_layout()
plt.show()


In [None]:
# Detections per image (images without detections presumably have no rows in preds).
dets_per_img = preds.groupby("image_path").size()
if dets_per_img.empty:
    # .max() of an empty Series is NaN and range() would raise.
    print("No test detections — nothing to plot.")
else:
    fig, ax = plt.subplots()
    ax.hist(dets_per_img, bins=range(1, int(dets_per_img.max()) + 2))
    ax.set(xlabel="#detections per image", ylabel="#images", title="Detections per image")
    fig.tight_layout()
    plt.show()


In [None]:
# Score vs box area, then the distribution of box sizes (sqrt of area).
w, h = preds["xmax"] - preds["xmin"], preds["ymax"] - preds["ymin"]
area = (w * h).clip(lower=1)  # floor at 1 px^2 so log10 is finite

fig, ax = plt.subplots()
ax.scatter(np.log10(area), preds["score"], s=5, alpha=0.3)
ax.set(xlabel="log10(area px^2)", ylabel="score", title="Score vs box area")
fig.tight_layout()
plt.show()

fig, ax = plt.subplots()
ax.hist(np.sqrt(area), bins=40)
ax.set(xlabel="box size (pixels, sqrt area)", ylabel="#detections", title="Box size distribution")
fig.tight_layout()
plt.show()


In [None]:
# Spatial heatmap of detection centres over a fixed IMG_SIZE x IMG_SIZE canvas.
# NOTE(review): 512 is an assumed image size, not read from the data — confirm.
IMG_SIZE = 512
cx = (preds["xmin"] + preds["xmax"]) / 2.0
cy = (preds["ymin"] + preds["ymax"]) / 2.0

# histogram2d silently drops points outside `range`; report them so a wrong IMG_SIZE is visible.
n_outside = int((~(cx.between(0, IMG_SIZE) & cy.between(0, IMG_SIZE))).sum())
if n_outside:
    print(f"{n_outside} detection centres fall outside [0, {IMG_SIZE}] (or are NaN) and are not shown.")

H, xe, ye = np.histogram2d(cx, cy, bins=50, range=[[0, IMG_SIZE], [0, IMG_SIZE]])
fig, ax = plt.subplots()
im = ax.imshow(H.T, origin="lower", extent=[0, IMG_SIZE, 0, IMG_SIZE], aspect="equal")
fig.colorbar(im, ax=ax, label="#detections")
ax.set(title="Detection centre heatmap", xlabel="x centre (px)", ylabel="y centre (px)")
fig.tight_layout()
plt.show()


In [None]:
# Separate the detection-only counts from classification accuracy on matched boxes (IoU=0.5).
det = detection_only_counts(preds, gt, iou=0.5)
print(det, " ", sep="\n")  # TPd/FPd/FNd, then a spacer line

cls_result = classification_on_matched(preds, gt, class_names, iou=0.5)
cm, acc, per_cls_rec = cls_result
if cm is not None:
    print(f"classification-only accuracy on matched boxes: {acc:.3f}", cm, sep="\n")

