# **Multi-Model Transfer Learning for Image Classification**

This notebook demonstrates transfer learning with several pre-trained models for image classification. We fine-tune three backbones (ResNet50, ResNet18 and EfficientNet-B0) on the same dataset and compare their validation accuracy to pick the best one for our task.

## What is Transfer Learning?

Transfer learning is a machine learning technique where a model developed for one task is reused as the starting point for a model on a second task. It is particularly useful in deep learning, where pre-trained models (such as those trained on ImageNet) can be fine-tuned for specific domains with less labeled data than training from scratch would require.
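
As a minimal sketch of the idea, the PyTorch snippet below freezes a stand-in "backbone" and trains only a new classification head. The tiny `nn.Sequential` backbone is a hypothetical placeholder so the example runs without downloading weights; in this notebook that role is played by the model returned by `get_pretrained_model`, whose internals are not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained feature extractor. In practice this
# would be a backbone such as a torchvision ResNet with pre-trained weights.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # -> (batch, 8) feature vector
)

# Freeze the backbone: its weights are reused as-is and not updated.
for param in backbone.parameters():
    param.requires_grad = False

# New task-specific head, trained from scratch (15 classes, as in our dataset).
head = nn.Linear(8, 15)
model = nn.Sequential(backbone, head)

# Give the optimizer only the parameters that still require gradients.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

# One training step on a random batch of four 32x32 RGB images.
x = torch.randn(4, 3, 32, 32)
y = torch.randint(0, 15, (4,))
logits = model(x)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(logits.shape)  # torch.Size([4, 15])
print(trainable)     # 135 = 8 * 15 weights + 15 biases
```

Passing only parameters with `requires_grad=True` to the optimizer mirrors what `train_single_model` does later with `filter(lambda p: p.requires_grad, ...)`.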

## Key Benefits
- Shorter training time
- Less labeled data required
- Often higher accuracy than training from scratch, especially on small datasets
- Lower computational cost

---

## 1. Setup and Import Packages

First, let's import the necessary packages and modules:

```python
# Import core packages
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import torchvision
import time

# Import from our CNN module
from src.cnn import CNN, load_data, load_model_weights

# Import from our utils module
from src.utils import (
    train_one_epoch, validate, train_model_with_early_stopping,
    plot_training_history, plot_confusion_matrix, predict_sample_images,
    get_pretrained_model
)

# Import Weights & Biases
import wandb

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
```
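
The seeding above can be wrapped in a reusable helper. This is a sketch, not part of `src.utils`: the `set_seed` name and the optional `deterministic` flag (which switches cuDNN to deterministic algorithms) are our own additions. Even with these settings, bit-exact reproducibility on GPU is not guaranteed for every operation.

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy and PyTorch RNGs (CPU and, if present, all GPUs)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Trade speed for reproducibility in cuDNN convolution algorithms.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed(42)
a = torch.rand(3)
set_seed(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True: same seed, same random draws
```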


## 2. Configuration Parameters

Define the experiment parameters and model configurations:

```python
# Common training parameters
common_config = {
    "batch_size": 32,       # Number of images per batch
    "img_size": 224,        # Image size for model input
    "epochs": 20,           # Maximum number of training epochs
    "patience": 5,          # Early stopping patience
}

# Model-specific configurations
model_configs = {
    "ResNet50": {
        "base_model": "resnet50",
        "learning_rate": 1e-3,
        "weight_decay": 1e-5,
        "unfreeze_layers": 2
    },
    "ResNet18": {
        "base_model": "resnet18",
        "learning_rate": 5e-4,
        "weight_decay": 1e-5,
        "unfreeze_layers": 2
    },
    "EfficientNet-B0": {
        "base_model": "efficientnet_b0",
        "learning_rate": 1e-3,
        "weight_decay": 1e-5,
        "unfreeze_layers": 2
    }
}

print("Common configuration:")
for key, value in common_config.items():
    print(f"  {key}: {value}")

print("\nModel configurations:")
for model_name, config in model_configs.items():
    print(f"  {model_name}: {config}")
```



```text
Common configuration:
  batch_size: 32
  img_size: 224
  epochs: 20
  patience: 5

Model configurations:
  ResNet50: {'base_model': 'resnet50', 'learning_rate': 0.001, 'weight_decay': 1e-05, 'unfreeze_layers': 2}
  ResNet18: {'base_model': 'resnet18', 'learning_rate': 0.0005, 'weight_decay': 1e-05, 'unfreeze_layers': 2}
  EfficientNet-B0: {'base_model': 'efficientnet_b0', 'learning_rate': 0.001, 'weight_decay': 1e-05, 'unfreeze_layers': 2}
```
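
Both configuration dictionaries are later merged with `{**a, **b}` unpacking inside `train_single_model`. With this syntax the right-hand dictionary wins on duplicate keys. The two dictionaries currently share no keys, so the order does not matter yet, but it would if a model-specific entry ever overrode a common one (the `batch_size` override below is hypothetical):

```python
common_config = {"batch_size": 32, "img_size": 224, "epochs": 20, "patience": 5}
resnet18_config = {"base_model": "resnet18", "learning_rate": 5e-4,
                   "weight_decay": 1e-5, "unfreeze_layers": 2}

# No overlapping keys: the merged run config simply has all eight entries.
run_config = {**common_config, **resnet18_config}
print(len(run_config))  # 8

# Hypothetical model-specific override of a common setting:
# the dictionary unpacked last takes precedence.
override = {"batch_size": 16}
print({**common_config, **override}["batch_size"])  # 16
print({**override, **common_config}["batch_size"])  # 32
```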


## 3. Data Loading

Load the image data from our dataset using the load_data function from our CNN module:

```python
# Path to training and validation data
# Update these paths to match your dataset location
train_dir = './dataset/training'
valid_dir = './dataset/validation'

# Load data using the load_data function
train_loader, valid_loader, num_classes = load_data(
    train_dir, 
    valid_dir, 
    batch_size=common_config["batch_size"], 
    img_size=common_config["img_size"]
)

# Display dataset information
print(f"Number of classes: {num_classes}")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(valid_loader.dataset)}")

# Get class names from the dataset
class_names = train_loader.dataset.classes
print(f"\nClasses: {class_names}")
```

```text
Number of classes: 15
Training samples: 2985
Validation samples: 1500

Classes: ['Bedroom', 'Coast', 'Forest', 'Highway', 'Industrial', 'Inside city', 'Kitchen', 'Living room', 'Mountain', 'Office', 'Open country', 'Store', 'Street', 'Suburb', 'Tall building']
```
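
From the sample counts above we can estimate how many mini-batches each epoch contains. This assumes `load_data` builds standard `DataLoader`s with `drop_last=False` (PyTorch's default); `len(train_loader)` gives the actual number.

```python
import math

# Values taken from the dataset summary printed above.
train_samples, valid_samples = 2985, 1500
batch_size, max_epochs = 32, 20

# Assumes drop_last=False: the final partial batch is kept, so the
# batch count is the ceiling of samples / batch_size.
train_batches = math.ceil(train_samples / batch_size)
valid_batches = math.ceil(valid_samples / batch_size)
last_batch = train_samples - (train_batches - 1) * batch_size
max_steps = train_batches * max_epochs

print(train_batches, valid_batches)  # 94 47
print(last_batch)                    # 9 images in the last training batch
print(max_steps)                     # 1880 optimizer steps if early stopping never fires
```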


## 4. Prepare for Multiple Models

Create a dictionary to store each model's results and choose the device used for training:

```python
# Dictionary to store results for each model
models_results = {}

# Set the device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Print GPU information if available
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
```

```text
Using device: cpu
```
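
The output above shows training ran on the CPU. A hedged variant of the device selection that also tries Apple's Metal (MPS) backend is sketched below; `pick_device` is a hypothetical helper, and whether every operation used by these models is supported on MPS depends on the installed PyTorch version.

```python
import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple-silicon MPS, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in newer PyTorch builds; guard for older ones.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(f"Using device: {device}")
```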


## 5. Training Functions

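Define a helper that trains a single model: it fine-tunes the pre-trained backbone, applies early stopping on validation loss, logs metrics and artifacts to Weights & Biases (W&B), and returns the results: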

```python
def train_single_model(model_name, model_config, common_config):
    """Fine-tune one model, log it to W&B, and return its results.

    Uses the notebook globals num_classes, device, train_loader,
    valid_loader and class_names defined in the cells above.
    """
    print(f"\n{'-'*50}")
    print(f"Training {model_name}")
    print(f"{'-'*50}\n")
    
    # 1) Start a W&B run
    run = wandb.init(
        project="deep-lab-models",
        name=f"run-{model_name}",
        group="multi-model-transfer-learning",
        config={**common_config, **model_config}
    )

    # 2) Load the pre-trained backbone and wrap it with our classification head
    base_model = get_pretrained_model(model_config["base_model"])
    print(f"Using {model_config['base_model']} as base model")
    model = CNN(base_model, num_classes, model_config["unfreeze_layers"]).to(device)

    # Show which parameters will be updated during fine-tuning
    print("\nTrainable layers:")
    trainable_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"  {name}")
            trainable_params += param.numel()
    print(f"\nTotal trainable parameters: {trainable_params:,}")

    # Loss function and optimizer (only trainable parameters are optimized)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=model_config["learning_rate"],
        weight_decay=model_config["weight_decay"]
    )

    # 3) Train with early stopping on validation loss, logging each epoch
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_path = f"best_{model_name}.pt"
    start_time = time.time()

    for epoch in range(common_config["epochs"]):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_preds, val_labels = validate(model, valid_loader, criterion, device)

        # Log metrics to W&B
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc
        })

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        # Manual early stopping: checkpoint on improvement,
        # stop after `patience` epochs without one
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= common_config["patience"]:
                print("Early stopping triggered.")
                break

    training_time = time.time() - start_time

    # Restore the best checkpoint and re-evaluate it, so the returned model,
    # predictions and accuracy all come from the epoch with the lowest val loss
    model.load_state_dict(torch.load(best_path, map_location=device))
    _, val_acc, val_preds, val_labels = validate(model, valid_loader, criterion, device)

    # 4) Log the confusion matrix and sample predictions
    fig_cm = plot_confusion_matrix(val_labels, val_preds, class_names)
    wandb.log({"confusion_matrix": wandb.Image(fig_cm)})

    sample_fig = predict_sample_images(model, valid_loader, device, class_names, num_samples=5)
    wandb.log({"sample_predictions": wandb.Image(sample_fig)})

    # 5) Save the final (best) model and log it as a W&B artifact
    model_filename = f"{model_name.lower().replace('-', '_')}_finetune.pt"
    torch.save(model.state_dict(), model_filename)
    artifact = wandb.Artifact(f"{model_name}-model", type="model")
    artifact.add_file(model_filename)
    run.log_artifact(artifact)

    run.finish()

    return {
        "model": model,
        "history": history,
        "val_preds": val_preds,
        "val_labels": val_labels,
        "val_acc": val_acc,
        "training_time": training_time,
        "config": {**model_config, **common_config}
    }
```
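
The manual early stopping inside `train_single_model` can be factored into a small, testable class. This is a sketch under our own naming (`EarlyStopping`, `min_delta` and `improved` are hypothetical); `src.utils` also exports `train_model_with_early_stopping`, but its interface is not shown in this notebook, so the class below does not try to mirror it.

```python
class EarlyStopping:
    """Patience-based early stopping on a validation loss."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta        # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.epochs_no_improve = 0
        self.improved = False             # True if the last step set a new best (save a checkpoint)

    def step(self, val_loss: float, epoch: int) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        self.improved = val_loss < self.best_loss - self.min_delta
        if self.improved:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_no_improve = 0
        else:
            self.epochs_no_improve += 1
        return self.epochs_no_improve >= self.patience


# Simulated validation losses: the best value is reached at epoch 1.
stopper = EarlyStopping(patience=2)
losses = [1.0, 0.8, 0.85, 0.9, 0.7]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss, epoch):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_epoch)  # 3 1
```

Note that the later, lower loss at epoch 4 is never seen: with `patience=2`, two non-improving epochs end training, which is why restoring the best checkpoint matters.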


## 6. Model 1: ResNet50

Train our first model: ResNet50

```python
# Train ResNet50
model_name = "ResNet50"
models_results[model_name] = train_single_model(model_name, model_configs[model_name], common_config)

# Plot training history
plot_training_history(models_results[model_name]["history"])

# Plot confusion matrix
plot_confusion_matrix(
    models_results[model_name]["val_labels"], 
    models_results[model_name]["val_preds"], 
    class_names
)

# Show sample predictions
predict_sample_images(
    models_results[model_name]["model"], 
    valid_loader, 
    device, 
    class_names, 
    num_samples=5
)
```


```text
--------------------------------------------------
Training ResNet50
--------------------------------------------------

Using resnet50 as base model

Trainable layers:
  fc.0.weight
  fc.0.bias
  fc.3.weight
  fc.3.bias

Total trainable parameters: 2,113,551

KeyboardInterrupt:
```

## 7. Model 2: ResNet18

Train our second model: ResNet18

```python
# Train ResNet18
model_name = "ResNet18"
models_results[model_name] = train_single_model(model_name, model_configs[model_name], common_config)

# Plot training history
plot_training_history(models_results[model_name]["history"])

# Plot confusion matrix
plot_confusion_matrix(
    models_results[model_name]["val_labels"], 
    models_results[model_name]["val_preds"], 
    class_names
)

# Show sample predictions
predict_sample_images(
    models_results[model_name]["model"], 
    valid_loader, 
    device, 
    class_names, 
    num_samples=5
)
```


```text
--------------------------------------------------
Training ResNet18
--------------------------------------------------

Using resnet18 as base model

Trainable layers:
  fc.0.weight
  fc.0.bias
  fc.3.weight
  fc.3.bias

Total trainable parameters: 540,687

KeyboardInterrupt:
```
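
The trainable-parameter totals printed for ResNet50 and ResNet18 can be checked by hand. They are consistent with a head made of two linear layers (`fc.0` and `fc.3`) with a hidden width of 1024, applied to the pooled backbone features (2048 for ResNet50, 512 for ResNet18). The hidden width is inferred from these counts, not read from `src.cnn`, so treat it as an assumption. The listings also show that only the head is trainable in these runs: no backbone parameters appear, despite `unfreeze_layers=2`.

```python
def mlp_head_params(in_features: int, hidden: int, num_classes: int) -> int:
    """Weights + biases of Linear(in_features, hidden) then Linear(hidden, num_classes)."""
    first = in_features * hidden + hidden
    second = hidden * num_classes + num_classes
    return first + second


# Pooled feature sizes of the torchvision backbones; hidden=1024 is inferred.
print(f"{mlp_head_params(2048, 1024, 15):,}")  # 2,113,551 (ResNet50)
print(f"{mlp_head_params(512, 1024, 15):,}")   # 540,687 (ResNet18)
```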



## 8. Model 3: EfficientNet-B0

Train our third model: EfficientNet-B0

```python
# Train EfficientNet-B0
model_name = "EfficientNet-B0"
models_results[model_name] = train_single_model(model_name, model_configs[model_name], common_config)

# Plot training history
plot_training_history(models_results[model_name]["history"])

# Plot confusion matrix
plot_confusion_matrix(
    models_results[model_name]["val_labels"], 
    models_results[model_name]["val_preds"], 
    class_names
)

# Show sample predictions
predict_sample_images(
    models_results[model_name]["model"], 
    valid_loader, 
    device, 
    class_names, 
    num_samples=5
)
```

## 9. Select the Best Model

Select the model with the highest validation accuracy and save its weights:

```python
# Guard against running this cell before any training cell has finished
if not models_results:
    raise RuntimeError("No trained models found; run the training cells above first.")

# Find the model with the highest validation accuracy
best_model_name = max(models_results, key=lambda name: models_results[name]["val_acc"])
best_model = models_results[best_model_name]["model"]
best_val_acc = models_results[best_model_name]["val_acc"]

print(f"Best model: {best_model_name} with validation accuracy: {best_val_acc:.4f}")

# Save the best model's weights under a distinct file name
# (state_dict, as in train_single_model)
best_model_filename = f"best_model_{best_model_name.lower().replace('-', '_')}.pt"
torch.save(best_model.state_dict(), best_model_filename)
print(f"Best model saved as {best_model_filename}")
```
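
The selection cell reports only the winner. To see all candidates side by side, a ranked table can be built from `models_results`. This sketch assumes each entry has the `val_acc` and `training_time` keys returned by `train_single_model` (`training_time` may be `None`); `rank_models` and `print_comparison` are hypothetical helpers.

```python
def rank_models(results: dict) -> list:
    """Return (name, val_acc, training_time) tuples sorted by accuracy, best first."""
    if not results:
        raise ValueError("No trained models to compare; run the training cells first.")
    rows = [(name, r["val_acc"], r.get("training_time")) for name, r in results.items()]
    return sorted(rows, key=lambda row: row[1], reverse=True)


def print_comparison(rows: list) -> None:
    """Print a fixed-width comparison table."""
    print(f"{'Model':<18}{'Val acc':>10}{'Time (s)':>12}")
    for name, acc, secs in rows:
        time_str = f"{secs:.1f}" if secs is not None else "n/a"
        print(f"{name:<18}{acc:>10.4f}{time_str:>12}")
```

Usage: `print_comparison(rank_models(models_results))`.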

## 10. Final Evaluation

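Inspect the best model's predictions on the validation set. Because this same set was used for early stopping and model selection, these results are likely to be optimistic estimates of performance on unseen images.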

```python
# Show predictions from the best model
print(f"Showing predictions from the best model: {best_model_name}")
predict_sample_images(best_model, valid_loader, device, class_names, num_samples=8)

# Plot confusion matrix for the best model
plot_confusion_matrix(
    models_results[best_model_name]["val_labels"],
    models_results[best_model_name]["val_preds"],
    class_names
)
```
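
The confusion matrix can be summarized as per-class accuracy (the recall of each class), which makes it easier to spot which scene categories the best model struggles with. This sketch assumes `val_labels` and `val_preds` are sequences of integer class indices (call `.tolist()` first if `validate` returns tensors); `per_class_accuracy` is a hypothetical helper, not part of `src.utils`.

```python
from collections import Counter


def per_class_accuracy(labels, preds, class_names):
    """Fraction of each class's samples predicted correctly (per-class recall).

    Classes with no samples in `labels` are omitted from the result.
    """
    total = Counter(labels)
    correct = Counter(y for y, p in zip(labels, preds) if y == p)
    return {class_names[c]: correct[c] / total[c] for c in sorted(total)}
```

For example, `per_class_accuracy(models_results[best_model_name]["val_labels"], models_results[best_model_name]["val_preds"], class_names)` returns a dictionary keyed by class name.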