# Plant Disease Detection using Transfer Learning (EfficientNet)

This notebook demonstrates how to use transfer learning with EfficientNet to classify plant diseases based on images.

In [23]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt

# Check TensorFlow version
print(f'TensorFlow Version: {tf.__version__}')

TensorFlow Version: 2.13.0


## 1. Load and Preprocess Dataset
Assume we have a dataset structured with separate folders for each class of plant disease.

In [32]:
# Define paths
train_dir = '../input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/train'
val_dir = '../input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/train'

# Image dimensions
img_height, img_width = 256, 256
batch_size = 32

# Data generators with augmentation for training
train_datagen = ImageDataGenerator(
    rescale=1.0/255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)
val_datagen = ImageDataGenerator(rescale=1.0/255)

# Load data
train_data = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical'
)
val_data = val_datagen.flow_from_directory(
    val_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical'
)


for data_batch, labels_batch in train_data:
    print(data_batch.shape, labels_batch.shape)
    break

Found 70295 images belonging to 38 classes.
Found 70295 images belonging to 38 classes.
(32, 256, 256, 3) (32, 38)


## 2. Build the Model using EfficientNet

In [46]:
# Load the EfficientNetB0 model pre-trained on ImageNet
base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))

# Freeze the base model
base_model.trainable = True

# Add custom layers on top
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
output = Dense(train_data.num_classes, activation='softmax')(x)

# Define the model
model = Model(inputs=base_model.input, outputs=output)


# Compile the model
model.compile(optimizer="Adam", loss='categorical_crossentropy', metrics=['accuracy'])

# Summary of the model
model.summary()

print(f"Train data classes: {train_data.num_classes}")

Model: "model_9"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_10 (InputLayer)       [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 rescaling_18 (Rescaling)    (None, 256, 256, 3)          0         ['input_10[0][0]']            
                                                                                                  
 normalization_9 (Normaliza  (None, 256, 256, 3)          7         ['rescaling_18[0][0]']        
 tion)                                                                                            
                                                                                                  
 rescaling_19 (Rescaling)    (None, 256, 256, 3)          0         ['normalization_9[0][0]'

## 3. Train the Model

In [50]:
# Callbacks
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)

class BatchValidationCallback(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print(f" Step {batch + 1}: Train Accuracy = {logs['accuracy']:.4f}")
    
    def on_test_batch_end(self, batch, logs=None):
        print(f" Validation Step {batch + 1}: Validation Accuracy = {logs['accuracy']:.4f}")


log_batch_metrics = BatchValidationCallback()

# Train the model
try:
    history = model.fit_generator(
        train_data,
        validation_data=val_data,
        epochs=5,
        callbacks=[early_stopping, checkpoint, log_batch_metrics],
        verbose=1
    )
except Exception as e:
    print(f"Error during training: {e}")

  history = model.fit_generator(


Epoch 1/5
 Step 1: Train Accuracy = 0.7500
   1/2197 [..............................] - ETA: 35:34 - loss: 0.6456 - accuracy: 0.7500 Step 2: Train Accuracy = 0.8125
   2/2197 [..............................] - ETA: 18:55 - loss: 0.5657 - accuracy: 0.8125 Step 3: Train Accuracy = 0.8125
   3/2197 [..............................] - ETA: 20:28 - loss: 0.5549 - accuracy: 0.8125 Step 4: Train Accuracy = 0.8281
   4/2197 [..............................] - ETA: 21:13 - loss: 0.5893 - accuracy: 0.8281 Step 5: Train Accuracy = 0.8500
   5/2197 [..............................] - ETA: 20:21 - loss: 0.5229 - accuracy: 0.8500 Step 6: Train Accuracy = 0.8646
   6/2197 [..............................] - ETA: 19:32 - loss: 0.4721 - accuracy: 0.8646 Step 7: Train Accuracy = 0.8750
   7/2197 [..............................] - ETA: 20:24 - loss: 0.4565 - accuracy: 0.8750 Step 8: Train Accuracy = 0.8633
   8/2197 [..............................] - ETA: 20:40 - loss: 0.4858 - accuracy: 0.8633 Step 9: Train

KeyboardInterrupt: 

## 4. Evaluate the Model

In [None]:
# Evaluate on validation data
val_loss, val_acc = model.evaluate(val_data)
print(f'Validation Loss: {val_loss}')
print(f'Validation Accuracy: {val_acc}')

# Plot training history
plt.figure(figsize=(12, 6))
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

plt.figure(figsize=(12, 6))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()