# 🌱 Crop Disease Classification Model for Android

This notebook creates a TensorFlow Lite model optimized for your Android crop disease detection app.

## 📋 What this notebook does:
1. Creates a MobileNetV2-based model (optimized for mobile)
2. Trains on a crop disease dataset
3. Converts to TensorFlow Lite with quantization
4. Tests the model
5. Downloads the model and labels files for Android integration

## 🚀 Quick Start:
1. Upload your dataset or use a public one
2. Run all cells
3. Download the generated .tflite file
4. Replace model.tflite in your Android project

In [None]:
# Install required packages
!pip install tensorflow matplotlib pillow

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
from google.colab import files

# Report the runtime environment: the TensorFlow build in use and whatever
# GPU devices TensorFlow can see (an empty list means CPU-only).
tf_version = tf.__version__
gpu_devices = tf.config.list_physical_devices('GPU')
print("✅ TensorFlow version:", tf_version)
print("✅ GPU available:", gpu_devices)

In [None]:
# Configuration - shared settings read by the cells below; modify as needed.
CONFIG = dict(
    # Must match the input size the Android app feeds the model
    # (128x128 for the current model).
    IMAGE_SIZE=(128, 128),
    BATCH_SIZE=32,
    EPOCHS=15,
    LEARNING_RATE=0.001,
    # Two output classes: healthy, diseased.
    NUM_CLASSES=2,
    # Used as the file-name stem of the exported .tflite model.
    MODEL_NAME='crop_disease_classifier_v2',
)

# Echo every setting so the run's parameters are visible in the cell output.
print("📋 Configuration:")
for setting_name in CONFIG:
    print(f"   {setting_name}: {CONFIG[setting_name]}")

## 📁 Dataset Setup

### Option 1: Use Kaggle Dataset (Recommended)
Popular crop disease datasets on Kaggle:
- [Plant Disease Dataset](https://www.kaggle.com/vipoooool/new-plant-diseases-dataset)
- [Crop Disease Dataset](https://www.kaggle.com/rashikrahmanpritom/plant-disease-recognition-dataset)
- [PlantVillage Dataset](https://www.kaggle.com/emmarex/plantdisease)

### Option 2: Upload Your Own Dataset
Structure your data like this:
```
dataset/
  train/
    healthy/
      img1.jpg
      img2.jpg
    diseased/
      img1.jpg
      img2.jpg
  val/
    healthy/
    diseased/
```

In [None]:
# Option 1: Upload PlantVillage dataset (RECOMMENDED)
#
# Uploads a zip archive through the Colab file picker, extracts it into the
# working directory and looks for the extracted dataset folder. The result is
# `dataset_folder`: the folder name, or None when nothing usable was found.
import os
import zipfile

print("📁 Upload your PlantVillage dataset:")
uploaded = files.upload()  # Upload your plantvillage.zip file

if uploaded:
    # Extract the dataset
    zip_filename = next(iter(uploaded))  # Get uploaded filename
    print(f"Extracting {zip_filename}...")

    with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
        zip_ref.extractall('.')
else:
    # A cancelled/empty upload used to crash with IndexError here; report it
    # and continue so `dataset_folder` ends up as None instead.
    print("⚠️ No file was uploaded; skipping extraction.")

# Find the extracted folder. Entries are sorted so the pick is deterministic
# (os.listdir returns names in arbitrary order).
candidate_dirs = sorted(f for f in os.listdir('.') if os.path.isdir(f))
folders = [f for f in candidate_dirs if 'plant' in f.lower() or 'village' in f.lower()]
if not folders:
    # Look for any folder with many subdirectories (likely the dataset)
    folders = [
        f for f in candidate_dirs
        if sum(1 for d in os.listdir(f) if os.path.isdir(os.path.join(f, d))) > 10
    ]

print(f"Found dataset folders: {folders}")
dataset_folder = folders[0] if folders else None

if dataset_folder:
    print(f"✅ Using dataset folder: {dataset_folder}")

    # Reorganize PlantVillage for binary classification
    def setup_plantvillage_binary(dataset_path):
        """Copy a folder-per-class image tree into a healthy/diseased dataset.

        Every immediate sub-directory of ``dataset_path`` is treated as one
        source class. A class whose name contains "healthy" (case-insensitive)
        is mapped to ``healthy``; every other class is mapped to ``diseased``.
        The images of each source class are split 80/20 (in sorted file-name
        order) into ``crop_disease_dataset/train`` and
        ``crop_disease_dataset/val`` under the current working directory.

        Args:
            dataset_path: Path (str or Path) of the folder holding the class
                sub-directories.

        Returns:
            Tuple ``(train_dir, val_dir)`` of output directory paths as strings.
        """
        from pathlib import Path
        import shutil

        # Suffixes are compared case-insensitively: a case-sensitive "*.jpg"
        # glob silently skips files named "*.JPG".
        image_suffixes = {".jpg", ".jpeg", ".png"}

        print("🔄 Reorganizing PlantVillage for binary classification...")

        # Create output directories
        output_dir = Path("crop_disease_dataset")
        train_dir = output_dir / "train"
        val_dir = output_dir / "val"

        for split_dir in (train_dir, val_dir):
            for class_name in ("healthy", "diseased"):
                (split_dir / class_name).mkdir(parents=True, exist_ok=True)

        # Process each class directory (sorted for a reproducible run)
        dataset_path = Path(dataset_path)
        class_dirs = sorted(d for d in dataset_path.iterdir() if d.is_dir())

        for class_dir in class_dirs:
            # Determine if healthy or diseased
            target_class = "healthy" if "healthy" in class_dir.name.lower() else "diseased"

            # Get all images; sorting makes the train/val split deterministic
            image_files = sorted(
                p for p in class_dir.iterdir()
                if p.is_file() and p.suffix.lower() in image_suffixes
            )
            if not image_files:
                continue

            print(f"   {class_dir.name}: {len(image_files)} images → {target_class}")

            # Split 80/20
            split_idx = int(len(image_files) * 0.8)
            splits = (
                (train_dir, image_files[:split_idx]),
                (val_dir, image_files[split_idx:]),
            )

            # Copy files, prefixing with the source class name so files from
            # different source classes cannot collide in the merged folder.
            # The real suffix is kept so a PNG is not mislabelled as ".jpg".
            for split_dir, split_files in splits:
                for i, img_file in enumerate(split_files):
                    dest_name = f"{class_dir.name}_{i:04d}{img_file.suffix.lower()}"
                    shutil.copy2(img_file, split_dir / target_class / dest_name)

        # Print summary (counts what is actually present in the output folders)
        train_healthy = len(list((train_dir / "healthy").glob("*")))
        train_diseased = len(list((train_dir / "diseased").glob("*")))
        val_healthy = len(list((val_dir / "healthy").glob("*")))
        val_diseased = len(list((val_dir / "diseased").glob("*")))

        print(f"\n✅ Dataset reorganized!")
        print(f"📊 Training: {train_healthy} healthy, {train_diseased} diseased")
        print(f"📊 Validation: {val_healthy} healthy, {val_diseased} diseased")

        if train_healthy + train_diseased + val_healthy + val_diseased == 0:
            print(f"⚠️ No images were found under {dataset_path}; check the folder layout.")

        return str(train_dir), str(val_dir)

    # Setup the dataset
    TRAIN_DIR, VAL_DIR = setup_plantvillage_binary(dataset_folder)

else:
    print("⚠️ Dataset folder not found. Creating demo dataset...")
    # Option 2: For demo, create synthetic data.
    # Everything below belongs to this fallback branch only; running it
    # unconditionally would overwrite TRAIN_DIR/VAL_DIR of a real dataset.
    print("🔄 Creating demo dataset...")

    # 50 dummy images per class for training, 10 per class for validation.
    for split, n_images in (('train', 50), ('val', 10)):
        for class_name in ('healthy', 'diseased'):
            os.makedirs(f'demo_dataset/{split}/{class_name}', exist_ok=True)

        for i in range(n_images):
            # Healthy images (greenish); values stay <= 249, so the uint8
            # addition cannot wrap around.
            img = np.random.randint(50, 200, (128, 128, 3), dtype=np.uint8)
            img[:, :, 1] += 50  # More green
            plt.imsave(f'demo_dataset/{split}/healthy/img_{i}.jpg', img)

            # Diseased images (brownish/yellowish)
            img = np.random.randint(100, 255, (128, 128, 3), dtype=np.uint8)
            img[:, :, 2] = img[:, :, 2] // 2  # Less blue
            plt.imsave(f'demo_dataset/{split}/diseased/img_{i}.jpg', img)

    print("✅ Demo dataset created!")

    # Set dataset paths (modify these for your actual dataset)
    TRAIN_DIR = 'demo_dataset/train'
    VAL_DIR = 'demo_dataset/val'

print(f"📁 Training data: {TRAIN_DIR}")
print(f"📁 Validation data: {VAL_DIR}")

In [None]:
# Create data generators
def create_data_generators(train_dir, val_dir):
    """Build the training and validation image iterators.

    Both pipelines rescale pixel values by 1/255. Only the training pipeline
    applies random augmentation. Image size and batch size are read from the
    module-level CONFIG; labels are one-hot ('categorical').

    Args:
        train_dir: Directory with one sub-folder per class (training images).
        val_dir: Directory with one sub-folder per class (validation images).

    Returns:
        Tuple ``(train_generator, val_generator)``.
    """
    # Random transforms applied to training images only.
    augmentation = dict(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    # Options shared by both directory readers.
    flow_options = dict(
        target_size=CONFIG['IMAGE_SIZE'],
        batch_size=CONFIG['BATCH_SIZE'],
        class_mode='categorical',
    )

    train_datagen = ImageDataGenerator(rescale=1./255, **augmentation)
    # Validation data (no augmentation)
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
    val_generator = val_datagen.flow_from_directory(val_dir, **flow_options)

    return train_generator, val_generator

train_gen, val_gen = create_data_generators(TRAIN_DIR, VAL_DIR)

# Summarize what the readers found on disk.
for caption, detail in (
    ("✅ Training samples:", train_gen.samples),
    ("✅ Validation samples:", val_gen.samples),
    ("✅ Classes:", train_gen.class_indices),
):
    print(f"{caption} {detail}")

In [None]:
# Create model architecture
def create_mobile_model(num_classes=2, input_shape=(128, 128, 3)):
    """Build a frozen MobileNetV2 backbone with a small softmax classifier head.

    Args:
        num_classes: Number of output units of the softmax layer.
        input_shape: Shape of one input image, passed to both the backbone
            and the model's Input layer.

    Returns:
        The uncompiled Keras ``Model``.
    """
    # MobileNetV2 without its ImageNet classifier; alpha=0.75 is the width
    # multiplier, giving a smaller network than the default.
    backbone = MobileNetV2(
        input_shape=input_shape,
        alpha=0.75,
        include_top=False,
        weights='imagenet',
    )
    # Freeze base model initially: only the new head is trained.
    backbone.trainable = False

    # Custom classification head: pool -> dropout -> softmax.
    image_input = tf.keras.Input(shape=input_shape)
    features = backbone(image_input, training=False)
    pooled = GlobalAveragePooling2D()(features)
    regularized = Dropout(0.2)(pooled)
    class_probs = Dense(num_classes, activation='softmax')(regularized)

    return Model(image_input, class_probs)

# Create the model
model_input_shape = (*CONFIG['IMAGE_SIZE'], 3)
model = create_mobile_model(
    num_classes=CONFIG['NUM_CLASSES'],
    input_shape=model_input_shape,
)

print("🏗️ Model architecture created!")
model.summary()

In [None]:
# Compile and train the model
optimizer = Adam(learning_rate=CONFIG['LEARNING_RATE'])
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Callbacks for better training:
# - stop when validation accuracy has not improved for 5 epochs and roll the
#   weights back to the best epoch seen;
# - cut the learning rate to 20% after 3 epochs without validation-loss
#   improvement, never going below 1e-7.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True,
)
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-7,
)
callbacks = [early_stopping, lr_on_plateau]

print("🚀 Starting training...")

# Train the model; `history` is kept for the plotting cell.
history = model.fit(
    train_gen,
    epochs=CONFIG['EPOCHS'],
    validation_data=val_gen,
    callbacks=callbacks,
    verbose=1,
)

print("✅ Training completed!")

In [None]:
# Plot training history: accuracy on the left panel, loss on the right.
# Each row: (train key, val key, train label, val label, title, y-axis label).
panels = (
    ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
     'Model Accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Training Loss', 'Validation Loss',
     'Model Loss', 'Loss'),
)

plt.figure(figsize=(12, 4))

for position, panel in enumerate(panels, start=1):
    train_key, val_key, train_label, val_label, panel_title, y_label = panel
    plt.subplot(1, 2, position)
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()

# Print final metrics (validation accuracy of the last epoch that ran)
val_accuracy_per_epoch = history.history['val_accuracy']
final_acc = val_accuracy_per_epoch[-1]
print(f"🎯 Final validation accuracy: {final_acc:.3f}")

In [None]:
# Convert to TensorFlow Lite with quantization
def convert_to_tflite_optimized(model, model_name, calibration_data=None,
                                num_calibration_samples=100):
    """Convert a Keras model to TensorFlow Lite, preferring full INT8 quantization.

    Full-integer conversion is attempted first. If it fails for any reason,
    the model is converted again without a representative dataset, keeping
    float32 input/output. The result is written to
    ``{model_name}_quantized.tflite`` or ``{model_name}_float32.tflite``.

    Args:
        model: Trained Keras model to convert.
        model_name: File-name stem for the saved ``.tflite`` file.
        calibration_data: Optional iterable of image batches used to calibrate
            the INT8 ranges. Each item is either an array of images or an
            ``(images, labels)`` tuple, preprocessed exactly like the training
            input. When None, random noise is used (the previous behaviour),
            which generally gives poorly calibrated ranges.
        num_calibration_samples: Maximum number of single-image samples fed to
            the converter.

    Returns:
        Tuple ``(tflite_model, filename)``: the serialized model bytes and the
        path they were written to.
    """
    # Create representative dataset for quantization
    def representative_dataset():
        if calibration_data is None:
            # No real data supplied: fall back to uniform noise in [0, 1).
            for _ in range(num_calibration_samples):
                noise = np.random.random((1, *CONFIG['IMAGE_SIZE'], 3))
                yield [noise.astype(np.float32)]
            return

        # Use actual (preprocessed) images so activation ranges are realistic.
        remaining = num_calibration_samples
        for batch in calibration_data:
            images = batch[0] if isinstance(batch, (tuple, list)) else batch
            for image in np.asarray(images, dtype=np.float32):
                if remaining <= 0:
                    return
                # The converter expects one sample with a leading batch axis.
                yield [image[np.newaxis, ...]]
                remaining -= 1
            # Keras directory iterators loop forever, so stop explicitly.
            if remaining <= 0:
                return

    # Convert with INT8 quantization (best for mobile)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    try:
        tflite_model = converter.convert()
        print("✅ INT8 quantized model created!")
        quantized = True
    except Exception as e:
        # Deliberately broad: any converter failure should still leave the
        # user with a usable model via the fallback below.
        print(f"⚠️ INT8 quantization failed: {e}")
        print("Falling back to float32...")

        # Fallback with float32 input/output. NOTE: Optimize.DEFAULT without a
        # representative dataset still applies dynamic-range weight
        # quantization; only the model's I/O stays float32.
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        quantized = False
        print("✅ Float32 model created!")

    # Save the model
    filename = f'{model_name}_quantized.tflite' if quantized else f'{model_name}_float32.tflite'
    with open(filename, 'wb') as f:
        f.write(tflite_model)

    print(f"💾 Model saved as: {filename}")
    print(f"📊 Model size: {len(tflite_model)} bytes ({len(tflite_model)/1024:.1f} KB)")

    return tflite_model, filename

# Convert the model, calibrating the INT8 ranges on real training images
tflite_model, tflite_filename = convert_to_tflite_optimized(
    model, CONFIG['MODEL_NAME'], calibration_data=train_gen
)

In [None]:
# Test the TensorFlow Lite model
def test_tflite_model(filename, class_names=None):
    """Smoke-test a .tflite file by running one inference on random input.

    Args:
        filename: Path of the ``.tflite`` model to load.
        class_names: Class labels ordered by model output index. Defaults to
            ``['diseased', 'healthy']``, the alphabetical order in which
            ``flow_from_directory`` assigns indices to this notebook's
            ``diseased``/``healthy`` folders.

    Returns:
        True when the inference ran to completion.
    """
    if class_names is None:
        class_names = ['diseased', 'healthy']

    # Load the model
    interpreter = tf.lite.Interpreter(model_path=filename)
    interpreter.allocate_tensors()

    # Get input and output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("🔍 TensorFlow Lite Model Details:")
    print(f"   Input shape: {input_details[0]['shape']}")
    print(f"   Input type: {input_details[0]['dtype']}")
    print(f"   Output shape: {output_details[0]['shape']}")
    print(f"   Output type: {output_details[0]['dtype']}")

    # Test with sample data matching the model's input dtype
    input_shape = input_details[0]['shape']
    input_dtype = input_details[0]['dtype']

    if input_dtype == np.int8:
        test_input = np.random.randint(-128, 127, input_shape, dtype=np.int8)
    elif input_dtype == np.uint8:
        test_input = np.random.randint(0, 255, input_shape, dtype=np.uint8)
    else:
        test_input = np.random.random(input_shape).astype(np.float32)

    # Run inference
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])

    print(f"\n🧪 Test Results:")
    print(f"   Raw output: {output}")

    # Convert output based on type
    if output_details[0]['dtype'] == np.int8:
        # Dequantize with the tensor's own parameters:
        # real = (quantized - zero_point) * scale.
        scale, zero_point = output_details[0]['quantization']
        if scale:
            probs = (output.astype(np.float32) - zero_point) * scale
        else:
            # A scale of 0 means no quantization parameters were recorded;
            # keep the previous approximate mapping to [0, 1].
            probs = (output.astype(np.float32) + 128) / 255.0
        print(f"   Probabilities: {probs}")
        predicted_class = int(np.argmax(probs))
    else:
        predicted_class = int(np.argmax(output))

    print(f"   Predicted class: {predicted_class} ({class_names[predicted_class]})")

    return True

# Test the model, labelling outputs in the index order used during training
test_success = test_tflite_model(
    tflite_filename,
    sorted(train_gen.class_indices, key=train_gen.class_indices.get),
)
print("\n✅ Model testing completed!")

In [None]:
# Create labels.txt file for Android
# Line i of labels.txt must name the class behind model output index i.
# flow_from_directory assigns indices alphabetically (diseased=0, healthy=1),
# so the order is taken from the training generator instead of hard-coded.
labels = sorted(train_gen.class_indices, key=train_gen.class_indices.get)
with open('labels.txt', 'w', encoding='utf-8') as f:
    f.writelines(f'{label}\n' for label in labels)

print("📝 labels.txt created!")
print("Labels:", labels)

# Display file sizes
import os
model_size = os.path.getsize(tflite_filename)
labels_size = os.path.getsize('labels.txt')

print(f"\n📊 Final Files:")
print(f"   {tflite_filename}: {model_size} bytes ({model_size/1024:.1f} KB)")
print(f"   labels.txt: {labels_size} bytes")

In [None]:
# Download files for Android integration
print("📱 Ready for Android Integration!")
print("\n🔽 Downloading files...")

# Send the model and its labels file to the browser, in that order.
for artifact in (tflite_filename, 'labels.txt'):
    files.download(artifact)

print("\n✅ Files downloaded!")

# Integration checklist for the Android project.
print("\n📋 Next Steps:")
next_steps = (
    "1. Replace 'app/src/main/assets/model.tflite' with the downloaded model",
    "2. Replace 'app/src/main/assets/labels.txt' with the downloaded labels",
    "3. Build and test your Android app",
    "4. The app should now work with real crop disease detection!",
)
for step in next_steps:
    print(step)

# Recap of the exported model; the type is inferred from the file name
# produced by the conversion cell.
image_size = CONFIG['IMAGE_SIZE']
if 'quantized' in tflite_filename:
    model_type = 'INT8 Quantized'
else:
    model_type = 'Float32'

print("\n🎯 Model Specifications:")
print(f"   Input: {image_size[0]}x{image_size[1]}x3")
print(f"   Output: {CONFIG['NUM_CLASSES']} classes (healthy, diseased)")
print(f"   Type: {model_type}")
print(f"   Size: {model_size/1024:.1f} KB")