In [26]:


import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image_dataset_from_directory
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt


DATASET_DIR = "/home/sk/Desktop/crop_disease_dataset/archive(1)/onion datasets"

IMG_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 42
EPOCHS = 20


train_ds = image_dataset_from_directory(
    DATASET_DIR,
    validation_split=0.2,
    subset="training",
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

val_ds = image_dataset_from_directory(
    DATASET_DIR,
    validation_split=0.2,
    subset="validation",
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

class_names = train_ds.class_names
num_classes = len(class_names)
print("Classes:", class_names)




class_weight = {
    1: 2.4561,  # Virosis-D
    2: 0.3656,  # Healthy leaves
    3: 1.4847,  # Purple blotch
    5: 0.6622,  # Iris yellow virus_augment
    6: 1.5151,  # Alternaria_D
    7: 0.9855,  # Fusarium-D
    8: 0.8071,  # Caterpillar-P
    9: 0.7830   # stemphylium Leaf Blight
}

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)


data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
])


base_model = keras.applications.EfficientNetV2B0(
    include_top=False,
    input_shape=IMG_SIZE + (3,),
    weights="imagenet"
)
base_model.trainable = False

inputs = keras.Input(shape=IMG_SIZE + (3,))
x = data_augmentation(inputs)
x = keras.applications.efficientnet_v2.preprocess_input(x)
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs, outputs)

Found 11968 files belonging to 8 classes.
Using 9575 files for training.
Found 11968 files belonging to 8 classes.
Using 2393 files for validation.
Classes: ['Alternaria_D', 'Caterpillar-P', 'Fusarium-D', 'Healthy leaves', 'Iris yellow virus_augment', 'Purple blotch', 'Virosis-D', 'stemphylium Leaf Blight']


In [30]:
# 8. Compile
# ----------------------
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

# ----------------------
# 9. Train with Class Weights
# ----------------------
callbacks = [
    keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
    keras.callbacks.ModelCheckpoint("efficientnetv2_onion1.h5", save_best_only=True)
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks,
    class_weight=class_weight
)

Epoch 1/10


E0000 00:00:1758619261.575857  352153 meta_optimizer.cc:967] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/functional_11_1/efficientnetv2-b0_1/block2b_drop_1/stateless_dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 199ms/step - accuracy: 0.8449 - loss: 0.3998 



[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 250ms/step - accuracy: 0.8436 - loss: 0.4106 - val_accuracy: 0.8512 - val_loss: 0.4144
Epoch 2/10
[1m299/300[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 202ms/step - accuracy: 0.8407 - loss: 0.4024 



[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 250ms/step - accuracy: 0.8425 - loss: 0.4063 - val_accuracy: 0.8508 - val_loss: 0.4122
Epoch 3/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 254ms/step - accuracy: 0.8480 - loss: 0.3882 - val_accuracy: 0.8462 - val_loss: 0.4136
Epoch 4/10
[1m299/300[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 210ms/step - accuracy: 0.8496 - loss: 0.3891  



[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 260ms/step - accuracy: 0.8476 - loss: 0.3989 - val_accuracy: 0.8533 - val_loss: 0.4014
Epoch 5/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 211ms/step - accuracy: 0.8492 - loss: 0.3867 



[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 261ms/step - accuracy: 0.8476 - loss: 0.3911 - val_accuracy: 0.8567 - val_loss: 0.3955
Epoch 6/10
[1m299/300[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 211ms/step - accuracy: 0.8502 - loss: 0.3751  



[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 259ms/step - accuracy: 0.8483 - loss: 0.3877 - val_accuracy: 0.8546 - val_loss: 0.3937
Epoch 7/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 256ms/step - accuracy: 0.8518 - loss: 0.3892 - val_accuracy: 0.8517 - val_loss: 0.4008
Epoch 8/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 257ms/step - accuracy: 0.8516 - loss: 0.3798 - val_accuracy: 0.8517 - val_loss: 0.3947
Epoch 9/10
[1m299/300[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 208ms/step - accuracy: 0.8523 - loss: 0.3751 



[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 257ms/step - accuracy: 0.8517 - loss: 0.3779 - val_accuracy: 0.8583 - val_loss: 0.3877
Epoch 10/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 254ms/step - accuracy: 0.8559 - loss: 0.3771 - val_accuracy: 0.8562 - val_loss: 0.3910


In [None]:
# 10. Evaluate & Confusion Matrix
# ----------------------
# Get true labels and predictions
import seaborn as sns
y_true = []
y_pred = []
for images, labels in val_ds:
    preds = model.predict(images, verbose=0)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(preds, axis=1))

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()

# Classification Report
print("\nClassification Report:\n")
print(classification_report(y_true, y_pred, target_names=class_names))
