In [7]:
cassava_class_dirs = sorted(os.listdir(os.path.join(
    ROOT_DIR, "data", "processed", "cassava", "train"
)))
cassava_class_dirs


['Cassava Bacterial Blight (CBB)',
 'Cassava Brown Streak Disease (CBSD)',
 'Cassava Green Mottle (CGM)',
 'Cassava Mosaic Disease (CMD)',
 'Healthy']

In [8]:
rice_class_dirs = sorted(os.listdir(os.path.join(
    ROOT_DIR, "data", "processed", "riceleaf", "train"
)))
rice_class_dirs


['bacterial_leaf_blight',
 'brown_spot',
 'healthy',
 'leaf_blast',
 'leaf_scald',
 'narrow_brown_spot']

In [9]:
pv_class_dirs = sorted(os.listdir(os.path.join(
    ROOT_DIR, "data", "processed", "plantVillage", "train"
)))
pv_class_dirs


['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn_(maize)___Common_rust_',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Corn_(maize)___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_

In [3]:
for name, path in {
    "species": PATH_SPECIES,
    "cassava": PATH_CASSAVA,
    "rice": PATH_RICE,
    "plantVillage": PATH_PV
}.items():
    state = torch.load(path, map_location="cpu")
    num_classes = state["heads.head.weight"].shape[0]
    print(f"{name} model classes = {num_classes}")


species model classes = 16
cassava model classes = 5
rice model classes = 6
plantVillage model classes = 38


In [10]:
# ============================================================
# test_classifiers.ipynb  (single-cell version)
#
# Full inference pipeline test for:
#  - YOLO leaf detection
#  - Species ViT classifier (16 classes)
#  - Cassava / Rice / PlantVillage classifiers
#
# Works fully with your local folder layout.
# ============================================================

import os
import torch
import cv2
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from ultralytics import YOLO

print("Torch:", torch.__version__)
import torchvision
print("Torchvision:", torchvision.__version__)

# ------------------------------------------------------------
# 1. RESOLVE PROJECT ROOT (notebook-safe)
# ------------------------------------------------------------
CURRENT_DIR = os.getcwd()
ROOT_DIR = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
MODEL_DIR = os.path.join(ROOT_DIR, "models")
TEST_DIR  = os.path.join(ROOT_DIR, "tests", "test_images")

print("ROOT_DIR:", ROOT_DIR)
print("MODEL_DIR:", MODEL_DIR)
print("TEST_DIR:", TEST_DIR)

# ------------------------------------------------------------
# 2. MODEL PATHS
# ------------------------------------------------------------
PATH_SPECIES = os.path.join(MODEL_DIR, "species_classifier_vit.pth")
PATH_CASSAVA = os.path.join(MODEL_DIR, "cassava_best.pth")
PATH_RICE    = os.path.join(MODEL_DIR, "rice_leaf_best.pth")
PATH_PV      = os.path.join(MODEL_DIR, "plant_village_best.pth")
PATH_YOLO    = os.path.join(MODEL_DIR, "yolo_plantdoc_detect.pt")

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# ------------------------------------------------------------
# 3. LOAD YOLO MODEL
# ------------------------------------------------------------
yolo = YOLO(PATH_YOLO)

# ------------------------------------------------------------
# 4. IMAGE TRANSFORMS FOR ViT
# ------------------------------------------------------------
IMG_SIZE = 224
vit_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# ------------------------------------------------------------
# 5. LOAD ViT MODELS WITH AUTOMATIC CLASS SIZE DETECTION
# ------------------------------------------------------------
def load_vit(path):
    state = torch.load(path, map_location=device)
    num_classes = state["heads.head.weight"].shape[0]
    
    model = models.vit_b_16(weights=None)
    in_features = model.heads[-1].in_features
    model.heads[-1] = nn.Linear(in_features, num_classes)
    
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model

species_model = load_vit(PATH_SPECIES)
cassava_model = load_vit(PATH_CASSAVA)
rice_model    = load_vit(PATH_RICE)
pv_model      = load_vit(PATH_PV)

print("Models loaded.")

# ------------------------------------------------------------
# 6. LOAD LABELS FROM FOLDER NAMES
# ------------------------------------------------------------
def load_class_names(path):
    return sorted([
        d for d in os.listdir(path)
        if os.path.isdir(os.path.join(path, d))
    ])

# species (manual list from your Drive screenshot)
species_labels = [
    "Apple",
    "Blueberry",
    "Cassava",
    "Cherry_(including_sour)",
    "Corn_(maize)",
    "Grape",
    "Orange",
    "Peach",
    "Pepper,_bell",
    "Potato",
    "Raspberry",
    "Rice",
    "Soybean",
    "Squash",
    "Strawberry",
    "Tomato"
]

# cassava = 5 classes
cassava_labels = load_class_names(os.path.join(ROOT_DIR, "data", "processed", "cassava", "train"))

# rice = 6 classes
rice_labels = load_class_names(os.path.join(ROOT_DIR, "data", "processed", "riceleaf", "train"))

# PlantVillage = 38 classes
pv_labels = load_class_names(os.path.join(ROOT_DIR, "data", "processed", "plantVillage", "train"))

print("Species labels:", len(species_labels))
print("Cassava labels:", len(cassava_labels))
print("Rice labels:", len(rice_labels))
print("PlantVillage labels:", len(pv_labels))

# ------------------------------------------------------------
# 7. HELPERS
# ------------------------------------------------------------
def predict_vit(pil_img, model, labels):
    tensor = vit_tfms(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(tensor).argmax(1).item()
    return labels[pred]

def crop_leaf_from_yolo(image_path):
    results = yolo(image_path)[0]
    if len(results.boxes) == 0:
        print("‚ö†Ô∏è No leaf detected:", os.path.basename(image_path))
        return None
    x1, y1, x2, y2 = results.boxes[0].xyxy[0].cpu().numpy().astype(int)
    img = cv2.imread(image_path)
    crop = img[y1:y2, x1:x2]
    crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
    return Image.fromarray(crop)

# ------------------------------------------------------------
# 8. RUN TESTS
# ------------------------------------------------------------
def run_tests():
    files = [x for x in os.listdir(TEST_DIR) if x.lower().endswith((".jpg", ".jpeg", ".png"))]
    if not files:
        print("‚ùå No test images found.")
        return
    
    print("\n========== RUNNING TESTS ==========\n")
    results = []
    
    for f in files:
        path = os.path.join(TEST_DIR, f)
        print(f"\nüü¶ Testing: {f}")

        leaf = crop_leaf_from_yolo(path)
        if leaf is None:
            continue

        species = predict_vit(leaf, species_model, species_labels)
        print("  üîç Species:", species)

        if species == "Cassava":
            disease = predict_vit(leaf, cassava_model, cassava_labels)
        elif species == "Rice":
            disease = predict_vit(leaf, rice_model, rice_labels)
        else:
            disease = predict_vit(leaf, pv_model, pv_labels)

        print("  ü¶† Disease:", disease)
        results.append((f, species, disease))
    
    print("\n========== RESULTS ==========\n")
    for f, s, d in results:
        print(f"{f:25s} | {s:12s} | {d}")

    return results

# ------------------------------------------------------------
# RUN
# ------------------------------------------------------------
run_tests()


Torch: 2.3.0+cpu
Torchvision: 0.18.0+cpu
ROOT_DIR: c:\Users\User\Desktop\Data Science\Projects\crop-disease-detection
MODEL_DIR: c:\Users\User\Desktop\Data Science\Projects\crop-disease-detection\models
TEST_DIR: c:\Users\User\Desktop\Data Science\Projects\crop-disease-detection\tests\test_images
Using device: cpu
Models loaded.
Species labels: 16
Cassava labels: 5
Rice labels: 6
PlantVillage labels: 38



üü¶ Testing: cassava_Cassava Bacterial Blight (CBB)_3944841972.jpg

image 1/1 c:\Users\User\Desktop\Data Science\Projects\crop-disease-detection\tests\test_images\cassava_Cassava Bacterial Blight (CBB)_3944841972.jpg: 480x640 9 leafs, 185.7ms
Speed: 3.7ms preprocess, 185.7ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640)
  üîç Species: Cassava
  ü¶† Disease: Cassava Bacterial Blight (CBB)

üü¶ Testing: cassava_Cassava Bacterial Blight (CBB)_586054705.jpg

image 1/1 c:\Users\User\Desktop\Data Science\Projects\crop-disease-detection\tests\test_images\cassava_Cassa

[('cassava_Cassava Bacterial Blight (CBB)_3944841972.jpg',
  'Cassava',
  'Cassava Bacterial Blight (CBB)'),
 ('cassava_Cassava Bacterial Blight (CBB)_586054705.jpg',
  'Cassava',
  'Healthy'),
 ('cassava_Cassava Brown Streak Disease (CBSD)_4219389723.jpg',
  'Corn_(maize)',
  'Tomato___Late_blight'),
 ('cassava_Cassava Brown Streak Disease (CBSD)_486370102.jpg',
  'Cassava',
  'Cassava Brown Streak Disease (CBSD)'),
 ('cassava_Cassava Green Mottle (CGM)_3311389928.jpg',
  'Cassava',
  'Cassava Brown Streak Disease (CBSD)'),
 ('cassava_Cassava Green Mottle (CGM)_4183847559.jpg',
  'Cassava',
  'Cassava Green Mottle (CGM)'),
 ('cassava_Cassava Mosaic Disease (CMD)_4084470563.jpg', 'Cassava', 'Healthy'),
 ('cassava_Cassava Mosaic Disease (CMD)_719222576.jpg',
  'Cassava',
  'Cassava Mosaic Disease (CMD)'),
 ('cassava_Healthy_1272381477.jpg',
  'Cassava',
  'Cassava Bacterial Blight (CBB)'),
 ('cassava_Healthy_1763396057.jpg', 'Cassava', 'Healthy'),
 ('PlantDoc_07c.jpg', 'Cassava', 'Cassa