# In [3]:
import os
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers
from tensorflow.keras.applications import EfficientNetB0, MobileNet
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import seaborn as sns
import keras_cv

# CBAM implementation
def cbam_block(input_tensor, reduction_ratio=16):
    """Apply channel attention followed by spatial attention (CBAM-style).

    The tensor is treated as channels-last: axes 1 and 2 are pooled for the
    channel branch and the last axis is pooled for the spatial branch.

    Args:
        input_tensor: Feature map whose last dimension is a statically known
            channel count.
        reduction_ratio: Divisor applied to the channel count to size the
            hidden layer of the shared channel-attention MLP.

    Returns:
        The input re-weighted first per channel and then per spatial position.

    Raises:
        ValueError: If the channel dimension of ``input_tensor`` is unknown.
    """
    channels = input_tensor.shape[-1]
    if channels is None:
        raise ValueError("cbam_block requires a statically known channel dimension")

    # Clamp to one unit so a feature map with fewer channels than
    # `reduction_ratio` does not request a zero-width Dense layer.
    hidden_units = max(1, channels // reduction_ratio)

    # Channel attention: pool over the two spatial axes, keeping them as
    # size-1 dims so the result broadcasts against the input.
    avg_pool = tf.reduce_mean(input_tensor, axis=[1, 2], keepdims=True)
    max_pool = tf.reduce_max(input_tensor, axis=[1, 2], keepdims=True)

    # Both pooled descriptors go through the SAME two Dense layers
    # (shared weights), so the layers are instantiated once and called twice.
    mlp = layers.Dense(hidden_units, activation='relu')
    mlp_out = layers.Dense(channels)

    avg_out = mlp_out(mlp(avg_pool))
    max_out = mlp_out(mlp(max_pool))

    channel_attention = tf.nn.sigmoid(avg_out + max_out)
    channel_refined = input_tensor * channel_attention

    # Spatial attention: pool over the channel axis, stack the two maps and
    # reduce them to a single sigmoid-activated map with a 7x7 convolution.
    avg_pool_spatial = tf.reduce_mean(channel_refined, axis=-1, keepdims=True)
    max_pool_spatial = tf.reduce_max(channel_refined, axis=-1, keepdims=True)
    concat = tf.concat([avg_pool_spatial, max_pool_spatial], axis=-1)
    spatial_attention = layers.Conv2D(1, kernel_size=7, padding='same', activation='sigmoid')(concat)

    refined_output = channel_refined * spatial_attention
    return refined_output

# Model builder
def build_model(base_model_name='EfficientNetB0', input_shape=(224, 224, 3), num_classes=10):
    """Assemble a frozen ImageNet backbone with a CBAM block and softmax head.

    Args:
        base_model_name: Either 'EfficientNetB0' or 'MobileNet'.
        input_shape: Shape of one input image, passed to both the backbone
            and the model's Input layer.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled Keras model mapping images to class probabilities.

    Raises:
        ValueError: If ``base_model_name`` is not one of the supported names.

    NOTE(review): the Keras documentation for EfficientNet describes built-in
    input normalization that expects raw [0, 255] pixels -- confirm that the
    data pipeline feeding this model uses a matching scale.
    """
    if base_model_name not in ('EfficientNetB0', 'MobileNet'):
        raise ValueError("Unsupported base model name")

    backbone_cls = EfficientNetB0 if base_model_name == 'EfficientNetB0' else MobileNet
    backbone = backbone_cls(include_top=False, weights='imagenet', input_shape=input_shape)

    # Freeze the pretrained weights; only the attention block and head train.
    backbone.trainable = False

    image_input = layers.Input(shape=input_shape)
    # training=False keeps the backbone in inference mode when called.
    features = backbone(image_input, training=False)
    attended = cbam_block(features)
    pooled = layers.GlobalAveragePooling2D()(attended)
    regularized = layers.Dropout(0.5)(pooled)
    class_probs = layers.Dense(num_classes, activation='softmax')(regularized)

    return models.Model(image_input, class_probs)

# Image preprocessing
def preprocess_images(train_dir, test_dir, input_size=(224, 224), batch_size=32, validation_split=0.2,
                      rescale=1.0 / 255.0):
    """Create the training, validation and test directory iterators.

    Training images are randomly augmented. Validation and test images are
    only rescaled, so the metrics computed on them describe unaugmented data.

    Args:
        train_dir: Directory with one sub-directory per class; it is split
            into the training and validation subsets.
        test_dir: Directory with one sub-directory per class, used for the
            test iterator.
        input_size: (height, width) every image is resized to.
        batch_size: Number of images per batch for all three iterators.
        validation_split: Fraction of ``train_dir`` reserved for validation.
        rescale: Multiplier applied to pixel values by all three generators.
            Pass ``None`` to leave pixel values unscaled (e.g. for a backbone
            that normalizes its own input).

    Returns:
        A ``(train_generator, validation_generator, test_generator)`` tuple.
        The test iterator is not shuffled, so its order matches its
        ``classes`` attribute.
    """
    train_datagen = ImageDataGenerator(
        rescale=rescale,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
        validation_split=validation_split,
    )

    # Separate generator without augmentation for the validation subset.
    # It is given the same `validation_split` so that it selects the
    # complementary part of `train_dir` to the training subset.
    validation_datagen = ImageDataGenerator(
        rescale=rescale,
        validation_split=validation_split,
    )

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=input_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='training',
    )

    validation_generator = validation_datagen.flow_from_directory(
        train_dir,
        target_size=input_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation',
    )

    test_datagen = ImageDataGenerator(rescale=rescale)

    test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=input_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False,
    )

    return train_generator, validation_generator, test_generator

# Plot loss curves
def plot_loss_curves(history):
    """Show training/validation loss and accuracy side by side.

    Args:
        history: Object whose ``history`` attribute is a mapping with the
            keys 'loss', 'val_loss', 'accuracy' and 'val_accuracy', each
            holding one value per epoch.
    """
    # (train key, train label, validation key, validation label, y label, title)
    panels = (
        ('loss', 'Training Loss', 'val_loss', 'Validation Loss', 'Loss', 'Loss Curve'),
        ('accuracy', 'Training Accuracy', 'val_accuracy', 'Validation Accuracy',
         'Accuracy', 'Accuracy Curve'),
    )

    plt.figure(figsize=(12, 4))

    for position, panel in enumerate(panels, start=1):
        train_key, train_label, val_key, val_label, y_label, title = panel
        plt.subplot(1, 2, position)
        plt.plot(history.history[train_key], label=train_label)
        plt.plot(history.history[val_key], label=val_label)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()

    plt.show()

# Plot confusion matrix
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw the confusion matrix of ``y_pred`` against ``y_true`` as a heatmap.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Tick labels used on both axes of the heatmap.
    """
    matrix = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 8))
    heatmap_options = {
        'annot': True,
        'fmt': 'd',
        'cmap': 'Blues',
        'xticklabels': class_names,
        'yticklabels': class_names,
    }
    sns.heatmap(matrix, **heatmap_options)

    # Columns are predictions, rows are ground truth.
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()

# Training function
def train_model(base_model_name='EfficientNetB0', train_dir='train', test_dir='test',
                input_size=(224, 224), batch_size=32, epochs=10, learning_rate=1e-4, validation_split=0.2,
                use_focal_loss=False):
    """Train, evaluate and report on an image classifier.

    Builds the data iterators and the model, trains with balanced class
    weights, then prints the test accuracy and a classification report and
    shows the training curves and the confusion matrix. The best model by
    validation accuracy is written to 'best_model.h5'.

    Args:
        base_model_name: Backbone name forwarded to ``build_model``.
        train_dir: Directory split into training and validation subsets.
        test_dir: Directory used for the final evaluation.
        input_size: (height, width) of the images fed to the model.
        batch_size: Batch size for all data iterators.
        epochs: Maximum number of training epochs.
        learning_rate: Learning rate for the Adam optimizer.
        validation_split: Fraction of ``train_dir`` held out for validation.
        use_focal_loss: Use ``keras_cv.losses.FocalLoss`` instead of
            categorical cross-entropy.

    Returns:
        A ``(model, history)`` tuple: the trained model and the object
        returned by ``model.fit``.
    """
    # Preprocess data
    train_gen, val_gen, test_gen = preprocess_images(train_dir, test_dir, input_size, batch_size, validation_split)

    # Get number of classes
    num_classes = len(train_gen.class_indices)
    class_names = list(train_gen.class_indices.keys())

    # Compute class weights straight from the iterator's integer label array.
    # Iterating the generator to count labels would decode and augment images
    # for nothing, and the iterator does not stop on its own.
    # compute_class_weight raises ValueError if a class has no training samples.
    balanced_weights = compute_class_weight(
        'balanced',
        classes=np.arange(num_classes),
        y=train_gen.classes,
    )
    class_weights = {index: float(weight) for index, weight in enumerate(balanced_weights)}

    # Build model
    model = build_model(base_model_name=base_model_name, input_shape=(*input_size, 3), num_classes=num_classes)

    # Compile model
    loss_fn = (
        keras_cv.losses.FocalLoss(alpha=0.25, gamma=2.0) if use_focal_loss
        else 'categorical_crossentropy'
    )
    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss=loss_fn,
        metrics=['accuracy']
    )

    # Callbacks: keep the best checkpoint on disk and stop after 5 epochs
    # without improvement in validation accuracy, restoring the best weights.
    checkpoint_cb = callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_accuracy', mode='max')
    early_stopping_cb = callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)

    # Train
    print("Training....")
    history = model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        class_weight=class_weights,
        callbacks=[checkpoint_cb, early_stopping_cb]
    )

    # Evaluate
    test_loss, test_acc = model.evaluate(test_gen)
    print(f"Test Accuracy: {test_acc:.2f}")

    # Plot loss curves
    plot_loss_curves(history)

    # Predict and evaluate. The iterator is reset so prediction starts at the
    # first batch and the predictions line up with `test_gen.classes`.
    test_gen.reset()
    y_pred = model.predict(test_gen)
    y_pred_classes = np.argmax(y_pred, axis=1)
    y_true = test_gen.classes

    # Classification report. Explicit labels keep `target_names` aligned even
    # when some class never appears in the test labels or predictions.
    # NOTE(review): class names come from the training directory; this assumes
    # the test directory uses the same class sub-directories -- confirm.
    print("\nClassification Report:\n")
    print(classification_report(
        y_true,
        y_pred_classes,
        labels=np.arange(num_classes),
        target_names=class_names,
    ))

    # Confusion matrix
    plot_confusion_matrix(y_true, y_pred_classes, class_names)

    return model, history

# Run the full pipeline on the PlantVillage split with an EfficientNetB0
# backbone and focal loss.
trained_model, training_history = train_model(
    # Data
    train_dir='./PlantVillage/train',
    test_dir='./PlantVillage/test',
    validation_split=0.2,
    input_size=(224, 224),
    batch_size=32,
    # Model and optimization
    base_model_name='EfficientNetB0',
    use_focal_loss=True,
    learning_rate=1e-4,
    epochs=20,
)


# Cell output:
# Found 16328 images belonging to 17 classes.
# Found 4073 images belonging to 17 classes.
# Found 576 images belonging to 17 classes.
#
#
# KeyboardInterrupt: