In [1]:
from iragca.matplotlib import Styles
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import torch
from torchvision import transforms

from lib.architectures import SimpleCNN
from lib.config import Directories
from lib.data import PlantDocDiseaseDetection, PlantVillageDiseaseDetection

plt.style.use(Styles.ML.value)

# Stem of the checkpoint file; the weights are read from MODELS_DIR / f"{model_name}.pth".
model_name = "disease_detection_model"

# Preprocessing shared by both evaluation datasets: resize to a fixed size, then convert to a tensor.
IMAGE_SIZE = (32, 32)
transform_pipeline = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])

# Root folder under which both kagglehub datasets live (one subfolder per dataset).
KAGGLE_DATA_PATH = Directories.EXTERNAL_DATA_DIR.value / "kagglehub"

plantvillage = PlantVillageDiseaseDetection(
    data_path=KAGGLE_DATA_PATH / "plantvillage",
    transforms=transform_pipeline,
)
plantdoc = PlantDocDiseaseDetection(
    data_path=KAGGLE_DATA_PATH / "plantdoc",
    transforms=transform_pipeline,
)

In [2]:
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
# Evaluation-only loaders. F1 and accuracy over the full dataset do not depend on
# iteration order, so shuffling only adds nondeterminism; keep a fixed order instead.
EVAL_BATCH_SIZE = 32
plantdoc_loader = DataLoader(plantdoc, batch_size=EVAL_BATCH_SIZE, shuffle=False)
plantvillage_loader = DataLoader(plantvillage, batch_size=EVAL_BATCH_SIZE, shuffle=False)

In [4]:
# Rebuild the architecture (3 input channels, a single output unit) and restore trained weights.
model = SimpleCNN(channels=3, output_dim=1)
# map_location="cpu": a checkpoint saved during a GPU run still loads on a CPU-only machine;
#   `test` moves the model to the evaluation device afterwards.
# weights_only=True: a state_dict needs only tensors and plain containers, so refuse any other
#   pickled objects (a tampered .pth cannot run arbitrary code). Requires PyTorch >= 1.13.
checkpoint_path = Directories.MODELS_DIR.value / f"{model_name}.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))


def test(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    all_preds = []
    all_labels = []
    all_outputs = []
    THRESHOLD = 0.5
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=f"Evaluating Dataset {data_loader.dataset.__class__.__name__}"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds = (outputs >= THRESHOLD).long()

            all_outputs.extend(outputs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    f1 = f1_score(all_labels, all_preds)
    accuracy = accuracy_score(all_labels, all_preds)

    return f1, accuracy

# Score the restored model on the PlantDoc loader and report both metrics to 4 decimals.
f1, accuracy = test(model, plantdoc_loader)
for metric_label, metric_value in (("F1 Score", f1), ("Accuracy", accuracy)):
    print(f"{metric_label}: {metric_value:.4f}")

Evaluating Dataset PlantDocDiseaseDetection: 100%|██████████| 92/92 [01:05<00:00,  1.41it/s]

F1 Score: 0.8382
Accuracy: 0.7644





In [5]:
# Same evaluation on the PlantVillage loader; one line per metric.
f1, accuracy = test(model, plantvillage_loader)
print(f"F1 Score: {f1:.4f}", f"Accuracy: {accuracy:.4f}", sep="\n")

Evaluating Dataset PlantVillageDiseaseDetection: 100%|██████████| 1698/1698 [00:58<00:00, 29.18it/s]


F1 Score: 0.8040
Accuracy: 0.7608
