In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os
import matplotlib.pyplot as plt

In [3]:
# Configuration
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 20
DATA_DIR = 'PlantVillage'  

In [5]:
def create_model(num_classes):
    """Create a CNN model with transfer learning"""
    base_model = keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet'
    )
    
    # Freeze base model
    base_model.trainable = False
    
    model = keras.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(0.3),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation='softmax')
    ])
    
    return model

In [7]:
def prepare_data(data_dir):
    """Prepare training and validation data"""
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        validation_split=0.2,
        fill_mode='nearest'
    )
    
    train_generator = train_datagen.flow_from_directory(
        data_dir,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='training'
    )
    
    val_generator = train_datagen.flow_from_directory(
        data_dir,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='validation'
    )
    
    return train_generator, val_generator


In [9]:
def plot_history(history):
    """Plot training history"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    ax1.plot(history.history['accuracy'], label='Train Accuracy')
    ax1.plot(history.history['val_accuracy'], label='Val Accuracy')
    ax1.set_title('Model Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax1.grid(True)
    
    ax2.plot(history.history['loss'], label='Train Loss')
    ax2.plot(history.history['val_loss'], label='Val Loss')
    ax2.set_title('Model Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.savefig('training_history.png')
    print("Training history plot saved as 'training_history.png'")


In [11]:
def train_model():
    """Main training function"""
    print("Preparing data...")
    train_gen, val_gen = prepare_data(DATA_DIR)
    
    num_classes = len(train_gen.class_indices)
    print(f"Number of classes: {num_classes}")
    print(f"Classes: {list(train_gen.class_indices.keys())}")
    
    print("\nBuilding model...")
    model = create_model(num_classes)
    
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    
    print("\nModel Summary:")
    model.summary()
    
    # Callbacks
    callbacks = [
        keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3,
            min_lr=1e-7
        ),
        keras.callbacks.ModelCheckpoint(
            'crop_disease_model_best.h5',
            monitor='val_accuracy',
            save_best_only=True
        )
    ]
    
    print("\nTraining model...")
    history = model.fit(
        train_gen,
        epochs=EPOCHS,
        validation_data=val_gen,
        callbacks=callbacks
    )
    
    # Save final model
    model.save('crop_disease_model.h5')
    print("\nModel saved as 'crop_disease_model.h5'")
    
    # Save class labels
    import json
    class_labels = {v: k for k, v in train_gen.class_indices.items()}
    with open('class_labels.json', 'w') as f:
        json.dump(class_labels, f)
    print("Class labels saved as 'class_labels.json'")
    
    # Plot training history
    plot_history(history)
    
    return model, history


In [None]:
if __name__ == "__main__":
    # Check if data directory exists
    if not os.path.exists(DATA_DIR):
        print(f"Error: Data directory '{DATA_DIR}' not found!")
        print("\nPlease download the PlantVillage dataset from Kaggle:")
        print("https://www.kaggle.com/datasets/emmarex/plantdisease")
        print("Extract it and update the DATA_DIR variable in this script.")
    else:
        model, history = train_model()
        print("\nTraining complete!")
        print(f"Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")

Preparing data...
Found 16516 images belonging to 15 classes.
Found 4122 images belonging to 15 classes.
Number of classes: 15
Classes: ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']

Building model...

Model Summary:



Training model...
Epoch 1/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.5918 - loss: 1.3310



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m796s[0m 2s/step - accuracy: 0.6922 - loss: 0.9651 - val_accuracy: 0.8190 - val_loss: 0.5237 - learning_rate: 0.0010
Epoch 2/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.7687 - loss: 0.6744



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1068s[0m 2s/step - accuracy: 0.7838 - loss: 0.6361 - val_accuracy: 0.8542 - val_loss: 0.4275 - learning_rate: 0.0010
Epoch 3/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7s/step - accuracy: 0.8160 - loss: 0.5512



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4044s[0m 8s/step - accuracy: 0.8158 - loss: 0.5504 - val_accuracy: 0.8574 - val_loss: 0.4063 - learning_rate: 0.0010
Epoch 4/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8113 - loss: 0.5405



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m909s[0m 2s/step - accuracy: 0.8156 - loss: 0.5348 - val_accuracy: 0.8816 - val_loss: 0.3434 - learning_rate: 0.0010
Epoch 5/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m950s[0m 2s/step - accuracy: 0.8285 - loss: 0.5026 - val_accuracy: 0.8695 - val_loss: 0.3840 - learning_rate: 0.0010
Epoch 6/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8343 - loss: 0.4831



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m848s[0m 2s/step - accuracy: 0.8349 - loss: 0.4847 - val_accuracy: 0.8855 - val_loss: 0.3469 - learning_rate: 0.0010
Epoch 7/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m964s[0m 2s/step - accuracy: 0.8430 - loss: 0.4639 - val_accuracy: 0.8848 - val_loss: 0.3318 - learning_rate: 0.0010
Epoch 8/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.8463 - loss: 0.4475



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1134s[0m 2s/step - accuracy: 0.8420 - loss: 0.4638 - val_accuracy: 0.8967 - val_loss: 0.3121 - learning_rate: 0.0010
Epoch 9/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1237s[0m 2s/step - accuracy: 0.8483 - loss: 0.4452 - val_accuracy: 0.8901 - val_loss: 0.3165 - learning_rate: 0.0010
Epoch 10/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m889s[0m 2s/step - accuracy: 0.8501 - loss: 0.4312 - val_accuracy: 0.8908 - val_loss: 0.3193 - learning_rate: 0.0010
Epoch 11/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m841s[0m 2s/step - accuracy: 0.8526 - loss: 0.4302 - val_accuracy: 0.8906 - val_loss: 0.3283 - learning_rate: 0.0010
Epoch 12/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8624 - loss: 0.3997



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m880s[0m 2s/step - accuracy: 0.8622 - loss: 0.3972 - val_accuracy: 0.9003 - val_loss: 0.2907 - learning_rate: 5.0000e-04
Epoch 13/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8732 - loss: 0.3837



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m875s[0m 2s/step - accuracy: 0.8725 - loss: 0.3780 - val_accuracy: 0.9115 - val_loss: 0.2737 - learning_rate: 5.0000e-04
Epoch 14/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8726 - loss: 0.3875



[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m879s[0m 2s/step - accuracy: 0.8709 - loss: 0.3810 - val_accuracy: 0.9151 - val_loss: 0.2542 - learning_rate: 5.0000e-04
Epoch 15/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m878s[0m 2s/step - accuracy: 0.8781 - loss: 0.3634 - val_accuracy: 0.9073 - val_loss: 0.2724 - learning_rate: 5.0000e-04
Epoch 16/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m909s[0m 2s/step - accuracy: 0.8724 - loss: 0.3721 - val_accuracy: 0.9078 - val_loss: 0.2705 - learning_rate: 5.0000e-04
Epoch 17/20
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m848s[0m 2s/step - accuracy: 0.8743 - loss: 0.3654 - val_accuracy: 0.9064 - val_loss: 0.2636 - learning_rate: 5.0000e-04
Epoch 18/20
[1m266/517[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m7:01[0m 2s/step - accuracy: 0.8832 - loss: 0.3385