# Transfer Learning Image Classifier Development

This notebook demonstrates the development of an image classifier using transfer learning with pre-trained models. We'll leverage established architectures trained on ImageNet and fine-tune them for our specific classification task.

## Setup and Imports

In [ ]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16, EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import os
import json
from datetime import datetime
import sys
sys.path.append('../..')
from ml_models_core.src.base_classifier import BaseImageClassifier
from ml_models_core.src.model_registry import ModelRegistry, ModelMetadata
from ml_models_core.src.data_loaders import get_unified_classification_data

## Configuration

In [ ]:
# Notebook-wide configuration consumed by the cells below.
# 'num_classes' is overwritten by load_and_preprocess_data() once the dataset
# is loaded; 'random_seed' controls the train/validation shuffle.
# NOTE(review): the numeric values are starting-point defaults - tune them for
# the dataset at hand.
CONFIG = {
    'base_model': 'resnet50',          # one of: 'resnet50', 'vgg16', 'efficientnet'
    'input_shape': (224, 224, 3),
    'num_classes': 2,
    'validation_split': 0.2,
    'random_seed': 42,
    'batch_size': 32,
    'epochs': 20,
    'learning_rate': 1e-3,
    'fine_tune_epochs': 10,
    'fine_tune_learning_rate': 1e-5,
    'dropout_rate': 0.5,
    'l2_regularization': 1e-4,
}


def load_and_preprocess_data():
    """
    Load and preprocess the unified classification dataset for transfer learning.

    The loader's train/val/test partitions are pooled, shuffled with
    CONFIG['random_seed'] (default 42), and re-split positionally using
    CONFIG['validation_split'].

    Returns:
        (X_train, X_val, y_train, y_val, class_names): images resized to
        CONFIG['input_shape'][:2] and scaled to float32 in [0, 1]; labels
        one-hot encoded with CONFIG['num_classes'] columns.

    Side effects:
        Sets CONFIG['num_classes'] to the number of classes reported by the loader.
    """
    print("Loading unified classification dataset...")

    # Load the unified dataset (use sklearn format for raw data access)
    data_split = get_unified_classification_data(framework='sklearn')

    # Pool the DataSplit partitions; they are re-split below.
    # np.vstack / np.hstack always return ndarrays, so no further conversion is needed.
    X = np.vstack([data_split.X_train, data_split.X_val, data_split.X_test])
    y = np.hstack([data_split.y_train, data_split.y_val, data_split.y_test])
    class_names = data_split.class_names

    print(f"Loaded {len(X)} images from {len(class_names)} classes")
    print(f"Classes (first 10): {class_names[:10]}")
    print(f"Total classes: {len(class_names)}")
    print(f"Image shape: {X[0].shape}")

    # Update CONFIG with actual number of classes
    CONFIG['num_classes'] = len(class_names)

    # Shuffle before the positional split. Without it the validation set is just
    # the tail of the concatenated train/val/test arrays, and a class-ordered
    # dataset would produce train/val sets covering different classes.
    rng = np.random.default_rng(CONFIG.get('random_seed', 42))
    order = rng.permutation(len(X))
    X, y = X[order], y[order]

    # Resize to the backbone's configured input size (224x224 by default).
    height, width = CONFIG['input_shape'][:2]
    print(f"Resizing images to {height}x{width} for transfer learning...")
    X_resized = tf.image.resize(X, [height, width]).numpy()

    # Normalize pixel values to [0, 1]
    X_resized = X_resized.astype(np.float32) / 255.0

    # Convert labels to categorical
    y_categorical = to_categorical(y, CONFIG['num_classes'])

    # Split data
    split_idx = int(len(X_resized) * (1 - CONFIG['validation_split']))

    X_train, X_val = X_resized[:split_idx], X_resized[split_idx:]
    y_train, y_val = y_categorical[:split_idx], y_categorical[split_idx:]

    print(f"Training samples: {len(X_train)}")
    print(f"Validation samples: {len(X_val)}")
    print(f"Class distribution in training: {np.bincount(np.argmax(y_train, axis=1))}")

    return X_train, X_val, y_train, y_val, class_names

# Load the data
X_train, X_val, y_train, y_val, class_names = load_and_preprocess_data()

## Data Loading and Preprocessing

In [ ]:
def load_and_preprocess_data():
    """
    Load and preprocess the unified classification dataset for transfer learning.

    Samples are shuffled with CONFIG['random_seed'] (default 42) and then split
    positionally using CONFIG['validation_split'].

    Returns:
        (X_train, X_val, y_train, y_val, class_names): images resized to
        CONFIG['input_shape'][:2] and scaled to float32 in [0, 1]; labels
        one-hot encoded with CONFIG['num_classes'] columns.

    Side effects:
        Sets CONFIG['num_classes'] to the number of classes reported by the loader.
    """
    print("Loading unified classification dataset...")

    # Load the unified dataset
    data = get_unified_classification_data(framework='tensorflow')

    # NOTE(review): this indexes the loader result as a dict ('X', 'y',
    # 'class_names'), whereas the Configuration cell treats the same loader's
    # return value as a DataSplit object with X_train/X_val/X_test attributes.
    # Confirm which contract get_unified_classification_data actually has.
    X = data['X']
    y = data['y']
    class_names = data['class_names']

    print(f"Loaded {len(X)} images from {len(class_names)} classes")
    print(f"Classes (first 10): {class_names[:10]}")
    print(f"Total classes: {len(class_names)}")
    print(f"Image shape: {X[0].shape}")

    # Update CONFIG with actual number of classes
    CONFIG['num_classes'] = len(class_names)

    # Convert to numpy arrays if needed
    if not isinstance(X, np.ndarray):
        X = np.array(X)
    if not isinstance(y, np.ndarray):
        y = np.array(y)

    # Shuffle before the positional split; if the loader returns samples grouped
    # by class, an unshuffled tail split would put whole classes only in validation.
    rng = np.random.default_rng(CONFIG.get('random_seed', 42))
    order = rng.permutation(len(X))
    X, y = X[order], y[order]

    # Resize to the backbone's configured input size (224x224 by default).
    height, width = CONFIG['input_shape'][:2]
    print(f"Resizing images to {height}x{width} for transfer learning...")
    X_resized = tf.image.resize(X, [height, width]).numpy()

    # Normalize pixel values to [0, 1]
    X_resized = X_resized.astype(np.float32) / 255.0

    # Convert labels to categorical
    y_categorical = to_categorical(y, CONFIG['num_classes'])

    # Split data
    split_idx = int(len(X_resized) * (1 - CONFIG['validation_split']))

    X_train, X_val = X_resized[:split_idx], X_resized[split_idx:]
    y_train, y_val = y_categorical[:split_idx], y_categorical[split_idx:]

    print(f"Training samples: {len(X_train)}")
    print(f"Validation samples: {len(X_val)}")
    print(f"Class distribution in training: {np.bincount(np.argmax(y_train, axis=1))}")

    return X_train, X_val, y_train, y_val, class_names

# Load the data
X_train, X_val, y_train, y_val, class_names = load_and_preprocess_data()

## Pre-trained Model Selection and Architecture

In [None]:
def create_base_model(model_name='resnet50', input_shape=(224, 224, 3)):
    """
    Create and return a frozen, ImageNet-pretrained backbone without its top.

    Args:
        model_name: one of 'resnet50', 'vgg16', 'efficientnet'.
        input_shape: (height, width, channels) of the model input.

    Returns:
        The Keras backbone with ``trainable = False``.

    Raises:
        ValueError: if ``model_name`` is not supported.

    NOTE(review): Keras documents model-specific ``preprocess_input`` scaling for
    these backbones (e.g. ResNet50/VGG16 expect 0-255 caffe-style input with mean
    subtraction; EfficientNet rescales 0-255 input internally), while this
    notebook feeds images scaled to [0, 1]. Confirm this mismatch is intended.
    """
    base_models = {
        'resnet50': ResNet50,
        'vgg16': VGG16,
        'efficientnet': EfficientNetB0
    }

    if model_name not in base_models:
        raise ValueError(f"Model {model_name} not supported. Choose from: {list(base_models.keys())}")

    base_model = base_models[model_name](
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )

    # Freeze base model layers initially
    base_model.trainable = False

    # Keras variables have no .numel() (that is a PyTorch tensor method), so the
    # element count is derived from each variable's shape.
    trainable_params = sum(int(np.prod(tuple(w.shape))) for w in base_model.trainable_weights)

    print(f"Base model: {model_name}")
    print(f"Total parameters: {base_model.count_params():,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return base_model

def create_transfer_model(base_model, num_classes=2, dropout_rate=0.5,
                          l2_regularization=None, hidden_units=(512, 256)):
    """
    Create the complete transfer learning model with a custom classification head.

    Head: GlobalAveragePooling2D -> BatchNormalization -> for each entry of
    ``hidden_units``: Dense(relu, L2) + Dropout -> Dense(num_classes, softmax).

    Args:
        base_model: Keras backbone whose ``input``/``output`` are wired into the model.
        num_classes: number of softmax outputs.
        dropout_rate: dropout applied after every hidden Dense layer.
        l2_regularization: L2 factor for hidden Dense kernels; ``None`` falls back
            to the notebook's ``CONFIG['l2_regularization']``.
        hidden_units: widths of the hidden Dense layers (default matches the
            original 512 -> 256 head).

    Returns:
        A functional ``Model`` from ``base_model.input`` to the softmax predictions.
    """
    if l2_regularization is None:
        l2_regularization = CONFIG['l2_regularization']

    x = GlobalAveragePooling2D()(base_model.output)
    x = BatchNormalization()(x)
    for units in hidden_units:
        x = Dense(
            units,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2_regularization),
        )(x)
        x = Dropout(dropout_rate)(x)
    predictions = Dense(num_classes, activation='softmax', name='predictions')(x)

    return Model(inputs=base_model.input, outputs=predictions)

# Build the frozen ImageNet backbone from CONFIG, then attach the custom head.
base_model = create_base_model(
    model_name=CONFIG['base_model'],
    input_shape=CONFIG['input_shape'],
)
model = create_transfer_model(
    base_model,
    num_classes=CONFIG['num_classes'],
    dropout_rate=CONFIG['dropout_rate'],
)

print("\nComplete model summary:")
model.summary()

## Model Compilation and Callbacks

In [None]:
# Compile model: categorical cross-entropy matches the one-hot labels and the
# softmax head. 'top_k_categorical_accuracy' uses Keras' default k (5), so it is
# trivially 1.0 when there are 5 or fewer classes.
model.compile(
    optimizer=Adam(learning_rate=CONFIG['learning_rate']),
    loss='categorical_crossentropy',
    metrics=['accuracy', 'top_k_categorical_accuracy']
)

checkpoint_path = '../models/transfer_model_best.h5'
# Create the checkpoint directory up front; depending on the Keras version,
# saving into a missing directory fails at the first improved epoch.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Setup callbacks (shared by both training phases).
callbacks = [
    # Stop when val_loss stalls for 10 epochs and roll back to the best weights.
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    # Halve the learning rate after 5 epochs without val_loss improvement.
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    # Persist the full model whenever val_accuracy improves.
    ModelCheckpoint(
        checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        save_weights_only=False,
        verbose=1
    )
]

print("Model compiled with callbacks ready.")

## Initial Training (Frozen Base Model)

In [None]:
print("Starting initial training with frozen base model...")

# Phase 1: the backbone is frozen, so only the newly added head learns.
phase_one_fit_args = {
    'batch_size': CONFIG['batch_size'],
    'epochs': CONFIG['epochs'],
    'validation_data': (X_val, y_val),
    'callbacks': callbacks,
    'verbose': 1,
}
history_initial = model.fit(X_train, y_train, **phase_one_fit_args)

print("Initial training completed.")

## Fine-tuning (Unfreezing Base Model)

In [None]:
print("Starting fine-tuning phase...")

# Unfreeze the base model for fine-tuning
base_model.trainable = True

# Fine-tune from this layer onwards (upper half of the backbone).
fine_tune_at = len(base_model.layers) // 2

# Freeze all layers before fine_tune_at
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Keep BatchNormalization layers frozen in the unfrozen half too. In TF2 Keras a
# BN layer with trainable=False runs in inference mode, so its ImageNet moving
# statistics are not overwritten by small fine-tuning batches (Keras fine-tuning guide).
for layer in base_model.layers[fine_tune_at:]:
    if isinstance(layer, BatchNormalization):
        layer.trainable = False

print(f"Fine-tuning from layer {fine_tune_at} onwards")
# Keras variables have no .numel() (a PyTorch method); count elements from each shape.
print(f"Trainable parameters: {sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights):,}")

# Recompile with lower learning rate for fine-tuning; recompiling is required
# for the trainable-flag changes above to take effect.
model.compile(
    optimizer=Adam(learning_rate=CONFIG['fine_tune_learning_rate']),
    loss='categorical_crossentropy',
    metrics=['accuracy', 'top_k_categorical_accuracy']
)

# Continue training with fine-tuning
history_finetune = model.fit(
    X_train, y_train,
    batch_size=CONFIG['batch_size'],
    epochs=CONFIG['fine_tune_epochs'],
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    verbose=1
)

print("Fine-tuning completed.")

## Training Visualization and Analysis

In [None]:
def plot_training_history(history_initial, history_finetune):
    """
    Plot loss and accuracy curves for both training phases on one timeline.

    The frozen-base history is followed by the fine-tuning history; a dashed
    green line is drawn at the epoch count of the first phase. Afterwards the
    best and the final validation accuracy are printed.
    """
    first = history_initial.history
    second = history_finetune.history
    phase_boundary = len(first['loss'])

    def _joined(key):
        # Per-epoch values for `key`, phase one followed by phase two.
        return first[key] + second[key]

    # (title, y-label, train series, val series, train legend, val legend)
    panels = (
        ('Model Loss', 'Loss',
         _joined('loss'), _joined('val_loss'),
         'Training Loss', 'Validation Loss'),
        ('Model Accuracy', 'Accuracy',
         _joined('accuracy'), _joined('val_accuracy'),
         'Training Accuracy', 'Validation Accuracy'),
    )
    epoch_axis = range(1, len(panels[0][2]) + 1)

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    for ax, (title, ylabel, train_series, val_series, train_label, val_label) in zip(axes, panels):
        ax.plot(epoch_axis, train_series, 'b-', label=train_label)
        ax.plot(epoch_axis, val_series, 'r-', label=val_label)
        ax.axvline(x=phase_boundary, color='g', linestyle='--', alpha=0.7, label='Fine-tuning Start')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

    val_accuracy_series = panels[1][3]
    print(f"Best validation accuracy: {max(val_accuracy_series):.4f}")
    print(f"Final validation accuracy: {val_accuracy_series[-1]:.4f}")


plot_training_history(history_initial, history_finetune)

## Model Evaluation

In [ ]:
# Evaluate model on validation set (returns loss followed by the compiled metrics).
val_loss, val_accuracy, val_top_k = model.evaluate(X_val, y_val, verbose=0)
print("Validation Results:")
print(f"Loss: {val_loss:.4f}")
print(f"Accuracy: {val_accuracy:.4f}")
print(f"Top-k Accuracy: {val_top_k:.4f}")

# Generate predictions; labels are one-hot, so argmax recovers class indices.
y_pred_probs = model.predict(X_val)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.argmax(y_val, axis=1)

# Classification report (show first 10 classes for readability)
print("\nClassification Report (first 10 classes):")
unique_classes = sorted(set(y_true.tolist()))
display_classes = unique_classes[:10]

if len(display_classes) < len(unique_classes):
    print(f"Note: Showing first 10 of {len(unique_classes)} classes")

print(classification_report(y_true, y_pred,
                            target_names=[class_names[i] for i in display_classes],
                            labels=display_classes))

# Confusion matrix (only for manageable number of classes)
if len(class_names) <= 15:
    # Pin the label set to every class index: without `labels`, sklearn sizes
    # the matrix by the classes that actually occur in y_true/y_pred, which no
    # longer lines up with the class_names tick labels when a class is missing.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
else:
    print(f"Confusion matrix skipped (too many classes: {len(class_names)})")

## Transfer Learning Classifier Implementation

In [ ]:
class TransferLearningClassifier(BaseImageClassifier):
    """
    Transfer learning image classifier using pre-trained models.

    Adapts a Keras model (ImageNet backbone + custom head) to the
    BaseImageClassifier interface: it resizes/scales raw images, runs the
    model, and maps softmax outputs to class names.
    """

    def __init__(self, config=None, class_names=None):
        # NOTE(review): BaseImageClassifier.__init__ is not called; confirm the
        # base class needs no initialisation of its own.
        # Fall back to the notebook-level CONFIG when no config is supplied.
        self.config = config or CONFIG
        self.model = None
        # Explicit None/empty check: `class_names or default` raises
        # "truth value of an array is ambiguous" when given a NumPy array.
        if class_names is None or len(class_names) == 0:
            class_names = ['Class_0', 'Class_1']
        self.class_names = list(class_names)
        self.training_history = None

    def load_model(self, model_path: str) -> None:
        """
        Load a trained transfer learning model from ``model_path``.

        Raises:
            Whatever ``tf.keras.models.load_model`` raises; the error is printed
            and then re-raised.
        """
        try:
            self.model = tf.keras.models.load_model(model_path)
            print(f"Model loaded from {model_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Prepare a raw image (H, W, C) or batch (N, H, W, C) for the model.

        Converts to float32, resizes to ``config['input_shape'][:2]`` when the
        spatial size differs, scales to [0, 1] when values exceed 1.0, and adds
        a batch dimension to single images.
        """
        image = np.asarray(image, dtype=np.float32)

        # tuple() so a JSON-loaded list config compares equal to the array shape.
        target_hw = tuple(self.config['input_shape'][:2])
        # Spatial dims sit just before the channel axis for both (H, W, C) and (N, H, W, C).
        if tuple(image.shape[-3:-1]) != target_hw:
            # tf.image.resize returns a tf.Tensor; convert back to NumPy so that
            # .max() and the division below work.
            image = tf.image.resize(image, target_hw).numpy()

        # Heuristic: values above 1.0 are assumed to be on a 0-255 scale.
        if image.max() > 1.0:
            image = image / 255.0

        # Add batch dimension if needed
        if image.ndim == 3:
            image = np.expand_dims(image, axis=0)

        return image

    def predict(self, image: np.ndarray) -> dict:
        """
        Run the model on a raw image and return ``{class_name: probability}``.

        The image is passed through ``preprocess`` first; when a batch is given,
        only the first item's probabilities are returned.

        Raises:
            ValueError: if no model has been loaded or assigned.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        preprocessed_image = self.preprocess(image)
        predictions = self.model.predict(preprocessed_image, verbose=0)

        # Convert to probabilities dict
        probs = predictions[0] if len(predictions.shape) > 1 else predictions

        return {
            self.class_names[i]: float(prob)
            for i, prob in enumerate(probs)
        }

    def get_metadata(self) -> dict:
        """
        Get model metadata and configuration.
        """
        return {
            'model_type': 'transfer_learning',
            'base_model': self.config['base_model'],
            'input_shape': self.config['input_shape'],
            'num_classes': self.config['num_classes'],
            'class_names': self.class_names,
            'preprocessing': 'resize_and_normalize',
            'framework': 'tensorflow',
            'architecture': 'pretrained_with_custom_head'
        }

    def save_model(self, model_path: str) -> None:
        """
        Save the trained model to ``model_path``, creating parent directories.

        Raises:
            ValueError: if there is no model to save.
        """
        if self.model is None:
            raise ValueError("No model to save. Train or load a model first.")

        # os.makedirs('') raises, so only create a directory when the path has one.
        directory = os.path.dirname(model_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.model.save(model_path)
        print(f"Model saved to {model_path}")

# Wrap the trained Keras model in the BaseImageClassifier-compatible adapter.
transfer_classifier = TransferLearningClassifier(CONFIG, class_names)
transfer_classifier.model = model

print("Transfer learning classifier created successfully.")
print("Metadata:", transfer_classifier.get_metadata())
class_preview_suffix = '...' if len(class_names) > 10 else ''
print(f"Training on {len(class_names)} classes: {class_names[:10]}{class_preview_suffix}")

## Model Performance Analysis

In [None]:
def analyze_model_performance(classifier, X_test, y_test, max_samples=100):
    """
    Comprehensive performance analysis of the transfer learning model.

    Runs ``classifier.predict`` image by image (exercising the wrapper's own
    preprocessing path) on the first ``max_samples`` images. It prints accuracy
    and mean confidence, plots the confidence distribution, and returns summary stats.

    Args:
        classifier: object exposing ``predict(image) -> {class_name: prob}`` and
            ``class_names``.
        X_test: indexable collection of images.
        y_test: one-hot labels aligned with ``X_test``.
        max_samples: number of leading samples to analyze (default 100).

    Returns:
        dict with 'accuracy', 'mean_confidence' and 'std_confidence'.

    Raises:
        ValueError: if there are no samples to analyze.
    """
    print("=== Transfer Learning Model Performance Analysis ===")

    n_samples = min(max_samples, len(X_test))
    if n_samples == 0:
        raise ValueError("No samples available for performance analysis.")

    # Test predictions (subset for demonstration)
    predictions = []
    for i, image in enumerate(X_test[:n_samples]):
        predictions.append(classifier.predict(image))
        if i % 20 == 0:
            print(f"Processed {i+1}/{n_samples} test images")

    # Convert prediction dicts back to a (n_samples, num_classes) array in class order.
    pred_probs = np.array([[pred[class_name] for class_name in classifier.class_names]
                           for pred in predictions])
    pred_classes = np.argmax(pred_probs, axis=1)
    true_classes = np.argmax(y_test[:n_samples], axis=1)

    # Calculate metrics
    accuracy = np.mean(pred_classes == true_classes)
    confidences = np.max(pred_probs, axis=1)

    print(f"\nTest Accuracy: {accuracy:.4f}")
    print(f"Average Confidence: {np.mean(confidences):.4f}")

    # Plot prediction confidence distribution
    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    plt.hist(confidences, bins=20, alpha=0.7, edgecolor='black')
    plt.title('Prediction Confidence Distribution')
    plt.xlabel('Confidence')
    plt.ylabel('Frequency')
    plt.axvline(np.mean(confidences), color='red', linestyle='--',
                label=f'Mean: {np.mean(confidences):.3f}')
    plt.legend()

    # Plot accuracy vs confidence
    plt.subplot(1, 2, 2)
    correct = (pred_classes == true_classes)
    plt.scatter(confidences[correct], [1]*sum(correct), alpha=0.6,
                label='Correct', color='green')
    plt.scatter(confidences[~correct], [0]*sum(~correct), alpha=0.6,
                label='Incorrect', color='red')
    plt.title('Prediction Accuracy vs Confidence')
    plt.xlabel('Confidence')
    plt.ylabel('Correct (1) / Incorrect (0)')
    plt.legend()

    plt.tight_layout()
    plt.show()

    return {
        'accuracy': accuracy,
        'mean_confidence': np.mean(confidences),
        'std_confidence': np.std(confidences)
    }

# Run the wrapper-level analysis on the validation split, then echo each metric.
performance_metrics = analyze_model_performance(transfer_classifier, X_val, y_val)
print("\nFinal Performance Summary:")
for metric_name, metric_value in performance_metrics.items():
    print(f"{metric_name}: {metric_value:.4f}")

## Model Registry Integration

In [None]:
# Register model in the model registry
registry = ModelRegistry()

# Save the model
model_save_path = '../models/transfer_learning_classifier.h5'
transfer_classifier.save_model(model_save_path)

# Create metadata
metadata = ModelMetadata(
    name="transfer_learning_classifier",
    version="1.0.0",
    model_type="transfer_learning",
    accuracy=performance_metrics['accuracy'],
    training_date=datetime.now().isoformat(),
    model_path=model_save_path,
    config=CONFIG,
    performance_metrics={
        'validation_accuracy': val_accuracy,
        'validation_loss': val_loss,
        'mean_confidence': performance_metrics['mean_confidence'],
        'std_confidence': performance_metrics['std_confidence']
    }
)

# Register the model
registry.register_model(metadata)
print("Model registered successfully in the model registry.")

# Save configuration
config_path = '../models/transfer_learning_config.json'
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)
print(f"Configuration saved to {config_path}")

## Summary and Next Steps

In [ ]:
print("=== Transfer Learning Development Summary ===")
print(f"Base Model: {CONFIG['base_model']}")
print("Training Strategy: Two-phase (frozen + fine-tuning)")
print(f"Final Validation Accuracy: {val_accuracy:.4f}")
print(f"Model Parameters: {model.count_params():,}")
print(f"Total Classes: {len(class_names)}")
print(f"Training Images: {len(X_train)}")
print(f"Validation Images: {len(X_val)}")

print("\nDataset Information:")
print(f"- Unified classification dataset with {len(class_names)} classes")
print(f"- Classes include: {', '.join(class_names[:5])}{'...' if len(class_names) > 5 else ''}")
print(f"- Images resized to {CONFIG['input_shape'][:2]} for transfer learning")

print("\nKey Features:")
# Report the backbone actually configured rather than a hard-coded name.
print(f"- Pre-trained ImageNet weights ({CONFIG['base_model']})")
print("- Custom classification head")
print("- Two-phase training strategy")
print("- Comprehensive evaluation metrics")
print("- Compatible with unified dataset")

print("\nModel Integration:")
print("- Implements BaseImageClassifier interface")
print("- Registered in ModelRegistry")
print("- Ready for ensemble integration")
print("- Compatible with API deployment")

print("\nNext Steps:")
print("1. Experiment with different pre-trained models")
print("2. Optimize hyperparameters")
print("3. Implement model ensembling")
print("4. Deploy to production API")
print("5. Monitor model performance")