# 03 — Evaluate & Report

In [None]:
# Cell 1: imports & basic setup

from pathlib import Path
import os
import json
from typing import Dict, List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# If you want to use your src/ modules:
import sys
# Resolved relative to the current working directory; presumably the notebook
# runs from a folder one level below the repo root -- adjust if not.
PROJECT_ROOT = Path("..").resolve()
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:  # guard keeps cell re-runs from adding duplicates
    sys.path.append(_project_root_entry)

from src.config import MODELS_DIR, DATA_PROCESSED_DIR, DEVICE
from src.models.loaders import (
    load_yolo_seg_model,
    load_vit_model,
)
from src.pipeline.inference import (
    apply_mask_and_crop,
    run_pipeline,
)


### Define transform and label mappings
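Update `SPECIES_LABELS` and the per-species disease label lists so that each list's order matches the class-index order used when the corresponding model was trained (for folder-based training, typically the class folder names in the order the training loader assigned them).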

In [None]:
# Cell 2: transforms & label mappings

# Standard ViT / ImageNet preprocessing (adjust if you trained with something different)
IMG_SIZE = 224
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_eval_steps = [
    # (h, w) tuple -> exact square output; aspect ratio is not preserved
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # PIL image -> float tensor scaled to [0, 1]
    transforms.ToTensor(),
    # per-channel standardization with the ImageNet statistics
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
eval_tfms = transforms.Compose(_eval_steps)

# ---- LABEL MAPPINGS ----
# These must match the class order you used in training: entry i of each list
# is the name returned when the corresponding model's arg-max class index is i.
# Keep a trailing comma after EVERY entry -- Python silently concatenates
# adjacent string literals, so a forgotten comma merges two class names into one.

# Example mapping for species classifier:
SPECIES_LABELS = [
    "cassava",
    "rice",
    "plantvillage",  # <-- e.g. generic bucket for PlantVillage dataset
    # add/remove depending on what you actually trained
]

# class index -> species name (used to decode the species classifier's output)
SPECIES_IDX2NAME = dict(enumerate(SPECIES_LABELS))

# Example disease label mappings per model
# These should match the folders in each dataset's 'train'/'val' set
CASSAVA_LABELS = [
    "cassava_bacterial_blight",
    "cassava_brown_streak",
    "cassava_green_mottle",
    "cassava_healthy",
]

RICE_LABELS = [
    "rice_bacterial_leaf_blight",
    "rice_brown_spot",
    "rice_leaf_smut",
]

PLANTVILLAGE_LABELS = [
    # TODO: fill this with your actual PlantVillage class names in order
    "tomato_bacterial_spot",
    "tomato_early_blight",
    "tomato_healthy",
    # ...
]

# Keys are species names exactly as they appear in SPECIES_LABELS: the predicted
# species name is used directly as the lookup key for its disease labels.
DISEASE_LABEL_MAP = {
    "cassava": CASSAVA_LABELS,
    "rice": RICE_LABELS,
    "plantvillage": PLANTVILLAGE_LABELS,
}


### Load all models

In [None]:
# Cell 3: load models

# Segmentation model used to mask/crop the leaf before classification
yolo = load_yolo_seg_model(MODELS_DIR / "yolo_plantdoc_seg.pt").to(DEVICE)


def _load_vit_for_eval(checkpoint_name):
    """Load a ViT checkpoint from MODELS_DIR, move it to DEVICE, and switch it to eval mode."""
    net = load_vit_model(MODELS_DIR / checkpoint_name).to(DEVICE)
    net.eval()
    return net


species_model = _load_vit_for_eval("species_classifier_vit.pth")
cassava_model = _load_vit_for_eval("cassava_best.pth")
rice_model = _load_vit_for_eval("rice_leaf_best.pth")
plantv_model = _load_vit_for_eval("plant_village_best.pth")

# predicted species name -> disease classifier for that species
disease_models = dict(
    cassava=cassava_model,
    rice=rice_model,
    plantvillage=plantv_model,
)


### Generic evaluation dataset
This dataset walks an image folder tree and uses the name of each image's top-level class folder (directly under the root) as its ground-truth label.

In [None]:
# Cell 4: generic dataset for evaluation

class FolderDataset(Dataset):
    """Image-folder dataset whose ground-truth label is the class folder name.

    Expects folder structure::

        root/
            class_1/
                img1.jpg
                ...
            class_2/
                ...

    Images may sit at any depth below a class folder; they are labelled with
    the name of the top-level class folder. Files directly under ``root`` are
    ignored. Samples are sorted, so their order is reproducible across machines.

    Args:
        root: Dataset root directory.
        transform: Optional callable applied to the loaded RGB PIL image. When
            not given, ``transforms.ToTensor()`` is used.
        extensions: File suffixes (matched case-insensitively) treated as
            images. Pass a tuple/list of suffixes, not a single string.

    Each item is ``(image_tensor, label, path_str)``.
    """

    def __init__(
        self,
        root: Path,
        transform=None,
        extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
    ):
        self.root = Path(root)
        self.transform = transform
        allowed_suffixes = {ext.lower() for ext in extensions}

        self.samples: List[Tuple[Path, str]] = []
        for class_dir in sorted(self.root.iterdir()):
            if not class_dir.is_dir():
                continue
            label = class_dir.name
            # rglob yields in filesystem order; sort so sample indices are stable.
            for img_path in sorted(class_dir.rglob("*")):
                # is_file(): a directory named e.g. "foo.jpg" is not an image.
                if img_path.is_file() and img_path.suffix.lower() in allowed_suffixes:
                    self.samples.append((img_path, label))

        self.class_names = sorted({label for _, label in self.samples})

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        img_path, label = self.samples[idx]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it remains usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform:
            img_tensor = self.transform(img)
        else:
            img_tensor = transforms.ToTensor()(img)

        return img_tensor, label, str(img_path)


### Wrappers to run the full pipeline on an image file

In [None]:
# Cell 5: pipeline wrapper for notebook use

def predict_species(img_tensor: torch.Tensor, model: torch.nn.Module) -> str:
    """Return the species name ``model`` predicts for a single image.

    ``img_tensor`` is one preprocessed image without a batch axis; a batch
    axis of size 1 is prepended before the forward pass. The arg-max class
    index is decoded through ``SPECIES_IDX2NAME``. ``model`` is put in eval mode.
    """
    model.eval()
    with torch.no_grad():
        batch = img_tensor.unsqueeze(0)  # [1, C, H, W]
        batch = batch.to(DEVICE)
        scores = model(batch)
        best_idx = scores.argmax(dim=1).item()
    return SPECIES_IDX2NAME[best_idx]

def predict_disease(
    img_tensor: torch.Tensor,
    model: torch.nn.Module,
    species_name: str,
) -> str:
    """Classify the disease of a single preprocessed image with a species-specific model.

    Args:
        img_tensor: One preprocessed image without a batch axis (a batch axis
            of size 1 is added here with ``unsqueeze(0)``).
        model: Disease classifier for ``species_name``; switched to eval mode.
        species_name: Key into ``DISEASE_LABEL_MAP`` selecting the label list.

    Returns:
        The label at the model's arg-max class index.

    Raises:
        KeyError: ``species_name`` has no entry in ``DISEASE_LABEL_MAP``.
        ValueError: the predicted class index is beyond the configured label
            list, i.e. the list is shorter than the model's output classes.
    """
    # Look up labels first so an unknown species fails before the forward pass.
    labels = DISEASE_LABEL_MAP[species_name]

    model.eval()
    with torch.no_grad():
        batch = img_tensor.unsqueeze(0).to(DEVICE)  # [1, C, H, W]
        logits = model(batch)
        pred_idx = torch.argmax(logits, dim=1).item()

    if pred_idx >= len(labels):
        # Previously a bare "list index out of range"; name the real problem.
        raise ValueError(
            f"Model for '{species_name}' predicted class index {pred_idx}, but "
            f"DISEASE_LABEL_MAP['{species_name}'] has only {len(labels)} labels; "
            "make the label list match the model's output classes (in training order)."
        )
    return labels[pred_idx]

def run_full_pipeline_on_path(
    img_path: Path,
) -> Tuple[str, str]:
    """Run segmentation -> species -> disease classification on one image file.

    Loads the image as RGB and masks/crops the leaf with the YOLO model via
    ``apply_mask_and_crop``. It then preprocesses the crop with ``eval_tfms``,
    predicts the species, and finally predicts the disease with that species'
    model from ``disease_models``.

    Returns:
        ``(species_pred, disease_pred)``.

    Raises:
        ValueError: the predicted species has no entry in ``disease_models``.
    """
    # 1) read image; the context manager closes the file handle, and convert()
    #    returns an independent, fully loaded copy
    with Image.open(img_path) as raw:
        img_np = np.array(raw.convert("RGB"))

    # 2) YOLO segmentation -> mask + crop (helper from src.pipeline.inference)
    #    NOTE(review): assumes the returned crop is an array Image.fromarray
    #    accepts (e.g. HxWx3 uint8) -- confirm in src.pipeline.inference.
    leaf = apply_mask_and_crop(yolo, img_np)

    # 3) convert cropped leaf to a model-ready tensor
    leaf_tensor = eval_tfms(Image.fromarray(leaf))

    # 4) species prediction
    species_pred = predict_species(leaf_tensor, species_model)

    # 5) disease model selected by the predicted species
    if species_pred not in disease_models:
        raise ValueError(f"No disease model configured for species '{species_pred}'")

    disease_pred = predict_disease(leaf_tensor, disease_models[species_pred], species_pred)
    return species_pred, disease_pred


### Evaluation loop for one dataset

In [None]:
# Cell 6: evaluation loop for one dataset

def evaluate_dataset(
    dataset_root: Path,
    assumed_species: str,
    batch_size: int = 1,
) -> Dict:
    """Run the full pipeline over one folder dataset and score its predictions.

    Args:
        dataset_root: Folder laid out as ``<class_name>/<images>`` (e.g.
            ``data/processed/cassava/val``). Each class folder name is the
            ground-truth disease label. A ``str`` path is also accepted.
        assumed_species: Ground-truth species of every image in this dataset;
            compared with the pipeline's species prediction.
        batch_size: Currently unused -- images are processed one at a time.
            Kept for backward compatibility.

    Returns:
        A dict with the following keys, computed over the images the pipeline
        processed successfully (images that raise are printed and skipped):

        - ``dataset_root``, ``assumed_species``: echo of the inputs.
        - ``accuracy``: disease accuracy.
        - ``classification_report``: sklearn's report as a dict.
        - ``confusion_matrix``: list of lists; rows and columns ordered by ``labels``.
        - ``labels``: the label order used by ``confusion_matrix``.
        - ``species_accuracy``: fraction whose species matched ``assumed_species``.
        - ``num_samples``: total images found.
        - ``num_failed``: images skipped.

        If no image was processed, the accuracies are NaN, and the report,
        matrix and labels are empty.
    """
    dataset_root = Path(dataset_root)
    # Only the (path, label) index is needed: the pipeline reads each image
    # itself. Iterating ``ds`` would decode every image a second time via
    # __getitem__ -- outside the try/except below, so one unreadable file
    # would abort the whole evaluation.
    ds = FolderDataset(dataset_root, transform=None)

    y_true: List[str] = []
    y_pred: List[str] = []
    species_correct = 0
    num_failed = 0

    for img_path, label in tqdm(ds.samples, desc=f"Evaluating {dataset_root.name}"):
        try:
            species_pred, disease_pred = run_full_pipeline_on_path(img_path)
        except Exception as e:
            # Best-effort evaluation: report and skip images the pipeline can't handle.
            print(f"Error on {img_path}: {e}")
            num_failed += 1
            continue

        if species_pred == assumed_species:
            species_correct += 1
        y_true.append(label)
        y_pred.append(disease_pred)

    if y_true:
        labels = sorted(set(y_true) | set(y_pred))
        acc = accuracy_score(y_true, y_pred)
        species_acc = species_correct / len(y_true)
        cls_report = classification_report(y_true, y_pred, zero_division=0, output_dict=True)
        cm = confusion_matrix(y_true, y_pred, labels=labels).tolist()
    else:
        # Nothing evaluated (empty folder, or every image failed): sklearn's
        # metrics raise on empty input, so report NaN/empty results instead.
        labels, cls_report, cm = [], {}, []
        acc = species_acc = float("nan")

    return {
        "dataset_root": str(dataset_root),
        "assumed_species": assumed_species,
        "accuracy": acc,
        "classification_report": cls_report,
        "confusion_matrix": cm,
        "labels": labels,
        "species_accuracy": species_acc,
        "num_samples": len(ds),
        "num_failed": num_failed,
    }


### Run evaluations for all datasets and print results

In [None]:
# Cell 7: run evaluation for all processed datasets you want

results = {}

# Example paths – change to match your actual processed layout
cassava_val_dir = DATA_PROCESSED_DIR / "cassava" / "val"
rice_val_dir    = DATA_PROCESSED_DIR / "riceleaf" / "val"
plantv_val_dir  = DATA_PROCESSED_DIR / "plantVillage" / "val"

# (species key, validation folder). The key is both the results key and the
# assumed ground-truth species compared against the species classifier's output.
eval_targets = [
    ("cassava", cassava_val_dir),
    ("rice", rice_val_dir),
    ("plantvillage", plantv_val_dir),
]

for species_key, val_dir in eval_targets:
    if not val_dir.exists():
        # Say so instead of skipping silently: a wrong path would otherwise only
        # show up as a shorter summary and a KeyError when results are looked up.
        print(f"Skipping {species_key}: {val_dir} does not exist")
        continue
    results[species_key] = evaluate_dataset(val_dir, assumed_species=species_key)

print("Summary accuracy:")
if not results:
    print("(no datasets evaluated)")
for name, r in results.items():
    print(f"{name}: {r['accuracy']:.4f}")


### Inspect one classification report

In [None]:
# Cell 8: inspect report for one dataset

import pandas as pd

ds_name = "cassava"  # change to "rice" / "plantvillage" etc
dataset_result = results[ds_name]
report = dataset_result["classification_report"]

# Transpose so the report's per-label entries become rows and the metrics become columns
df_report = pd.DataFrame(report).transpose()
df_report


### Save metrics to JSON for the PowerPoint report

In [None]:
# Cell 9: save results to experiments/ folder

EXPERIMENTS_DIR = PROJECT_ROOT / "experiments"
EXPERIMENTS_DIR.mkdir(parents=True, exist_ok=True)

out_path = EXPERIMENTS_DIR / "full_pipeline_eval.json"

# Serialize before opening the file. json.dump would stream into an already
# truncated file, so a serialization error (e.g. a non-JSON-serializable value)
# would leave a partial document in place of the previous results.
payload = json.dumps(results, indent=2)
with open(out_path, "w", encoding="utf-8") as f:
    f.write(payload)

print("Saved evaluation results to:", out_path)
