# Improved Disease Detection Model Training

This notebook provides an enhanced approach to training the animal skin disease detection model with:
- EfficientNetB0 architecture (better than InceptionV3)
- Improved data augmentation
- Better evaluation metrics
- Proper callbacks and training strategies
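- Hold-out validation during training and a final test-set evaluation (k-fold cross-validation is not implemented in this notebook)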

## 1. Import Required Libraries

In [None]:
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, CSVLogger
)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import plot_model

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd
import os
from datetime import datetime

# Fix the NumPy and TensorFlow global random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Report the runtime environment: library version first, then visible GPUs
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {gpu_devices}")

## 2. Configuration and Hyperparameters

In [None]:
# Central settings read by every later cell of the notebook
CONFIG = dict(
    IMG_SIZE=224,        # square input resolution fed to the network
    BATCH_SIZE=16,       # images per batch for all generators
    EPOCHS=100,          # upper bound on the total number of epochs
    LEARNING_RATE=0.001,
    NUM_CLASSES=3,
    CLASS_NAMES=['Bacterial', 'Fungal', 'Healthy'],
    TRAIN_DIR='./src/Train',
    VAL_DIR='./src/Validation',
    TEST_DIR='./src/Test',
    MODEL_DIR='./model',
    LOGS_DIR='./logs',
)

# Make sure the output folders exist before anything tries to write to them
for output_key in ('MODEL_DIR', 'LOGS_DIR'):
    os.makedirs(CONFIG[output_key], exist_ok=True)

# Echo the effective settings, one "key: value" pair per line
print("Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in CONFIG.items()))

## 3. Enhanced Data Generators with Better Augmentation

In [None]:
def create_data_generators():
    """Build the training, validation and test image generators.

    Training images are augmented; validation and test images are only
    rescaled. When ``CONFIG['VAL_DIR']`` does not exist, a hold-out fraction
    of ``CONFIG['TRAIN_DIR']`` is used for validation instead, and the
    training generator is restricted to the remaining images so the two
    sets do not overlap.

    Returns:
        tuple: ``(train_generator, val_generator, test_generator)``.
        ``test_generator`` is ``None`` when ``CONFIG['TEST_DIR']`` does not
        exist.
    """
    holdout_fraction = 0.2
    target_size = (CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE'])
    use_train_split = not os.path.exists(CONFIG['VAL_DIR'])

    # NOTE(review): Keras EfficientNet embeds its own rescaling and is
    # documented to expect [0, 255] pixels; make sure the model accounts for
    # the [0, 1] scaling applied by every generator below.

    # Training data generator with extensive augmentation
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=45,
        width_shift_range=0.3,
        height_shift_range=0.3,
        shear_range=0.3,
        zoom_range=0.3,
        brightness_range=[0.7, 1.3],
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest',
        validation_split=holdout_fraction  # only takes effect with subset=
    )

    # Validation and test data generators (no augmentation)
    val_test_datagen = ImageDataGenerator(rescale=1./255)

    # Training generator. subset=None keeps every image (separate VAL_DIR
    # available); subset='training' excludes the hold-out fraction.
    train_generator = train_datagen.flow_from_directory(
        CONFIG['TRAIN_DIR'],
        target_size=target_size,
        batch_size=CONFIG['BATCH_SIZE'],
        class_mode='categorical',
        shuffle=True,
        seed=42,
        subset='training' if use_train_split else None
    )

    # Validation generator
    if not use_train_split:
        val_generator = val_test_datagen.flow_from_directory(
            CONFIG['VAL_DIR'],
            target_size=target_size,
            batch_size=CONFIG['BATCH_SIZE'],
            class_mode='categorical',
            shuffle=False
        )
    else:
        print("No separate validation directory found. Using training split.")
        # Same split fraction as the training generator, but without any
        # augmentation, so validation metrics reflect unmodified images.
        split_datagen = ImageDataGenerator(
            rescale=1./255,
            validation_split=holdout_fraction
        )
        val_generator = split_datagen.flow_from_directory(
            CONFIG['TRAIN_DIR'],
            target_size=target_size,
            batch_size=CONFIG['BATCH_SIZE'],
            class_mode='categorical',
            shuffle=False,
            subset='validation'
        )

    # Test generator (if exists)
    test_generator = None
    if os.path.exists(CONFIG['TEST_DIR']):
        test_generator = val_test_datagen.flow_from_directory(
            CONFIG['TEST_DIR'],
            target_size=target_size,
            batch_size=CONFIG['BATCH_SIZE'],
            class_mode='categorical',
            shuffle=False
        )

    return train_generator, val_generator, test_generator

# Instantiate the three generators once; later cells reuse them
train_gen, val_gen, test_gen = create_data_generators()

# Summarise how many images each available split contains
print(f"Training samples: {train_gen.samples}")
for split_label, split_gen in (("Validation", val_gen), ("Test", test_gen)):
    if split_gen:
        print(f"{split_label} samples: {split_gen.samples}")
print(f"Class indices: {train_gen.class_indices}")

## 4. Build Enhanced Model with EfficientNetB0

In [None]:
def create_efficientnet_model():
    """Create the classifier: a frozen EfficientNetB0 plus a small dense head.

    The model accepts images with pixel values in [0, 1] (the scaling applied
    by the data generators in this notebook) and converts them back to the
    [0, 255] range before they reach the EfficientNet backbone.

    Returns:
        tuple: ``(model, base_model)`` where ``model`` is the full, uncompiled
        Keras model and ``base_model`` is the EfficientNetB0 backbone, returned
        separately so it can be unfrozen later for fine-tuning.
    """

    # Load pre-trained EfficientNetB0 without its ImageNet classification head
    base_model = EfficientNetB0(
        input_shape=(CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE'], 3),
        include_top=False,
        weights='imagenet'
    )

    # Freeze base model initially
    base_model.trainable = False

    inputs = tf.keras.Input(shape=(CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE'], 3))

    # The Keras EfficientNet graph contains its own input rescaling and is
    # documented to expect raw [0, 255] pixels. The pipeline feeds [0, 1]
    # values, so undo that scaling here instead of normalising twice.
    x = tf.keras.layers.Rescaling(255.0, name='restore_pixel_range')(inputs)

    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, also after the backbone is unfrozen.
    x = base_model(x, training=False)

    # Custom classification head
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.2)(x)
    outputs = Dense(CONFIG['NUM_CLASSES'], activation='softmax')(x)

    model = Model(inputs, outputs)

    return model, base_model

# Create the model
model, base_model = create_efficientnet_model()

# Compile the model. Precision and Recall are passed as metric objects:
# the bare strings 'precision'/'recall' are not accepted by every Keras
# version, while the explicit names keep the history keys unchanged.
model.compile(
    optimizer=Adam(learning_rate=CONFIG['LEARNING_RATE']),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ]
)

# Display model summary
model.summary()

# Save model architecture plot. plot_model needs the optional pydot/graphviz
# packages; a missing diagram should not stop the training run.
try:
    plot_model(model, to_file=f"{CONFIG['MODEL_DIR']}/model_architecture.png",
               show_shapes=True, show_layer_names=True)
except ImportError as plot_error:
    print(f"⚠ Skipping model architecture plot: {plot_error}")

## 5. Setup Advanced Callbacks

In [None]:
def create_callbacks():
    """Build the list of Keras callbacks used while fitting the model.

    All four callbacks share one timestamp so the checkpoint file and the
    CSV log of a run can be matched up afterwards.

    Returns:
        list: EarlyStopping, ReduceLROnPlateau, ModelCheckpoint and CSVLogger,
        in that order.
    """
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_file = f"{CONFIG['MODEL_DIR']}/best_model_{run_stamp}.h5"
    log_file = f"{CONFIG['LOGS_DIR']}/training_log_{run_stamp}.csv"

    # Stop once val_loss has not improved for 15 epochs and roll back to the
    # best weights seen so far
    stop_early = EarlyStopping(
        monitor='val_loss',
        patience=15,
        restore_best_weights=True,
        verbose=1,
    )

    # Halve the learning rate after 7 stagnant epochs, never below 1e-7
    shrink_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=7,
        min_lr=1e-7,
        verbose=1,
    )

    # Keep only the full model with the lowest val_loss on disk
    keep_best = ModelCheckpoint(
        filepath=checkpoint_file,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
        verbose=1,
    )

    # Append per-epoch metrics to a CSV file
    csv_log = CSVLogger(filename=log_file, append=True)

    return [stop_early, shrink_lr, keep_best, csv_log]

# Build the callback list once; the training cells pass it to model.fit
callbacks = create_callbacks()
callback_count = len(callbacks)
print(f"Created {callback_count} callbacks for enhanced training")

## 6. Train the Model (Phase 1: Frozen Base)

In [None]:
# Phase 1: Train with frozen base model
print("Phase 1: Training with frozen base model...")

# Number of batches per pass. len(generator) counts the final partial batch
# too, so no sample is dropped and a dataset smaller than one batch still
# yields one step (floor division would give zero steps and would always
# skip the trailing validation images).
steps_per_epoch = len(train_gen)
validation_steps = len(val_gen) if val_gen else None

# Train for initial epochs
history_phase1 = model.fit(
    train_gen,
    epochs=20,  # Initial training phase
    steps_per_epoch=steps_per_epoch,
    validation_data=val_gen,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1
)

print("Phase 1 training completed!")

## 7. Fine-tuning (Phase 2: Unfreeze Some Layers)

In [None]:
# Phase 2: Fine-tune by unfreezing some layers
print("Phase 2: Fine-tuning with unfrozen layers...")

# Unfreeze the top layers of the base model
base_model.trainable = True

# Fine-tune from this layer onwards (the upper half of the backbone)
fine_tune_at = len(base_model.layers) // 2

# Freeze all layers before fine_tune_at
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile so the changed trainable flags take effect, with a lower learning
# rate for fine-tuning. Precision and Recall are passed as metric objects:
# the bare strings 'precision'/'recall' are not accepted by every Keras
# version, while the explicit names keep the history keys unchanged.
model.compile(
    optimizer=Adam(learning_rate=CONFIG['LEARNING_RATE'] / 10),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ]
)

# Count inside the backbone: in the outer model the whole backbone is a
# single layer, so counting model.layers would not show what was unfrozen.
unfrozen_count = sum(1 for layer in base_model.layers if layer.trainable)
print(f"Unfrozen layers: {unfrozen_count} of {len(base_model.layers)} base-model layers")

# Continue training where phase 1 stopped, so epoch numbers stay continuous
history_phase2 = model.fit(
    train_gen,
    epochs=CONFIG['EPOCHS'],
    initial_epoch=len(history_phase1.history['loss']),
    steps_per_epoch=steps_per_epoch,
    validation_data=val_gen,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1
)

print("Fine-tuning completed!")

## 8. Enhanced Evaluation and Visualization

In [None]:
def _collect_metric(histories, *keys):
    """Concatenate one metric across training phases into a new list.

    For each history the first of ``keys`` found in ``history.history`` is
    used; a phase that has none of the keys contributes nothing. The stored
    history lists are never modified.
    """
    merged = []
    for phase in histories:
        for key in keys:
            if key in phase.history:
                merged.extend(phase.history[key])
                break
    return merged


def plot_training_history(history1, history2=None):
    """Plot accuracy, loss and learning rate for one or two training phases.

    Args:
        history1: History object returned by the first ``model.fit`` call.
        history2: Optional History object of a second (fine-tuning) phase.
            When given, its curves are appended to those of ``history1`` and
            a vertical line marks where it starts.

    The figure is saved to ``CONFIG['MODEL_DIR']/training_history.png`` and
    then shown. Validation and learning-rate curves are drawn only when the
    histories contain them.
    """
    histories = [history1] if history2 is None else [history1, history2]

    # Epoch index at which the second phase begins (None for a single phase)
    fine_tune_epoch = len(history1.history['loss']) if history2 is not None else None

    plt.figure(figsize=(15, 5))

    # Accuracy and loss panels share the same layout
    for position, (metric, label) in enumerate(
            (('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
        train_values = _collect_metric(histories, metric)
        val_values = _collect_metric(histories, f'val_{metric}')

        plt.subplot(1, 3, position)
        # Each series gets its own x-range so a missing entry in one series
        # cannot cause a length mismatch with another.
        plt.plot(range(len(train_values)), train_values, 'r-',
                 label=f'Training {label}')
        if val_values:
            plt.plot(range(len(val_values)), val_values, 'b-',
                     label=f'Validation {label}')
        if fine_tune_epoch:
            plt.axvline(x=fine_tune_epoch, color='g', linestyle='--',
                        label='Fine-tuning Start')
        plt.title(f'Training and Validation {label}')
        plt.xlabel('Epochs')
        plt.ylabel(label)
        plt.legend()
        plt.grid(True)

    # Plot learning rate (if available). Depending on the Keras version the
    # value is logged as 'lr' or 'learning_rate'.
    plt.subplot(1, 3, 3)
    lr_values = _collect_metric(histories, 'lr', 'learning_rate')
    if lr_values:
        plt.plot(range(len(lr_values)), lr_values, 'g-', label='Learning Rate')
        plt.title('Learning Rate Schedule')
        plt.xlabel('Epochs')
        plt.ylabel('Learning Rate')
        plt.yscale('log')
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{CONFIG['MODEL_DIR']}/training_history.png", dpi=300, bbox_inches='tight')
    plt.show()

# Draw the combined curves of the frozen-base phase and the fine-tuning phase
plot_training_history(history1=history_phase1, history2=history_phase2)

In [None]:
def evaluate_model_comprehensive(model, test_generator=None, val_generator=None):
    """Evaluate the model and report per-class and overall metrics.

    The test generator is used when it is available, otherwise the
    validation generator. A classification report is printed, raw and
    row-normalised confusion matrices are plotted and saved to
    ``CONFIG['MODEL_DIR']/confusion_matrix.png``, and per-class and overall
    accuracy are printed.

    Args:
        model: Trained Keras model providing ``predict``.
        test_generator: Optional generator over the test set.
        val_generator: Optional generator over the validation set.

    Returns:
        tuple: ``(report, cm)`` - the classification report as a dict and the
        confusion matrix. ``(None, None)`` when no evaluation data is
        available, so callers can always unpack two values.
    """

    # Use test set if available, otherwise validation set
    eval_gen = test_generator if test_generator else val_generator
    eval_name = "Test" if test_generator else "Validation"

    if eval_gen is None:
        print("No evaluation data available!")
        return None, None

    # NOTE(review): assumes CONFIG['CLASS_NAMES'] is listed in the same order
    # as the generator's class indices - confirm against the folder names.
    class_names = CONFIG['CLASS_NAMES']
    # Pin the label set so a class absent from this split still gets its own
    # row/column and the names stay aligned with the indices.
    label_ids = list(range(len(class_names)))

    print(f"\n=== {eval_name} Set Evaluation ===")

    # Get predictions
    eval_gen.reset()
    predictions = model.predict(eval_gen, verbose=1)
    predicted_classes = np.argmax(predictions, axis=1)

    # Get true labels
    true_classes = eval_gen.classes[:len(predicted_classes)]

    # Classification report: dict form for the caller, text form for display
    print("\nClassification Report:")
    report = classification_report(
        true_classes, predicted_classes,
        labels=label_ids,
        target_names=class_names,
        output_dict=True
    )
    print(classification_report(true_classes, predicted_classes,
                                labels=label_ids, target_names=class_names))

    # Confusion Matrix
    cm = confusion_matrix(true_classes, predicted_classes, labels=label_ids)

    plt.figure(figsize=(10, 4))

    # Plot confusion matrix
    plt.subplot(1, 2, 1)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title(f'{eval_name} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')

    # Plot normalized confusion matrix. Rows without any true samples stay
    # at 0 instead of producing a division-by-zero NaN.
    plt.subplot(1, 2, 2)
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_totals,
                        out=np.zeros(cm.shape, dtype=float),
                        where=row_totals != 0)
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title(f'{eval_name} Confusion Matrix (Normalized)')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')

    plt.tight_layout()
    plt.savefig(f"{CONFIG['MODEL_DIR']}/confusion_matrix.png", dpi=300, bbox_inches='tight')
    plt.show()

    # Per-class accuracy (diagonal hits divided by the true count per class)
    print("\nPer-class Accuracy:")
    for name, hits, total in zip(class_names, cm.diagonal(), row_totals.ravel()):
        if total:
            print(f"  {name}: {hits / total:.3f}")
        else:
            print(f"  {name}: n/a (no samples)")

    # Overall metrics
    overall_accuracy = np.sum(cm.diagonal()) / np.sum(cm)
    print(f"\nOverall Accuracy: {overall_accuracy:.3f}")

    return report, cm

# Score the trained model on the test set, or on the validation set if there is no test set
evaluation_report, conf_matrix = evaluate_model_comprehensive(
    model,
    test_generator=test_gen,
    val_generator=val_gen,
)

## 9. Export Models in Multiple Formats

In [None]:
def export_model_formats(model, model_dir):
    """Write the trained model to disk in several deployment formats.

    Saves, in order: a Keras H5 file, a TensorFlow SavedModel directory, a
    standard TFLite file, an optimised TFLite file and, if the
    ``tensorflowjs`` package can be imported, a TensorFlow.js export. A JSON
    file describing the export is written last. All artefacts share one
    timestamp in their names.

    Args:
        model: Trained Keras model to export.
        model_dir: Directory that receives every artefact.

    Returns:
        dict: Metadata about the export (timestamp, architecture, input
        shape, classes and the H5/TFLite paths), identical to the JSON
        content apart from the tuple-to-list conversion done by ``json``.
    """
    import json

    def _write_binary(destination, payload):
        """Store raw bytes at ``destination``."""
        with open(destination, 'wb') as handle:
            handle.write(payload)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    print("Exporting model in multiple formats...")

    # 1. Keras H5 file
    h5_path = f"{model_dir}/improved_model_{stamp}.h5"
    model.save(h5_path)
    print(f"✓ H5 model saved: {h5_path}")

    # 2. TensorFlow SavedModel directory
    savedmodel_path = f"{model_dir}/saved_model_{stamp}"
    model.save(savedmodel_path, save_format='tf')
    print(f"✓ SavedModel saved: {savedmodel_path}")

    # 3. TFLite: one converter produces both variants; the plain conversion
    #    runs first, then the optimisation flag is switched on.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    tflite_path = f"{model_dir}/model_{stamp}.tflite"
    _write_binary(tflite_path, converter.convert())
    print(f"✓ TFLite model saved: {tflite_path}")

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_opt_path = f"{model_dir}/optimized_model_{stamp}.tflite"
    _write_binary(tflite_opt_path, converter.convert())
    print(f"✓ Optimized TFLite model saved: {tflite_opt_path}")

    # 4. TensorFlow.js export is optional and skipped when the package is absent
    try:
        import tensorflowjs as tfjs
        tfjs_path = f"{model_dir}/tfjs_{stamp}"
        tfjs.converters.save_keras_model(model, tfjs_path)
        print(f"✓ TensorFlow.js model saved: {tfjs_path}")
    except ImportError:
        print("⚠ TensorFlow.js not installed. Run: pip install tensorflowjs")

    # Metadata describing this export
    side = CONFIG['IMG_SIZE']
    model_info = {
        'timestamp': stamp,
        'architecture': 'EfficientNetB0',
        'input_shape': (side, side, 3),
        'num_classes': CONFIG['NUM_CLASSES'],
        'class_names': CONFIG['CLASS_NAMES'],
        'h5_path': h5_path,
        'tflite_path': tflite_path,
        'optimized_tflite_path': tflite_opt_path
    }

    info_path = f"{model_dir}/model_info_{stamp}.json"
    with open(info_path, 'w') as info_file:
        json.dump(model_info, info_file, indent=2)
    print(f"✓ Model info saved: {info_path}")

    return model_info

# Write all deployment artefacts for the trained model into the model directory
model_info = export_model_formats(model=model, model_dir=CONFIG['MODEL_DIR'])
print("\nModel export completed!")

## 10. Model Integration Instructions

In [None]:
print("""
=== MODEL INTEGRATION INSTRUCTIONS ===

Your trained model has been exported in multiple formats:

1. H5 Format (.h5): 
   - Use for Python/Flask backend
   - Copy to your backend/py/ directory
   - Load with: tf.keras.models.load_model('model.h5')

2. TFLite Format (.tflite):
   - Use for mobile apps (Android/iOS)
   - Smaller file size, optimized for mobile
   - Use TensorFlow Lite interpreter

3. TensorFlow.js Format:
   - Use for web frontend (React)
   - Copy to frontend/public/ directory
   - Load with: tf.loadLayersModel()

NEXT STEPS:
1. Copy the .h5 file to your Flask backend
2. Update your Flask server to use the new model
3. Test the integration with sample images
4. Deploy and monitor performance

REMEMBER:
- Image preprocessing: resize to 224x224, normalize to [0,1]
- Model expects 4D input: (batch_size, 224, 224, 3)
- Output: 3 classes [Bacterial, Fungal, Healthy]
""")

# Create a simple test script.
# Values baked into the generated file:
# - the model path is made absolute, because the script is written into
#   MODEL_DIR while the exported path is relative to the current working
#   directory, and it is inserted via repr() so backslashes or quotes in
#   the path cannot break the generated string literal;
# - the resize target follows CONFIG['IMG_SIZE'] instead of a fixed 224, so
#   the script stays in sync with the resolution the model was built with.
embedded_model_path = os.path.abspath(model_info["h5_path"])
embedded_img_size = CONFIG['IMG_SIZE']
embedded_class_names = CONFIG['CLASS_NAMES']

test_script = f'''
# Simple test script for the trained model
import tensorflow as tf
import numpy as np
from PIL import Image

# Load the model
model = tf.keras.models.load_model({embedded_model_path!r})

def predict_disease(image_path):
    """Predict disease from image path"""
    # Load and preprocess image
    img = Image.open(image_path).convert('RGB')
    img = img.resize(({embedded_img_size}, {embedded_img_size}))
    img_array = np.array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    
    # Make prediction
    predictions = model.predict(img_array)
    predicted_class = np.argmax(predictions[0])
    confidence = predictions[0][predicted_class]
    
    class_names = {embedded_class_names!r}
    
    return {{
        'predicted_class': class_names[predicted_class],
        'confidence': float(confidence),
        'all_predictions': {{
            class_names[i]: float(predictions[0][i]) 
            for i in range(len(class_names))
        }}
    }}

# Example usage:
# result = predict_disease('path/to/your/image.jpg')
# print(result)
'''

with open(f"{CONFIG['MODEL_DIR']}/test_model.py", 'w', encoding='utf-8') as f:
    f.write(test_script)

print(f"Test script created: {CONFIG['MODEL_DIR']}/test_model.py")