# AI-Based Tomato & Potato Disease Classification - Google Colab Training

**Author:** Peter Maina (136532)  
**Institution:** Strathmore University  
**Project:** Final Year AI/ML Project  

This notebook implements the complete training pipeline for plant disease classification on Google Colab.

---

## Table of Contents
1. [Setup Environment](#setup)
2. [Download Dataset](#download)
3. [Data Exploration](#exploration)
4. [Data Preprocessing](#preprocessing)
5. [Model Training](#training)
6. [Model Evaluation](#evaluation)
7. [Model Export](#export)

---

## 1. Setup Environment

First, check GPU availability, mount Google Drive, clone the project repository, and install the dependencies.

In [None]:
# Check GPU availability
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
print("\nGPU Available:", tf.config.list_physical_devices('GPU'))

# Enable GPU memory growth
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("✓ GPU memory growth enabled")
    except RuntimeError as e:
        print(e)

In [None]:
# Mount Google Drive (optional - for saving models)
from google.colab import drive
drive.mount('/content/drive')
print("✓ Google Drive mounted")

In [None]:
# Clone project repository
!git clone https://github.com/YOUR_USERNAME/AI-Based-Tomato-and-Potato-Disease-Classification-App.git
%cd AI-Based-Tomato-and-Potato-Disease-Classification-App
!ls -la

In [None]:
# Install required packages
!pip install -q pyyaml kaggle scikit-learn seaborn
print("✓ Dependencies installed")

### Setup Kaggle API

Upload your `kaggle.json` file to access the PlantVillage dataset.

**How to get kaggle.json:**
1. Go to https://www.kaggle.com/settings
2. Scroll to the API section
3. Click "Create New API Token" (this downloads `kaggle.json`)
4. Run the cell below and upload the downloaded file
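
`kaggle.json` is a small JSON file with two fields, `username` and `key`, and the Kaggle CLI reads it from `~/.kaggle/kaggle.json`. If the download later fails with an authentication error, a quick check of that file can save time. The sketch below is a minimal, hypothetical helper (`check_kaggle_credentials` is not part of the project) that reports missing fields and overly open permissions, which is why the setup cell restricts the file to mode `600`:

```python
import json
import stat
from pathlib import Path


def check_kaggle_credentials(path):
    """Return a list of problems found in a kaggle.json file (an empty list means it looks OK)."""
    path = Path(path)
    if not path.is_file():
        return [f"{path} does not exist"]
    try:
        creds = json.loads(path.read_text())
    except json.JSONDecodeError as exc:
        return [f"{path} is not valid JSON: {exc}"]
    if not isinstance(creds, dict):
        return [f"{path} should contain a JSON object"]

    problems = []
    # The Kaggle CLI authenticates with these two fields
    for field in ("username", "key"):
        value = creds.get(field)
        if not isinstance(value, str) or not value:
            problems.append(f"missing or empty '{field}' field")

    # The CLI warns when the key is readable by group/other users (hence chmod 600)
    mode = stat.S_IMODE(path.stat().st_mode)
    if mode & (stat.S_IRWXG | stat.S_IRWXO):
        problems.append(f"permissions are {oct(mode)}; expected 0o600")
    return problems
```

After the upload cell has run, `check_kaggle_credentials(Path.home() / '.kaggle' / 'kaggle.json')` should return an empty list.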

In [None]:
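# Upload kaggle.json
from google.colab import files
from pathlib import Path

print("Please upload your kaggle.json file:")
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file uploaded; re-run this cell and select kaggle.json")

# Save the upload as ~/.kaggle/kaggle.json (works even if the browser renamed it, e.g. "kaggle (1).json")
kaggle_dir = Path.home() / '.kaggle'
kaggle_dir.mkdir(parents=True, exist_ok=True)
cred_path = kaggle_dir / 'kaggle.json'
cred_path.write_bytes(next(iter(uploaded.values())))
cred_path.chmod(0o600)  # owner-only access; the Kaggle CLI warns otherwise
print("\n✓ Kaggle credentials configured")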

## 2. Download Dataset

Download the PlantVillage dataset from Kaggle (4.37 GB, ~54,000 images).
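
If the download script keeps the archive and also extracts it, both copies must fit on the runtime's disk at the same time. A minimal sketch for checking free space before running the download cell, using the standard library's `shutil.disk_usage`; `has_enough_space` and its 2× headroom factor (archive plus extracted files) are assumptions, not settings from the project:

```python
import shutil


def has_enough_space(path, required_gb, headroom=2.0):
    """Return (ok, free_gb): whether `path` has at least required_gb * headroom GB free."""
    free_gb = shutil.disk_usage(path).free / 1e9  # decimal GB, matching the dataset size quoted above
    return free_gb >= required_gb * headroom, free_gb
```

In Colab, `ok, free_gb = has_enough_space('/content', required_gb=4.37)` shows whether the download is likely to fit.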

In [None]:
# Download PlantVillage dataset
!python data/scripts/download_dataset.py --colab-mode

## 3. Data Exploration

Explore the dataset structure and visualize sample images.
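
The exploration cells below assume one folder per class under each crop folder, i.e. `data/raw/<crop>/<class>/<image>`; that layout is inferred from the code rather than documented. A small sketch (the `summarize_layout` helper is hypothetical) that prints what is actually on disk, so a layout mismatch shows up before the plots fail. It matches extensions case-insensitively because image files may be named `*.JPG`:

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def summarize_layout(root):
    """Map each crop folder under `root` to {class_name: image_count}."""
    summary = {}
    for crop_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        summary[crop_dir.name] = {
            class_dir.name: sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES  # .jpg, .JPG, .png, ...
            )
            for class_dir in sorted(p for p in crop_dir.iterdir() if p.is_dir())
        }
    return summary


if Path("data/raw").exists():
    for crop, classes in summarize_layout("data/raw").items():
        print(f"{crop}: {len(classes)} classes, {sum(classes.values()):,} images")
```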

In [None]:
# Import libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image
import yaml

# Load configuration
with open('data/configs/data_config.yaml', 'r') as f:
    data_config = yaml.safe_load(f)

print("Dataset Configuration:")
print(f"  Total Images: {data_config['dataset']['total_images']:,}")
print(f"  Size: {data_config['dataset']['size_gb']} GB")
print(f"  Number of Classes: {data_config['num_classes']}")

In [None]:
# Visualize sample images from each crop
import math

IMAGE_EXTS = {'.jpg', '.jpeg', '.png'}

def visualize_samples(data_dir, crop_type, n_samples=6):
    """Show the first image from each of up to n_samples disease classes of a crop."""
    crop_dir = Path(data_dir) / crop_type
    
    if not crop_dir.exists():
        print(f"Directory not found: {crop_dir}")
        return
    
    # Disease class folders, sorted for a reproducible order
    disease_classes = sorted(d for d in crop_dir.iterdir() if d.is_dir())[:n_samples]
    if not disease_classes:
        print(f"No class folders in: {crop_dir}")
        return
    
    # Size the grid to the number of classes shown (at most 3 per row)
    ncols = min(3, len(disease_classes))
    nrows = math.ceil(len(disease_classes) / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    fig.suptitle(f'{crop_type.upper()} Sample Images', fontsize=16)
    
    for ax in axes.flat:
        ax.axis('off')  # also hides any unused panels
    
    for ax, disease_dir in zip(axes.flat, disease_classes):
        # Case-insensitive extension match (files may be named *.JPG)
        image_files = sorted(p for p in disease_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if image_files:
            ax.imshow(Image.open(image_files[0]))
            ax.set_title(disease_dir.name, fontsize=10)
    
    plt.tight_layout()
    plt.show()

# Visualize tomato samples
visualize_samples('data/raw', 'tomato')

# Visualize potato samples
visualize_samples('data/raw', 'potato', n_samples=3)

In [None]:
# Analyze class distribution
def analyze_distribution(data_dir):
    """Count images per class and plot the class distribution."""
    data_dir = Path(data_dir)
    image_exts = {'.jpg', '.jpeg', '.png'}
    
    class_counts = {}
    
    for crop in ['tomato', 'potato']:
        crop_dir = data_dir / crop
        if crop_dir.exists():
            for disease_dir in sorted(crop_dir.iterdir()):
                if disease_dir.is_dir():
                    # Case-insensitive extension match (files may be named *.JPG)
                    n_images = sum(1 for p in disease_dir.iterdir() if p.suffix.lower() in image_exts)
                    # Prefix with the crop so two folders with the same name (e.g. "healthy") don't overwrite each other
                    class_counts[f"{crop}/{disease_dir.name}"] = n_images
    
    if not class_counts:
        print(f"No class folders found under {data_dir}")
        return
    
    # Plot distribution
    plt.figure(figsize=(15, 6))
    plt.bar(range(len(class_counts)), list(class_counts.values()))
    plt.xticks(range(len(class_counts)), list(class_counts.keys()), rotation=45, ha='right')
    plt.xlabel('Disease Class')
    plt.ylabel('Number of Images')
    plt.title('Class Distribution')
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    total = sum(class_counts.values())
    print(f"\nTotal Images: {total:,}")
    print(f"Min Images per Class: {min(class_counts.values()):,}")
    print(f"Max Images per Class: {max(class_counts.values()):,}")
    print(f"Average Images per Class: {total / len(class_counts):,.0f}")

analyze_distribution('data/raw')

## 4. Data Preprocessing

Preprocess images and split into train/validation/test sets.
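
`split_dataset.py` is not shown here, so its exact ratios and seed are unknown. The usual approach splits each class separately, so every class keeps the same proportions in all three sets (a stratified split). A sketch of the per-class step; the 70/15/15 ratios, the seed of 42, and the `split_class_files` name are assumptions:

```python
import random


def split_class_files(files, train=0.70, val=0.15, seed=42):
    """Shuffle one class's files reproducibly and split them into (train, val, test) lists."""
    files = sorted(files)               # fixed starting order, so the seed alone determines the result
    random.Random(seed).shuffle(files)  # local RNG: leaves the global random state untouched
    n_train = round(len(files) * train)
    n_val = round(len(files) * val)
    return (
        files[:n_train],
        files[n_train:n_train + n_val],
        files[n_train + n_val:],        # remainder (about 15%) goes to test
    )
```

Calling it once per class folder and concatenating the results gives a stratified train/validation/test split.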

In [None]:
# Preprocess dataset
!python data/scripts/preprocess_data.py

In [None]:
# Split dataset into train/val/test
!python data/scripts/split_dataset.py

## 5. Model Training

Three architectures are supported (only the MobileNetV2 cell runs by default; uncomment the others to train them):
- Baseline CNN
- MobileNetV2 (recommended for mobile deployment)
- EfficientNetB0 (best accuracy)
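
To train all three architectures back to back with identical settings, the commands in the training cells can be generated in a loop. This sketch reuses only the `ml/training.py` flags that appear in this notebook; `build_training_command` and `train_all` are hypothetical helpers, and the architecture strings are copied from the training cells:

```python
import subprocess
import sys

ARCHITECTURES = ["baseline", "MobileNetV2", "EfficientNetB0"]


def build_training_command(architecture, epochs=50, batch_size=32, use_gpu=True):
    """Return the argv list for one ml/training.py run (flags as used in this notebook)."""
    cmd = [sys.executable, "ml/training.py",
           "--architecture", architecture,
           "--epochs", str(epochs),
           "--batch-size", str(batch_size)]
    if use_gpu:
        cmd.append("--use-gpu")
    return cmd


def train_all(architectures=ARCHITECTURES, **kwargs):
    """Run each architecture in turn; check=True stops at the first failing run."""
    for arch in architectures:
        cmd = build_training_command(arch, **kwargs)
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```

Run `train_all()` from the repository root. Three 50-epoch runs may not fit in a single Colab session, so reducing `epochs` for a first comparison may be worthwhile.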

In [None]:
# Train MobileNetV2 model (recommended)
!python ml/training.py --architecture MobileNetV2 --epochs 50 --batch-size 32 --use-gpu

In [None]:
# Optional: Train Baseline CNN
# !python ml/training.py --architecture baseline --epochs 50 --batch-size 32 --use-gpu

In [None]:
# Optional: Train EfficientNetB0 (best accuracy, slower training)
# !python ml/training.py --architecture EfficientNetB0 --epochs 50 --batch-size 32 --use-gpu

### Visualize Training History
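
The history file is presumably a JSON dump of Keras's `History.history` dict, with one list per metric such as `loss`, `val_loss`, `accuracy` and `val_accuracy`. That format is an assumption, since the file is written by `ml/training.py`, and older Keras versions used `val_acc` instead of `val_accuracy`. Under that assumption, a sketch that finds the epoch with the best validation score, which complements the plots:

```python
def best_epoch(history, metric="val_accuracy", mode="max"):
    """Return (epoch_number, value) of the best value of `metric`, with epochs counted from 1."""
    values = history.get(metric)
    if not values:
        raise KeyError(f"metric {metric!r} not found in history; keys: {sorted(history)}")
    pick = max if mode == "max" else min
    best_idx = pick(range(len(values)), key=values.__getitem__)  # first index wins on ties
    return best_idx + 1, values[best_idx]
```

After loading `history`, `best_epoch(history)` or `best_epoch(history, "val_loss", mode="min")` shows where validation performance peaked.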

In [None]:
# Load and visualize training history
import json
import sys
from pathlib import Path
sys.path.insert(0, '/content/AI-Based-Tomato-and-Potato-Disease-Classification-App')
from ml.utils import plot_training_history

# Find training history files
history_dir = Path('ml/logs/training')
history_files = list(history_dir.glob('*_history.json'))

if history_files:
    # Load the most recently written history (by modification time, not file name)
    latest_history = max(history_files, key=lambda p: p.stat().st_mtime)
    print(f"Loading history from: {latest_history}")
    
    with open(latest_history, 'r') as f:
        history = json.load(f)
    
    # Plot training history
    plot_training_history(history)
else:
    print(f"No training history found in {history_dir}")

## 6. Model Evaluation

Evaluate the trained model on the test set.
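
The evaluation script's exact output is not shown here, but the two core quantities behind "evaluation metrics and confusion matrix" are simple. Accuracy is the fraction of test images whose predicted class matches the true class. The confusion matrix counts every (true, predicted) pair, so off-diagonal cells reveal which diseases get mixed up. A dependency-free sketch with hypothetical labels:

```python
def confusion_matrix(y_true, y_pred, labels):
    """Rows = true class, columns = predicted class, both in the order of `labels`."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def accuracy(matrix):
    """Correct predictions (the diagonal) divided by all predictions."""
    total = sum(map(sum, matrix))
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


def per_class_recall(matrix, labels):
    """For each true class, the share of its images that were predicted correctly."""
    return {
        label: (row[i] / sum(row) if sum(row) else 0.0)
        for i, (label, row) in enumerate(zip(labels, matrix))
    }
```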

In [None]:
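# Select a trained model: the most recently saved .h5 file
# (the file name alone doesn't say which model scored best; compare the evaluation results for that)
from pathlib import Path

model_dir = Path('ml/trained_models/final')
model_files = list(model_dir.glob('*.h5'))

if model_files:
    best_model = max(model_files, key=lambda p: p.stat().st_mtime)
    print(f"Found model: {best_model}")
else:
    print(f"No trained models found in {model_dir}")
    best_model = None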

In [None]:
# Evaluate model
if best_model:
    !python ml/evaluation.py --model {best_model}

### Test Inference on Sample Images
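
`predict_single_image` returns a predicted class and a confidence; its internals live in `ml/inference.py` and are not shown here. In a softmax classifier, those two values are normally the arg-max of the output probability vector and the probability at that index. A sketch of that step, extended to the top-k candidates, which is more informative when two diseases look alike. The label names and probabilities in the usage line are made up:

```python
def top_k(probabilities, labels, k=3):
    """Return the k (label, probability) pairs with the highest probability, best first."""
    if len(probabilities) != len(labels):
        raise ValueError("need one probability per label")
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

For example, `top_k([0.10, 0.25, 0.65], ["Potato_healthy", "Potato_early_blight", "Potato_late_blight"], k=1)` gives `[("Potato_late_blight", 0.65)]`.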

In [None]:
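# Test inference on a sample image
from pathlib import Path
from ml.inference import predict_single_image

# Get a sample test image (class folders only; empty list if the test set is missing)
test_dir = Path('data/processed/test')
test_classes = sorted(d for d in test_dir.iterdir() if d.is_dir()) if test_dir.exists() else []

if test_classes and best_model:
    # Get first image from first class (case-insensitive extension match)
    sample_class = test_classes[0]
    image_files = sorted(p for p in sample_class.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'})
    
    if image_files:
        sample_image = image_files[0]
        print(f"Testing on: {sample_image}")
        print(f"True class: {sample_class.name}\n")
        
        # Make prediction
        predicted_class, confidence = predict_single_image(
            str(best_model),
            str(sample_image),
            visualize=True
        )
        print(f"Predicted class: {predicted_class} (confidence: {confidence})")
else:
    print("No test images or trained model available")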

## 7. Model Export

Convert the trained model to TensorFlow Lite format, with float16 quantization, for mobile deployment.
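
The export cell requests `float16` quantization. With float16 post-training quantization, weights are stored as 16-bit (2-byte) floats instead of 32-bit (4-byte) ones, so weight storage shrinks by up to half, at the cost of a small rounding error per weight. Python's `struct` module can check the arithmetic, because its `'e'` format is IEEE 754 half precision. The parameter count below is made up for illustration:

```python
import struct


def weight_storage_mb(num_params, bytes_per_param):
    """Approximate weight storage in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_param / 1e6


def float16_round_trip(value):
    """Store `value` as IEEE 754 half precision and read it back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


params = 3_000_000  # hypothetical parameter count, for illustration only
print(f"float32: {weight_storage_mb(params, 4):.1f} MB, float16: {weight_storage_mb(params, 2):.1f} MB")
print(f"0.1234567 stored as float16 reads back as {float16_round_trip(0.1234567)}")
```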

In [None]:
# Convert to TensorFlow Lite
from ml.utils import convert_to_tflite

if best_model:
    tflite_model = convert_to_tflite(
        str(best_model),
        quantization='float16'
    )
    print(f"\n✓ TFLite model saved: {tflite_model}")

### Save Models to Google Drive
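
Google Drive is mounted into Colab over the network, so for large model files it can be worth confirming that each copy is complete before the runtime is recycled. A sketch using SHA-256 checksums from the standard library; `copy_and_verify` is a hypothetical helper and the cell below does not use it:

```python
import hashlib
import shutil
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large models need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def copy_and_verify(src, dest_dir):
    """Copy src into dest_dir and raise if the copy's checksum differs from the source."""
    dest = Path(dest_dir) / Path(src).name
    shutil.copy2(src, dest)
    if sha256_of(src) != sha256_of(dest):
        raise OSError(f"Checksum mismatch after copying {src} -> {dest}")
    return dest
```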

In [None]:
# Copy models to Google Drive for persistence (the runtime's local disk is wiped when Colab resets it)
import shutil
from pathlib import Path

drive_root = Path('/content/drive/MyDrive')
if not drive_root.exists():
    raise RuntimeError("Google Drive is not mounted; run the drive.mount() cell in Section 1 first")

drive_model_dir = drive_root / 'PlantDiseaseModels'
drive_model_dir.mkdir(parents=True, exist_ok=True)

# Copy the trained Keras model
if best_model:
    shutil.copy(best_model, drive_model_dir)
    print(f"✓ Model saved to: {drive_model_dir / best_model.name}")

    # Copy the TFLite model (expected next to the .h5 file, with the same stem)
    tflite_path = best_model.with_suffix('.tflite')
    if tflite_path.exists():
        shutil.copy(tflite_path, drive_model_dir)
        print(f"✓ TFLite model saved to: {drive_model_dir / tflite_path.name}")
    else:
        print(f"TFLite model not found at {tflite_path}; copy it manually")

# Copy evaluation reports
eval_files = list(Path('ml/logs').glob('*_evaluation.*'))
for eval_file in eval_files:
    shutil.copy(eval_file, drive_model_dir)
    print(f"✓ Evaluation report saved: {drive_model_dir / eval_file.name}")

print(f"\n✓ Done. Saved files are in {drive_model_dir}")

## Summary

**Training Complete!**

You have successfully:
1. ✓ Downloaded and explored the PlantVillage dataset
2. ✓ Preprocessed and split the data
3. ✓ Trained a deep learning model
4. ✓ Evaluated model performance
5. ✓ Converted model to TFLite format
6. ✓ Saved models to Google Drive

**Next Steps:**
- Download the TFLite model for mobile app integration
- Review evaluation metrics and confusion matrix
- Fine-tune hyperparameters if needed
- Deploy the model in a mobile application

**Expected Performance:**
- Baseline CNN: 85-90% accuracy
- MobileNetV2: 92-95% accuracy
- EfficientNetB0: 95-97% accuracy