# Transfer Learning Image Classifier Development

This notebook demonstrates the development of an image classifier using transfer learning with pre-trained models. We'll leverage established architectures trained on ImageNet and fine-tune them for our specific classification task.

## Setup and Imports

In [ ]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import os
import json
from datetime import datetime
import sys

# Add path to access extracted modules
sys.path.append('../..')
sys.path.append('../')

# Import from extracted transfer learning modules
from src.classifier import TransferLearningClassifier
from src.trainer import TransferLearningTrainer
from src.models import TransferLearningModel
from src.config import TransferLearningClassifierConfig
from src.data_loader import create_memory_efficient_datasets, discover_classes

# Import from ml_models_core
from ml_models_core.src.base_classifier import BaseImageClassifier
from ml_models_core.src.model_registry import ModelRegistry, ModelMetadata

## Configuration

## Memory-Efficient Data Loading

**Note**: This notebook uses memory-efficient data loading similar to the deep learning notebooks. Instead of loading all images into memory at once, we:
1. Store only file paths in memory
2. Load images on-demand during training using TensorFlow's tf.data API
3. Use data streaming and prefetching for optimal performance

This approach allows training on large datasets without running out of memory.

In [ ]:
# Assemble the run configuration through the extracted config module, then
# build the streaming train/val/test datasets from the first dataset
# directory that exists on disk.
from pathlib import Path

config = TransferLearningClassifierConfig(
    base_model_name='resnet50',
    image_size=(224, 224),
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-3,
    fine_tune_layers=10,
    fine_tune_learning_rate=1e-5,
    dropout_rate=0.5,
    validation_split=0.2,
    dense_units=[512, 256],
    mixed_precision=True,
    use_xla=True,
    class_weights=True,
    cache_dataset=True,
)

# Echo the settings that matter most for reproducing the run.
print("Configuration created using extracted config module:")
config_overview = (
    ("Base model", config.base_model_name),
    ("Image size", config.image_size),
    ("Batch size", config.batch_size),
    ("Mixed precision", config.mixed_precision),
    ("XLA compilation", config.use_xla),
)
for setting_label, setting_value in config_overview:
    print(f"{setting_label}: {setting_value}")

print("\nCreating memory-efficient datasets...")

# Preferred dataset location; the other known datasets are tried only when
# it is missing.
downloads_dir = Path("../../data/downloads")
dataset_path = downloads_dir / "combined_unified_classification"

if not dataset_path.exists():
    fallback_names = ("combined_unified_classification", "oxford_pets", "vegetables")
    existing_dirs = (
        downloads_dir / dir_name
        for dir_name in fallback_names
        if (downloads_dir / dir_name).exists()
    )
    dataset_path = next(existing_dirs, None)
    if dataset_path is None:
        raise FileNotFoundError("No datasets found. Please run data preparation first.")

print(f"Dataset path: {dataset_path}")

# The extracted loader returns the three dataset splits plus the class names
# and class weights in a single tuple.
dataset_bundle = create_memory_efficient_datasets(str(dataset_path), config)
train_dataset, val_dataset, test_dataset, class_names, class_weights = dataset_bundle

# The class count is only known after the loader has returned the class names.
config.num_classes = len(class_names)

print("\nMemory-efficient datasets created successfully!")
print(f"Training on {config.num_classes} classes")
print(f"Classes (first 10): {class_names[:10]}")
print(f"Class weights enabled: {config.class_weights}")
print(f"Dataset caching enabled: {config.cache_dataset}")

## Data Loading and Preprocessing

## Pre-trained Model Selection and Architecture

In [ ]:
# Build the transfer-learning network with the extracted model module and
# report its size, input/output shapes and the first few layers.
print("Creating transfer learning model using extracted modules...")

# The wrapper takes the shared config plus the class count stored on it.
transfer_model = TransferLearningModel(config, config.num_classes)

# Build the model
model = transfer_model.build_model()

# TensorFlow/Keras variables have no PyTorch-style `.numel()`; derive the
# element count from each variable's shape instead. A scalar variable has
# shape () and np.prod(()) is 1, so scalars are counted correctly.
trainable_param_count = sum(
    int(np.prod(variable.shape)) for variable in model.trainable_variables
)

print(f"\nModel created successfully:")
print(f"Base model: {config.base_model_name}")
print(f"Total parameters: {model.count_params():,}")
print(f"Trainable parameters: {trainable_param_count:,}")
print(f"Input shape: {model.input_shape}")
print(f"Output shape: {model.output_shape}")

# Display model architecture info; empty feature entries are skipped.
model_info = transfer_model.get_model_info()
print(f"\nModel architecture features:")
for feature in model_info.get('features', []):
    if feature:
        print(f"  - {feature}")

# Only the first ten layers are listed to keep the output short.
print(f"\nModel summary (first few layers):")
for i, layer in enumerate(model.layers[:10]):
    print(f"  {i}: {layer.name} - {layer.__class__.__name__}")
if len(model.layers) > 10:
    print(f"  ... and {len(model.layers) - 10} more layers")

## Model Compilation and Callbacks

In [ ]:
# Ask the extracted model wrapper for its training callbacks and list them.
print("Setting up training callbacks...")

callback_log_dir = "../logs/transfer_learning"
callbacks = transfer_model.get_callbacks(log_dir=callback_log_dir)

print("Callbacks configured:")
for registered_callback in callbacks:
    callback_kind = type(registered_callback).__name__
    print(f"  - {callback_kind}")

print("\nModel compilation completed using extracted modules.")

## Training (Frozen Base Model, Then Fine-Tuning)

In [ ]:
# Run training through the extracted trainer and unpack its result dict
# into the notebook-level names used by the later cells.
print("Starting training using extracted trainer...")

classifier = TransferLearningClassifier(config=config, class_names=class_names)
trainer = TransferLearningTrainer(classifier, config)

# A single call covers both phases: frozen base, then fine-tuning.
results = trainer.train(str(dataset_path))

result_keys = ('model', 'metrics', 'training_history', 'fine_tune_history')
model, training_metrics, training_history, fine_tune_history = (
    results[key] for key in result_keys
)

print("\nTraining completed successfully!")
print(f"Test accuracy: {training_metrics['test_accuracy']:.4f}")
print(f"Best validation accuracy: {training_metrics['best_val_accuracy']:.4f}")
print(f"Model parameters: {training_metrics['model_parameters']:,}")
print(f"Fine-tuned: {training_metrics['fine_tuned']}")

# Attach the returned model to the classifier for use in later cells.
classifier.model = model

In [ ]:
# Summarise the two training phases and try to save the training-history plot.
print("Fine-tuning phase was handled automatically by the trainer.")

if not fine_tune_history:
    print("Fine-tuning was skipped (fine_tune_layers = 0)")
else:
    print("Fine-tuning was performed successfully.")
    # Epoch counts are taken from the length of each phase's 'loss' series.
    print(f"Phase 1 (frozen base): {len(training_history['loss'])} epochs")
    print(f"Phase 2 (fine-tuning): {len(fine_tune_history['loss'])} epochs")

print("\nPlotting training history...")
history_plot_path = "../logs/transfer_learning/training_history.png"
try:
    trainer.plot_training_history(save_path=history_plot_path)
except Exception as plot_error:
    # Plotting is best-effort: report the failure and keep the notebook running.
    print(f"Could not plot training history: {plot_error}")

print("Training visualization completed.")

## Model Evaluation

In [ ]:
# Report the test-set metrics computed by the trainer, then break them down
# per class with a classification report and (for small label sets) a
# confusion-matrix heatmap.
print("Evaluating model using test results from trainer...")

test_results = results['test_results']
test_accuracy = test_results['accuracy']
test_loss = test_results['loss']

print(f"Test Results:")
print(f"Loss: {test_loss:.4f}")
print(f"Accuracy: {test_accuracy:.4f}")

# Get predictions from trainer results
# NOTE(review): assumes both entries hold integer class indices (they are
# used to index class_names below) -- confirm against the trainer.
y_pred = np.array(test_results['predictions'])
y_true = np.array(test_results['true_labels'])

print(f"\nEvaluation based on {len(y_true)} test samples")

# Classification report (show first 10 classes for readability).
# np.unique returns the distinct labels already sorted; .tolist() converts
# them to plain Python values for indexing class_names.
unique_classes = np.unique(y_true).tolist()
display_classes = unique_classes[:10]

if len(display_classes) < len(unique_classes):
    print(f"Note: Showing first 10 of {len(unique_classes)} classes")

print("\nClassification Report:")
print(classification_report(y_true, y_pred,
                          target_names=[class_names[i] for i in display_classes],
                          labels=display_classes))

# Confusion matrix (only for manageable number of classes)
if len(class_names) <= 15:
    # Pin the matrix to the full label set. Without `labels`, scikit-learn
    # only includes classes that occur in y_true/y_pred, so a class missing
    # from the test data would shrink the matrix and misalign the tick labels.
    all_class_indices = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=all_class_indices)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
else:
    print(f"Confusion matrix skipped (too many classes: {len(class_names)})")

In [ ]:
# Print the trained classifier's metadata and a short preview of its classes.
print("Transfer learning classifier created using extracted modules.")

print("Classifier metadata:")
metadata = classifier.get_metadata()
for field_name, field_value in metadata.items():
    print(f"  {field_name}: {field_value}")

print(f"\nTraining completed on {len(class_names)} classes")

# Preview at most ten class names; an ellipsis marks a truncated list.
class_preview = ', '.join(class_names[:10])
truncation_marker = '...' if len(class_names) > 10 else ''
print(f"Classes: {class_preview}{truncation_marker}")

## Model Performance Analysis

In [None]:
# Evaluate the trained model on the validation split, then build a
# classification report and confusion matrix from a bounded sample of it.

# NOTE(review): the three-way unpacking assumes the model was compiled with
# exactly two metrics (accuracy and top-k accuracy) -- confirm in the trainer.
val_loss, val_accuracy, val_top_k = model.evaluate(val_dataset, verbose=0)
print(f"Validation Results:")
print(f"Loss: {val_loss:.4f}")
print(f"Accuracy: {val_accuracy:.4f}")
print(f"Top-k Accuracy: {val_top_k:.4f}")

# Generate predictions for confusion matrix (limited sample)
print("\nGenerating predictions for confusion matrix...")
y_pred_list = []
y_true_list = []

# Collect predictions from validation dataset (limit to first 1000 samples for efficiency)
samples_collected = 0
max_samples = 1000
report_every = 200
next_report = report_every

for images, labels in val_dataset:
    if samples_collected >= max_samples:
        break

    # predict_on_batch runs one batch directly; calling predict() repeatedly
    # inside a Python loop is discouraged for per-batch inference.
    predictions = model.predict_on_batch(images)
    y_pred_list.extend(np.argmax(predictions, axis=1))
    # NOTE(review): argmax over axis 1 assumes one-hot encoded labels --
    # confirm against the data loader.
    y_true_list.extend(np.argmax(labels.numpy(), axis=1))
    samples_collected += len(images)

    # Report whenever a 200-sample threshold is crossed. An exact
    # `% 200 == 0` test almost never fires when the batch size does not
    # divide 200 (with batches of 32 it only fires at multiples of 800).
    if samples_collected >= next_report:
        print(f"Processed {samples_collected} samples...")
        next_report = (samples_collected // report_every + 1) * report_every

# The last batch can overshoot the cap; trim so the report covers at most
# max_samples samples.
y_pred = np.array(y_pred_list[:max_samples])
y_true = np.array(y_true_list[:max_samples])

# Classification report (show first 10 classes for readability)
print(f"\nClassification Report (based on {len(y_true)} samples, first 10 classes):")
# np.unique returns the distinct labels already sorted; .tolist() converts
# them to plain Python ints for indexing class_names.
unique_classes = np.unique(y_true).tolist()
display_classes = unique_classes[:10]

if len(display_classes) < len(unique_classes):
    print(f"Note: Showing first 10 of {len(unique_classes)} classes")

print(classification_report(y_true, y_pred,
                          target_names=[class_names[i] for i in display_classes],
                          labels=display_classes))

# Confusion matrix (only for manageable number of classes)
if len(class_names) <= 15:
    # Pin the matrix to the full label set. Only a sample of the validation
    # data is used, so some classes may be absent; without `labels` the
    # matrix would then be smaller than the class_names tick labels.
    all_class_indices = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=all_class_indices)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
else:
    print(f"Confusion matrix skipped (too many classes: {len(class_names)})")

In [ ]:
# Save the trained model, register it in the model registry, and write the
# training history to disk.
print("Saving model using extracted modules...")

model_save_path = '../models/transfer_learning_classifier.h5'
final_accuracy = training_metrics['test_accuracy']

# The classifier is given the weights together with the class names,
# accuracy and training history.
classifier.save_model(
    model_save_path,
    model=model,
    class_names=class_names,
    accuracy=final_accuracy,
    training_history=training_history,
)

registry = ModelRegistry()

# Registry entry built from the training results; the timestamp is taken at
# registration time.
registered_at = datetime.now().isoformat()
metadata = ModelMetadata(
    name="transfer_learning_classifier",
    version="1.0.0",
    model_type="transfer_learning",
    accuracy=final_accuracy,
    training_date=registered_at,
    model_path=model_save_path,
    config=config.to_dict(),
    performance_metrics=training_metrics,
)

registry.register_model(metadata)
print("Model registered successfully in the model registry.")

# Write the training history as JSON through the trainer.
history_path = '../logs/transfer_learning_training_history.json'
trainer.save_training_history(history_path)
print(f"Training history saved to {history_path}")

## Summary and Next Steps

In [ ]:
# Print the end-of-notebook summary: headline numbers first, then the fixed
# informational sections, each stored as (heading, lines) and printed in order.
print("=== Transfer Learning Development Summary ===")
print("✅ Successfully extracted transfer learning code into modular src files")
print("✅ Updated notebook to use extracted modules")
print()

headline_lines = [
    f"Base Model: {config.base_model_name}",
    "Framework: TensorFlow/Keras",
    "Training Strategy: Two-phase (frozen + fine-tuning)",
    f"Final Test Accuracy: {training_metrics['test_accuracy']:.4f}",
    f"Best Validation Accuracy: {training_metrics['best_val_accuracy']:.4f}",
    f"Model Parameters: {training_metrics['model_parameters']:,}",
    f"Total Classes: {len(class_names)}",
    f"Training Samples: {training_metrics['train_samples']}",
    f"Test Samples: {training_metrics['test_samples']}",
]
for headline in headline_lines:
    print(headline)

summary_sections = [
    ("\nExtracted Modules:", [
        "- src/config.py: TransferLearningClassifierConfig",
        "- src/models.py: TransferLearningModel with TensorFlow/Keras",
        "- src/data_loader.py: Memory-efficient TensorFlow datasets",
        "- src/trainer.py: TransferLearningTrainer with two-phase training",
        "- src/classifier.py: TransferLearningClassifier implementing BaseImageClassifier",
        "- scripts/train.py: CLI training script",
    ]),
    ("\nKey Features Implemented:", [
        "✅ Pre-trained ImageNet weights",
        "✅ Two-phase training (frozen base + fine-tuning)",
        "✅ Memory-efficient tf.data loading",
        "✅ Mixed precision training",
        "✅ XLA compilation",
        "✅ Class weight balancing",
        "✅ Comprehensive callbacks",
        "✅ TensorBoard logging",
        "✅ Model checkpointing",
    ]),
    ("\nIntegration Features:", [
        "✅ Implements BaseImageClassifier interface",
        "✅ Compatible with ModelRegistry",
        "✅ Configurable via dataclass",
        "✅ Memory-efficient data loading",
        "✅ Production-ready CLI script",
    ]),
    ("\nNext Steps:", [
        "1. ✅ Code extraction completed",
        "2. ✅ Notebook integration completed",
        "3. Run CLI training script for validation",
        "4. Add unit tests for extracted modules",
        "5. Deploy to production API",
        "6. Monitor model performance",
    ]),
    ("\nCLI Usage:", [
        "python scripts/train.py --data_path /path/to/data --base_model resnet50 --epochs 20",
    ]),
]
for section_heading, section_lines in summary_sections:
    print(section_heading)
    for section_line in section_lines:
        print(section_line)