# Denoising autoencoder

## 1. Notebook setup

### 1.1. Imports

In [None]:
# Standard library
import sys
from pathlib import Path

# Third-party
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Add src directory to path
sys.path.append(str(Path.cwd().parent))

# Local imports
from src.data_utils import load_coco_cached
from src.visualization import plot_image_grid
from src.noise import add_gaussian_noise, add_salt_pepper_noise, add_speckle_noise

### 1.2. Configuration

In [None]:
# Training configuration
TRAIN_MODEL = True  # Set to False to download pre-trained model from HuggingFace

# Training hyperparameters
LATENT_DIM = 128
EPOCHS = 60
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# Noise parameters
NOISE_FACTOR = 0.25  # Train/evaluate strictly on Gaussian noise

# Callback settings
EARLY_STOPPING_PATIENCE = 8
CHECKPOINT_PATH = Path('../models/denoising_ae_latent128_best.keras')

# GPU configuration.
# This constant used to be assigned twice in this cell (1, then 0); the later
# assignment was the one that took effect, so 0 is kept as the single value.
GPU_ID = 0  # Which GPU to use (0-indexed). Set to None to use all available GPUs.

# Create models directory (parents=True so a missing parent does not fail)
models_dir = Path('../models')
models_dir.mkdir(parents=True, exist_ok=True)

# Configure GPU
gpus = tf.config.list_physical_devices('GPU')

if not gpus:
    print('No GPU available, using CPU')
else:
    # Fail with a readable message instead of a bare IndexError from gpus[GPU_ID]
    if GPU_ID is not None and not 0 <= GPU_ID < len(gpus):
        raise ValueError(
            f'GPU_ID={GPU_ID} is out of range: {len(gpus)} GPU(s) detected '
            f'(valid IDs are 0-{len(gpus) - 1}, or None to use all GPUs)'
        )

    try:
        if GPU_ID is not None:

            # Use specific GPU
            selected_gpu = gpus[GPU_ID]
            tf.config.set_visible_devices(selected_gpu, 'GPU')
            tf.config.experimental.set_memory_growth(selected_gpu, True)
            print(f'Using GPU {GPU_ID}: {selected_gpu.name}')

        else:

            # Use all GPUs with memory growth
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)

            print(f'Using {len(gpus)} GPU(s): {[gpu.name for gpu in gpus]}')

    except RuntimeError as e:
        # TensorFlow raises RuntimeError when device settings are changed after
        # the runtime has been initialized (e.g. when this cell is re-run);
        # report it and continue with the existing device configuration.
        print(e)

## 2. Data preparation

In [None]:
# Load the cached COCO subset as (train, test) pairs of (images, labels).
# First run: Downloads full COCO (~95GB) and saves 10% subset to ../data/
# Subsequent runs: Loads quickly from cached subset file
coco_splits = load_coco_cached(subset_percent=10, normalize=True)
(x_train, y_train), (x_test, y_test) = coco_splits

# Report the shape of each split, then the pixel value range of the training images
for split_label, split_images in (('Training set', x_train), ('Test set', x_test)):
    print(f'{split_label}: {split_images.shape}')

print(f'Value range: [{x_train.min():.2f}, {x_train.max():.2f}]')

## Visualize Different Noise Types

In [None]:
# Pick one training image, kept as a batch of size 1 for the noise helpers
sample_idx = 15
sample_image = x_train[sample_idx:sample_idx+1]

# Corrupt the same image with each of the three noise models
gaussian_noisy = add_gaussian_noise(sample_image, noise_factor=0.1)
salt_pepper_noisy = add_salt_pepper_noise(sample_image, amount=0.1)
speckle_noisy = add_speckle_noise(sample_image, noise_factor=0.1)

# One panel per variant, drawn left to right in this order
noise_panels = [
    ('Original', sample_image),
    ('Gaussian Noise', gaussian_noisy),
    ('Salt & Pepper Noise', salt_pepper_noisy),
    ('Speckle Noise', speckle_noisy),
]

fig, axes = plt.subplots(1, len(noise_panels), figsize=(8, 4))

for panel_ax, (panel_title, image_batch) in zip(axes, noise_panels):
    panel_ax.imshow(image_batch[0])
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

## Create Training Data with Noise


For this example, we'll focus on Gaussian noise, both for training **and** for the main quantitative evaluation. Later, we'll also evaluate the model on other noise types (salt & pepper, speckle) to highlight that it specializes in the corruption it was trained on.

In [None]:
# Corrupt both splits with Gaussian noise at the configured NOISE_FACTOR
noise_factor = NOISE_FACTOR

# Train split first, then test split (same call order as before)
x_train_noisy, x_test_noisy = (
    add_gaussian_noise(clean_split, noise_factor=noise_factor)
    for clean_split in (x_train, x_test)
)

print(f'Created noisy datasets with Gaussian noise factor: {noise_factor}')
for split_label, noisy_split in (
    ('Noisy training set', x_train_noisy),
    ('Noisy test set', x_test_noisy),
):
    print(f'{split_label}: {noisy_split.shape}')

In [None]:
# Show random training images (top row) above their noisy versions (bottom row)
n_samples = 8
indices = np.random.choice(len(x_train), n_samples, replace=False)

fig, axes = plt.subplots(2, n_samples, figsize=(20, 5))
fig.suptitle('Clean Images (top) vs Noisy Images (bottom)', fontsize=16, y=1.02)

# axes.T yields one (top, bottom) pair of axes per column
for (clean_ax, noisy_ax), idx in zip(axes.T, indices):
    clean_ax.imshow(x_train[idx])
    clean_ax.axis('off')

    noisy_ax.imshow(x_train_noisy[idx])
    noisy_ax.axis('off')

plt.tight_layout()
plt.show()

## Build Denoising Autoencoder

The architecture is similar to a compression autoencoder, but:
- **Input**: Noisy images
- **Output**: Clean images
- The bottleneck forces the network to learn noise-resistant features

In [None]:
def build_denoising_autoencoder(input_shape=(64, 64, 3), latent_dim=128, learning_rate=0.001):
    '''
    Build and compile a convolutional denoising autoencoder.

    The encoder applies four Conv2D + MaxPooling2D stages (each halving the
    spatial size), flattens, and projects to a dense bottleneck. The decoder
    mirrors it with a dense layer, a reshape, and four stride-2
    Conv2DTranspose stages, followed by a sigmoid Conv2D output layer.

    The decoder seed size and the number of output channels are derived from
    input_shape, so the output shape always equals the input shape. They were
    previously hard-coded to 4x4 and 3, which only matched 64x64x3 inputs.

    Args:
        input_shape: (height, width, channels) of the input images. Height
            and width must be divisible by 16 (four 2x downsampling stages).
        latent_dim: Dimension of latent bottleneck
        learning_rate: Learning rate for the Adam optimizer. The default
            (0.001) matches what the string shortcut 'adam' used before.

    Returns:
        Compiled Keras model (MSE loss, MAE metric)

    Raises:
        ValueError: If input_shape does not unpack into three values, or if
            height/width is not divisible by 16.
    '''
    encoder_filters = (64, 128, 256, 512)
    downscale = 2 ** len(encoder_filters)  # total spatial reduction: 16x

    # Validate before touching Keras so bad shapes fail fast and clearly
    height, width, channels = input_shape
    if height % downscale or width % downscale:
        raise ValueError(
            f'input_shape height and width must be divisible by {downscale}, '
            f'got {tuple(input_shape)}'
        )
    bottleneck_h = height // downscale
    bottleneck_w = width // downscale
    bottleneck_filters = encoder_filters[-1]

    # Encoder: each stage halves the spatial size (64 -> 32 -> 16 -> 8 -> 4 by default)
    encoder_input = keras.Input(shape=input_shape)
    x = encoder_input
    for n_filters in encoder_filters:
        x = layers.Conv2D(n_filters, 3, activation='relu', padding='same')(x)
        x = layers.MaxPooling2D(2, padding='same')(x)

    # Bottleneck
    x = layers.Flatten()(x)
    latent = layers.Dense(latent_dim, activation='relu', name='latent')(x)

    # Decoder: expand the latent vector back to the encoder's final feature map
    x = layers.Dense(
        bottleneck_h * bottleneck_w * bottleneck_filters, activation='relu'
    )(latent)
    x = layers.Reshape((bottleneck_h, bottleneck_w, bottleneck_filters))(x)

    # Each stride-2 transposed convolution doubles the spatial size
    for n_filters in reversed(encoder_filters):
        x = layers.Conv2DTranspose(
            n_filters, 3, activation='relu', strides=2, padding='same'
        )(x)

    # Sigmoid bounds the reconstructed pixel values to (0, 1)
    decoder_output = layers.Conv2D(channels, 3, activation='sigmoid', padding='same')(x)

    # Build model
    autoencoder = keras.Model(encoder_input, decoder_output, name='denoising_autoencoder')

    # Compile
    autoencoder.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='mse',
        metrics=['mae']
    )

    return autoencoder

In [None]:
# Build model using the bottleneck size from the configuration cell
# (previously hard-coded to 128, so changing LATENT_DIM had no effect)
model = build_denoising_autoencoder(latent_dim=LATENT_DIM)
model.summary()

## Train the Denoising Autoencoder

Key insight: We train on **noisy inputs** but optimize to match **clean outputs**.

In [None]:
# Train the model
# Input: noisy images, Target: clean images
#
# Uses the values from the configuration cell (EPOCHS, BATCH_SIZE,
# EARLY_STOPPING_PATIENCE, CHECKPOINT_PATH). Epochs and batch size were
# previously hard-coded here (20 / 32) and the imported callbacks were unused.
# NOTE(review): TRAIN_MODEL is not consulted here; no code for downloading a
# pre-trained model is visible in this notebook, so training always runs.

# Make sure the checkpoint directory exists before ModelCheckpoint writes to it
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

training_callbacks = [
    # Stop once val_loss has not improved for EARLY_STOPPING_PATIENCE epochs
    # and roll the model back to its best weights
    EarlyStopping(
        monitor='val_loss',
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        verbose=1
    ),
    # Keep only the best model (lowest val_loss) on disk
    ModelCheckpoint(
        filepath=str(CHECKPOINT_PATH),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
]

# NOTE(review): the test split doubles as the validation set, so early
# stopping and checkpoint selection are tuned on the same data that is later
# used for evaluation; consider holding out a separate validation split.
history = model.fit(
    x_train_noisy,  # Noisy input
    x_train,        # Clean target
    validation_data=(x_test_noisy, x_test),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    callbacks=training_callbacks,
    verbose=1
)

## Training History

In [None]:
# Plot the training and validation curves for the loss and the MAE metric
fig, axes = plt.subplots(1, 2, figsize=(8, 3))

# (history key, panel title, y-axis label, legend suffix) for each panel
curve_specs = [
    ('loss', 'Training Loss', 'MSE Loss', 'Loss'),
    ('mae', 'Mean Absolute Error', 'MAE', 'MAE'),
]

for curve_ax, (metric_key, panel_title, y_label, legend_suffix) in zip(axes, curve_specs):
    curve_ax.set_title(panel_title)
    curve_ax.plot(history.history[metric_key], label=f'Train {legend_suffix}')
    curve_ax.plot(history.history[f'val_{metric_key}'], label=f'Val {legend_suffix}')
    curve_ax.set_xlabel('Epoch')
    curve_ax.set_ylabel(y_label)
    curve_ax.legend()

plt.tight_layout()
plt.show()

## Evaluate Denoising Performance

In [None]:
# Run the noisy test images through the trained autoencoder
x_test_denoised = model.predict(x_test_noisy, verbose=0)

# Mean squared error against the clean images, before and after denoising
mse_noisy = np.mean(np.square(x_test - x_test_noisy))
mse_denoised = np.mean(np.square(x_test - x_test_denoised))
mse_reduction_pct = (mse_noisy - mse_denoised) / mse_noisy * 100

# Assemble the report line by line and emit it in one go
report_lines = [
    'Performance Metrics:',
    '=' * 50,
    '\nNoisy Images vs Clean:',
    f'  MSE:  {mse_noisy:.6f}',
    '\nDenoised Images vs Clean:',
    f'  MSE:  {mse_denoised:.6f}',
    '\nImprovement:',
    f'  MSE:  {mse_reduction_pct:.1f}% reduction',
]
print('\n'.join(report_lines))

## Visual Comparison

In [None]:
# Visualize denoising results: one column per random test image, with the
# clean original on top, the noisy input in the middle and the model output
# at the bottom.
n_samples = 8
indices = np.random.choice(len(x_test), n_samples, replace=False)

fig, axes = plt.subplots(3, n_samples, figsize=(20, 7.5))
fig.suptitle('Denoising Results: Original (top) vs Noisy (middle) vs Denoised (bottom)', 
             fontsize=16, y=0.995)

# (row label, image source) for each grid row, top to bottom
grid_rows = [
    ('Original', x_test),
    ('Noisy', x_test_noisy),
    ('Denoised', x_test_denoised),
]

for i, idx in enumerate(indices):
    for row, (row_label, images) in enumerate(grid_rows):
        ax = axes[row, i]
        ax.imshow(images[idx])

        # axis('off') hides every axis decoration, including the y-label, so
        # the row labels below were never drawn. Remove only the ticks and
        # the frame so the label in the first column stays visible.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        if i == 0:
            ax.set_ylabel(row_label, fontsize=12, rotation=0, labelpad=40, va='center')

    # Add MSE metric for each image (clean vs denoised)
    mse_val = np.mean((x_test[idx] - x_test_denoised[idx]) ** 2)
    axes[2, i].set_title(f'MSE: {mse_val:.4f}', fontsize=9)

plt.tight_layout()
plt.show()

## How the model handles unseen noise

In [None]:
# Measure how well the Gaussian-trained model denoises each noise type.
# Each entry pairs a display label with a function that corrupts a batch.
comparison_configs = [
    ('Gaussian (trained)', lambda imgs: add_gaussian_noise(imgs, noise_factor=NOISE_FACTOR)),
    ('Salt & Pepper', lambda imgs: add_salt_pepper_noise(imgs, amount=min(0.5, NOISE_FACTOR / 2))),
    ('Speckle', lambda imgs: add_speckle_noise(imgs, noise_factor=NOISE_FACTOR))
]


def _score_noise_type(label, noise_fn):
    '''Corrupt x_test with noise_fn, denoise it, and return the MSE summary row.'''
    corrupted = noise_fn(x_test)
    restored = model.predict(corrupted, verbose=0)
    error_before = np.mean((x_test - corrupted) ** 2)
    error_after = np.mean((x_test - restored) ** 2)
    return {
        'noise_type': label,
        'mse_noisy': error_before,
        'mse_denoised': error_after,
        'improvement_pct': (error_before - error_after) / error_before * 100
    }


generalization_results = [
    _score_noise_type(label, noise_fn) for label, noise_fn in comparison_configs
]


# Print the results as a fixed-width table
table_rule_width = 90
print('\nModel performance across noise types (trained on Gaussian only):')
print('=' * table_rule_width)
print(f'{"Noise Type":<20} | {"MSE Noisy":>12} | {"MSE Denoised":>14} | {"Improvement":>12}')
print('-' * table_rule_width)
for row in generalization_results:
    row_cells = [
        f"{row['noise_type']:<20}",
        f"{row['mse_noisy']:>12.6f}",
        f"{row['mse_denoised']:>14.6f}",
        f"{row['improvement_pct']:>10.1f}%",
    ]
    print(' | '.join(row_cells))

## Experiment: Different Noise Levels

Let's test how the model performs on different noise levels.

In [None]:
# Test on different noise levels: corrupt the test set with Gaussian noise of
# increasing strength and record the MSE before and after denoising.
noise_levels = [0.1, 0.2, 0.3, 0.4, 0.5]
results = []

# The loop variable is named `level` rather than `noise_factor` so it does not
# overwrite the notebook-level `noise_factor` that records the noise strength
# used to build the training data.
for level in noise_levels:

    # Create noisy images
    x_test_noisy_temp = add_gaussian_noise(x_test, noise_factor=level)
    
    # Denoise
    x_test_denoised_temp = model.predict(x_test_noisy_temp, verbose=0)
    
    # Calculate metrics (MSE against the clean test images)
    mse_before = np.mean((x_test - x_test_noisy_temp) ** 2)
    mse_after = np.mean((x_test - x_test_denoised_temp) ** 2)
    improvement = (mse_before - mse_after) / mse_before * 100
    
    results.append({
        'noise_factor': level,
        'mse_before': mse_before,
        'mse_after': mse_after,
        'improvement_pct': improvement
    })

# Print results as a fixed-width table
print('\nPerformance across different noise levels:')
print('=' * 70)
print(f'{"Noise":>8} | {"MSE Before":>12} | {"MSE After":>12} | {"Improvement":>12}')
print('-' * 70)
for r in results:
    print(
        f"{r['noise_factor']:>8.2f} | "
        f"{r['mse_before']:>12.6f} | "
        f"{r['mse_after']:>12.6f} | "
        f"{r['improvement_pct']:>11.1f}%"
    )

In [None]:
# Plot the noise-level sweep: absolute MSE on the left, relative gain on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
mse_ax, gain_ax = axes

# Pull each column out of the list of result rows
noise_factors = [row['noise_factor'] for row in results]
mse_before_vals = [row['mse_before'] for row in results]
mse_after_vals = [row['mse_after'] for row in results]
improvement_vals = [row['improvement_pct'] for row in results]

# Line styling shared by every curve
line_style = {'linewidth': 2, 'markersize': 8}

# MSE comparison
mse_ax.plot(noise_factors, mse_before_vals, 'o-', label='Before Denoising', **line_style)
mse_ax.plot(noise_factors, mse_after_vals, 's-', label='After Denoising', **line_style)
mse_ax.set_xlabel('Noise Factor', fontsize=12)
mse_ax.set_ylabel('MSE', fontsize=12)
mse_ax.set_title('MSE vs Noise Level', fontsize=14)
mse_ax.legend()
mse_ax.grid(True, alpha=0.3)

# Relative improvement
gain_ax.plot(noise_factors, improvement_vals, 'd-', color='tab:green', **line_style)
gain_ax.set_xlabel('Noise Factor', fontsize=12)
gain_ax.set_ylabel('MSE Improvement (%)', fontsize=12)
gain_ax.set_title('Relative Improvement vs Noise', fontsize=14)
gain_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Save the Model

In [None]:
# Write the trained autoencoder to disk and report where it went
model_path = '../models/denoising_ae_latent128.keras'
model.save(filepath=model_path)
print('Model saved to: {}'.format(model_path))