# Transfer Learning Image Classifier Development

This notebook develops an image classifier using transfer learning. An ImageNet-pretrained backbone (ResNet50, VGG16, or EfficientNetB0) is first trained with its weights frozen beneath a new classification head. Its upper layers are then unfrozen and fine-tuned on our dataset at a lower learning rate.

## Setup and Imports

In [None]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16, EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import os
import json
from datetime import datetime
import sys
sys.path.append('../..')
from ml_models_core.src.base_classifier import BaseImageClassifier
from ml_models_core.src.model_registry import ModelRegistry, ModelMetadata
from ml_models_core.src.data_loaders import get_unified_classification_data

## Configuration

## Memory-Efficient Data Loading

**Note**: Like the deep learning notebooks, this notebook loads data memory-efficiently. Instead of reading every image into memory up front, we:
1. Store only file paths in memory
2. Decode images on demand during training using TensorFlow's `tf.data` API
3. Stream batches with prefetching, so image loading overlaps with training

This approach allows training on large datasets without running out of memory.

In [None]:
# Configuration: every tunable value for this notebook lives here
CONFIG = {
    'base_model': 'resnet50',          # one of: 'resnet50', 'vgg16', 'efficientnet'
    'input_shape': (224, 224, 3),
    'num_classes': 2,  # Will be updated based on dataset
    'batch_size': 32,
    'epochs': 20,                      # phase 1: frozen backbone
    'fine_tune_epochs': 10,            # phase 2: partially unfrozen backbone
    'learning_rate': 1e-3,
    'fine_tune_learning_rate': 1e-5,
    'dropout_rate': 0.5,
    'l2_regularization': 1e-4,
    'validation_split': 0.2,
    'random_seed': 42,                 # seeds the train/val split, shuffling and weight init
}

# Seed global RNGs so "Restart & Run All" reproduces the same split and head initialisation
np.random.seed(CONFIG['random_seed'])
tf.random.set_seed(CONFIG['random_seed'])

# Extensions that tf.image.decode_image can read. TIFF is intentionally excluded:
# decode_image does not support it and would fail mid-epoch.
# NOTE(review): WebP decoding support depends on the installed TF version — confirm.
_SUPPORTED_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')


def _resolve_dataset_path():
    """Return the first existing dataset directory, preferring the combined dataset."""
    from pathlib import Path
    base_data_dir = Path("../../data/downloads")
    candidates = [
        base_data_dir / "combined_unified_classification",
        base_data_dir / "oxford_pets",
        base_data_dir / "vegetables",
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    raise FileNotFoundError("No datasets found. Please run data preparation first.")


def _scan_image_folder(dataset_path):
    """Collect (image_paths, labels, class_names) from a one-subdirectory-per-class layout.

    Class indices are assigned consecutively over the class directories actually
    kept, so labels always lie in ``range(len(class_names))``. Hidden entries
    (e.g. macOS ``._img.jpg`` resource forks) are skipped. Extensions are matched
    case-insensitively.
    """
    image_paths, labels, class_names = [], [], []
    for class_dir in sorted(dataset_path.iterdir()):
        if not class_dir.is_dir() or class_dir.name.startswith('.'):
            continue
        # Index = number of classes kept so far (not the position among all entries),
        # otherwise stray files/hidden dirs would create gaps past num_classes.
        class_idx = len(class_names)
        class_names.append(class_dir.name)
        for img_path in sorted(class_dir.iterdir()):
            if (img_path.is_file()
                    and not img_path.name.startswith('.')
                    and img_path.suffix.lower() in _SUPPORTED_IMAGE_EXTENSIONS):
                image_paths.append(str(img_path))
                labels.append(class_idx)
    return image_paths, labels, class_names


def _make_image_dataset(paths, labels, num_classes, training, seed=None):
    """Build a batched, prefetched tf.data pipeline that decodes images lazily.

    When ``training`` is True the file list is reshuffled every epoch and light
    flip/brightness/contrast augmentation is applied.
    """
    target_size = CONFIG['input_shape'][:2]

    def load_and_preprocess_image(path, label):
        """Read, decode (as RGB), resize and scale one image to [0, 1]; one-hot the label."""
        image = tf.io.read_file(path)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.ensure_shape(image, [None, None, 3])
        image = tf.image.resize(image, target_size)
        # NOTE(review): the ImageNet backbones were presumably trained with their own
        # preprocess_input conventions; this [0, 1] scaling matches what
        # TransferLearningClassifier.preprocess does at inference — confirm intent.
        image = tf.cast(image, tf.float32) / 255.0
        return image, tf.one_hot(label, num_classes)

    def augment_image(image, label):
        """Apply random flip and photometric jitter to a training image."""
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        # Jitter can push values outside [0, 1]; keep the same range as validation.
        image = tf.clip_by_value(image, 0.0, 1.0)
        return image, label

    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    if training:
        # Shuffle cheap path strings (not decoded images) and reshuffle every epoch.
        dataset = dataset.shuffle(max(len(paths), 1), seed=seed,
                                  reshuffle_each_iteration=True)
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    if training:
        dataset = dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(CONFIG['batch_size']).prefetch(tf.data.AUTOTUNE)


def create_memory_efficient_dataset(seed=None):
    """
    Create memory-efficient TensorFlow datasets that load images on-demand.

    Only file paths are held in memory. Images are decoded, resized to
    ``CONFIG['input_shape'][:2]`` and scaled to [0, 1] inside the tf.data
    pipeline. Side effect: sets ``CONFIG['num_classes']``.

    Args:
        seed: Seed for the train/val split and per-epoch shuffling. Defaults to
            ``CONFIG.get('random_seed')``; ``None`` means non-deterministic.

    Returns:
        (train_dataset, val_dataset, class_names, n_train, n_val). The training
        dataset is reshuffled each epoch and augmented. Both datasets yield
        batches of (images, one-hot labels).

    Raises:
        FileNotFoundError: if none of the candidate dataset directories exist.
        ValueError: if the chosen directory contains no supported images.
    """
    print("Creating memory-efficient dataset...")
    if seed is None:
        seed = CONFIG.get('random_seed')

    dataset_path = _resolve_dataset_path()
    print(f"Dataset path: {dataset_path}")

    # Collect image paths and labels without loading images
    image_paths, labels, class_names = _scan_image_folder(dataset_path)
    print(f"Found {len(image_paths)} images from {len(class_names)} classes")
    print(f"Classes (first 10): {class_names[:10]}")
    if not image_paths:
        raise ValueError(f"No supported images found under {dataset_path}")

    # Update CONFIG with actual number of classes
    CONFIG['num_classes'] = len(class_names)
    num_classes = CONFIG['num_classes']

    # Shuffle once (seeded) before the train/val split
    order = np.random.default_rng(seed).permutation(len(image_paths))
    image_paths = np.array(image_paths)[order]
    labels = np.array(labels)[order]

    split_idx = int(len(image_paths) * (1 - CONFIG['validation_split']))
    train_paths, val_paths = image_paths[:split_idx], image_paths[split_idx:]
    train_labels, val_labels = labels[:split_idx], labels[split_idx:]

    print(f"Training samples: {len(train_paths)}")
    print(f"Validation samples: {len(val_paths)}")

    train_dataset_aug = _make_image_dataset(train_paths, train_labels, num_classes,
                                            training=True, seed=seed)
    val_dataset = _make_image_dataset(val_paths, val_labels, num_classes, training=False)

    return train_dataset_aug, val_dataset, class_names, len(train_paths), len(val_paths)

# Build the streaming train/validation pipelines (this also sets CONFIG['num_classes'])
train_dataset, val_dataset, class_names, n_train, n_val = create_memory_efficient_dataset()

# Ceiling division: the pipelines also yield the final partial batch.
n_train_batches = -(-n_train // CONFIG['batch_size'])
n_val_batches = -(-n_val // CONFIG['batch_size'])

print("\nMemory-efficient datasets created successfully!")
print(f"Training on {CONFIG['num_classes']} classes")
print(f"Training batches: {n_train_batches}")
print(f"Validation batches: {n_val_batches}")

## Data Loading and Preprocessing

## Pre-trained Model Selection and Architecture

In [None]:
def create_base_model(model_name='resnet50', input_shape=(224, 224, 3)):
    """
    Build an ImageNet-pretrained convolutional backbone without its classifier top.

    Args:
        model_name: Key selecting the architecture ('resnet50', 'vgg16', 'efficientnet').
        input_shape: (height, width, channels) of the network input.

    Returns:
        The backbone Keras model, with all of its layers frozen.

    Raises:
        ValueError: if ``model_name`` is not one of the supported keys.
    """
    supported = {'resnet50': ResNet50, 'vgg16': VGG16, 'efficientnet': EfficientNetB0}
    if model_name not in supported:
        raise ValueError(f"Model {model_name} not supported. Choose from: {list(supported.keys())}")

    builder = supported[model_name]
    backbone = builder(weights='imagenet', include_top=False, input_shape=input_shape)

    # Phase one trains only the new head, so the whole backbone starts frozen.
    backbone.trainable = False

    trainable_count = sum(tf.size(weight).numpy() for weight in backbone.trainable_weights)
    print(f"Base model: {model_name}")
    print(f"Total parameters: {backbone.count_params():,}")
    print(f"Trainable parameters: {trainable_count:,}")

    return backbone

def create_transfer_model(base_model, num_classes=2, dropout_rate=0.5,
                          hidden_units=(512, 256), l2_regularization=None):
    """
    Create the complete transfer learning model with custom head.

    The head is: global average pooling -> batch norm -> one
    (Dense(relu, L2) -> Dropout) pair per entry in ``hidden_units`` -> softmax.

    Args:
        base_model: Backbone whose ``input``/``output`` tensors the head is attached to.
        num_classes: Number of output classes (softmax width).
        dropout_rate: Dropout rate applied after every hidden Dense layer.
        hidden_units: Widths of the hidden Dense layers, in order. The default
            reproduces the original 512 -> 256 head.
        l2_regularization: L2 factor for hidden-layer kernels. Defaults to
            ``CONFIG['l2_regularization']``.

    Returns:
        A Keras ``Model`` mapping ``base_model.input`` to class probabilities.
    """
    if l2_regularization is None:
        l2_regularization = CONFIG['l2_regularization']

    x = GlobalAveragePooling2D()(base_model.output)
    x = BatchNormalization()(x)
    for units in hidden_units:
        x = Dense(units, activation='relu',
                  kernel_regularizer=tf.keras.regularizers.l2(l2_regularization))(x)
        x = Dropout(dropout_rate)(x)
    predictions = Dense(num_classes, activation='softmax', name='predictions')(x)

    return Model(inputs=base_model.input, outputs=predictions)

# Build the frozen pretrained backbone, then attach a classification head sized to the dataset
base_model = create_base_model(model_name=CONFIG['base_model'], input_shape=CONFIG['input_shape'])
model = create_transfer_model(
    base_model,
    num_classes=CONFIG['num_classes'],
    dropout_rate=CONFIG['dropout_rate'],
)

print("\nComplete model summary:")
model.summary()

## Model Compilation and Callbacks

In [None]:
# Top-k accuracy defaults to k=5, which is trivially 1.0 whenever k >= num_classes.
# Use k < num_classes (capped at 5) so the metric stays informative. The metric keeps
# its usual name so evaluate() still returns [loss, accuracy, top_k].
top_k = min(5, max(1, CONFIG['num_classes'] - 1))

# Compile model for phase 1 (frozen backbone)
model.compile(
    optimizer=Adam(learning_rate=CONFIG['learning_rate']),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.TopKCategoricalAccuracy(k=top_k, name='top_k_categorical_accuracy'),
    ]
)

# ModelCheckpoint writes into this directory; create it in case it does not exist yet.
checkpoint_path = '../models/transfer_model_best.h5'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Setup callbacks (shared by both training phases)
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    ModelCheckpoint(
        checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        save_weights_only=False,
        verbose=1
    )
]

print("Model compiled with callbacks ready.")

## Initial Training (Frozen Base Model)

In [None]:
print("Starting initial training with frozen base model...")

# Use ceiling division so every epoch covers the whole dataset, including the final
# partial batch. With floor division the step count differs from the dataset's
# cardinality. tf.keras then keeps one iterator across epochs, so later epochs run
# out of data and training stops early. Floor division also gives 0 validation
# steps whenever n_val < batch_size.
steps_per_epoch = -(-n_train // CONFIG['batch_size'])
validation_steps = -(-n_val // CONFIG['batch_size'])

# Phase 1: only the new classification head is trainable
history_initial = model.fit(
    train_dataset,
    epochs=CONFIG['epochs'],
    validation_data=val_dataset,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1
)

print("Initial training completed.")

In [None]:
print("Starting fine-tuning phase...")

# Unfreeze the base model, then re-freeze its first half so that only the
# later layers are updated.
base_model.trainable = True
fine_tune_at = len(base_model.layers) // 2
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Keep BatchNormalization layers frozen. In TF2 Keras a non-trainable BN layer runs
# in inference mode, so small fine-tuning batches do not overwrite its pretrained
# moving statistics.
for layer in base_model.layers[fine_tune_at:]:
    if isinstance(layer, BatchNormalization):
        layer.trainable = False

print(f"Fine-tuning from layer {fine_tune_at} onwards")
print(f"Trainable parameters: {sum(tf.size(w).numpy() for w in model.trainable_weights):,}")

# Recompile with a lower learning rate (required for trainable changes to take effect).
# k < num_classes keeps top-k accuracy from being trivially 1.0.
top_k = min(5, max(1, CONFIG['num_classes'] - 1))
model.compile(
    optimizer=Adam(learning_rate=CONFIG['fine_tune_learning_rate']),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.TopKCategoricalAccuracy(k=top_k, name='top_k_categorical_accuracy'),
    ]
)

# Ceiling division so each epoch consumes the full dataset (including the last
# partial batch) and the iterator is re-created every epoch.
steps_per_epoch = -(-n_train // CONFIG['batch_size'])
validation_steps = -(-n_val // CONFIG['batch_size'])

# Phase 2: continue training with the upper half of the backbone unfrozen
history_finetune = model.fit(
    train_dataset,
    epochs=CONFIG['fine_tune_epochs'],
    validation_data=val_dataset,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1
)

print("Fine-tuning completed.")

## Model Evaluation

In [None]:
# Evaluate model on the full validation set
val_loss, val_accuracy, val_top_k = model.evaluate(val_dataset, verbose=0)
print("Validation Results:")
print(f"Loss: {val_loss:.4f}")
print(f"Accuracy: {val_accuracy:.4f}")
print(f"Top-k Accuracy: {val_top_k:.4f}")

# Collect predictions from (roughly) the first `max_samples` validation images
print("\nGenerating predictions for confusion matrix...")
max_samples = 1000
progress_every = 200
y_pred_list, y_true_list = [], []
samples_collected = 0
next_progress = progress_every

for images, labels in val_dataset:
    if samples_collected >= max_samples:
        break
    # A direct call avoids model.predict's per-call setup overhead inside a loop
    batch_probs = np.asarray(model(images, training=False))
    y_pred_list.extend(np.argmax(batch_probs, axis=1))
    y_true_list.extend(np.argmax(labels.numpy(), axis=1))
    samples_collected += len(images)
    # Batch sizes rarely land exactly on multiples of 200, so track the next threshold
    if samples_collected >= next_progress:
        print(f"Processed {samples_collected} samples...")
        next_progress = (samples_collected // progress_every + 1) * progress_every

y_pred = np.array(y_pred_list)
y_true = np.array(y_true_list)

if y_true.size == 0:
    print("No validation samples available; skipping classification report and confusion matrix.")
else:
    # Classification report (first 10 classes present in the sample, for readability)
    print(f"\nClassification Report (based on {len(y_true)} samples, first 10 classes):")
    unique_classes = sorted(set(y_true.tolist()))
    display_classes = unique_classes[:10]
    if len(display_classes) < len(unique_classes):
        print(f"Note: Showing first 10 of {len(unique_classes)} classes")
    print(classification_report(
        y_true, y_pred,
        labels=display_classes,
        target_names=[class_names[i] for i in display_classes],
        zero_division=0,
    ))

    # Confusion matrix (only for a manageable number of classes)
    if len(class_names) <= 15:
        # Pass every class index so the matrix is always len(class_names) square and
        # lines up with the tick labels, even if some classes are absent from the sample.
        cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
        fig, ax = plt.subplots(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set(title='Confusion Matrix', xlabel='Predicted Label', ylabel='True Label')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)
        fig.tight_layout()
        plt.show()
    else:
        print(f"Confusion matrix skipped (too many classes: {len(class_names)})")

In [None]:
class TransferLearningClassifier(BaseImageClassifier):
    """
    Transfer learning image classifier using pre-trained models.

    Wraps a Keras model (trained in this notebook or loaded from disk) behind the
    BaseImageClassifier interface. ``preprocess`` mirrors this notebook's training
    pipeline: RGB, resized to ``config['input_shape'][:2]``, float32 in [0, 1].
    """

    def __init__(self, config=None, class_names=None):
        """
        Args:
            config: Configuration dict; falls back to the notebook-level CONFIG.
            class_names: Output label names, index-aligned with the model's softmax
                outputs. Defaults to ['Class_0', 'Class_1'] when falsy.
        """
        self.config = config or CONFIG
        self.model = None
        self.class_names = class_names or ['Class_0', 'Class_1']
        self.training_history = None

    def load_model(self, model_path: str) -> None:
        """
        Load a trained transfer learning model from ``model_path``.

        Raises:
            Exception: whatever ``tf.keras.models.load_model`` raises, after logging it.
        """
        try:
            self.model = tf.keras.models.load_model(model_path)
            print(f"Model loaded from {model_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Convert a raw image (or batch of images) into the model's input format.

        Args:
            image: Array of shape (H, W), (H, W, C) or (N, H, W, C) with C in
                {1, 3, 4}. Grayscale is replicated to RGB and an alpha channel is
                dropped, matching ``decode_image(channels=3)`` at training time.
                Integer images, and float images with values above 1.0, are
                divided by 255.

        Returns:
            float32 array of shape (N, H', W', 3), where (H', W') is
            ``config['input_shape'][:2]``.
        """
        image = np.asarray(image)
        is_integer = np.issubdtype(image.dtype, np.integer)
        image = image.astype(np.float32, copy=False)

        # Bring channels to RGB
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        elif image.shape[-1] == 4:
            image = image[..., :3]

        # Resize to model input size. The spatial dims are the two axes before channels,
        # for both single images and batches.
        target_hw = tuple(self.config['input_shape'][:2])
        if tuple(image.shape[-3:-1]) != target_hw:
            # tf.image.resize returns a tf.Tensor; convert back so numpy methods work below
            image = tf.image.resize(image, target_hw).numpy()

        # Normalize pixel values to [0, 1] if not already normalized
        if is_integer or image.max() > 1.0:
            image = image / 255.0

        # Add batch dimension if needed
        if image.ndim == 3:
            image = np.expand_dims(image, axis=0)

        return image.astype(np.float32, copy=False)

    def predict(self, image: np.ndarray) -> dict:
        """
        Preprocess ``image`` and return a {class_name: probability} dict for the
        first (or only) image.

        Raises:
            ValueError: if no model has been loaded or assigned.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        preprocessed_image = self.preprocess(image)
        predictions = self.model.predict(preprocessed_image, verbose=0)

        # Convert to probabilities dict
        probs = predictions[0] if len(predictions.shape) > 1 else predictions

        return {
            self.class_names[i]: float(prob)
            for i, prob in enumerate(probs)
        }

    def get_metadata(self) -> dict:
        """
        Get model metadata and configuration.
        """
        return {
            'model_type': 'transfer_learning',
            'base_model': self.config['base_model'],
            'input_shape': self.config['input_shape'],
            'num_classes': self.config['num_classes'],
            'class_names': self.class_names,
            'preprocessing': 'resize_and_normalize',
            'framework': 'tensorflow',
            'architecture': 'pretrained_with_custom_head'
        }

    def save_model(self, model_path: str) -> None:
        """
        Save the trained model to ``model_path``, creating parent directories.

        Raises:
            ValueError: if there is no model to save.
        """
        if self.model is None:
            raise ValueError("No model to save. Train or load a model first.")

        # os.makedirs('') raises, so only create a directory when the path has one
        directory = os.path.dirname(model_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.model.save(model_path)
        print(f"Model saved to {model_path}")

# Wrap the trained Keras model in the BaseImageClassifier-compatible interface
transfer_classifier = TransferLearningClassifier(CONFIG, class_names)
transfer_classifier.model = model

more_marker = '...' if len(class_names) > 10 else ''
print("Transfer learning classifier created successfully.")
print("Metadata:", transfer_classifier.get_metadata())
print(f"Training on {len(class_names)} classes: {class_names[:10]}{more_marker}")

## Model Performance Analysis

In [None]:
# NOTE(review): this cell repeats the "Model Evaluation" cell; re-running it here
# ensures val_loss / val_accuracy reflect the current `model` before registration.
val_loss, val_accuracy, val_top_k = model.evaluate(val_dataset, verbose=0)
print("Validation Results:")
print(f"Loss: {val_loss:.4f}")
print(f"Accuracy: {val_accuracy:.4f}")
print(f"Top-k Accuracy: {val_top_k:.4f}")

# Collect predictions from (roughly) the first `max_samples` validation images
print("\nGenerating predictions for confusion matrix...")
max_samples = 1000
progress_every = 200
y_pred_list, y_true_list = [], []
samples_collected = 0
next_progress = progress_every

for images, labels in val_dataset:
    if samples_collected >= max_samples:
        break
    # A direct call avoids model.predict's per-call setup overhead inside a loop
    batch_probs = np.asarray(model(images, training=False))
    y_pred_list.extend(np.argmax(batch_probs, axis=1))
    y_true_list.extend(np.argmax(labels.numpy(), axis=1))
    samples_collected += len(images)
    # Batch sizes rarely land exactly on multiples of 200, so track the next threshold
    if samples_collected >= next_progress:
        print(f"Processed {samples_collected} samples...")
        next_progress = (samples_collected // progress_every + 1) * progress_every

y_pred = np.array(y_pred_list)
y_true = np.array(y_true_list)

if y_true.size == 0:
    print("No validation samples available; skipping classification report and confusion matrix.")
else:
    # Classification report (first 10 classes present in the sample, for readability)
    print(f"\nClassification Report (based on {len(y_true)} samples, first 10 classes):")
    unique_classes = sorted(set(y_true.tolist()))
    display_classes = unique_classes[:10]
    if len(display_classes) < len(unique_classes):
        print(f"Note: Showing first 10 of {len(unique_classes)} classes")
    print(classification_report(
        y_true, y_pred,
        labels=display_classes,
        target_names=[class_names[i] for i in display_classes],
        zero_division=0,
    ))

    # Confusion matrix (only for a manageable number of classes)
    if len(class_names) <= 15:
        # Pass every class index so the matrix is always len(class_names) square and
        # lines up with the tick labels, even if some classes are absent from the sample.
        cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
        fig, ax = plt.subplots(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set(title='Confusion Matrix', xlabel='Predicted Label', ylabel='True Label')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)
        fig.tight_layout()
        plt.show()
    else:
        print(f"Confusion matrix skipped (too many classes: {len(class_names)})")

In [None]:
# Measure prediction confidence (max softmax probability) on up to
# `confidence_max_samples` validation images, so the registry stores real values
# instead of hard-coded placeholders.
confidence_max_samples = 1000
top_probs = []
for images, _ in val_dataset:
    if len(top_probs) >= confidence_max_samples:
        break
    batch_probs = np.asarray(model(images, training=False))
    top_probs.extend(batch_probs.max(axis=1).tolist())
top_probs = np.asarray(top_probs[:confidence_max_samples], dtype=np.float64)

performance_metrics = {
    'accuracy': float(val_accuracy),
    'mean_confidence': float(top_probs.mean()) if top_probs.size else None,
    'std_confidence': float(top_probs.std()) if top_probs.size else None,
}

# Register model in the model registry
registry = ModelRegistry()

# Save the model (save_model creates the parent directory)
model_save_path = '../models/transfer_learning_classifier.h5'
transfer_classifier.save_model(model_save_path)

# Create metadata. Pass a copy of CONFIG so later edits to the global don't
# mutate what was registered.
metadata = ModelMetadata(
    name="transfer_learning_classifier",
    version="1.0.0",
    model_type="transfer_learning",
    accuracy=performance_metrics['accuracy'],
    training_date=datetime.now().isoformat(),
    model_path=model_save_path,
    config=dict(CONFIG),
    performance_metrics={
        'validation_accuracy': float(val_accuracy),
        'validation_loss': float(val_loss),
        'mean_confidence': performance_metrics['mean_confidence'],
        'std_confidence': performance_metrics['std_confidence']
    }
)

# Register the model
registry.register_model(metadata)
print("Model registered successfully in the model registry.")

# Save configuration alongside the model
config_path = '../models/transfer_learning_config.json'
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)
print(f"Configuration saved to {config_path}")

## Summary and Next Steps

In [None]:
print("=== Transfer Learning Development Summary ===")
print(f"Base Model: {CONFIG['base_model']}")
print("Training Strategy: Two-phase (frozen + fine-tuning)")
# val_accuracy only exists if an evaluation cell has been run in this kernel
try:
    print(f"Final Validation Accuracy: {val_accuracy:.4f}")
except NameError:
    print("Final Validation Accuracy: Run evaluation cell to get accuracy")
print(f"Model Parameters: {model.count_params():,}")
print(f"Total Classes: {len(class_names)}")
print(f"Training Images: {n_train}")
print(f"Validation Images: {n_val}")

print("\nDataset Information:")
print(f"- Memory-efficient unified classification dataset with {len(class_names)} classes")
print(f"- Classes include: {', '.join(class_names[:5])}{'...' if len(class_names) > 5 else ''}")
print(f"- Images resized to {CONFIG['input_shape'][:2]} for transfer learning")
print("- Images loaded on-demand to save memory")

print("\nKey Features:")
# Report the configured backbone rather than a hard-coded architecture name
print(f"- Pre-trained ImageNet weights ({CONFIG['base_model']})")
print("- Custom classification head")
print("- Two-phase training strategy")
print("- Memory-efficient data loading with tf.data")
print("- Comprehensive evaluation metrics")
print("- Compatible with unified dataset")

print("\nModel Integration:")
print("- Implements BaseImageClassifier interface")
print("- Registered in ModelRegistry")
print("- Ready for ensemble integration")
print("- Compatible with API deployment")

print("\nNext Steps:")
print("1. Experiment with different pre-trained models")
print("2. Optimize hyperparameters")
print("3. Implement model ensembling")
print("4. Deploy to production API")
print("5. Monitor model performance")