# CNN for Saliency Center Detection

This notebook implements a convolutional neural network (CNN) that regresses the "saliency center" of an image: the center of its most salient object. The model takes a 256×256×4 RGB + Depth image as input and outputs the center's normalized UV coordinates (u, v), each in [0, 1].
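
As a quick illustration of the coordinate convention, the sketch below converts between normalized UV and pixel coordinates. It assumes u runs along the image width (x) and v along the height (y), which matches how the later cells scale predictions. The helper names `uv_to_pixel` and `pixel_to_uv` are hypothetical and not part of the notebook.

```python
def uv_to_pixel(u, v, width=256, height=256):
    """Map normalized (u, v) in [0, 1] to pixel (x, y)."""
    return u * width, v * height


def pixel_to_uv(x, y, width=256, height=256):
    """Map pixel (x, y) back to normalized (u, v)."""
    return x / width, y / height


x, y = uv_to_pixel(0.5, 0.25)
print(x, y)                # 128.0 64.0
print(pixel_to_uv(x, y))   # (0.5, 0.25)
```

Because the labels are normalized, the same (u, v) pair stays valid if the image is later resized.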

## Table of Contents
1. Import Required Libraries
2. Model Architecture
3. Data Generation and Preprocessing
4. Training Setup
5. Model Training
6. Evaluation and Visualization
7. Testing and Inference

## 1. Import Required Libraries

First, let's install any missing packages and then import the libraries needed to build our saliency center detection model.

In [None]:
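# Install required packages if not already installed
import importlib
import subprocess
import sys

def install_package(pip_name, import_name=None):
    """Install pip_name unless its module (import_name) can already be imported."""
    try:
        importlib.import_module(import_name or pip_name)
    except ImportError:
        print(f"Installing {pip_name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])

# pip distribution name -> importable module name (they differ for some packages)
packages = {
    "tensorflow": "tensorflow",
    "numpy": "numpy",
    "matplotlib": "matplotlib",
    "opencv-python": "cv2",
    "scikit-learn": "sklearn",
    "pillow": "PIL",
}
for pip_name, import_name in packages.items():
    install_package(pip_name, import_name)

print("All required packages are available.")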

In [None]:
# Import required libraries
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

import numpy as np
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
import os
import warnings
warnings.filterwarnings('ignore')

# Check TensorFlow version
print(f"TensorFlow version: {tf.__version__}")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Configure matplotlib
plt.style.use('default')
plt.rcParams['figure.figsize'] = [12, 8]

## 2. Model Architecture

Let's build the CNN model for detecting the saliency center. The model takes 256×256×4 images (RGB + Depth) and outputs normalized UV coordinates.

In [None]:
# CNN for detecting the "saliency center" pixel of an image.
# It should return the uv coordinates of the center of the most salient object in the image.

def create_saliency_center_model(input_shape=(256, 256, 4)):
    """
    Create a CNN model for saliency center detection.

    Args:
        input_shape: Shape of input images (height, width, channels)

    Returns:
        Uncompiled Keras Sequential model (call compile() before training)
    """
    model = Sequential([
        # Input: 256×256×4 (RGB + Depth)
        Input(shape=input_shape),

        # Convolutional blocks - extract features
        Conv2D(32, (3, 3), activation='relu', padding='same'),
        Conv2D(32, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),  # 128×128

        Conv2D(64, (3, 3), activation='relu', padding='same'),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),  # 64×64

        Conv2D(128, (3, 3), activation='relu', padding='same'),
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),  # 32×32

        Conv2D(256, (3, 3), activation='relu', padding='same'),
        Conv2D(256, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),  # 16×16

        # Flatten and dense layers
        Flatten(),
        Dense(512, activation='relu'),
        Dropout(0.5),
        Dense(256, activation='relu'),
        Dropout(0.3),

        # Output: 2 values (u, v coordinates) constrained to range [0, 1] in u and v
        Dense(2, activation='sigmoid')
    ])

    return model

# Create the model
saliency_center_model = create_saliency_center_model()

# Compile for regression
saliency_center_model.compile(
    optimizer='adam',
    loss='mse',  # Mean Squared Error for coordinate prediction
    metrics=['mae']  # Mean Absolute Error for monitoring
)

# Display model architecture
print("Saliency Center Detection Model Architecture:")
print("=" * 50)
saliency_center_model.summary()

# Report the total parameter count (every parameter in this model is trainable)
total_params = saliency_center_model.count_params()
print(f"\nTotal parameters: {total_params:,}")
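
To see where the parameters come from, the following sketch recomputes the count by hand from the layer sizes in the model definition (four blocks of two 3×3 convolutions, then dense layers of 512, 256, and 2 units). The helpers `conv_params` and `dense_params` are hypothetical names introduced only for this check; the result should agree with `count_params()`.

```python
def conv_params(in_ch, out_ch, k=3):
    """Weights plus biases of a k x k Conv2D layer."""
    return (k * k * in_ch + 1) * out_ch


def dense_params(in_units, out_units):
    """Weights plus biases of a Dense layer."""
    return (in_units + 1) * out_units


size, channels, total = 256, 4, 0
for filters in (32, 64, 128, 256):
    # Each block: two 3x3 'same' convolutions, then 2x2 max pooling.
    total += conv_params(channels, filters) + conv_params(filters, filters)
    channels, size = filters, size // 2

conv_total = total
flat = size * size * channels          # 16 * 16 * 256 features after Flatten
units = flat
for out_units in (512, 256, 2):
    total += dense_params(units, out_units)
    units = out_units

print(f"conv layers:  {conv_total:,}")                # 1,172,544
print(f"first dense:  {dense_params(flat, 512):,}")   # 33,554,944
print(f"total:        {total:,}")                     # 34,859,330
```

Nearly all of the roughly 34.9 million parameters sit in the first dense layer, because `Flatten` hands it 65,536 features. If model size matters, shrinking that interface (for example with further pooling) is the obvious place to look.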

## 3. Data Generation and Preprocessing

Since we don't have a real dataset, let's create synthetic data for demonstration purposes. In a real scenario, you would load your actual RGB+D images and corresponding saliency center annotations.

In [None]:
def generate_synthetic_data(num_samples=1000, img_size=(256, 256)):
    """
    Generate synthetic RGB+D images with known saliency centers for training.

    Args:
        num_samples: Number of samples to generate
        img_size: Size of images (height, width)

    Returns:
        X: Array of shape (num_samples, height, width, 4) - RGB+D images
        y: Array of shape (num_samples, 2) - Normalized UV coordinates
    """
    height, width = img_size
    X = np.zeros((num_samples, height, width, 4), dtype=np.float32)
    y = np.zeros((num_samples, 2), dtype=np.float32)

    for i in range(num_samples):
        # Create base image with noise
        rgb_img = np.random.uniform(0, 0.3, (height, width, 3))
        depth_img = np.random.uniform(0, 0.3, (height, width, 1))

        # Pick the salient object's center, kept away from the image edges
        # (the largest circles can still be clipped at the border)
        center_u = np.random.uniform(0.2, 0.8)
        center_v = np.random.uniform(0.2, 0.8)

        # Convert normalized coordinates to pixel coordinates
        center_x = int(center_u * width)
        center_y = int(center_v * height)

        # Create salient object
        object_type = np.random.choice(['circle', 'rectangle'])

        if object_type == 'circle':
            radius = np.random.randint(20, 60)
            y_grid, x_grid = np.ogrid[:height, :width]
            mask = (x_grid - center_x)**2 + (y_grid - center_y)**2 <= radius**2

            # Bright colored circle
            color = np.random.uniform(0.7, 1.0, 3)
            for c in range(3):
                rgb_img[mask, c] = color[c]
            depth_img[mask, 0] = np.random.uniform(0.8, 1.0)

        else:  # rectangle
            size = np.random.randint(30, 80)
            x1 = max(0, center_x - size//2)
            x2 = min(width, center_x + size//2)
            y1 = max(0, center_y - size//2)
            y2 = min(height, center_y + size//2)

            # Bright colored rectangle
            color = np.random.uniform(0.7, 1.0, 3)
            rgb_img[y1:y2, x1:x2] = color.reshape(1, 1, 3)
            depth_img[y1:y2, x1:x2, 0] = np.random.uniform(0.8, 1.0)

        # Combine RGB and Depth
        X[i] = np.concatenate([rgb_img, depth_img], axis=2)
        y[i] = [center_u, center_v]  # Normalized coordinates

    return X, y

# Generate synthetic dataset
print("Generating synthetic dataset...")
num_samples = 2000
X_data, y_data = generate_synthetic_data(num_samples)

print(f"Generated {num_samples} samples")
print(f"Image shape: {X_data.shape}")
print(f"Coordinates shape: {y_data.shape}")
print(f"Coordinate ranges - U: [{y_data[:, 0].min():.3f}, {y_data[:, 0].max():.3f}]")
print(f"Coordinate ranges - V: [{y_data[:, 1].min():.3f}, {y_data[:, 1].max():.3f}]")

In [None]:
# Visualize some sample data
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle('Sample Synthetic Data', fontsize=16)

for i in range(4):
    # RGB visualization
    axes[0, i].imshow(X_data[i, :, :, :3])  # RGB channels
    axes[0, i].set_title(f'RGB - Sample {i+1}')
    axes[0, i].set_xticks([])
    axes[0, i].set_yticks([])

    # Mark the true saliency center
    u, v = y_data[i]
    img_h, img_w = X_data.shape[1:3]  # avoid hard-coding the image size
    center_x = u * img_w
    center_y = v * img_h
    axes[0, i].plot(center_x, center_y, 'r+', markersize=15, markeredgewidth=3)
    axes[0, i].text(10, 30, f'UV: ({u:.2f}, {v:.2f})', color='red',
                   bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    # Depth visualization
    axes[1, i].imshow(X_data[i, :, :, 3], cmap='gray')  # Depth channel
    axes[1, i].set_title(f'Depth - Sample {i+1}')
    axes[1, i].set_xticks([])
    axes[1, i].set_yticks([])
    axes[1, i].plot(center_x, center_y, 'r+', markersize=15, markeredgewidth=3)

plt.tight_layout()
plt.show()

## 4. Training Setup

Let's split our data into training and validation sets and set up training callbacks.

In [None]:
# Split data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    X_data, y_data, test_size=0.2, random_state=42
)

print(f"Training set size: {X_train.shape[0]} samples")
print(f"Validation set size: {X_val.shape[0]} samples")

# Define training callbacks
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    ModelCheckpoint(
        'saliency_center_model_best.h5',
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    )
]

# Training parameters
batch_size = 16
epochs = 50

print(f"Training configuration:")
print(f"  Batch size: {batch_size}")
print(f"  Max epochs: {epochs}")
print(f"  Callbacks: Early stopping, Learning rate reduction, Model checkpointing")
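
To build intuition for how the two patience settings above interact, here is a simplified, pure-Python replay of the rules. It is a sketch of the idea, not Keras's implementation: it ignores `min_delta`, cooldown, and checkpointing, and the function name `simulate_callbacks` and the loss curve are made up for illustration. The starting learning rate of 1e-3 is Adam's default in Keras.

```python
def simulate_callbacks(val_losses, lr=1e-3, stop_patience=10,
                       lr_patience=5, factor=0.5, min_lr=1e-7):
    """Replay a validation-loss curve through simplified patience rules.

    Returns (epochs_run, final_lr, best_epoch); epochs are 1-based.
    """
    best, best_epoch = float("inf"), 0
    stop_wait = lr_wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch = loss, epoch
            stop_wait = lr_wait = 0
            continue
        stop_wait += 1
        lr_wait += 1
        if lr_wait >= lr_patience:          # plateau: shrink the learning rate
            lr = max(lr * factor, min_lr)
            lr_wait = 0
        if stop_wait >= stop_patience:      # no progress for too long: stop
            return epoch, lr, best_epoch
    return len(val_losses), lr, best_epoch


# Hypothetical curve: improves for 8 epochs, then stays flat.
curve = [1.0 / e for e in range(1, 9)] + [0.2] * 20
print(simulate_callbacks(curve))   # (18, 0.00025, 8)
```

With these settings the learning rate is halved twice (after 5 and 10 stagnant epochs) before training stops at epoch 18, and the best weights are those from epoch 8.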

## 5. Model Training

Now let's train our saliency center detection model.

In [None]:
# Train the model
print("Starting training...")
print("=" * 50)

history = saliency_center_model.fit(
    X_train, y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    verbose=1
)

print("\nTraining completed!")
print("=" * 50)

# Report the best epoch (lowest validation loss) rather than the last one:
# with restore_best_weights=True, early stopping rolls the model back to it.
val_losses = history.history['val_loss']
best_epoch = int(np.argmin(val_losses))

print(f"Best epoch: {best_epoch + 1} of {len(val_losses)}")
print(f"Training Loss (MSE): {history.history['loss'][best_epoch]:.6f}")
print(f"Validation Loss (MSE): {val_losses[best_epoch]:.6f}")
print(f"Training MAE: {history.history['mae'][best_epoch]:.6f}")
print(f"Validation MAE: {history.history['val_mae'][best_epoch]:.6f}")

## 6. Evaluation and Visualization

Let's visualize the training progress and evaluate the model performance.

In [None]:
# Plot training history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Loss plot
ax1.plot(history.history['loss'], label='Training Loss', linewidth=2)
ax1.plot(history.history['val_loss'], label='Validation Loss', linewidth=2)
ax1.set_title('Model Loss (MSE)', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# MAE plot
ax2.plot(history.history['mae'], label='Training MAE', linewidth=2)
ax2.plot(history.history['val_mae'], label='Validation MAE', linewidth=2)
ax2.set_title('Model Mean Absolute Error', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('MAE')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Calculate pixel-level error statistics
def calculate_pixel_errors(y_true, y_pred, img_size=(256, 256)):
    """Calculate pixel-level errors from normalized coordinates"""
    # Convert normalized coordinates to pixel coordinates
    pixel_true = y_true * np.array([img_size[1], img_size[0]])  # [width, height]
    pixel_pred = y_pred * np.array([img_size[1], img_size[0]])

    # Calculate Euclidean distance in pixels
    pixel_errors = np.sqrt(np.sum((pixel_true - pixel_pred)**2, axis=1))

    return pixel_errors, pixel_true, pixel_pred

# Evaluate on validation set
val_predictions = saliency_center_model.predict(X_val, verbose=0)
pixel_errors, pixel_true, pixel_pred = calculate_pixel_errors(y_val, val_predictions)

print("Validation Set Performance:")
print("=" * 40)
print(f"Mean Squared Error (normalized): {mean_squared_error(y_val, val_predictions):.6f}")
print(f"Mean Absolute Error (normalized): {mean_absolute_error(y_val, val_predictions):.6f}")
print(f"Mean pixel error: {np.mean(pixel_errors):.2f} pixels")
print(f"Median pixel error: {np.median(pixel_errors):.2f} pixels")
print(f"95th percentile pixel error: {np.percentile(pixel_errors, 95):.2f} pixels")
print(f"Max pixel error: {np.max(pixel_errors):.2f} pixels")
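
As a sanity check on the pixel-error metric, the small worked example below uses hand-picked values (not model output) whose errors are easy to verify by eye. It follows the same convention as `calculate_pixel_errors`: u is scaled by the width, v by the height, and the error is the Euclidean distance in pixels.

```python
import numpy as np

img_w, img_h = 256, 256
y_true = np.array([[0.50, 0.50], [0.25, 0.75]])
# Offsets chosen so the errors are a 3-4-5 triangle and a pure 10 px shift in u.
y_pred = y_true + np.array([[3 / img_w, 4 / img_h], [10 / img_w, 0.0]])

scale = np.array([img_w, img_h])            # u scales with width, v with height
errors_px = np.linalg.norm((y_pred - y_true) * scale, axis=1)
print(errors_px)                            # 5 px and 10 px

threshold_px = 20                           # threshold used in the final summary cell
print(f"within {threshold_px} px: {np.mean(errors_px < threshold_px):.0%}")
```

For a 256×256 image, a normalized MAE of 0.01 corresponds to about 2.56 pixels of error per axis, which is a handy way to read the training curves.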

In [None]:
# Visualize prediction results
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
fig.suptitle('Saliency Center Detection Results', fontsize=16, fontweight='bold')

for i in range(4):
    idx = i * 10  # Sample different images

    # Original RGB image
    axes[0, i].imshow(X_val[idx, :, :, :3])
    axes[0, i].set_title(f'RGB Image {idx+1}')
    axes[0, i].set_xticks([])
    axes[0, i].set_yticks([])

    # Mark true and predicted centers
    true_u, true_v = y_val[idx]
    pred_u, pred_v = val_predictions[idx]

    img_h, img_w = X_val.shape[1:3]  # avoid hard-coding the image size
    true_x, true_y = true_u * img_w, true_v * img_h
    pred_x, pred_y = pred_u * img_w, pred_v * img_h

    axes[0, i].plot(true_x, true_y, 'g+', markersize=15, markeredgewidth=3, label='True')
    axes[0, i].plot(pred_x, pred_y, 'r+', markersize=15, markeredgewidth=3, label='Predicted')
    if i == 0:
        axes[0, i].legend()

    # Depth image
    axes[1, i].imshow(X_val[idx, :, :, 3], cmap='gray')
    axes[1, i].set_title(f'Depth Image {idx+1}')
    axes[1, i].set_xticks([])
    axes[1, i].set_yticks([])
    axes[1, i].plot(true_x, true_y, 'g+', markersize=15, markeredgewidth=3)
    axes[1, i].plot(pred_x, pred_y, 'r+', markersize=15, markeredgewidth=3)

    # Error visualization
    error_map = np.zeros((256, 256, 3))
    error_map[:, :, 0] = X_val[idx, :, :, 0]  # Use red channel as base

    # Draw error line
    cv2.line(error_map,
            (int(true_x), int(true_y)),
            (int(pred_x), int(pred_y)),
            (1, 1, 0), 2)  # Yellow line

    axes[2, i].imshow(error_map)
    axes[2, i].set_title(f'Error: {pixel_errors[idx]:.1f} pixels')
    axes[2, i].set_xticks([])
    axes[2, i].set_yticks([])

    # Add coordinate text
    axes[2, i].text(10, 30, f'True: ({true_u:.2f}, {true_v:.2f})',
                   color='green', fontweight='bold',
                   bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
    axes[2, i].text(10, 60, f'Pred: ({pred_u:.2f}, {pred_v:.2f})',
                   color='red', fontweight='bold',
                   bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

plt.tight_layout()
plt.show()

## 7. Testing and Inference

Let's test the model with new synthetic data and demonstrate how to use it for inference.

In [None]:
# Generate test data
print("Generating test data...")
X_test, y_test = generate_synthetic_data(200)
print(f"Test set size: {X_test.shape[0]} samples")

# Test the model
test_predictions = saliency_center_model.predict(X_test, verbose=0)
test_pixel_errors, test_pixel_true, test_pixel_pred = calculate_pixel_errors(y_test, test_predictions)

print("\nTest Set Performance:")
print("=" * 40)
print(f"Mean Squared Error (normalized): {mean_squared_error(y_test, test_predictions):.6f}")
print(f"Mean Absolute Error (normalized): {mean_absolute_error(y_test, test_predictions):.6f}")
print(f"Mean pixel error: {np.mean(test_pixel_errors):.2f} pixels")
print(f"Median pixel error: {np.median(test_pixel_errors):.2f} pixels")
print(f"95th percentile pixel error: {np.percentile(test_pixel_errors, 95):.2f} pixels")

# Function for inference on a single image
def predict_saliency_center(model, image):
    """
    Predict saliency center for a single image.

    Args:
        model: Trained Keras model
        image: Input image of shape (height, width, 4) - RGB+D

    Returns:
        u, v: Normalized coordinates of saliency center
        pixel_x, pixel_y: Pixel coordinates of saliency center
    """
    # Ensure image is the right shape
    if len(image.shape) == 3:
        image = np.expand_dims(image, axis=0)  # Add batch dimension

    # Predict
    prediction = model.predict(image, verbose=0)[0]
    u, v = prediction

    # Convert to pixel coordinates
    height, width = image.shape[1:3]
    pixel_x = u * width
    pixel_y = v * height

    return u, v, pixel_x, pixel_y

# Demonstrate inference on a few test examples
print("\nInference Examples:")
print("=" * 40)

for i in range(3):
    # Select a test image
    test_img = X_test[i]
    true_u, true_v = y_test[i]

    # Predict saliency center
    pred_u, pred_v, pred_x, pred_y = predict_saliency_center(saliency_center_model, test_img)

    # Calculate error in pixels, scaling the true UV by the actual image size
    img_h, img_w = test_img.shape[:2]
    pixel_error = np.sqrt((true_u * img_w - pred_x)**2 + (true_v * img_h - pred_y)**2)

    print(f"Image {i+1}:")
    print(f"  True center (UV): ({true_u:.3f}, {true_v:.3f})")
    print(f"  Predicted center (UV): ({pred_u:.3f}, {pred_v:.3f})")
    print(f"  Predicted center (pixels): ({pred_x:.1f}, {pred_y:.1f})")
    print(f"  Pixel error: {pixel_error:.2f}")
    print()

In [None]:
# Error distribution analysis
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Histogram of pixel errors
ax1.hist(test_pixel_errors, bins=30, alpha=0.7, edgecolor='black')
ax1.axvline(np.mean(test_pixel_errors), color='red', linestyle='--',
           label=f'Mean: {np.mean(test_pixel_errors):.2f}')
ax1.axvline(np.median(test_pixel_errors), color='green', linestyle='--',
           label=f'Median: {np.median(test_pixel_errors):.2f}')
ax1.set_xlabel('Pixel Error')
ax1.set_ylabel('Frequency')
ax1.set_title('Distribution of Pixel Errors', fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Scatter plot of predicted vs true coordinates
ax2.scatter(y_test[:, 0], test_predictions[:, 0], alpha=0.6, label='U coordinate')
ax2.scatter(y_test[:, 1], test_predictions[:, 1], alpha=0.6, label='V coordinate')
ax2.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect prediction')
ax2.set_xlabel('True Coordinates')
ax2.set_ylabel('Predicted Coordinates')
ax2.set_title('Predicted vs True Coordinates', fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)

plt.tight_layout()
plt.show()

# Final model summary
print("\n" + "="*60)
print("SALIENCY CENTER DETECTION MODEL SUMMARY")
print("="*60)
print("Model Architecture: CNN with 4 conv blocks + 2 hidden dense layers + sigmoid output layer")
print(f"Input Shape: (256, 256, 4) - RGB + Depth")
print(f"Output Shape: (2,) - Normalized UV coordinates [0,1]")
print(f"Total Parameters: {saliency_center_model.count_params():,}")
print(f"Training Samples: {X_train.shape[0]}")
print(f"Validation Samples: {X_val.shape[0]}")
print(f"Test Samples: {X_test.shape[0]}")
print(f"\nFinal Performance:")
print(f"  Mean Pixel Error: {np.mean(test_pixel_errors):.2f} pixels")
print(f"  Median Pixel Error: {np.median(test_pixel_errors):.2f} pixels")
print(f"  Success Rate (< 20 pixels): {np.sum(test_pixel_errors < 20) / len(test_pixel_errors) * 100:.1f}%")
print("="*60)

## Usage Notes and Next Steps

### How to Use This Model:

1. **Input Format**: The model expects 256×256×4 images (RGB + Depth channels)
2. **Output Format**: Returns normalized UV coordinates [0,1] representing the saliency center
3. **Preprocessing**: Scale all four channels to the [0, 1] range, as in the synthetic training data
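
The preprocessing step might look like the sketch below. It assumes an 8-bit RGB image and a depth map in sensor units that share the same 256×256 resolution. The function name `to_model_input` and the `max_depth` range are hypothetical, and the random arrays only stand in for real files.

```python
import numpy as np


def to_model_input(rgb_uint8, depth, max_depth):
    """Stack an 8-bit RGB image and a depth map into a float32 HxWx4 array in [0, 1].

    max_depth is an assumed sensor range in the same units as `depth`.
    """
    rgb = rgb_uint8.astype(np.float32) / 255.0
    d = np.clip(depth.astype(np.float32) / max_depth, 0.0, 1.0)
    return np.concatenate([rgb, d[..., np.newaxis]], axis=-1)


# Stand-in data; a real pipeline would read these from files.
rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
depth = rng.uniform(0.0, 12.0, size=(256, 256))   # e.g. metres, some beyond range

x = to_model_input(rgb, depth, max_depth=10.0)
print(x.shape, x.dtype, float(x.min()), float(x.max()))
```

Note that the synthetic data gives the salient object larger depth-channel values than the background, so real depth maps may need rescaling or inverting to match. That choice depends on the sensor and is left open here. Images of other sizes would also need resizing to 256×256 first.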

### For Real Applications:

1. **Replace Synthetic Data**: Use your actual RGB+D dataset with manually annotated saliency centers
2. **Data Augmentation**: Add rotation, scaling, and color jittering to improve generalization; apply every geometric transform to the UV labels as well
3. **Transfer Learning**: Pre-train on synthetic data, then fine-tune on real data
4. **Multi-scale Training**: Train on different image resolutions for robustness
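
For augmentation, any geometric transform has to be applied to the label too, or the image and its target drift apart. The sketch below shows this for the simplest case, a horizontal flip. The function name `hflip_with_label` is hypothetical, and the toy label is placed at a pixel center so the mirrored position is exact.

```python
import numpy as np


def hflip_with_label(image, uv):
    """Mirror an HxWxC image left-right and mirror the u coordinate to match."""
    u, v = uv
    return image[:, ::-1, :].copy(), np.array([1.0 - u, v], dtype=np.float32)


# Toy RGB+D image with a single bright pixel standing in for the salient object.
img = np.zeros((256, 256, 4), dtype=np.float32)
row, col = 64, 32
img[row, col, :] = 1.0
uv = np.array([(col + 0.5) / 256, (row + 0.5) / 256])   # label at the pixel centre

flipped, flipped_uv = hflip_with_label(img, uv)
new_row, new_col = np.argwhere(flipped[..., 0] == 1.0)[0]
print(new_row, new_col)          # object moved to the mirrored column
print(flipped_uv * 256)          # label follows it
```

Rotation and scaling need the analogous coordinate transform, while color jittering leaves the labels unchanged.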

### Potential Improvements:

- **Attention Mechanisms**: Add spatial attention layers to focus on salient regions
- **Multi-task Learning**: Predict both center and saliency maps simultaneously
- **Uncertainty Estimation**: Add prediction confidence/uncertainty quantification
- **Real-time Optimization**: Model quantization and pruning for deployment
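
For the multi-task idea, a saliency-map target is needed alongside the UV label. One common choice in keypoint detection, used here purely as an assumption, is a Gaussian bump centered on the annotated point. The function name `center_heatmap`, the 64×64 map size, and `sigma` are all hypothetical.

```python
import numpy as np


def center_heatmap(u, v, size=64, sigma=3.0):
    """Return a size x size map with a Gaussian bump centred on (u, v)."""
    ys, xs = np.mgrid[0:size, 0:size]
    cx, cy = u * size, v * size
    # Compare against pixel centres, which sit at integer index + 0.5.
    sq_dist = (xs + 0.5 - cx) ** 2 + (ys + 0.5 - cy) ** 2
    return np.exp(-sq_dist / (2.0 * sigma ** 2)).astype(np.float32)


# Centre placed exactly on pixel (row 48, col 16) of the 64x64 map.
heat = center_heatmap(u=16.5 / 64, v=48.5 / 64)
peak_row, peak_col = np.unravel_index(np.argmax(heat), heat.shape)
print(heat.shape, peak_row, peak_col, float(heat.max()))
```

A second output head trained against such maps would give the network a denser training signal than two coordinates alone, though whether it helps on real data would need to be tested.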