# 🔬 Plant Disease Detection using CNN with Transfer Learning

This notebook implements MobileNetV2 and ResNet50 for plant disease classification using transfer learning.

## Features:
- Transfer Learning with pre-trained models
- Multi-class disease classification
- Data augmentation pipeline
- Two-phase training (frozen + fine-tuning)
- Confidence scoring for predictions

In [1]:
# Import required libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Deep Learning
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2, ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Sklearn
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import LabelEncoder

# Utils
import json
from datetime import datetime
from tqdm import tqdm

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

print("📚 Libraries imported successfully!")
print(f"TensorFlow version: {tf.__version__}")

📚 Libraries imported successfully!
TensorFlow version: 2.20.0-rc0


In [2]:
# Configuration
MODEL_TYPE = 'mobilenetv2'  # Change to 'resnet50' if preferred
INPUT_SHAPE = (224, 224, 3)
BATCH_SIZE = 32
EPOCHS = 20
FINE_TUNE_EPOCHS = 10

# Paths
BASE_PATH = Path("/Users/debabratapattnayak/web-dev/greencast")
DATASET_PATH = BASE_PATH / "processed_data" / "plantvillage_color_symlinks"
MODELS_PATH = BASE_PATH / "ml_models" / "trained_models"
RESULTS_PATH = BASE_PATH / "ml_models" / "results"

# Create directories
MODELS_PATH.mkdir(parents=True, exist_ok=True)
RESULTS_PATH.mkdir(parents=True, exist_ok=True)

print(f"📁 Dataset path: {DATASET_PATH}")
print(f"📁 Models will be saved to: {MODELS_PATH}")
print(f"📁 Results will be saved to: {RESULTS_PATH}")

📁 Dataset path: /Users/debabratapattnayak/web-dev/greencast/processed_data/plantvillage_color_symlinks
📁 Models will be saved to: /Users/debabratapattnayak/web-dev/greencast/ml_models/trained_models
📁 Results will be saved to: /Users/debabratapattnayak/web-dev/greencast/ml_models/results


In [3]:
# Check dataset and count classes
if not DATASET_PATH.exists():
    print(f"❌ Dataset not found at: {DATASET_PATH}")
    print("Please ensure the processed dataset exists.")
else:
    train_path = DATASET_PATH / 'train'
    if train_path.exists():
        classes = [d.name for d in train_path.iterdir() if d.is_dir()]
        num_classes = len(classes)
        
        print(f"📊 Found {num_classes} disease classes")
        print(f"🏷️ Sample classes: {classes[:5]}..." if len(classes) > 5 else f"🏷️ Classes: {classes}")
        
        # Count images per class (sample)
        class_counts = {}
        for class_name in classes[:5]:  # Check first 5 classes
            class_path = train_path / class_name
            image_count = 0
            for ext in ['*.jpg', '*.jpeg', '*.JPG', '*.JPEG', '*.png', '*.PNG']:
                image_count += len(list(class_path.glob(ext)))
            class_counts[class_name] = image_count
        
        print(f"\n📈 Sample class distribution:")
        for class_name, count in class_counts.items():
            print(f"  {class_name}: {count} images")
    else:
        print(f"❌ Training data not found at: {train_path}")
        num_classes = 0

📊 Found 38 disease classes
🏷️ Sample classes: ['Strawberry___healthy', 'Grape___Black_rot', 'Potato___Early_blight', 'Blueberry___healthy', 'Corn_(maize)___healthy']...

📈 Sample class distribution:
  Strawberry___healthy: 320 images
  Grape___Black_rot: 826 images
  Potato___Early_blight: 700 images
  Blueberry___healthy: 1052 images
  Corn_(maize)___healthy: 814 images


## 📊 Data Preparation

Create data generators with augmentation for training.

In [5]:
def create_data_generators(dataset_path, batch_size=32, validation_split=0.2):
    """Create data generators for training"""
    
    print(f"📁 Creating data generators from: {dataset_path}")
    
    # Data augmentation for training
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
        validation_split=validation_split
    )
    
    # Only rescaling for validation
    val_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_split
    )
    
    # Create generators
    train_generator = train_datagen.flow_from_directory(
        dataset_path / 'train',
        target_size=INPUT_SHAPE[:2],
        batch_size=batch_size,
        class_mode='categorical',
        subset='training',
        shuffle=True
    )
    
    validation_generator = val_datagen.flow_from_directory(
        dataset_path / 'train',
        target_size=INPUT_SHAPE[:2],
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation',
        shuffle=False
    )
    
    # Test generator (if test directory exists)
    test_generator = None
    if (dataset_path / 'test').exists():
        test_datagen = ImageDataGenerator(rescale=1./255)
        test_generator = test_datagen.flow_from_directory(
            dataset_path / 'test',
            target_size=INPUT_SHAPE[:2],
            batch_size=batch_size,
            class_mode='categorical',
            shuffle=False
        )
    
    print(f"✅ Data generators created!")
    print(f"📊 Training samples: {train_generator.samples}")
    print(f"📊 Validation samples: {validation_generator.samples}")
    if test_generator:
        print(f"📊 Test samples: {test_generator.samples}")
    
    return train_generator, validation_generator, test_generator

# Create data generators
if DATASET_PATH.exists() and num_classes > 0:
    train_gen, val_gen, test_gen = create_data_generators(DATASET_PATH, BATCH_SIZE)
    
    # Store class information
    class_indices = train_gen.class_indices
    class_names = list(class_indices.keys())
    
    print(f"🏷️ Classes ({len(class_names)}): {class_names[:5]}...")
else:
    print("❌ Cannot create data generators without valid dataset")

📁 Creating data generators from: /Users/debabratapattnayak/web-dev/greencast/processed_data/plantvillage_color_symlinks
Found 30453 images belonging to 38 classes.
Found 7594 images belonging to 38 classes.
Found 8129 images belonging to 38 classes.
✅ Data generators created!
📊 Training samples: 30453
📊 Validation samples: 7594
📊 Test samples: 8129
🏷️ Classes (38): ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy']...
