In [39]:
import argparse, os, random
from pathlib import Path
from typing import Dict, List, Tuple
import json
import csv
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

from torchvision import models, transforms
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

In [40]:
# total 16 unified classes, 7 species

In [50]:
def load_model_from_ckpt(pt_path: str, device: torch.device):
    ckpt = torch.load(pt_path, map_location="cpu")
    classes = ckpt.get("classes", None)

    # infer head size
    if isinstance(classes, list) and len(classes) > 0:
        num_classes = len(classes)
    else:
        w = ckpt["model"].get("classifier.3.weight", None) or ckpt["model"].get("classifier.1.weight", None)
        num_classes = int(w.shape[0]) if w is not None else 38

    m = models.mobilenet_v3_small(weights=None)
    in_features = m.classifier[3].in_features
    m.classifier[3] = nn.Linear(in_features, num_classes)
    m.load_state_dict(ckpt["model"])
    m.to(device).eval()

    if not isinstance(classes, list) or len(classes) != num_classes:
        classes = [str(i) for i in range(num_classes)]
    return m, classes

def build_transform(img_size=224, normalize=False):
    tf = [transforms.Resize((img_size, img_size)), transforms.ToTensor()]
    if normalize:
        tf += [transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])]
    return transforms.Compose(tf)


def collect_random_images(root: str, cm: Dict, domain: str, n: int) -> List[str]:
    exts = {".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"}
    rootp = Path(root)
    mapping = cm["plantvillage_to_unified"] if domain == "pv" else cm["plantdoc_to_unified"]
    allowed_folders = set(mapping.keys())

    paths = []
    for img in rootp.rglob("*"):
        if img.suffix.lower() in exts:
            folder = img.parent.name
            if folder in allowed_folders:
                paths.append(str(img))
    if not paths:
        print(f"No valid images found under {root} for domain={domain}. "
                         f"check the classmap and dataset structure.")
        return []
    
    random.shuffle(paths)
    return paths[:n]

In [51]:
def load_classmap(path: str) -> Dict:
    cm = json.load(open(path, "r"))
    return cm

def detect_domain(folder_name: str, cm: Dict) -> str | None:
    if folder_name in cm["plantvillage_to_unified"]:
        return "pv"
    if folder_name in cm["plantdoc_to_unified"]:
        return "pd"
    return None

def pv_to_unified_label(pv_name: str, cm: Dict) -> str | None:
    return cm["plantvillage_to_unified"].get(pv_name, None)

def gt_unified_from_path(p: str, cm: Dict) -> Tuple[str | None, str]:
    # gets the (gt_unified, domain) from the image path using folder names
    folder = Path(p).parent.name
    dom = detect_domain(folder, cm)
    if dom == "pv":
        return cm["plantvillage_to_unified"].get(folder, None), "pv"
    if dom == "pd":
        return cm["plantdoc_to_unified"].get(folder, None), "pd"
    return None, "?"

def validate_image_paths(paths: List[str], cm: Dict) -> List[str]:
    # keep only those images whose parent folder maps to any of the unified class in classmap
    kept = []
    for p in paths:
        folder = Path(p).parent.name
        if detect_domain(folder, cm) is not None:
            kept.append(p)
    return kept


In [52]:
@torch.no_grad()
def predict_topk(model, tensor, device, k):
    logits = model(tensor.to(device)).softmax(1).cpu().numpy()[0]
    top_idx = np.argsort(-logits)[:k]
    top_prob = logits[top_idx]
    return top_idx.tolist(), top_prob.tolist()

def draw_overlay(img_path: str,
                 gt_unified: str | None,
                 top_idx: List[int], top_prob: List[float],
                 pv_classes: List[str], cm: Dict,
                 save_dir: str, correct: bool):
    img = Image.open(img_path).convert("RGB")
    plt.figure(figsize=(5,5)); plt.imshow(img); plt.axis("off")

    # predicted names (PV + unified)
    pv_names = [pv_classes[i] for i in top_idx]
    uni_names = [pv_to_unified_label(n, cm) or "—" for n in pv_names]

    # header shows GT vs Pred, colored by correctness
    color = "green" if correct else "red"
    pred_line = f"PRED: {uni_names[0] or pv_names[0]} (p={top_prob[0]:.3f})"
    gt_line   = f"GT:   {gt_unified or '—'}"
    lines = [pred_line, gt_line, "", "top-k:"]
    for i, (pv, uni, p) in enumerate(zip(pv_names, uni_names, top_prob), 1):
        label = f"{uni} [{pv}]" if uni and uni != "—" else pv
        lines.append(f"{i}. {label} ({p:.3f})")

    plt.title("\n".join(lines), fontsize=9, color=color)
    outp = Path(save_dir) / (Path(img_path).stem + "_pred.png")
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    plt.tight_layout(); plt.savefig(outp, dpi=150); plt.close()
    return outp

In [53]:
CLASSMAP_PATH = "classmap.json"

IMAGES = []
# IMAGES = ["/home/devs/code/aai/project/PlantDoc-Dataset/test/Bell_pepper leaf/10148582-green-leaf-of-pepper.jpg", ]

DATA_DIR = "PlantDoc-Dataset/test/"
# DATA_DIR = "PlantVillage-Dataset/raw/color/"
DATA_DOMAIN = "pd"
SAMPLE_NUM = 3


# Default
CKPT = "runs/frontiers2023/run2/mobilenetv3small_best.pt"
IMG_SIZE = 224
NORMALIZE = False


SAVE_DIR = f"artifacts/OOD/{DATA_DOMAIN}"
SAVE_LOG = f"{SAVE_DIR}/logs/exp1.csv"


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cm = load_classmap(CLASSMAP_PATH)

# collect images
if IMAGES:
    imgs = validate_image_paths(IMAGES, cm)
    if not imgs:
        print("none of the provided images map to any unified classes from the classmap.json.")
elif DATA_DIR:
    imgs = collect_random_images(DATA_DIR, cm, DATA_DOMAIN, SAMPLE_NUM)

model, pv_classes = load_model_from_ckpt(CKPT, device)
tf = build_transform(IMG_SIZE, normalize=NORMALIZE)

writer = None
if SAVE_LOG:
    Path(SAVE_LOG).parent.mkdir(parents=True, exist_ok=True)
    f = open(SAVE_LOG, "a", newline="")
    writer = csv.writer(f)
    writer.writerow(["image_path","domain","gt_unified","pred_unified","pred_pv","prob","topk_hit","correct"])


num_correct = 0
for p in imgs:
    gt_u, dom = gt_unified_from_path(p, cm)
    print("GT:", gt_u, dom)

    x = tf(Image.open(p).convert("RGB")).unsqueeze(0)
    top_idx, top_prob = predict_topk(model, x, device, k=3)

    pv_names = [pv_classes[i] for i in top_idx]
    uni_names = [pv_to_unified_label(n, cm) for n in pv_names]
    print(pv_names, uni_names)

    ## matching the class name to verify the results
    pred_u = uni_names[0] or pv_names[0]
    
#     correct = (gt_u is not None) and (pred_u == gt_u)
#     topk_hit = (gt_u is not None) and (gt_u in [u for u in uni_names if u is not None])
#     if correct: num_correct += 1

#     print(f"\nImage: {p}  (domain={dom})")
#     print(f"  GT unified: {gt_u}")
#     print(f"  Pred unified: {pred_u}  (p={top_prob[0]:.4f})  "
#             f"{'✅' if correct else ('(in top-k) ✅' if topk_hit else '❌')}")
#     for i, (pv, uu, pr) in enumerate(zip(pv_names, uni_names, top_prob), 1):
#         label = f"{uu} [{pv}]" if uu else pv
#         print(f"   {i}. {label} ({pr:.4f})")

#     out = draw_overlay(p, gt_u, top_idx, top_prob, pv_classes, cm, SAVE_DIR, correct)
#     print(f"  saved: {out}")

#     if writer:
#         writer.writerow([p, dom, gt_u, pred_u, pv_names[0], f"{top_prob[0]:.6f}",
#                             int(bool(topk_hit)), int(bool(correct))])

# print(f"\nSummary: {num_correct}/{len(imgs)} correct top-1 "
#         f"({100.0*num_correct/len(imgs):.1f}%).")

# if writer:
#         f.close()

GT: corn__common_rust pd
['Corn_(maize)___healthy', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot'] ['corn__healthy', 'corn__common_rust', 'corn__gray_leaf_spot']
GT: corn__northern_leaf_blight pd
['Corn_(maize)___healthy', 'Squash___Powdery_mildew', 'Peach___Bacterial_spot'] ['corn__healthy', None, 'peach__bacterial_spot']
GT: tomato__healthy pd
['Tomato___healthy', 'Tomato___Late_blight', 'Blueberry___healthy'] ['tomato__healthy', 'tomato__late_blight', None]
