In [None]:
# Project root on the mounted Google Drive; the other directories are derived from it.
ROOT = "/content/drive/MyDrive/crop-disease-detection"

MODELS_DIR = ROOT + "/models"                      # holds the *_best.pth checkpoints
DATA_DIR = "/".join([ROOT, "data", "processed"])


In [None]:
# All imports for the model helpers below. `os` is required by the evaluation
# helper, which lists class folders on disk.
import os
from pathlib import Path

import timm
import torch
import torch.nn as nn
from torchvision import models

# Device passed as `map_location` when checkpoints are loaded.
DEVICE = "cpu"

def _build_vit_base(num_classes: int):
    """Return a torchvision ViT-B/16 whose classifier head emits ``num_classes`` logits.

    No pretrained weights are requested (``weights=None``): only the architecture
    is built, and callers in this notebook load a checkpoint into it afterwards.
    """
    vit = models.vit_b_16(weights=None)
    # Replace just the final Linear layer, keeping its input width.
    vit.heads.head = nn.Linear(vit.heads.head.in_features, num_classes)
    return vit


def load_vit_model(weights_path: str, num_classes: int):
    """Build a ViT-B/16 with a ``num_classes`` head, load saved weights, and return it ready for inference.

    Args:
        weights_path: Path to a checkpoint holding a plain ``state_dict``
            (it is passed straight to ``load_state_dict``).
        num_classes: Output size of the classification head; must match the checkpoint.

    Returns:
        The model on ``DEVICE``, switched to ``eval()`` mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when checkpoint keys/shapes do not match.
    """
    model = _build_vit_base(num_classes)
    # A state dict only contains tensors and plain containers, so restrict unpickling
    # to those instead of allowing arbitrary pickled objects to execute.
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the existing (CPU-built) parameters, so the
    # model has to be moved explicitly for any DEVICE other than "cpu".
    model.to(DEVICE)
    model.eval()
    return model

def load_disease_model(weights_path: str, num_classes: int):
    """Load a crop-disease classifier checkpoint for evaluation.

    Every disease model in this notebook is a ViT-B/16, so this delegates to
    ``load_vit_model``; the evaluation code calls this name rather than the
    ViT-specific one.
    """
    disease_model = load_vit_model(weights_path, num_classes=num_classes)
    return disease_model


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, classification_report

def evaluate_imagefolder_model(model_path, test_dir, batch_size: int = 32):
    """Evaluate a saved disease-model checkpoint on an ImageFolder-style test split.

    Class names, ``num_classes`` and label indices all come from the same
    ``datasets.ImageFolder`` instance, so they cannot disagree with each other.

    Args:
        model_path: Checkpoint path accepted by ``load_disease_model``.
        test_dir: Directory with one sub-folder per class (torchvision ImageFolder layout).
        batch_size: Images per forward pass. Each image is still classified
            independently, so this only affects speed and memory use.

    Prints the detected classes, the confusion matrix and sklearn's classification report.
    """
    # NOTE(review): no Normalize() step here — confirm this matches the training transform.
    tfms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # ImageFolder treats only sub-directories as classes (stray files such as
    # .DS_Store are ignored) and assigns label i to ds.classes[i].
    ds = datasets.ImageFolder(test_dir, transform=tfms)
    class_names = ds.classes
    num_classes = len(class_names)

    print("Detected disease classes:", class_names)

    model = load_disease_model(model_path, num_classes=num_classes)
    # Feed inputs to wherever the loaded model's parameters live.
    device = next(model.parameters()).device

    dl = DataLoader(ds, batch_size=batch_size, shuffle=False)

    y_true = []
    y_pred = []
    with torch.no_grad():
        for imgs, labels in dl:
            logits = model(imgs.to(device))
            y_pred.extend(logits.argmax(dim=1).tolist())
            y_true.extend(labels.tolist())

    # Pin the full label set so the matrix is always num_classes x num_classes and the
    # report does not fail when a class is missing from both y_true and y_pred.
    all_labels = list(range(num_classes))

    print("\nCONFUSION MATRIX:")
    print(confusion_matrix(y_true, y_pred, labels=all_labels))

    print("\nCLASSIFICATION REPORT:")
    print(classification_report(
        y_true,
        y_pred,
        labels=all_labels,
        target_names=class_names,
        zero_division=0,
    ))


## Cassava

In [None]:
# Display names for the 5 Cassava classes.
# NOTE(review): evaluate_imagefolder_model takes class names from the test folder's
# sub-directories, not from this list — confirm the orders agree before using it elsewhere.
cassava_labels = [
    "Cassava Bacterial Blight (CBB)",
    "Cassava Brown Streak Disease (CBSD)",
    "Cassava Green Mottle (CGM)",
    "Cassava Mosaic Disease (CMD)",
    "Healthy"
]

cassava_model = f"{MODELS_DIR}/cassava_best.pth"
cassava_test = "/content/processed/Cassava/test"

# Loaded through the shared helper (CPU, eval mode). The evaluation call below
# loads its own copy from the same path; `model` is kept for any later cells.
model = load_disease_model(cassava_model, num_classes=len(cassava_labels))

evaluate_imagefolder_model(cassava_model, cassava_test)


## PlantVillage

In [None]:
pv_model = f"{MODELS_DIR}/plant_village_best.pth"
pv_test = "/content/processed/plantVillage/test"

# The PlantVillage model has 38 output classes, so its names are read from the
# test split's class sub-folders, sorted the same way torchvision's ImageFolder sorts them.
pv_labels = sorted(p.name for p in Path(pv_test).iterdir() if p.is_dir())

# Loaded through the shared helper (CPU, eval mode). The evaluation call below
# loads its own copy from the same path; `model` is kept for any later cells.
model = load_disease_model(pv_model, num_classes=38)

evaluate_imagefolder_model(pv_model, pv_test)

## Rice Leaf

In [None]:
# Display names for the 6 rice-leaf classes.
# NOTE(review): torchvision's ImageFolder orders classes alphabetically by folder name
# (bacterial_leaf_blight, brown_spot, healthy, leaf_blast, leaf_scald, narrow_brown_spot),
# which differs from this list. If the model was trained with ImageFolder on these folder
# names, rice_labels[i] is NOT the model's class i — verify before decoding predictions with it.
rice_labels = [
    "leaf_blast",
    "brown_spot",
    "bacterial_leaf_blight",
    "leaf_scald",
    "narrow_brown_spot",
    "healthy"
]

rice_model = f"{MODELS_DIR}/rice_leaf_best.pth"
rice_test = "/content/processed/riceleaf/test"

# Loaded through the shared helper (CPU, eval mode). The evaluation call below
# loads its own copy from the same path; `model` is kept for any later cells.
model = load_disease_model(rice_model, num_classes=len(rice_labels))

evaluate_imagefolder_model(rice_model, rice_test)