# 🎵 MOSS: Multi-Objective Sound Synthesis

## Pareto Optimization for Audio-Visual Spectrogram Blending

This notebook is a **self-contained, comprehensive demonstration** of the MOSS system. It implements
the **exact same algorithms** as the web frontend, allowing you to understand and experiment with
the optimization process.

### The Problem

Given:
- A target **image** (e.g., Mona Lisa)
- A target **audio** (e.g., Tchaikovsky)

Find a spectrogram that:
1. **Looks like** the image when visualized
2. **Sounds like** the audio when converted back to waveform

This is inherently a **multi-objective optimization** problem because improving visual fidelity
typically degrades audio quality, and vice versa.

### Key Concepts

1. **Mask-Based Blending**: We blend image and audio spectrograms using a learnable mask
2. **Pareto Optimality**: A solution is Pareto-optimal if you can't improve one objective without worsening the other
3. **Hybrid Optimization**: We combine gradient descent (for fast seeding) with evolutionary algorithms (for exploration)
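To make the Pareto-optimality definition concrete, here is a minimal NumPy sketch. It is not part of MOSS, and `dominates` and `pareto_mask` are hypothetical helper names. It filters a set of (visual loss, audio loss) pairs, where both objectives are minimized, down to the non-dominated subset:

```python
import numpy as np

def dominates(a, b):
    """True if point a Pareto-dominates point b (all objectives minimized):
    a is no worse in every objective and strictly better in at least one."""
    a, b = np.asarray(a), np.asarray(b)
    return bool(np.all(a <= b) and np.any(a < b))

def pareto_mask(F):
    """Boolean mask of the non-dominated rows of an (N, n_obj) loss array."""
    F = np.asarray(F, dtype=float)
    keep = np.ones(len(F), dtype=bool)
    for i in range(len(F)):
        for j in range(len(F)):
            if i != j and dominates(F[j], F[i]):
                keep[i] = False
                break
    return keep

# Columns: (visual loss, audio loss); lower is better for both
F = np.array([
    [0.10, 0.90],  # great image, poor sound
    [0.50, 0.50],  # balanced
    [0.90, 0.10],  # great sound, poor image
    [0.60, 0.60],  # worse than the balanced point on both axes
])
print(pareto_mask(F))  # only the last row is dominated
```

The first three points are mutually non-dominated: moving between them always trades one objective against the other. That trade-off is exactly the front MOSS tries to map.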

## 1. Configuration and Imports

All parameters match the production MOSS system exactly.
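As a quick worked check of what the STFT settings in the next cell imply, this standalone sketch computes the resulting spectrogram shape and resolution. It assumes `torch.stft`'s default `center=True` padding, so a clip of `n` samples yields `1 + n // HOP_LENGTH` frames. The 10-second clip length is hypothetical.

```python
SAMPLE_RATE = 16000
N_FFT = 1024
HOP_LENGTH = 256

def stft_shape(num_samples, n_fft=N_FFT, hop=HOP_LENGTH):
    """(freq_bins, frames) of a one-sided STFT with center=True padding."""
    freq_bins = n_fft // 2 + 1
    frames = 1 + num_samples // hop
    return freq_bins, frames

duration_s = 10.0  # hypothetical clip length
bins, frames = stft_shape(int(duration_s * SAMPLE_RATE))
print(bins, frames)                     # 513 626
print(HOP_LENGTH / SAMPLE_RATE * 1000)  # 16.0 (ms between frames)
print(SAMPLE_RATE / N_FFT)              # 15.625 (Hz per frequency bin)
print(1 - HOP_LENGTH / N_FFT)           # 0.75 (frame overlap)
```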

In [None]:
import sys
sys.path.insert(0, '.')  # Add project root to path

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image
from IPython.display import HTML, Audio, display
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION - Copied from src/config.py EXACTLY
# ============================================================================
SAMPLE_RATE = 16000       # 16kHz for efficiency
N_FFT = 1024              # 513 frequency bins (N_FFT/2 + 1)
HOP_LENGTH = 256          # 75% overlap between frames
WIN_LENGTH = N_FFT

# Control grid resolution (the optimization variable)
GRID_HEIGHT = 64
GRID_WIDTH = 128
N_PARAMS = GRID_HEIGHT * GRID_WIDTH  # 8192 parameters

# Proxy optimization for speed (1/4 frequency resolution)
PROXY_HEIGHT = 129        # 513 / 4 ≈ 128, use 129 for safety (MaskEncoder hardcodes the same value)
DEVICE = 'cpu'

# Optimization hyperparameters (EXACT match to production)
SINGLE_STEPS = 500        # Steps for single-objective optimization
PARETO_SEED_STEPS = 200   # Gradient seeding steps for Pareto
PARETO_GENERATIONS = 50   # NSGA-II generations for Pareto
PARETO_SEED_POP = 10      # Population for gradient seeding
PARETO_EVOL_POP = 100     # Population for evolutionary phase

def plot_spectrogram(ax, mag_tensor, title=None, show_colorbar=False):
    """
    Plot spectrogram with EXACT same scaling as frontend.
    
    Uses robust percentile-based scaling:
    - vmax = 99.5th percentile of dB values
    - vmin = vmax - 80 dB (80 dB dynamic range)
    """
    spec_db = 20 * torch.log10(mag_tensor + 1e-8).cpu().numpy()
    ref_max = np.percentile(spec_db, 99.5)
    vmin = ref_max - 80
    vmax = ref_max
    
    im = ax.imshow(spec_db, origin='lower', aspect='auto', cmap='magma',
                   vmin=vmin, vmax=vmax)
    if title:
        ax.set_title(title)
    if show_colorbar:
        plt.colorbar(im, ax=ax, label='dB')
    return im

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {DEVICE}")
print(f"Control grid: {GRID_HEIGHT}×{GRID_WIDTH} = {N_PARAMS} parameters")

## 2. Loss Functions

### Visual Loss: Mean Absolute Error (MAE)

$$\mathcal{L}_{visual} = \frac{1}{F \cdot T} \sum_{f,t} |M_{mixed}(f,t) - M_{image}(f,t)|$$

We use MAE in the linear magnitude domain for computational efficiency.

### Audio Loss: Log-Domain L1

$$\mathcal{L}_{audio} = \frac{1}{F \cdot T} \sum_{f,t} |\log(M_{mixed}(f,t)) - \log(M_{audio}(f,t))|$$

The log domain roughly matches human loudness perception: equal dB differences sound about equally significant, regardless of the absolute level.
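This small NumPy sketch uses toy numbers, not MOSS data, to illustrate the difference between the two losses. The same 2× magnitude error costs the same log-L1 in a quiet bin and a loud bin, while the linear MAE is dominated by the loud bin.

```python
import numpy as np

target = np.array([0.01, 10.0])   # one quiet bin, one loud bin
mixed = target * 2.0              # both bins are off by the same factor of 2

linear_err = np.abs(mixed - target)                             # per-bin MAE terms
log_err = np.abs(np.log(mixed + 1e-8) - np.log(target + 1e-8))  # per-bin log-L1 terms

print(linear_err)  # 0.01 for the quiet bin, 10.0 for the loud bin
print(log_err)     # both approximately log(2) ≈ 0.693
```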

### Loss Normalization

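Both losses are normalized to [0, 1] by dividing by their worst-case values:
- **Worst visual**: Pure audio spectrogram (mask = 0)
- **Worst audio**: Pure image spectrogram (mask = 1)

These extremes are indeed the worst cases. For a mask value $m \in [0, 1]$, the per-bin visual error is $|m M_{image} + (1 - m) M_{audio} - M_{image}| = (1 - m)\,|M_{audio} - M_{image}|$, which is largest at $m = 0$. The blended magnitude always lies between $M_{image}$ and $M_{audio}$. Because $\log$ is monotonic, the per-bin log-domain audio error is therefore largest at $m = 1$.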
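This minimal sketch applies the normalization to random toy spectrograms. The shapes are hypothetical, and the masks are per-pixel without the encoder's blur. It checks numerically that masks in [0, 1] yield normalized losses in [0, 1].

```python
import numpy as np

rng = np.random.default_rng(0)
F_BINS, T_FRAMES = 32, 40  # hypothetical toy size
image_mag = rng.uniform(0.01, 1.0, (F_BINS, T_FRAMES))
audio_mag = rng.uniform(0.01, 1.0, (F_BINS, T_FRAMES))

def visual_loss(mixed):
    """Linear-domain MAE against the image spectrogram."""
    return np.mean(np.abs(mixed - image_mag))

def audio_loss(mixed):
    """Log-domain L1 against the audio spectrogram."""
    return np.mean(np.abs(np.log(mixed + 1e-8) - np.log(audio_mag + 1e-8)))

# Worst cases: pure audio for the visual loss, pure image for the audio loss
max_vis = visual_loss(audio_mag)  # mask = 0
max_aud = audio_loss(image_mag)   # mask = 1

for _ in range(5):
    mask = rng.uniform(0.0, 1.0, (F_BINS, T_FRAMES))
    mixed = mask * image_mag + (1 - mask) * audio_mag
    vis_norm = visual_loss(mixed) / max_vis
    aud_norm = audio_loss(mixed) / max_aud
    print(f"visual={vis_norm:.3f}  audio={aud_norm:.3f}")  # both stay within [0, 1]
```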

### 20× Visual Boost

During optimization, the visual loss is multiplied by 20 to ensure the image is clearly visible.
This compensates for the natural dominance of audio preservation in the loss landscape.
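In formula form, each individual $i$ minimizes the scalarized objective computed in `ParetoManager.optimize_step`, where hats denote the normalized losses. The per-individual terms are then summed over the population:

$$\mathcal{L}_{total}^{(i)} = 20\, w_{visual}^{(i)}\, \hat{\mathcal{L}}_{visual}^{(i)} + w_{audio}^{(i)}\, \hat{\mathcal{L}}_{audio}^{(i)}, \qquad w_{visual}^{(i)} + w_{audio}^{(i)} = 1$$

Only the visual term carries the factor of 20. The weights themselves still sum to 1.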

In [None]:
def calc_audio_mag_loss(mixed_mag, target_audio_mag):
    """
    Audio loss: L1 in log-magnitude domain.
    
    This matches human perception where equal dB differences
    feel equally significant regardless of absolute level.
    
    Args:
        mixed_mag: (B, F, T) or (F, T) - mixed spectrogram magnitude
        target_audio_mag: (1, F, T) or (F, T) - target audio magnitude
    
    Returns:
        Loss tensor of shape (B,) or scalar
    """
    mixed_log = torch.log(mixed_mag + 1e-8)
    target_log = torch.log(target_audio_mag + 1e-8)
    target_log = target_log.expand_as(mixed_log)
    
    loss = F.l1_loss(mixed_log, target_log, reduction='none')
    
    if mixed_mag.dim() == 3:
        return loss.mean(dim=(1, 2))  # (B,)
    return loss.mean()


def calc_visual_loss(mixed_mag, target_image_mag):
    """
    Visual loss: Mean Absolute Error in linear magnitude domain.
    
    Args:
        mixed_mag: (B, F, T) or (F, T)
        target_image_mag: (1, F, T) or (F, T)
    
    Returns:
        Loss tensor of shape (B,) or scalar
    """
    diff = torch.abs(mixed_mag - target_image_mag)
    
    if mixed_mag.dim() == 3:
        return diff.mean(dim=(1, 2))  # (B,)
    return diff.mean()

## 3. Gaussian Smoothing for Musicality

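The mask is smoothed with a Gaussian kernel to enforce **temporal and spectral continuity**.
This prevents harsh transitions that would create unmusical artifacts. The `gaussian_blur_2d` helper below is a standalone separable implementation. Inside the encoder, `MaskProcessor` builds its own 2D Gaussian kernel and blurs the low-resolution control grid before upsampling it to spectrogram size.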

$$\text{mask}_{smooth} = \text{mask} * G_\sigma$$

where $G_\sigma$ is a 2D Gaussian kernel with standard deviation $\sigma$.
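A separable blur works because a 2D Gaussian kernel is the outer product of two 1D Gaussians. Blurring rows and then columns therefore gives the same result as a single 2D convolution, at lower cost. The minimal NumPy check below verifies that identity. The ±3σ kernel width mirrors `gaussian_blur_2d`, and the helper name `gaussian_1d` is hypothetical.

```python
import numpy as np

def gaussian_1d(sigma):
    """Normalized 1D Gaussian kernel covering roughly ±3σ (odd length, at least 3)."""
    size = max(3, int(6 * sigma) | 1)
    coords = np.arange(size) - size // 2
    g = np.exp(-coords**2 / (2 * sigma**2))
    return g / g.sum()

sigma = 2.0
g = gaussian_1d(sigma)
kernel_2d = np.outer(g, g)  # separable 2D Gaussian

# Direct 2D Gaussian evaluated on the same grid, then normalized
coords = np.arange(len(g)) - len(g) // 2
yy, xx = np.meshgrid(coords, coords, indexing="ij")
direct = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
direct /= direct.sum()

print(len(g), np.allclose(kernel_2d, direct))  # 13 True
print(kernel_2d.sum())  # ≈ 1.0, so the blur preserves the average mask level
```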

In [None]:
def gaussian_blur_2d(x, sigma=1.0):
    """
    Apply separable Gaussian blur to enforce smooth masks.
    
    Args:
        x: Input tensor (B, H, W) or (H, W)
        sigma: Standard deviation of Gaussian kernel
    
    Returns:
        Blurred tensor of same shape
    """
    if sigma < 0.5:
        return x
    
    # Kernel size: 6σ (±3σ) captures about 99.7% of the Gaussian's mass
    kernel_size = int(6 * sigma) | 1  # Ensure odd
    kernel_size = max(3, kernel_size)
    
    # Create 1D Gaussian kernel
    coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - kernel_size // 2
    g = torch.exp(-coords**2 / (2 * sigma**2))
    g = g / g.sum()  # Normalize
    
    # Ensure 4D for conv2d
    pad = kernel_size // 2
    original_dim = x.dim()
    
    if x.dim() == 2:
        x = x.unsqueeze(0).unsqueeze(0)
    elif x.dim() == 3:
        x = x.unsqueeze(1)
    
    # Separable convolution (faster than 2D)
    x = F.conv2d(F.pad(x, (pad, pad, 0, 0), mode='replicate'), 
                 g.view(1, 1, 1, -1), padding=0)
    x = F.conv2d(F.pad(x, (0, 0, pad, pad), mode='replicate'), 
                 g.view(1, 1, -1, 1), padding=0)
    
    # Restore original dimensions
    if original_dim == 2:
        return x.squeeze(0).squeeze(0)
    elif original_dim == 3:
        return x.squeeze(1)
    return x

## 4. The Mask Encoder

The encoder converts a **low-resolution control grid** (64×128) into:
1. A **blended spectrogram** (513×T)
2. Optionally, **reconstructed audio** via inverse STFT

### Blending Formula

$$M_{mixed} = \text{mask} \cdot M_{image} + (1 - \text{mask}) \cdot M_{audio}$$

where:
- $\text{mask} = 1$: Show the image
- $\text{mask} = 0$: Keep the audio
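This toy NumPy illustration of the blending formula uses random stand-in magnitudes, not the real Mona Lisa or Tchaikovsky data. A mask of 1 reproduces the image spectrogram, a mask of 0 reproduces the audio spectrogram, and a spatially varying mask mixes them bin by bin.

```python
import numpy as np

rng = np.random.default_rng(1)
image_mag = rng.uniform(0.0, 1.0, (4, 6))  # stand-in for M_image (freq x time)
audio_mag = rng.uniform(0.0, 1.0, (4, 6))  # stand-in for M_audio

def blend(mask, img, aud):
    """M_mixed = mask * M_image + (1 - mask) * M_audio, element-wise."""
    return mask * img + (1 - mask) * aud

ones, zeros = np.ones_like(image_mag), np.zeros_like(image_mag)
print(np.allclose(blend(ones, image_mag, audio_mag), image_mag))   # True: show the image
print(np.allclose(blend(zeros, image_mag, audio_mag), audio_mag))  # True: keep the audio

# Mask that shows the image only in the upper half of the frequency bins (rows 2-3)
half = np.zeros_like(image_mag)
half[2:, :] = 1.0
mixed = blend(half, image_mag, audio_mag)
print(np.allclose(mixed[2:], image_mag[2:]), np.allclose(mixed[:2], audio_mag[:2]))  # True True
```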

### Proxy Optimization

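For speed, gradient steps run on a **proxy** spectrogram. It has 129 frequency bins instead of 513 (about 1/4 of the frequency resolution) and half as many time frames. That is roughly 8× fewer cells per evaluation, with minimal quality loss. The full-resolution spectrogram is used only when reconstructing audio (`return_wav=True`).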
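The proxy size can be checked with a little arithmetic. This sketch mirrors how `MaskEncoder` sets `proxy_height = 129` and `proxy_width = full_width // 2`. The 626-frame width is a hypothetical example, corresponding to a 10-second clip at the configured sample rate and hop length.

```python
full_height, full_width = 513, 626                # hypothetical 10 s clip
proxy_height, proxy_width = 129, full_width // 2  # as in MaskEncoder

full_cells = full_height * full_width
proxy_cells = proxy_height * proxy_width
print(proxy_width)                         # 313
print(round(full_cells / proxy_cells, 2))  # 7.95 (roughly 8x fewer cells)
```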

In [None]:
# =============================================================================
# MaskProcessor - EXACT COPY from src/audio_encoder.py
# =============================================================================
class MaskProcessor(nn.Module):
    """Helper module for Mask generation to separate convolution logic."""

    def __init__(self, h, w, sigma, device):
        super().__init__()
        self.h = h
        self.w = w
        self.device = device
        self._init_gaussian_kernel(sigma)

    def _init_gaussian_kernel(self, sigma):
        kernel_size = int(2 * math.ceil(2 * sigma) + 1)
        x_cord = torch.arange(kernel_size)
        x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
        y_grid = x_grid.t()
        xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

        mean = (kernel_size - 1) / 2.0
        variance = sigma**2.0
        gaussian_kernel = (1.0 / (2.0 * math.pi * variance)) * torch.exp(
            -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance)
        )
        gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
        self.blur_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size).to(
            self.device
        )
        self.kernel_padding = kernel_size // 2

    def forward(self, params, target_h, target_w):
        B = params.shape[0]
        grid = params.view(B, 1, self.h, self.w)

        # Blur the Low-Res Grid instead of the High-Res Up-sampled Mask
        grid_blurred = F.conv2d(grid, self.blur_kernel, padding=self.kernel_padding)

        mask = F.interpolate(
            grid_blurred,
            size=(target_h, target_w),
            mode="bilinear",
            align_corners=False,
        )

        return mask.squeeze(1)


# =============================================================================
# MaskEncoder - EXACT COPY from src/audio_encoder.py
# =============================================================================
class MaskEncoder(nn.Module):
    """Encodes parameters into audio via mask-based spectrogram blending."""

    def __init__(
        self,
        target_image: torch.Tensor,
        target_audio_path: str,
        grid_height: int = 128,
        grid_width: int = 256,
        smoothing_sigma: float = 1.0,
        device: str = DEVICE,
    ):
        super().__init__()
        self.device = device
        self.grid_height = grid_height
        self.grid_width = grid_width
        self.n_params = grid_height * grid_width
        self.smoothing_sigma = smoothing_sigma

        # 1. Load and Process Target Audio
        audio, sr = torchaudio.load(target_audio_path)
        if sr != SAMPLE_RATE:
            audio = torchaudio.functional.resample(audio, sr, SAMPLE_RATE)

        audio = audio.mean(dim=0, keepdim=True)  # Mix to mono

        self.mask_processor = MaskProcessor(
            grid_height, grid_width, smoothing_sigma, device
        )

        self.register_buffer("target_audio_waveform", audio)

        # Compute STFT
        window = torch.hann_window(N_FFT).to(device)
        stft = torch.stft(
            audio.to(device),
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            window=window,
            return_complex=True,
        )

        self.audio_mag = stft.abs() + 1e-8
        self.audio_phase = stft.angle()

        self.full_height = self.audio_mag.shape[1]
        self.full_width = self.audio_mag.shape[2]

        self.audio_log = torch.log(self.audio_mag)

        # 2. Process Target Image
        img = target_image.to(device)
        if img.dim() == 3:
            img = img.unsqueeze(0)

        target_h_visual = self.full_height
        # FLIP image vertically: spectrograms have low freq at bottom, images have origin at top
        img_flipped = torch.flip(img, dims=[-2])
        img_full_freq = F.interpolate(
            img_flipped,
            size=(target_h_visual, self.full_width),
            mode="bilinear",
            align_corners=False,
        )
        img_resized = img_full_freq

        # Dynamic Gain Staging
        audio_log_max = torch.quantile(self.audio_log, 0.98)
        audio_max_val = self.audio_log.max()
        audio_floor_val = torch.quantile(self.audio_log, 0.01)

        target_ceiling = audio_max_val - 0.5
        headroom_nat = (audio_log_max - target_ceiling).item()
        dynamic_range_nat = (target_ceiling - audio_floor_val).item() + 0.5
        dynamic_range_nat = max(4.0, min(dynamic_range_nat, 12.0))

        print("Dynamic Gain Staging:")
        print(f"  > Audio Max: {audio_max_val:.2f}, Floor (q01): {audio_floor_val:.2f}")
        print(f"  > Target Ceiling: {target_ceiling:.2f}")
        print(f"  > Adaptive Dynamic Range: {dynamic_range_nat:.2f}")

        audio_log_ceil = audio_log_max - headroom_nat
        audio_log_floor = audio_log_ceil - dynamic_range_nat

        self.audio_log = torch.clamp(self.audio_log, min=audio_log_floor)

        img_01 = img_resized.squeeze(0)
        img_01 = (img_01 - img_01.min()) / (img_01.max() - img_01.min() + 1e-8)
        img_01 = img_01.pow(1.8)  # Gamma correction

        self.image_log = img_01 * (audio_log_ceil - audio_log_floor) + audio_log_floor
        self.image_mag = torch.exp(self.image_log)
        self.audio_mag_static = torch.exp(self.audio_log)

        # PROXY OPTIMIZATION SETUP
        self.proxy_height = 129
        self.proxy_width = self.full_width // 2

        self.image_mag_proxy = F.interpolate(
            self.image_mag.unsqueeze(0),
            size=(self.proxy_height, self.proxy_width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

        self.audio_mag_proxy = F.interpolate(
            self.audio_mag_static.unsqueeze(0),
            size=(self.proxy_height, self.proxy_width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

        # Expose PROXY as the default reference for Optimizers
        self.image_mag_ref = self.image_mag_proxy
        self.audio_mag = self.audio_mag_proxy

        # Keep FULL res for final Export/Reconstruction
        self.image_mag_full = self.image_mag
        self.audio_mag_full = self.audio_mag_static

        print(f"Spectrogram: {self.full_height}×{self.full_width}")
        print(f"Proxy: {self.proxy_height}×{self.proxy_width}")

    def _compute_spectrogram(self, mask, img_mag, aud_mag):
        return mask * img_mag + (1 - mask) * aud_mag

    def _reconstruct_audio(self, mixed_mag, phase):
        complex_stft = torch.polar(mixed_mag, phase)

        audio_recon = torch.istft(
            complex_stft,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            window=torch.hann_window(N_FFT, device=mixed_mag.device),
        )

        max_val = audio_recon.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-8)
        audio_recon = audio_recon / max_val * 0.9
        return audio_recon

    def forward(self, params: torch.Tensor, return_wav: bool = True):
        batch_size = params.shape[0]

        # Generate Mask - target PROXY or FULL size based on mode
        target_h = self.proxy_height if not return_wav else self.full_height
        target_w = self.proxy_width if not return_wav else self.full_width

        mask = self.mask_processor(params, target_h, target_w)

        # Expand sources (Use PROXY or FULL based on mode)
        if not return_wav:
            img_mag = self.image_mag_ref.expand(batch_size, -1, -1)
            aud_mag = self.audio_mag.expand(batch_size, -1, -1)
        else:
            img_mag = self.image_mag_full.expand(batch_size, -1, -1)
            aud_mag = self.audio_mag_full.expand(batch_size, -1, -1)

        phase = self.audio_phase.expand(batch_size, -1, -1) if return_wav else None

        mixed_mag = self._compute_spectrogram(mask, img_mag, aud_mag)

        audio_recon = None
        if return_wav:
            audio_recon = self._reconstruct_audio(mixed_mag, phase)

        return audio_recon, mixed_mag

## 5. Load Data: Mona Lisa + Tchaikovsky
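The next cell sizes the control grid from the audio duration. It uses about 4 columns per second, rounded up to a multiple of 16, with a minimum of 16. Here is the same arithmetic as a standalone function (the name `grid_width_for` is hypothetical):

```python
def grid_width_for(duration_sec, cols_per_sec=4.0, multiple=16):
    """Round duration * cols_per_sec up to a multiple of 16 (at least 16)."""
    raw_width = int(duration_sec * cols_per_sec)
    width = ((raw_width + multiple - 1) // multiple) * multiple
    return max(width, multiple)

for seconds in (2.0, 30.0, 95.5):
    print(seconds, grid_width_for(seconds))
# 2.0 16
# 30.0 128   (the 64×128 default grid from the configuration cell)
# 95.5 384
```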

In [None]:
# Load target image (matches backend preprocess_image)
img_path = 'data/input/monalisa.jpg'
img_pil = Image.open(img_path).convert('L')  # Grayscale
img_tensor = torch.from_numpy(np.array(img_pil)).float() / 255.0
img_tensor = img_tensor.unsqueeze(0)  # Add channel dim: (1, H, W)

# Calculate grid dimensions (matches backend service.py lines 86-92)
audio_path = 'data/input/tchaikovsky.mp3'
waveform, sr = torchaudio.load(audio_path)
duration_sec = waveform.shape[-1] / sr
raw_width = int(duration_sec * 4.0)
grid_width = ((raw_width + 15) // 16) * 16
if grid_width < 16:
    grid_width = 16
grid_height = 64  # Matches backend

# Create encoder with EXACT backend parameters
sigma = 5.0  # Matches backend
encoder = MaskEncoder(
    img_tensor, 
    audio_path,
    grid_height=grid_height,
    grid_width=grid_width,
    smoothing_sigma=sigma,
    device=DEVICE
)

# Update notebook constants to match
GRID_HEIGHT = grid_height
GRID_WIDTH = grid_width
N_PARAMS = GRID_HEIGHT * GRID_WIDTH

print(f"\nGrid: {GRID_HEIGHT}√ó{GRID_WIDTH} = {N_PARAMS} parameters")

In [None]:
# Visualize inputs
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original image
axes[0].imshow(img_pil, cmap='gray')
axes[0].set_title('Target Image: Mona Lisa', fontsize=12)
axes[0].axis('off')

# Image as spectrogram (with frontend-matching scaling)
plot_spectrogram(axes[1], encoder.image_mag_full[0], 'Image → Spectrogram')
axes[1].set_xlabel('Time (frames)')
axes[1].set_ylabel('Frequency (bins)')

# Target audio spectrogram
plot_spectrogram(axes[2], encoder.audio_mag_full[0], 'Target Audio: Tchaikovsky')
axes[2].set_xlabel('Time (frames)')

plt.tight_layout()
plt.show()

# Play original audio
print("🔊 Original Tchaikovsky:")
display(Audio(encoder.target_audio_waveform[0].numpy(), rate=SAMPLE_RATE))

## 6. The Pareto Manager

This class manages a **population of masks** for multi-objective optimization.
It implements the **scalarization** approach where each individual optimizes a different
weighted combination of visual and audio objectives.

### Weight Distribution

To counteract the 20× visual boost, we use a **power-law weight distribution** across the $N$ individuals, indexed $i = 0, \dots, N-1$:

$$w_{visual}^{(i)} = \left(1 - \frac{i}{N-1}\right)^4, \qquad w_{audio}^{(i)} = 1 - w_{visual}^{(i)}$$

The fourth power bunches most weights toward low visual priority, ensuring balanced coverage of the Pareto front.
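This minimal NumPy sketch computes the weight schedule for the default seeding population of 10. It mirrors the `(1 - alpha)^4` computation in `ParetoManager.__init__`.

```python
import numpy as np

N = 10  # PARETO_SEED_POP
alpha = np.linspace(0.0, 1.0, N)
w_visual = (1.0 - alpha) ** 4
w_audio = 1.0 - w_visual

for i, (wv, wa) in enumerate(zip(w_visual, w_audio)):
    print(f"individual {i}: w_visual={wv:.3f}  w_audio={wa:.3f}")

# Fraction of individuals whose visual weight is below 0.5
print(np.mean(w_visual < 0.5))  # 0.8 (8 of 10 individuals)
```

In `ParetoManager`, individuals 0 and N-1 are additionally pinned to the pure-image and pure-audio anchors.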

### Normalization

Both losses are normalized by their worst-case values:

$$\hat{\mathcal{L}}_{visual} = \frac{\mathcal{L}_{visual}}{\max(\mathcal{L}_{visual})}$$
$$\hat{\mathcal{L}}_{audio} = \frac{\mathcal{L}_{audio}}{\max(\mathcal{L}_{audio})}$$

In [None]:
class ParetoManager(nn.Module):
    """
    Manages a population of masks for Pareto optimization.
    
    Uses scalarization with power-law weight distribution to find diverse
    solutions along the Pareto front.
    """
    
    def __init__(self, encoder, pop_size=10, learning_rate=0.05):
        super().__init__()
        self.encoder = encoder
        self.pop_size = pop_size
        self.grid_h = encoder.grid_height
        self.grid_w = encoder.grid_width
        
        # Initialize population in logit space (unbounded, sigmoid to [0,1])
        self.mask_logits = nn.Parameter(
            torch.randn(pop_size, self.grid_h * self.grid_w) * 0.5
        )
        
        # ====================================================================
        # Power-law weight distribution (counteracts 20x visual boost)
        # ====================================================================
        alpha = torch.linspace(0, 1, pop_size)
        
        # weights = (1 - alpha)^4
        # This bunches values near 0, so more individuals prioritize audio
        # which balances the 20x visual boost during optimization
        self.weights_img = (1.0 - alpha).pow(4.0)
        self.weights_aud = 1.0 - self.weights_img
        
        # Normalize to sum to 1 (already true by construction; kept as a safeguard)
        total = self.weights_img + self.weights_aud
        self.weights_img = self.weights_img / total
        self.weights_aud = self.weights_aud / total
        
        # Force anchors: pure image and pure audio at extremes
        if pop_size >= 2:
            with torch.no_grad():
                self.mask_logits[0].fill_(10.0)   # Index 0: Pure image
                self.mask_logits[-1].fill_(-10.0)  # Index -1: Pure audio
        
        # Optimizer
        self.optimizer = optim.Adam([self.mask_logits], lr=learning_rate)
        
        # Normalization factors (will be computed)
        self.scale_vis = 1.0
        self.scale_aud = 1.0
        
    def calculate_normalization(self):
        """
        Calculate worst-case losses to normalize objectives to [0, 1].
        """
        print("Calculating normalization factors...")
        with torch.no_grad():
            # Worst visual: pure audio (mask = 0)
            max_vis_loss = torch.abs(
                self.encoder.audio_mag - self.encoder.image_mag_ref
            ).mean().item()
            
            # Worst audio: pure image (mask = 1)
            max_aud_loss = calc_audio_mag_loss(
                self.encoder.image_mag_ref, 
                self.encoder.audio_mag
            ).item()
            
            # Avoid division by zero
            self.scale_vis = 1.0 / max(max_vis_loss, 1e-6)
            self.scale_aud = 1.0 / max(max_aud_loss, 1e-6)
            
            print(f"  Max visual loss: {max_vis_loss:.4f} → scale: {self.scale_vis:.4f}")
            print(f"  Max audio loss:  {max_aud_loss:.4f} → scale: {self.scale_aud:.4f}")
    
    def optimize_step(self):
        """
        Perform one gradient descent step for all population members.
        
        MATCHES BACKEND: encoder(return_wav=False) automatically uses proxy.
        
        Returns:
            avg_vis: Average normalized visual loss
            avg_aud: Average normalized audio loss
        """
        self.optimizer.zero_grad()
        
        # Get current masks
        masks = torch.sigmoid(self.mask_logits)
        
        # Forward pass - return_wav=False uses PROXY automatically!
        _, mixed_mag = self.encoder(masks, return_wav=False)
        
        # Calculate losses (matches backend - uses proxy refs)
        diff = torch.abs(mixed_mag - self.encoder.image_mag_ref)
        raw_loss_vis = diff.mean(dim=(1, 2))
        raw_loss_aud = calc_audio_mag_loss(mixed_mag, self.encoder.audio_mag)
        
        # Normalize
        loss_vis = raw_loss_vis * self.scale_vis
        loss_aud = raw_loss_aud * self.scale_aud
        
        # Scalarized loss with 20x visual boost
        total_loss = torch.sum(
            self.weights_img * loss_vis * 20.0 + self.weights_aud * loss_aud
        )
        
        # Backward and step
        total_loss.backward()
        self.optimizer.step()
        
        # Force clamp anchors
        if self.pop_size >= 2:
            with torch.no_grad():
                self.mask_logits[0].fill_(20.0)
                self.mask_logits[-1].fill_(-20.0)
        
        return loss_vis.detach().mean().item(), loss_aud.detach().mean().item()
    
    def get_current_front(self):
        """Get current population losses for plotting."""
        with torch.no_grad():
            masks = torch.sigmoid(self.mask_logits)
            _, mixed_mag = self.encoder(masks, return_wav=False)
            
            diff = torch.abs(mixed_mag - self.encoder.image_mag_ref)
            loss_vis = diff.mean(dim=(1, 2)) * self.scale_vis
            loss_aud = calc_audio_mag_loss(mixed_mag, self.encoder.audio_mag) * self.scale_aud
            
            return loss_vis.numpy(), loss_aud.numpy()

## 7. Single-Objective Optimization

This matches the frontend exactly when you adjust the "Better Sound" / "Better Image" sliders.

The weights are passed directly as `(w_visual, w_audio)` and normalized to sum to 1.
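Because `optimize_step` multiplies the normalized visual loss by 20, the effective visual-to-audio emphasis in the objective is $20\,w_{visual}/w_{audio}$. This plain-Python sketch normalizes each preset from the next cell and computes that ratio:

```python
PRESETS = {
    "Better Image": (0.8, 0.2),
    "Balanced": (0.5, 0.5),
    "Better Sound": (0.2, 0.8),
}
VISUAL_BOOST = 20.0  # the factor applied in ParetoManager.optimize_step

ratios = {}
for name, (w_vis, w_aud) in PRESETS.items():
    total = w_vis + w_aud
    w_vis_n, w_aud_n = w_vis / total, w_aud / total
    ratios[name] = VISUAL_BOOST * w_vis_n / w_aud_n
    print(f"{name:12s}: w_vis={w_vis_n:.2f}, w_aud={w_aud_n:.2f}, "
          f"effective ratio={ratios[name]:.0f}x")
```

Under these assumptions, even "Better Sound" weights a unit of normalized visual loss 5× more than a unit of normalized audio loss. The ratio describes the objective's emphasis; it does not predict the final loss values.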

In [None]:
def single_objective_optimize(encoder, w_vis=0.5, w_aud=0.5, steps=SINGLE_STEPS, lr=0.05):
    """
    Single-objective optimization with specified weights.
    
    This matches the frontend EXACTLY for the three presets:
    - "Better Image": w_vis=0.8, w_aud=0.2
    - "Balanced":     w_vis=0.5, w_aud=0.5
    - "Better Sound": w_vis=0.2, w_aud=0.8
    
    Args:
        encoder: MaskEncoder instance
        w_vis: Visual weight (0-1)
        w_aud: Audio weight (0-1)
        steps: Number of optimization steps (default: 500)
        lr: Learning rate
    
    Returns:
        mask: Optimized mask tensor
        history: List of (vis_loss, aud_loss) tuples
    """
    # Create single-individual manager
    manager = ParetoManager(encoder, pop_size=1, learning_rate=lr)
    manager.calculate_normalization()
    
    # Override weights (normalize to sum to 1)
    total_weight = w_vis + w_aud
    w_vis_norm = w_vis / total_weight
    w_aud_norm = w_aud / total_weight
    
    manager.weights_img = torch.tensor([w_vis_norm])
    manager.weights_aud = torch.tensor([w_aud_norm])
    
    # Initialize mask to neutral (logit = 0 → sigmoid = 0.5)
    with torch.no_grad():
        manager.mask_logits.data.zero_()
    
    history = []
    
    print(f"Optimizing with weights: Visual={w_vis_norm:.2f}, Audio={w_aud_norm:.2f}")
    print(f"Steps: {steps}, Learning rate: {lr}")
    
    for step in range(1, steps + 1):
        avg_vis, avg_aud = manager.optimize_step()
        history.append((avg_vis, avg_aud))
        
        if step % 100 == 0 or step == 1:
            print(f"  Step {step:4d}: Visual={avg_vis:.4f}, Audio={avg_aud:.4f}")
    
    # Get final mask
    final_mask = torch.sigmoid(manager.mask_logits).detach()
    
    return final_mask, history

In [None]:
# Run the three frontend presets
import time

print("=" * 60)
print("🎨 VISUAL PRIORITY (80/20) - matches 'Better Image' button")
print("=" * 60)
start = time.time()
mask_visual, hist_visual = single_objective_optimize(encoder, w_vis=0.8, w_aud=0.2)
print(f"Completed in {time.time() - start:.1f}s\n")

print("=" * 60)
print("⚖️  BALANCED (50/50)")
print("=" * 60)
start = time.time()
mask_balanced, hist_balanced = single_objective_optimize(encoder, w_vis=0.5, w_aud=0.5)
print(f"Completed in {time.time() - start:.1f}s\n")

print("=" * 60)
print("🔊 AUDIO PRIORITY (20/80) - matches 'Better Sound' button")
print("=" * 60)
start = time.time()
mask_audio, hist_audio = single_objective_optimize(encoder, w_vis=0.2, w_aud=0.8)
print(f"Completed in {time.time() - start:.1f}s\n")

In [None]:
# Plot convergence
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, hist, title in [
    (axes[0], hist_visual, 'Visual Priority (80/20)'),
    (axes[1], hist_balanced, 'Balanced (50/50)'),
    (axes[2], hist_audio, 'Audio Priority (20/80)'),
]:
    vis_losses = [h[0] for h in hist]
    aud_losses = [h[1] for h in hist]
    
    ax.plot(vis_losses, label='Visual Loss', color='purple', alpha=0.8)
    ax.plot(aud_losses, label='Audio Loss', color='green', alpha=0.8)
    ax.set_xlabel('Step')
    ax.set_ylabel('Normalized Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Visualize results
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

results = [
    (mask_visual, 'Visual Priority (80/20)'),
    (mask_balanced, 'Balanced (50/50)'),
    (mask_audio, 'Audio Priority (20/80)')
]

for i, (mask, title) in enumerate(results):
    # Generate full-resolution output
    _, mixed_mag = encoder(mask, return_wav=True)  # return_wav=True = full resolution
    
    # Mask visualization
    axes[0, i].imshow(
        mask.view(GRID_HEIGHT, GRID_WIDTH).numpy(), 
        cmap='gray', vmin=0, vmax=1
    )
    axes[0, i].set_title(f'{title}\nMask (white=image, black=audio)')
    axes[0, i].axis('off')
    
    # Spectrogram (with frontend-matching scaling)
    plot_spectrogram(axes[1, i], mixed_mag[0], 'Result Spectrogram')
    axes[1, i].set_xlabel('Time')
    axes[1, i].set_ylabel('Frequency')

plt.tight_layout()
plt.show()

In [None]:
# Listen to results
print("🔊 AUDIO PRIORITY (20/80) - Best sound quality:")
wav_audio, _ = encoder(mask_audio, return_wav=True)
display(Audio(wav_audio[0].numpy(), rate=SAMPLE_RATE))

print("\n⚖️  BALANCED (50/50):")
wav_balanced, _ = encoder(mask_balanced, return_wav=True)
display(Audio(wav_balanced[0].numpy(), rate=SAMPLE_RATE))

print("\n🎨 VISUAL PRIORITY (80/20) - Best visual fidelity:")
wav_visual, _ = encoder(mask_visual, return_wav=True)
display(Audio(wav_visual[0].numpy(), rate=SAMPLE_RATE))

## 8. Full Pareto Frontier

This matches the "Map Full Pareto Frontier" button in the frontend EXACTLY.

### Hybrid Optimization Strategy

1. **Phase 1: Gradient Seeding** (200 steps, population=10)
   - Run parallel Adam optimization with power-law weight distribution
   - Force anchor solutions at extremes (pure image, pure audio)
   - Fast exploration of the objective space

2. **Phase 2: Evolutionary Expansion** (50 generations, population=100)
   - NSGA-II with SBX crossover and polynomial mutation
   - Refines and expands the Pareto front
   - Non-dominated sorting plus crowding-distance selection keeps the front both converged and well spread
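NSGA-II ranks solutions first by non-dominated front. Within a front, it ranks them by crowding distance, which favors points in sparsely populated regions and keeps the front spread out. pymoo's `NSGA2` handles this internally. The NumPy sketch below shows the idea for a single front. The function name `crowding_distance` is hypothetical, and this is a simplified textbook version, not pymoo's implementation.

```python
import numpy as np

def crowding_distance(F):
    """Crowding distance for the points of one non-dominated front.

    F: (N, n_obj) objective values. Boundary points get infinite distance;
    interior points accumulate the range-normalized gap between their two
    neighbors along each objective.
    """
    F = np.asarray(F, dtype=float)
    n, n_obj = F.shape
    if n <= 2:
        return np.full(n, np.inf)
    dist = np.zeros(n)
    for m in range(n_obj):
        order = np.argsort(F[:, m])
        f_sorted = F[order, m]
        span = f_sorted[-1] - f_sorted[0]
        dist[order[0]] = dist[order[-1]] = np.inf
        if span == 0:
            continue
        # Gap between each interior point's neighbors, scaled by the objective's range
        dist[order[1:-1]] += (f_sorted[2:] - f_sorted[:-2]) / span
    return dist

# (visual loss, audio loss) for a toy front
front = np.array([[0.0, 1.0], [0.1, 0.8], [0.2, 0.6], [0.7, 0.2], [1.0, 0.0]])
print(crowding_distance(front))  # approximately [inf, 0.6, 1.2, 1.4, inf]
```

The isolated point (0.7, 0.2) gets the largest finite distance, so it would be preferred when the population has to be trimmed.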

In [None]:
def run_pareto_optimization(encoder, 
                            seed_steps=PARETO_SEED_STEPS,
                            seed_pop=PARETO_SEED_POP,
                            n_gen=PARETO_GENERATIONS,
                            evol_pop=PARETO_EVOL_POP):
    """
    Full Pareto frontier optimization matching the frontend exactly.
    
    Args:
        encoder: MaskEncoder instance
        seed_steps: Gradient seeding steps (default: 200)
        seed_pop: Seeding population size (default: 10)
        n_gen: NSGA-II generations (default: 50)
        evol_pop: Evolution population size (default: 100)
    
    Returns:
        final_X: (N, params) Pareto-optimal mask parameters
        final_F: (N, 2) Pareto-optimal objective values
        full_history: List of population snapshots for animation
    """
    from pymoo.core.problem import Problem
    from pymoo.core.callback import Callback
    from pymoo.algorithms.moo.nsga2 import NSGA2
    from pymoo.optimize import minimize
    from pymoo.operators.crossover.sbx import SBX
    from pymoo.operators.mutation.pm import PM
    
    # ========================================================================
    # Phase 1: Gradient Seeding
    # ========================================================================
    print("=" * 60)
    print("PHASE 1: GRADIENT SEEDING")
    print(f"Population: {seed_pop}, Steps: {seed_steps}")
    print("=" * 60)
    
    manager = ParetoManager(encoder, pop_size=seed_pop, learning_rate=0.05)
    manager.calculate_normalization()
    
    phase1_history = []
    
    for step in range(1, seed_steps + 1):
        manager.optimize_step()
        
        # Record history every 5 steps (for animation)
        if step % 5 == 0:
            vis, aud = manager.get_current_front()
            phase1_history.append(np.column_stack([vis, aud]))
        
        if step % 50 == 0:
            vis, aud = manager.get_current_front()
            print(f"  Step {step:4d}: Vis range [{vis.min():.3f}, {vis.max():.3f}], "
                  f"Aud range [{aud.min():.3f}, {aud.max():.3f}]")
    
    # Extract seeds
    with torch.no_grad():
        seed_masks = torch.sigmoid(manager.mask_logits).cpu().numpy()
    
    print(f"\nGenerated {len(seed_masks)} seed solutions")
    
    # ========================================================================
    # Phase 2: Evolutionary Expansion
    # ========================================================================
    print("\n" + "=" * 60)
    print("PHASE 2: EVOLUTIONARY EXPANSION (NSGA-II)")
    print(f"Population: {evol_pop}, Generations: {n_gen}")
    print("=" * 60)
    
    class MOSSProblem(Problem):
        def __init__(self, enc, scale_vis, scale_aud):
            self.enc = enc
            self.scale_vis = scale_vis
            self.scale_aud = scale_aud
            super().__init__(
                n_var=enc.grid_height * enc.grid_width,
                n_obj=2,
                xl=0.0,
                xu=1.0
            )
        
        def _evaluate(self, x, out, *args, **kwargs):
            with torch.no_grad():
                masks = torch.from_numpy(x).float()
                _, mixed_mag = self.enc(masks, return_wav=False)  # proxy resolution is used automatically
                
                vis = calc_visual_loss(mixed_mag, self.enc.image_mag_ref).cpu().numpy()
                aud = calc_audio_mag_loss(mixed_mag, self.enc.audio_mag).cpu().numpy()
                
                vis = vis * self.scale_vis
                aud = aud * self.scale_aud
                
                out['F'] = np.column_stack([vis, aud])
    
    class HistoryCallback(Callback):
        def __init__(self):
            super().__init__()
            self.history = []
        
        def notify(self, algorithm):
            # Record surviving population (not just offspring)
            self.history.append(algorithm.pop.get('F').copy())
    
    problem = MOSSProblem(encoder, manager.scale_vis, manager.scale_aud)
    
    # Initialize with gradient seeds, topped up with uniform random masks
    n_random = max(evol_pop - len(seed_masks), 0)
    X_init = np.vstack([
        seed_masks[:evol_pop],
        np.random.rand(n_random, problem.n_var)
    ])
    
    algorithm = NSGA2(
        pop_size=evol_pop,
        sampling=X_init,
        crossover=SBX(eta=15, prob=0.9),
        mutation=PM(eta=20),
        eliminate_duplicates=True
    )
    
    callback = HistoryCallback()
    
    result = minimize(
        problem,
        algorithm,
        ('n_gen', n_gen),
        callback=callback,
        verbose=True
    )
    
    full_history = phase1_history + callback.history
    
    print(f"\n✅ Found {len(result.F)} Pareto-optimal solutions")
    
    return result.X, result.F, full_history

In [None]:
# Run full Pareto optimization
import time

start = time.time()
pareto_X, pareto_F, pareto_history = run_pareto_optimization(encoder)
print(f"\nTotal time: {time.time() - start:.1f}s")
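
Pareto optimality means that no returned solution is beaten on both losses by another one. A minimal sketch of that definition, useful as a sanity check on `pareto_F`; `non_dominated_mask` is a hypothetical helper (not part of MOSS or pymoo) and assumes both columns are losses to minimize, as they are here:

```python
import numpy as np

def non_dominated_mask(F):
    """Return a boolean mask of the rows of F that no other row Pareto-dominates.

    Assumes every column of F is an objective to be minimized.
    """
    F = np.asarray(F, dtype=float)
    keep = np.ones(len(F), dtype=bool)
    for i in range(len(F)):
        # Row j dominates row i if it is no worse in every objective
        # and strictly better in at least one.
        no_worse = np.all(F <= F[i], axis=1)
        strictly_better = np.any(F < F[i], axis=1)
        if np.any(no_worse & strictly_better):
            keep[i] = False
    return keep

# Toy front: (0.5, 0.5) is dominated by (0.4, 0.4); the other points trade off.
F_demo = np.array([[0.1, 0.9], [0.4, 0.4], [0.5, 0.5], [0.9, 0.1]])
print(non_dominated_mask(F_demo))  # [ True  True False  True]
```

Applied to this notebook's output, `non_dominated_mask(pareto_F).all()` should be `True`, since pymoo's `result.F` holds the non-dominated solutions of the final population. The O(n²) loop is cheap at population sizes.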

In [None]:
# Visualize evolution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Final Pareto front
axes[0].scatter(
    pareto_F[:, 0], pareto_F[:, 1],
    c='#4ade80', s=60, edgecolors='white', linewidths=1,
    label='Final Pareto Front', zorder=10
)

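# Last seeding snapshot (close to the seeds handed to NSGA-II), for reference
seed_frames = PARETO_SEED_STEPS // 5  # History recorded every 5 steps
if pareto_history and seed_frames > 0:
    seeds = pareto_history[min(seed_frames, len(pareto_history)) - 1]
    axes[0].scatter(
        seeds[:, 0], seeds[:, 1],
        c='purple', s=40, alpha=0.3, label='End of Seeding'
    )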

axes[0].set_xlabel('Visual Loss (normalized) ↓')
axes[0].set_ylabel('Audio Loss (normalized) ↓')
axes[0].set_title('Pareto Frontier: Visual vs Audio Quality')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Evolution over time
n_hist = len(pareto_history)
phase1_len = PARETO_SEED_STEPS // 5  # History recorded every 5 steps

for i, frame in enumerate(pareto_history):
    if i < phase1_len:
        color = 'purple'
        alpha = 0.1 + 0.3 * (i / phase1_len)
    else:
        color = '#4ade80'
        alpha = 0.2 + 0.6 * ((i - phase1_len) / (n_hist - phase1_len))
    
    axes[1].scatter(frame[:, 0], frame[:, 1], c=color, alpha=alpha, s=20)

axes[1].set_xlabel('Visual Loss (normalized) ↓')
axes[1].set_ylabel('Audio Loss (normalized) ↓')
axes[1].set_title('Evolution Over Time (purple=seeding, green=evolution)')
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Animated Evolution (Matches Frontend)

This animation shows the optimization process with:
- **Phase 1 (Purple)**: Gradient seeding with smooth point transitions
- **Phase 2 (Green)**: NSGA-II evolution with fade trails
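
The smooth transitions depend on pairing each point with a point in the next frame. The `match_points` helper in the next cell pairs them greedily, so an early pairing can take another point's nearest neighbour. One possible alternative, sketched here and not something the frontend is known to use, is to solve the pairing as an assignment problem with SciPy's `linear_sum_assignment`, which minimizes the total distance over the same `cdist` cost matrix. `match_points_optimal` is a hypothetical name:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def match_points_optimal(prev, curr):
    """Reorder `curr` so that row i is the point paired with prev[i].

    Minimizes the summed Euclidean distance over all one-to-one pairings,
    rather than taking nearest neighbours one at a time.
    """
    if prev is None or len(prev) != len(curr):
        return curr  # no valid previous frame to match against
    cost = cdist(prev, curr)                  # cost[i, j] = |prev[i] - curr[j]|
    rows, cols = linear_sum_assignment(cost)  # optimal one-to-one pairing
    matched = np.empty_like(curr)
    matched[rows] = curr[cols]
    return matched

# Toy frames where greedy nearest-neighbour matching goes wrong
prev = np.array([[0.0, 0.0], [1.0, 0.0]])
curr = np.array([[0.9, 0.0], [-1.0, 0.0]])
matched = match_points_optimal(prev, curr)
print(matched)
print(np.linalg.norm(matched - prev, axis=1).sum())
```

In this toy case the greedy rule pairs `[0, 0]` with `[0.9, 0]` first and leaves `[1, 0]` to jump to `[-1, 0]` (total distance 2.9); the assignment solver picks the pairing with total 1.1. It has the same signature as `match_points`, so it could be swapped into `update` directly.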

In [None]:
from scipy.spatial.distance import cdist
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

def create_pareto_animation(history, phase1_frames=40):
    """
    Create an animated visualization of Pareto optimization.
    Matches frontend/backend animation style (light mode).
    """
    if not history:
        print("No history to animate")
        return None
    
    # Setup Figure (light mode version)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_title("Pareto Optimization Animation", fontsize=14, pad=10)
    ax.set_xlabel("Visual Loss (normalized) ↓")
    ax.set_ylabel("Audio Loss (normalized) ↓")
    ax.grid(True, alpha=0.3)
    
    # Find global bounds
    all_points = np.vstack(history)
    min_x, max_x = all_points[:, 0].min(), all_points[:, 0].max()
    min_y, max_y = all_points[:, 1].min(), all_points[:, 1].max()
    
    pad_x = (max_x - min_x) * 0.1 if max_x > min_x else 0.1
    pad_y = (max_y - min_y) * 0.1 if max_y > min_y else 0.1
    
    ax.set_xlim(min_x - pad_x, max_x + pad_x)
    ax.set_ylim(min_y - pad_y, max_y + pad_y)
    
    # Scatter for current frame
    scat = ax.scatter([], [], c='#a855f7', alpha=0.8, s=60, 
                      edgecolors='white', linewidths=0.5)
    
    # Trail scatters for fade effect
    trail_depth = 5
    trail_scats = [ax.scatter([], [], c='#4ade80', alpha=0, s=40) 
                   for _ in range(trail_depth)]
    
    text = ax.text(0.02, 0.98, '', transform=ax.transAxes, fontsize=10, va='top')
    
    prev_positions = [None]  # Use list to allow mutation in closure
    
    def match_points(prev, curr):
        """Match points between frames for smooth transition."""
        if prev is None or len(prev) != len(curr):
            return curr
        dist = cdist(prev, curr)
        matched = np.zeros_like(curr)
        used = set()
        for i in range(len(prev)):
            dists = dist[i].copy()
            dists[list(used)] = np.inf
            j = np.argmin(dists)
            matched[i] = curr[j]
            used.add(j)
        return matched
    
    def update(frame):
        if frame >= len(history):
            return (scat, text, *trail_scats)
        
        data = history[frame]
        
        # Smooth transition during seeding phase
        if frame < phase1_frames:
            data = match_points(prev_positions[0], data)
        prev_positions[0] = data.copy()
        
        scat.set_offsets(data)
        
        # Phase-based coloring
        if frame < phase1_frames:
            scat.set_facecolors('#a855f7')  # Purple for seeding
            text.set_text(f'Phase 1: Gradient Seeding (Step {(frame + 1) * 5})')  # snapshots start at step 5
            for ts in trail_scats:
                ts.set_offsets(np.empty((0, 2)))
        else:
            scat.set_facecolors('#4ade80')  # Green for evolution
            gen = frame - phase1_frames + 1  # pymoo's callback fires once per generation, starting at gen 1
            text.set_text(f'Phase 2: Evolutionary (Gen {gen})')
            
            # Update trails with fade effect
            for i, ts in enumerate(trail_scats):
                trail_frame = frame - (i + 1)
                if trail_frame >= phase1_frames and trail_frame < len(history):
                    ts.set_offsets(history[trail_frame])
                    alpha = 0.3 * (1 - (i / trail_depth))
                    ts.set_alpha(alpha)
                else:
                    ts.set_offsets(np.empty((0, 2)))
        
        return (scat, text, *trail_scats)
    
    ani = FuncAnimation(fig, update, frames=len(history), interval=100, blit=True)
    plt.close(fig)  # Prevent static display
    
    return ani

# Create and display animation
phase1_len = PARETO_SEED_STEPS // 5
ani = create_pareto_animation(pareto_history, phase1_frames=phase1_len)
if ani:
    display(HTML(ani.to_jshtml()))

In [None]:
# Explore the Pareto front
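# Explore the Pareto front, ordered from best visual (left) to best audio (right)
sorted_idx = np.argsort(pareto_F[:, 0])

# Sample up to 5 evenly spaced points along the front (also works for a single point)
n_samples = min(5, len(sorted_idx))
sample_indices = sorted_idx[np.linspace(0, len(sorted_idx) - 1, n_samples).astype(int)]

# squeeze=False keeps `axes` 2-D even when n_samples == 1
fig, axes = plt.subplots(2, n_samples, figsize=(4 * n_samples, 8), squeeze=False)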

for col, idx in enumerate(sample_indices):
    mask = torch.from_numpy(pareto_X[idx:idx+1]).float()
    with torch.no_grad():  # inference only, no autograd graph needed
        wav, mixed_mag = encoder(mask, return_wav=True)
    
    # Mask
    axes[0, col].imshow(
        mask.view(GRID_HEIGHT, GRID_WIDTH).numpy(),
        cmap='gray', vmin=0, vmax=1
    )
    axes[0, col].set_title(f'V:{pareto_F[idx,0]:.3f}, A:{pareto_F[idx,1]:.3f}')
    axes[0, col].axis('off')
    
    # Spectrogram (frontend-matching scaling)
    plot_spectrogram(axes[1, col], mixed_mag[0])
    axes[1, col].axis('off')

plt.suptitle('Samples Across the Pareto Front (Left=Best Visual, Right=Best Audio)')
plt.tight_layout()
plt.show()

# Listen to samples
for i, idx in enumerate(sample_indices):
    mask = torch.from_numpy(pareto_X[idx:idx+1]).float()
    with torch.no_grad():
        wav, _ = encoder(mask, return_wav=True)
    print(f"\n🎵 Sample {i+1}: Visual={pareto_F[idx,0]:.3f}, Audio={pareto_F[idx,1]:.3f}")
    display(Audio(wav[0].cpu().numpy(), rate=SAMPLE_RATE))

## 9. Summary

### What We Learned

1. **Mask-based blending** allows smooth control between image and audio fidelity
2. **Proxy optimization** (1/4 resolution) provides a 4× speedup with minimal quality loss
3. **20× visual boost** is necessary to make the image visible
4. **Power-law weights** counteract the boost for balanced Pareto coverage
5. **Hybrid optimization** (gradient + evolutionary) provides thorough exploration
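
Points 4 and 5 concern how well the front covers the trade-off. A standard way to put a number on that is the hypervolume: the area of objective space dominated by the front and bounded by a reference point (larger is better). This notebook does not compute it; the sketch below is a plain NumPy version for two minimized objectives, with `hypervolume_2d` a hypothetical helper name. pymoo also ships a hypervolume indicator (`pymoo.indicators.hv.HV` in recent releases).

```python
import numpy as np

def hypervolume_2d(F, ref):
    """Area dominated by a two-objective minimization front, bounded by `ref`.

    F   : (n, 2) array of objective values (lower is better in both columns).
    ref : (2,) reference point that a counted point must strictly beat.
    """
    F = np.asarray(F, dtype=float)
    ref = np.asarray(ref, dtype=float)
    # Ignore points that do not improve on the reference point in both objectives.
    F = F[np.all(F < ref, axis=1)]
    if len(F) == 0:
        return 0.0
    # Sweep left to right in objective 1 (ties broken by objective 2).
    order = np.lexsort((F[:, 1], F[:, 0]))
    area = 0.0
    best_f2 = ref[1]
    for f1, f2 in F[order]:
        if f2 < best_f2:  # dominated points add no area and are skipped
            area += (ref[0] - f1) * (best_f2 - f2)
            best_f2 = f2
    return area

front = np.array([[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]])
print(hypervolume_2d(front, ref=[4.0, 4.0]))  # 6.0
```

For this notebook, something like `hypervolume_2d(pareto_F, ref=1.1 * pareto_F.max(axis=0))` would summarize one run in a single number; the reference point is an assumption and must be held fixed when comparing runs.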

### Design Decisions

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| Sample rate | 16 kHz | Fast processing, sufficient for demos |
| Grid size | 64×128 | Balance between control and optimization speed |
| Smoothing σ | 5.0 | Prevents harsh artifacts, ensures musicality |
| Visual boost | 20× | Compensates for audio loss dominance |
| Weight power | 4 | Counteracts the 20× boost for balanced coverage |
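
Together with `N_FFT = 1024` and `HOP_LENGTH = 256` from the configuration cell, the table fixes the sizes the optimizers work with. A small arithmetic sketch; `moss_dimensions` is a hypothetical helper, and reading the grid as height × width only decides which axis is which, not the product:

```python
def moss_dimensions(sample_rate=16_000, n_fft=1024, hop_length=256,
                    grid_height=64, grid_width=128):
    """Derive sizes implied by the configuration (hypothetical helper).

    Defaults mirror the table above and the N_FFT/HOP_LENGTH config values.
    """
    freq_bins = n_fft // 2 + 1                  # one-sided STFT frequency bins
    frames_per_second = sample_rate / hop_length
    overlap = 1 - hop_length / n_fft            # fraction shared by adjacent frames
    n_var = grid_height * grid_width            # NSGA-II decision variables per mask
    return {
        "freq_bins": freq_bins,
        "frames_per_second": frames_per_second,
        "overlap": overlap,
        "n_var": n_var,
    }

print(moss_dimensions())
# {'freq_bins': 513, 'frames_per_second': 62.5, 'overlap': 0.75, 'n_var': 8192}
```

Each NSGA-II individual is therefore an 8,192-dimensional vector, which is a plausible reason for seeding the population with gradient-descent solutions instead of starting from random masks alone.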

---

*Made with üíú by [Vojtƒõch Kucha≈ô](https://github.com/kuchar-one)*