# 🌱 PlantVillage Crop Disease Classifier for Android

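This notebook trains a binary (healthy vs. diseased) crop-leaf classifier on the PlantVillage dataset and exports it as a TensorFlow Lite model for your Android app. It is designed to run in Google Colab.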

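## 🚀 Quick Start:
1. Have your PlantVillage dataset zip file ready (you will be prompted to upload it)
2. Run all cells
3. Download the generated `.tflite` model and `labels.txt`
4. Replace `model.tflite` and `labels.txt` in your Android project's assets folder (rename the downloaded model to `model.tflite`)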

In [None]:
# Install and import required packages
!pip install tensorflow matplotlib pillow

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import zipfile
import shutil
from pathlib import Path
from google.colab import files

# Environment sanity check: the loaded TF build and whether any GPU device is visible.
print("✅ TensorFlow version:", tf.__version__)
gpu_devices = tf.config.list_physical_devices('GPU')
print("✅ GPU available:", bool(gpu_devices))

In [None]:
# Configuration shared by every cell below (training and export settings).
CONFIG = dict(
    IMAGE_SIZE=(128, 128),      # Match your Android app's input size
    BATCH_SIZE=32,
    EPOCHS=15,
    LEARNING_RATE=0.001,
    NUM_CLASSES=2,              # binary task: healthy vs. diseased
    MODEL_NAME='plantvillage_crop_classifier',
)

print("📋 Configuration:")
print("\n".join(f"   {key}: {value}" for key, value in CONFIG.items()))

In [None]:
# Upload and extract PlantVillage dataset.
# files.upload() opens a browser file picker and returns {filename: contents};
# the extraction below then opens the uploaded file from the current directory.
print("📁 Please upload your PlantVillage dataset zip file:")
uploaded = files.upload()

if not uploaded:
    # Cancelling the picker yields an empty dict; fail with an actionable
    # message instead of an IndexError.
    raise RuntimeError("No file was uploaded. Re-run this cell and select the dataset .zip file.")

zip_filename = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"⚠️ {len(uploaded)} files uploaded; only {zip_filename} will be extracted")

if not zipfile.is_zipfile(zip_filename):
    raise RuntimeError(f"{zip_filename} is not a valid zip archive. Please upload the dataset as a .zip file.")

print(f"\n📦 Extracting {zip_filename}...")

with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall('.')

print("✅ Extraction completed!")

In [None]:
# Find and explore the dataset structure
def find_dataset_folder(root='.', max_depth=3, min_classes=10):
    """Find the PlantVillage dataset folder (one sub-folder per class).

    Searches the directory tree under *root* level by level, up to
    *max_depth* levels deep. A folder is a candidate when all of these hold:

    - it has more than *min_classes* sub-folders;
    - at least one sub-folder name mentions a crop (e.g. "tomato");
    - at least one sub-folder name mentions a health/disease term
      (e.g. "healthy", "blight").

    Extracted zips often nest the class folders a level or two down
    (e.g. ``PlantVillage/color/<class>``), which is why deeper levels are
    searched. Hidden folders (such as ``.config``) are skipped, and a matched
    folder is not descended into.

    Args:
        root: Directory to start from (default: current directory).
        max_depth: Number of directory levels below *root* to inspect.
            ``1`` only checks *root*'s direct children.
        min_classes: A candidate needs strictly more sub-folders than this.

    Returns:
        The candidate ``Path`` with the most sub-folders, or ``None``. Ties
        go to the alphabetically first candidate at the shallowest level.
    """
    plant_keywords = ('apple', 'corn', 'tomato', 'potato', 'grape', 'pepper')
    health_keywords = ('healthy', 'disease', 'scab', 'rust', 'blight', 'spot')

    def _subdirs(folder):
        # Sorted for a deterministic result; unreadable folders count as empty.
        try:
            return sorted(d for d in folder.iterdir() if d.is_dir())
        except OSError:
            return []

    def _looks_like_plantvillage(subdirs):
        # Many class folders whose names mention both crops and health states.
        if len(subdirs) <= min_classes:
            return False
        names = [d.name.lower() for d in subdirs]
        has_plants = any(k in name for name in names for k in plant_keywords)
        has_health = any(k in name for name in names for k in health_keywords)
        return has_plants and has_health

    candidates = []
    frontier = [Path(root)]
    for _ in range(max_depth):
        next_frontier = []
        for parent in frontier:
            for item in _subdirs(parent):
                if item.name.startswith('.'):
                    continue
                subdirs = _subdirs(item)
                if _looks_like_plantvillage(subdirs):
                    candidates.append((item, len(subdirs)))
                else:
                    next_frontier.append(item)
        frontier = next_frontier

    if not candidates:
        return None
    # Return the folder with the most class sub-directories.
    return max(candidates, key=lambda candidate: candidate[1])[0]

# Find the dataset and print a short summary of what it contains.
dataset_folder = find_dataset_folder()

# Image types counted in the preview: the same set the reorganisation step
# copies (.jpg/.jpeg/.png/.bmp/.tiff), compared case-insensitively.
PREVIEW_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

if dataset_folder:
    # Print the full path: the class folders may sit several levels deep.
    print(f"✅ Found dataset folder: {dataset_folder}")

    class_dirs = sorted(d for d in dataset_folder.iterdir() if d.is_dir())
    print(f"📊 Found {len(class_dirs)} classes")
    print("📋 Sample classes:")
    for class_dir in class_dirs[:10]:
        image_count = sum(1 for f in class_dir.iterdir() if f.suffix.lower() in PREVIEW_IMAGE_EXTENSIONS)
        print(f"   {class_dir.name}: {image_count} images")

    if len(class_dirs) > 10:
        print(f"   ... and {len(class_dirs) - 10} more classes")

else:
    print("❌ Could not find PlantVillage dataset!")
    print("Available folders:")
    for item in Path('.').iterdir():
        if item.is_dir():
            subdirs = sum(1 for d in item.iterdir() if d.is_dir())
            print(f"   📁 {item.name} ({subdirs} subdirectories)")

In [None]:
# Reorganize PlantVillage for binary classification (healthy vs diseased)
# Image file types copied into the binary dataset, matched case-insensitively.
_IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff'})


def _list_images(class_dir):
    """Return the image files directly inside *class_dir*, sorted by path."""
    # Sorting first makes the seeded shuffle independent of filesystem order;
    # suffix matching avoids both missed ".Jpg" files and duplicates that
    # case-variant glob patterns produce on case-insensitive filesystems.
    return sorted(
        f for f in class_dir.iterdir()
        if f.is_file() and f.suffix.lower() in _IMAGE_EXTENSIONS
    )


def _copy_images(image_files, dest_dir, prefix):
    """Copy *image_files* into *dest_dir* as ``<prefix>_<NNNN><ext>``."""
    for i, img_file in enumerate(image_files):
        # Keep the real (lower-cased) extension rather than renaming everything to .jpg.
        dest = dest_dir / f"{prefix}_{i:04d}{img_file.suffix.lower()}"
        try:
            shutil.copy2(img_file, dest)
        except OSError as e:
            # Best effort: one unreadable file should not abort the whole run.
            print(f"     ⚠️ Error copying {img_file.name}: {e}")


def _count_files(folder):
    """Return the number of regular files directly inside *folder*."""
    return sum(1 for p in folder.iterdir() if p.is_file())


def reorganize_plantvillage(source_folder, output_dir="crop_disease_dataset",
                            train_fraction=0.8, seed=42,
                            exclude_keywords=('background',)):
    """Reorganize PlantVillage dataset into healthy/diseased binary classification.

    Each class folder is mapped by its name (case-insensitive):

    - names containing "healthy" become ``healthy``;
    - names containing any of *exclude_keywords* are skipped (e.g. the
      "Background_without_leaves" folder shipped with some PlantVillage
      variants, which contains no leaf);
    - every other folder becomes ``diseased``.

    Each source class is split separately, so both splits contain every
    class. For a given *seed* the split is reproducible.

    Resulting layout::

        <output_dir>/{train,val}/{healthy,diseased}/<class>_<NNNN><ext>

    Args:
        source_folder: Folder holding one sub-folder per PlantVillage class.
        output_dir: Destination root; deleted and recreated on every call.
        train_fraction: Fraction of each class's images placed in ``train``.
        seed: Seed for the shuffle that decides the train/val split.
        exclude_keywords: Lower-case substrings of class folder names to skip.

    Returns:
        ``(train_dir, val_dir)`` as strings, or ``(None, None)`` when no image
        ended up in the output.
    """
    import random

    print("🔄 Reorganizing PlantVillage for binary classification...")

    output_dir = Path(output_dir)
    train_dir = output_dir / "train"
    val_dir = output_dir / "val"

    # Start clean so files from a previous run are not mixed into this split.
    if output_dir.exists():
        shutil.rmtree(output_dir)
    for split_dir in (train_dir, val_dir):
        for class_name in ("healthy", "diseased"):
            (split_dir / class_name).mkdir(parents=True, exist_ok=True)

    rng = random.Random(seed)
    class_dirs = sorted(d for d in Path(source_folder).iterdir() if d.is_dir())
    source_totals = {"healthy": 0, "diseased": 0}
    processed_classes = 0

    for class_dir in class_dirs:
        class_name = class_dir.name.lower()
        if any(keyword in class_name for keyword in exclude_keywords):
            print(f"   ⏭️ {class_dir.name}: skipped (not a crop class)")
            continue
        target_class = "healthy" if "healthy" in class_name else "diseased"

        image_files = _list_images(class_dir)
        if not image_files:
            continue

        print(f"   📁 {class_dir.name}: {len(image_files)} images → {target_class}")

        rng.shuffle(image_files)
        split_idx = int(len(image_files) * train_fraction)
        _copy_images(image_files[:split_idx], train_dir / target_class, class_dir.name)
        _copy_images(image_files[split_idx:], val_dir / target_class, class_dir.name)

        source_totals[target_class] += len(image_files)
        processed_classes += 1

    # Count what actually landed on disk (failed copies were reported and skipped).
    train_healthy = _count_files(train_dir / "healthy")
    train_diseased = _count_files(train_dir / "diseased")
    val_healthy = _count_files(val_dir / "healthy")
    val_diseased = _count_files(val_dir / "diseased")
    total = train_healthy + train_diseased + val_healthy + val_diseased

    print(f"\n✅ Dataset reorganization completed!")
    print(f"📊 Processed {processed_classes} classes")
    print(f"📊 Source images: {source_totals['healthy']} healthy, {source_totals['diseased']} diseased")
    print(f"📊 Training set: {train_healthy} healthy, {train_diseased} diseased")
    print(f"📊 Validation set: {val_healthy} healthy, {val_diseased} diseased")
    print(f"📊 Total images: {total}")

    if total == 0:
        print("❌ No images were processed! Check the dataset structure.")
        return None, None

    return str(train_dir), str(val_dir)

# Build the binary train/val split, or leave both directories unset when no dataset was found.
TRAIN_DIR = VAL_DIR = None
if not dataset_folder:
    print("❌ Cannot proceed without dataset folder")
else:
    TRAIN_DIR, VAL_DIR = reorganize_plantvillage(dataset_folder)

In [None]:
# Create data generators (only if we have valid directories)
if TRAIN_DIR and VAL_DIR:
    # Data augmentation for training; rescale maps pixel values to [0, 1].
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    # Validation data (no augmentation, same [0, 1] scaling)
    val_datagen = ImageDataGenerator(rescale=1./255)

    # Create generators
    train_generator = train_datagen.flow_from_directory(
        TRAIN_DIR,
        target_size=CONFIG['IMAGE_SIZE'],
        batch_size=CONFIG['BATCH_SIZE'],
        class_mode='categorical',
        shuffle=True
    )

    val_generator = val_datagen.flow_from_directory(
        VAL_DIR,
        target_size=CONFIG['IMAGE_SIZE'],
        batch_size=CONFIG['BATCH_SIZE'],
        class_mode='categorical',
        shuffle=False
    )

    print(f"✅ Data generators created!")
    print(f"📊 Training samples: {train_generator.samples}")
    print(f"📊 Validation samples: {val_generator.samples}")
    print(f"📊 Classes: {train_generator.class_indices}")

    # flow_from_directory numbers classes in alphanumeric folder order
    # ('diseased' -> 0, 'healthy' -> 1). Look up the name from the one-hot
    # label instead of assuming index 0 is "healthy".
    index_to_class = {index: name for name, index in train_generator.class_indices.items()}

    # Show sample images
    plt.figure(figsize=(12, 8))
    sample_batch, sample_labels = next(train_generator)

    for i in range(min(8, len(sample_batch))):
        plt.subplot(2, 4, i + 1)
        plt.imshow(sample_batch[i])
        class_name = index_to_class[int(np.argmax(sample_labels[i]))].capitalize()
        plt.title(class_name)
        plt.axis('off')

    plt.suptitle('Sample Training Images')
    plt.tight_layout()
    plt.show()

else:
    print("❌ Cannot create data generators without valid dataset directories")

In [None]:
# Create MobileNetV2 model (optimized for mobile)
def create_mobile_model(num_classes=2, input_shape=(128, 128, 3), alpha=0.75, dropout=0.2):
    """Build a MobileNetV2 transfer-learning classifier (frozen ImageNet base).

    The model accepts images scaled to [0, 1]. This notebook's data
    generators produce that range via ``rescale=1./255``. The model rescales
    inputs to [-1, 1] internally, which is the range the ImageNet MobileNetV2
    weights were trained with (``mobilenet_v2.preprocess_input``). Because the
    rescale lives inside the model, the exported TFLite file carries it too,
    so the app-side input range does not change.

    Args:
        num_classes: Size of the softmax output layer.
        input_shape: ``(height, width, channels)`` of the input images.
        alpha: MobileNetV2 width multiplier (smaller = lighter model).
        dropout: Dropout rate applied before the classification layer.

    Returns:
        An uncompiled ``tf.keras.Model`` whose base network is not trainable.
    """
    # Use MobileNetV2 as base (mobile-optimized)
    base_model = MobileNetV2(
        input_shape=input_shape,
        alpha=alpha,  # Width multiplier for smaller model
        include_top=False,
        weights='imagenet'
    )

    # Freeze base model initially
    base_model.trainable = False

    inputs = tf.keras.Input(shape=input_shape)
    # [0, 1] -> [-1, 1]  (x * 2 - 1), matching the pretrained weights' input range.
    x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
    # training=False keeps the frozen base's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs)

# Instantiate the classifier from CONFIG and report its size.
height, width = CONFIG['IMAGE_SIZE']
model = create_mobile_model(
    num_classes=CONFIG['NUM_CLASSES'],
    input_shape=(height, width, 3),
)

print("🏗️ Model created!")
print(f"📊 Total parameters: {model.count_params():,}")
model.summary()

In [None]:
# Compile and train the model
if TRAIN_DIR and VAL_DIR:
    model.compile(
        optimizer=Adam(learning_rate=CONFIG['LEARNING_RATE']),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Callbacks
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=3,
            min_lr=1e-7
        )
    ]

    # Every non-"healthy" class folder was merged into "diseased", so the two
    # labels are usually very unequal in size. Weight each class inversely to
    # its frequency ("balanced" weighting) so the loss does not simply favour
    # the majority class. max(count, 1) guards against an empty class folder.
    class_counts = np.bincount(train_generator.classes, minlength=train_generator.num_classes)
    total_samples = float(class_counts.sum())
    class_weight = {
        index: total_samples / (len(class_counts) * max(int(count), 1))
        for index, count in enumerate(class_counts)
    }
    print(f"⚖️ Class counts: {dict(enumerate(class_counts.tolist()))}")
    print(f"⚖️ Class weights: {class_weight}")

    print("🚀 Starting training...")

    # Train the model
    history = model.fit(
        train_generator,
        epochs=CONFIG['EPOCHS'],
        validation_data=val_generator,
        callbacks=callbacks,
        class_weight=class_weight,
        verbose=1
    )

    print("✅ Training completed!")

else:
    print("❌ Cannot train without valid dataset")
    # Fit briefly on random data so the rest of the notebook (export, test)
    # can still run; the resulting model makes meaningless predictions.
    print("🔄 Creating demo model with initialized weights...")

    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Dummy training data
    dummy_x = np.random.random((10, *CONFIG['IMAGE_SIZE'], 3)).astype(np.float32)
    dummy_y = tf.keras.utils.to_categorical(
        np.random.randint(0, CONFIG['NUM_CLASSES'], 10),
        CONFIG['NUM_CLASSES']
    )

    model.fit(dummy_x, dummy_y, epochs=1, verbose=0)
    print("✅ Demo model ready!")

In [None]:
# Plot training history (only exists after real training, not in demo mode)
if 'history' in globals():
    plt.figure(figsize=(12, 4))

    # One subplot per metric: training vs. validation curves.
    for position, (key, label) in enumerate((('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[key], label=f'Training {label}')
        plt.plot(history.history[f'val_{key}'], label=f'Validation {label}')
        plt.title(f'Model {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()

    plt.tight_layout()
    plt.show()

    # EarlyStopping(restore_best_weights=True) can roll the model back to the
    # best epoch, so the last epoch's score may not describe the exported model.
    # Report both.
    val_accuracy = history.history['val_accuracy']
    best_epoch = int(np.argmax(val_accuracy))
    print(f"🎯 Final validation accuracy: {val_accuracy[-1]:.3f}")
    print(f"🏆 Best validation accuracy: {val_accuracy[best_epoch]:.3f} (epoch {best_epoch + 1})")
else:
    print("📊 No training history to display (demo mode)")

In [None]:
# Convert to TensorFlow Lite with INT8 quantization
def convert_to_tflite_quantized(model, model_name, representative_images=None,
                                num_calibration_samples=100):
    """Convert *model* to TFLite, preferring full INT8 quantization.

    Full-integer quantization calibrates activation ranges on a
    representative dataset. Pass *representative_images* as real images
    preprocessed exactly like the training data (scaled to [0, 1]), in
    either of these forms:

    - a 4-D array ``(N, H, W, 3)``;
    - an iterable of image batches or ``(images, labels)`` tuples, such as a
      Keras ``DirectoryIterator``. Infinite iterators are fine, since at most
      *num_calibration_samples* images are drawn.

    With ``None``, uniform random noise is used. That still produces a model,
    but one poorly calibrated for real photos (acceptable only for demo mode).

    If INT8 conversion fails, the model is converted with
    ``tf.lite.Optimize.DEFAULT`` and no representative data. That is
    dynamic-range quantization: int8 weights, float32 inputs/outputs.

    Args:
        model: Trained Keras model.
        model_name: Prefix for the output filename.
        representative_images: Calibration images (see above) or ``None``.
        num_calibration_samples: Maximum number of images used for calibration.

    Returns:
        ``(tflite_model_bytes, filename)``. The file is also written to the
        current directory.
    """
    def representative_dataset():
        if representative_images is None:
            # No real data available: noise with the model's input shape.
            for _ in range(num_calibration_samples):
                yield [np.random.random((1, *model.input_shape[1:])).astype(np.float32)]
            return

        # A bare 4-D array is one batch; anything else is iterated as batches.
        if isinstance(representative_images, np.ndarray):
            batches = [representative_images]
        else:
            batches = representative_images
        produced = 0
        for batch in batches:
            # Generators with labels yield (images, labels) tuples.
            images = batch[0] if isinstance(batch, tuple) else batch
            for image in images:
                yield [np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)]
                produced += 1
                if produced >= num_calibration_samples:
                    return

    # Try INT8 quantization first (best for mobile)
    print("🔄 Converting to TensorFlow Lite with INT8 quantization...")

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    try:
        tflite_model = converter.convert()
        filename = f'{model_name}_int8.tflite'
        print("✅ INT8 quantized model created!")
    except Exception as e:
        # Deliberate best-effort fallback: any converter failure drops to float I/O.
        print(f"⚠️ INT8 quantization failed: {e}")
        print("🔄 Falling back to float32 input/output...")

        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        filename = f'{model_name}_float32.tflite'
        print("✅ Float32-I/O model created (dynamic-range quantized weights)!")

    # Save the model
    with open(filename, 'wb') as f:
        f.write(tflite_model)

    print(f"💾 Model saved as: {filename}")
    print(f"📊 Model size: {len(tflite_model)} bytes ({len(tflite_model)/1024:.1f} KB)")

    return tflite_model, filename

# Calibrate INT8 ranges on real validation images when the dataset is available.
# A fresh shuffled generator is used because val_generator is unshuffled, so its
# first batches would all come from a single source class.
calibration_images = None
if TRAIN_DIR and VAL_DIR:
    calibration_images = ImageDataGenerator(rescale=1./255).flow_from_directory(
        VAL_DIR,
        target_size=CONFIG['IMAGE_SIZE'],
        batch_size=CONFIG['BATCH_SIZE'],
        class_mode=None,  # images only; labels are not needed for calibration
        shuffle=True,
        seed=42
    )
else:
    print("⚠️ No dataset available: calibrating INT8 ranges on random noise (demo only)")

# Convert the model
tflite_model, tflite_filename = convert_to_tflite_quantized(
    model, CONFIG['MODEL_NAME'], representative_images=calibration_images
)

In [None]:
# Test the TensorFlow Lite model
def test_tflite_model(filename, class_names=('diseased', 'healthy')):
    """Smoke-test a TFLite model file on one random [0, 1] image.

    Loads the model and prints its input/output details. It then runs one
    inference and prints the predicted label. Integer inputs are quantized,
    and integer outputs dequantized, using the model's own scale and
    zero-point rather than assumed constants.

    Args:
        filename: Path to the ``.tflite`` file.
        class_names: Labels in model-output index order. The default matches
            the alphanumeric index order that Keras ``flow_from_directory``
            assigns to the ``diseased``/``healthy`` training folders.

    Returns:
        ``True`` once inference has run.
    """
    print(f"🧪 Testing TensorFlow Lite model: {filename}")

    # Load the model
    interpreter = tf.lite.Interpreter(model_path=filename)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    print("\n🔍 Model Details:")
    print(f"   Input shape: {input_details['shape']}")
    print(f"   Input type: {input_details['dtype']}")
    print(f"   Output shape: {output_details['shape']}")
    print(f"   Output type: {output_details['dtype']}")

    input_shape = input_details['shape']
    input_dtype = input_details['dtype']

    # Start from a float image in the training input range [0, 1].
    float_input = np.random.random(input_shape).astype(np.float32)
    if input_dtype in (np.int8, np.uint8):
        info = np.iinfo(input_dtype)
        scale, zero_point = input_details['quantization']
        if scale:
            test_input = np.clip(np.round(float_input / scale + zero_point), info.min, info.max).astype(input_dtype)
            print(f"   Quantized [0, 1] input with scale={scale:.6f}, zero_point={zero_point}")
        else:
            # No quantization parameters recorded: use the raw integer range.
            test_input = np.random.randint(info.min, info.max + 1, input_shape).astype(input_dtype)
            print(f"   Using raw integer input range: [{info.min}, {info.max}]")
    else:
        test_input = float_input
        print("   Using FLOAT32 input range: [0, 1]")

    # Run inference
    interpreter.set_tensor(input_details['index'], test_input)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details['index'])

    print(f"\n🧪 Test Results:")
    print(f"   Raw output: {output}")

    # Dequantize integer outputs with the model's output parameters.
    if output_details['dtype'] in (np.int8, np.uint8):
        scale, zero_point = output_details['quantization']
        probs = (output.astype(np.float32) - zero_point) * scale
        print(f"   Dequantized probabilities: {probs}")
    else:
        probs = output

    predicted_class = int(np.argmax(probs))
    confidence = float(np.max(probs))
    label = class_names[predicted_class] if predicted_class < len(class_names) else f"class {predicted_class}"
    print(f"   Predicted: {label} ({confidence:.3f} confidence)")

    return True

# Label order = index order the training generator used (alphanumeric folder names).
if 'train_generator' in globals():
    tflite_class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)
else:
    tflite_class_names = ('diseased', 'healthy')

# Test the model
test_success = test_tflite_model(tflite_filename, class_names=tflite_class_names)
print("\n✅ Model testing completed!")

In [None]:
# Create labels.txt file for Android. Line i must name model output index i.
# flow_from_directory assigns class indices in alphanumeric folder order
# ('diseased' -> 0, 'healthy' -> 1), so take the order from the generator
# instead of hard-coding it.
if 'train_generator' in globals():
    class_indices = train_generator.class_indices
    labels = sorted(class_indices, key=class_indices.get)
else:
    # Demo mode: the order flow_from_directory would assign to these folders.
    labels = sorted(['healthy', 'diseased'])

with open('labels.txt', 'w', encoding='utf-8') as f:
    f.writelines(f'{label}\n' for label in labels)

print("📝 labels.txt created!")
print(f"📋 Labels: {labels}")

# Display final file information
model_size = os.path.getsize(tflite_filename)
labels_size = os.path.getsize('labels.txt')

print(f"\n📊 Generated Files:")
print(f"   📱 {tflite_filename}: {model_size:,} bytes ({model_size/1024:.1f} KB)")
print(f"   📝 labels.txt: {labels_size} bytes")

print(f"\n🎯 Model Specifications:")
print(f"   📐 Input: {CONFIG['IMAGE_SIZE'][0]}x{CONFIG['IMAGE_SIZE'][1]}x3")
print(f"   🏷️ Classes: {len(labels)} ({', '.join(labels)})")
print(f"   🔢 Type: {'INT8 Quantized' if 'int8' in tflite_filename else 'Float32'}")
print(f"   💾 Size: {model_size/1024:.1f} KB")

In [None]:
# Download files for Android integration
print("📱 Ready for Android Integration!")
print("\n🔽 Downloading files...")

# Download the model and labels
files.download(tflite_filename)
files.download('labels.txt')

print("\n✅ Files downloaded successfully!")
print("\n📋 Next Steps:")
print("1. 📁 Replace 'app/src/main/assets/model.tflite' with the downloaded model")
print("2. 📝 Replace 'app/src/main/assets/labels.txt' with the downloaded labels")
print("3. 🔨 Build your Android project: ./gradlew assembleDebug")
print("4. 📱 Install and test your app")
print("5. 🎉 Your app should now detect crop diseases in real-time!")

# Describe the model that was actually produced: conversion may have fallen
# back to float32 I/O, and the input size comes from CONFIG.
print("\n💡 Tips:")
print(f"- The model expects {CONFIG['IMAGE_SIZE'][0]}x{CONFIG['IMAGE_SIZE'][1]} pixel RGB images")
if 'int8' in tflite_filename:
    print("- Input is INT8: scale pixels to [0, 1], then quantize with the model's input scale/zero-point")
    print("- Output is INT8: dequantize with the model's output scale/zero-point to get probabilities")
else:
    print("- Input is FLOAT32: RGB pixels scaled to [0, 1]")
    print("- Output is FLOAT32 softmax probabilities")
print("- Output index i corresponds to line i of labels.txt")
print("- Make sure your Android preprocessing matches this model format!")