In [2]:
from iragca.matplotlib import Styles
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import torch
from torchvision import transforms

from lib.architectures import SimpleCNN
from lib.config import Directories
from lib.data import (
    PlantDocDiseaseDetection, 
    PlantVillageDiseaseDetection,
    PlantDocSymptomIdentification, 
    PlantVillageSymptomIdentification,
    CombinedPlantDocDataset,
)
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score, classification_report
from tqdm import tqdm


# Use the project's shared matplotlib style for any figure drawn in this notebook.
plt.style.use(Styles.ML.value)

# File stem of the disease-detection checkpoint loaded from Directories.MODELS_DIR.
model_name = "disease_detection_model"


# Every image is resized to IMAGE_SIZE and then converted to a tensor.
IMAGE_SIZE = (32, 32)
transform_pipeline = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])

# Parent folder of the "plantvillage" and "plantdoc" dataset folders.
KAGGLE_DATA_PATH = Directories.EXTERNAL_DATA_DIR.value / "kagglehub"

# All datasets share the same preprocessing pipeline.
plantvillage = PlantVillageDiseaseDetection(
    data_path=KAGGLE_DATA_PATH / "plantvillage",
    transforms=transform_pipeline,
)
plantdoc = PlantDocDiseaseDetection(
    data_path=KAGGLE_DATA_PATH / "plantdoc",
    transforms=transform_pipeline,
)
plantdoc_si = PlantDocSymptomIdentification(
    data_path=KAGGLE_DATA_PATH / "plantdoc",
    transforms=transform_pipeline,
)
plantvillage_si = PlantVillageSymptomIdentification(
    data_path=KAGGLE_DATA_PATH / "plantvillage",
    transforms=transform_pipeline,
)
combined_plantdoc = CombinedPlantDocDataset(
    data_path=KAGGLE_DATA_PATH / "plantdoc",
    transforms=transform_pipeline,
)

In [4]:
# Evaluation loaders, one per dataset.
# Shuffling is turned off: in this notebook the loaders are only fed to the
# scoring functions, whose aggregate metrics do not depend on sample order, and
# a fixed order makes every re-run iterate the data identically.
plantdoc_loader = DataLoader(plantdoc, batch_size=32, shuffle=False)
plantvillage_loader = DataLoader(plantvillage, batch_size=32, shuffle=False)
plantdoc_si_loader = DataLoader(plantdoc_si, batch_size=32, shuffle=False)
plantvillage_si_loader = DataLoader(plantvillage_si, batch_size=32, shuffle=False)

In [4]:
# Disease detector: SimpleCNN built with channels=3 and a single output value.
model = SimpleCNN(channels=3, output_dim=1)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# a CPU-only machine; test() moves the model to the selected device afterwards.
model.load_state_dict(
    torch.load(Directories.MODELS_DIR.value / f"{model_name}.pth", map_location="cpu")
)


def test(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    all_preds = []
    all_labels = []
    all_outputs = []
    THRESHOLD = 0.5
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=f"Evaluating Dataset {data_loader.dataset.__class__.__name__}"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds = (outputs >= THRESHOLD).long()

            all_outputs.extend(outputs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    f1 = f1_score(all_labels, all_preds)
    accuracy = accuracy_score(all_labels, all_preds)

    return f1, accuracy


In [None]:

# Score the disease detector on the PlantDoc loader and report both metrics.
f1, accuracy = test(model=model, data_loader=plantdoc_loader)
print(f"F1 Score: {f1:.4f}", f"Accuracy: {accuracy:.4f}", sep="\n")

Evaluating Dataset PlantDocDiseaseDetection: 100%|██████████| 92/92 [00:27<00:00,  3.29it/s]

F1 Score: 0.8954
Accuracy: 0.8392





In [5]:
# Score the disease detector on the PlantVillage loader and report both metrics.
f1, accuracy = test(model=model, data_loader=plantvillage_loader)
print(f"F1 Score: {f1:.4f}", f"Accuracy: {accuracy:.4f}", sep="\n")

Evaluating Dataset PlantVillageDiseaseDetection: 100%|██████████| 1698/1698 [00:32<00:00, 52.47it/s]


F1 Score: 0.9901
Accuracy: 0.9856


## Symptom Identification

In [2]:
# Symptom identifier: SimpleCNN built with channels=3 and 12 output values.
symptom_identifier = SimpleCNN(channels=3, output_dim=12)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# a CPU-only machine; test_si() moves the model to the selected device later.
symptom_identifier.load_state_dict(
    torch.load(Directories.MODELS_DIR.value / "symptom_identification_model.pth", map_location="cpu")
)

<All keys matched successfully>

In [3]:
def test_si(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    all_preds = []
    all_labels = []
    all_outputs = []
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=f"Evaluating Dataset {data_loader.dataset.__class__.__name__}"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds = outputs.argmax(dim=1)

            all_outputs.extend(outputs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return all_labels, all_preds

In [None]:
# Gather labels and predictions of the symptom identifier on the PlantDoc loader.
all_labels, all_preds = test_si(model=symptom_identifier, data_loader=plantdoc_si_loader)

Evaluating Dataset PlantDocSymptomIdentification: 100%|██████████| 66/66 [00:16<00:00,  4.01it/s]


In [10]:
# Weighted-average F1 and accuracy, computed from the current all_labels / all_preds.
accuracy = accuracy_score(y_true=all_labels, y_pred=all_preds)
f1 = f1_score(y_true=all_labels, y_pred=all_preds, average='weighted')

print(f"F1 Score: {f1:.4f}", f"Accuracy: {accuracy:.4f}", sep="\n")

F1 Score: 0.7692
Accuracy: 0.7569


In [11]:
# Some predicted classes have no true samples in this dataset (support 0), which
# makes their recall undefined. zero_division=0 reports those scores as 0, the
# same value the default produces, without emitting the undefined-metric
# warnings.
print(classification_report(all_labels, all_preds, zero_division=0))

              precision    recall  f1-score   support

           0       0.83      0.83      0.83       769
           1       0.84      0.80      0.82       238
           2       0.00      0.00      0.00         0
           3       0.68      0.68      0.68       130
           4       0.03      1.00      0.06         2
           5       0.74      0.67      0.71        91
           6       0.80      0.52      0.63        54
           7       0.75      0.56      0.64        79
           8       0.77      0.77      0.77       223
           9       0.83      0.61      0.70        93
          10       0.72      0.74      0.73       415
          11       0.00      0.00      0.00         0

    accuracy                           0.76      2094
   macro avg       0.58      0.60      0.55      2094
weighted avg       0.79      0.76      0.77      2094



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [5]:
# Gather labels and predictions of the symptom identifier on the PlantVillage loader.
all_labels, all_preds = test_si(model=symptom_identifier, data_loader=plantvillage_si_loader)

Evaluating Dataset PlantVillageSymptomIdentification: 100%|██████████| 1226/1226 [00:22<00:00, 54.17it/s]


In [6]:
# Weighted-average F1 and accuracy, computed from the current all_labels / all_preds.
accuracy = accuracy_score(y_true=all_labels, y_pred=all_preds)
f1 = f1_score(y_true=all_labels, y_pred=all_preds, average='weighted')

print(f"F1 Score: {f1:.4f}", f"Accuracy: {accuracy:.4f}", sep="\n")

F1 Score: 0.9302
Accuracy: 0.9305


In [7]:
# Per-class precision / recall / F1 table for the current all_labels / all_preds.
print(classification_report(y_true=all_labels, y_pred=all_preds))

              precision    recall  f1-score   support

           0       0.96      0.78      0.86      6970
           1       0.99      0.99      0.99      5507
           2       0.99      0.98      0.98      5357
           3       0.99      0.96      0.97      2887
           4       0.97      0.85      0.91      1676
           5       0.97      0.87      0.92       952
           6       0.98      0.88      0.93       373
           7       0.96      0.95      0.96      1801
           8       0.97      0.97      0.97      1467
           9       0.92      0.92      0.92       630
          10       0.83      0.98      0.90     10492
          11       0.98      0.96      0.97      1109

    accuracy                           0.93     39221
   macro avg       0.96      0.92      0.94     39221
weighted avg       0.94      0.93      0.93     39221



In [8]:
# Add up the element counts of all parameter tensors of the symptom identifier.
total_params = sum(
    weights.numel()
    for weights in symptom_identifier.parameters()
)
print(f"Total parameters: {total_params}")

Total parameters: 268908


## Combined Identification and Detection

In [6]:
# The combined model is scored the same way as the symptom identifier (argmax
# over the outputs), so the same routine is reused under a second name.
test_comb = test_si

# Combined classifier: SimpleCNN built with channels=3 and 13 output values.
combined_classifier = SimpleCNN(channels=3, output_dim=13)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# a CPU-only machine; the evaluation routine moves the model to the selected
# device afterwards.
combined_classifier.load_state_dict(
    torch.load(Directories.MODELS_DIR.value / "combined_identification_model.pth", map_location="cpu")
)

# No shuffling: the loader is only used for scoring, and a fixed order keeps
# re-runs reproducible.
all_labels, all_preds = test_comb(
    combined_classifier,
    DataLoader(combined_plantdoc, batch_size=32, shuffle=False),
)

Evaluating Dataset CombinedPlantDocDataset: 100%|██████████| 92/92 [00:29<00:00,  3.11it/s]


In [7]:
# Weighted-average F1 and accuracy, computed from the current all_labels / all_preds.
accuracy = accuracy_score(y_true=all_labels, y_pred=all_preds)
f1 = f1_score(y_true=all_labels, y_pred=all_preds, average='weighted')

print(f"F1 Score: {f1:.4f}", f"Accuracy: {accuracy:.4f}", sep="\n")

F1 Score: 0.7897
Accuracy: 0.7754


In [8]:
# Some predicted classes have no true samples in this dataset (support 0), which
# makes their recall undefined. zero_division=0 reports those scores as 0, the
# same value the default produces, without emitting the undefined-metric
# warnings.
print(classification_report(all_labels, all_preds, zero_division=0))

              precision    recall  f1-score   support

           0       0.82      0.84      0.83       769
           1       0.80      0.74      0.77       238
           2       0.00      0.00      0.00         0
           3       0.87      0.75      0.81       130
           4       0.02      1.00      0.04         2
           5       0.76      0.56      0.65        91
           6       0.80      0.44      0.57        54
           7       0.66      0.72      0.69        79
           8       0.79      0.75      0.77       223
           9       0.71      0.61      0.66        93
          10       0.69      0.70      0.69       415
          11       0.00      0.00      0.00         0
          12       0.88      0.84      0.86       822

    accuracy                           0.78      2916
   macro avg       0.60      0.61      0.56      2916
weighted avg       0.81      0.78      0.79      2916



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [18]:
# Collapse the multi-class labels and predictions into a binary task:
# 0 = healthy, 1 = any other class. A prediction is mapped the same way
# whatever the true label is, so labels and predictions are converted
# independently instead of repeating the mapping in both branches.
healthy_class = combined_plantdoc.CLASS_MAP['healthy']

labels_for_combined = [int(label != healthy_class) for label in all_labels]
preds_for_combined = [int(pred != healthy_class) for pred in all_preds]

accuracy_score(labels_for_combined, preds_for_combined)

0.9242112482853223

In [19]:
# F1 of the collapsed healthy-vs-other task; class 1 (not healthy) is the positive class.
f1_score(y_true=labels_for_combined, y_pred=preds_for_combined)

0.947717057014431

In [20]:
# Add up the element counts of all parameter tensors of the combined classifier.
total_params = sum(
    weights.numel()
    for weights in combined_classifier.parameters()
)
print(f"Total parameters: {total_params}")

Total parameters: 269037
