# 🎵 MOSS: Multi-Objective Sound Synthesis

## Pareto Optimization for Audio-Visual Spectrogram Blending

This notebook demonstrates the core algorithms behind MOSS, a system that creates audio spectrograms
that simultaneously look like a target image and sound like a target audio track.

**The Problem**: Given:
- A target **image** (e.g., Mona Lisa)
- A target **audio** (e.g., Tchaikovsky)

Find a spectrogram that:
1. **Looks like** the image (when visualized)
2. **Sounds like** the audio (when converted back to waveform)

This is inherently a **multi-objective optimization** problem with two conflicting goals.
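Formally, let $m$ be the blending mask introduced in Section 2, with one value per cell of an $h \times w$ control grid. MOSS then searches for masks that solve the vector-valued problem

$$\min_{m \in [0,1]^{h \times w}} \big(\mathcal{L}_{visual}(m),\; \mathcal{L}_{audio}(m)\big)$$

The two losses pull the mask in opposite directions: $m = 1$ everywhere reproduces the image, and $m = 0$ everywhere reproduces the audio. So in general no single $m$ minimizes both losses; the best you can get is a set of trade-offs.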

## 1. Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image
from pytorch_msssim import ssim
from IPython.display import HTML, Audio, display
import warnings
warnings.filterwarnings('ignore')

# Configuration
SAMPLE_RATE = 16000  # 16kHz for efficiency
N_FFT = 1024         # 513 frequency bins
HOP_LENGTH = 256     # 75% overlap
WIN_LENGTH = N_FFT
DEVICE = 'cpu'

print(f"PyTorch: {torch.__version__}")
print(f"Device: {DEVICE}")
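With these settings, the spectrogram's shape follows from simple arithmetic. The helper below is hypothetical (it is not part of MOSS). It assumes `torch.stft`'s default `center=True`, which pads `N_FFT // 2` samples on each side of the signal:

```python
def stft_shape(num_samples, n_fft=1024, hop_length=256):
    """Return (frequency_bins, frames) of a centered STFT (hypothetical helper)."""
    freq_bins = n_fft // 2 + 1              # one-sided spectrum: 0 Hz up to Nyquist
    frames = 1 + num_samples // hop_length  # center=True pads n_fft // 2 on both sides
    return freq_bins, frames

# 30 seconds of audio at 16 kHz
print(stft_shape(30 * 16000))  # (513, 1876)
```

At 16 kHz, each frequency bin spans 16000 / 1024 = 15.625 Hz, and each frame advances 256 / 16000 s = 16 ms.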

## 2. Core Algorithm: Mask-Based Spectrogram Blending

The key insight is that we can **blend** two spectrograms using a learnable mask:

$$\text{Mixed} = \text{Mask} \cdot \text{Image} + (1 - \text{Mask}) \cdot \text{Audio}$$

Where:
- **Mask** ∈ [0, 1] for each time-frequency bin
- **Mask = 1**: Show the image (visual priority)
- **Mask = 0**: Keep the audio (audio priority)
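To make the formula concrete, here is a minimal NumPy sketch on made-up 1×3 "spectrograms". The values are illustrative and are not taken from MOSS. Mask values of 1, 0.5 and 0 select the image, the midpoint, and the audio, respectively:

```python
import numpy as np

image = np.array([[4.0, 4.0, 4.0]])  # illustrative image-derived magnitudes
audio = np.array([[2.0, 2.0, 2.0]])  # illustrative audio magnitudes
mask = np.array([[1.0, 0.5, 0.0]])   # per-bin blend weight in [0, 1]

# Linear interpolation between the two magnitude spectrograms
mixed = mask * image + (1 - mask) * audio
print(mixed)  # [[4. 3. 2.]]
```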

In [None]:
def gaussian_blur_2d(x, sigma=1.0):
    """Separable Gaussian blur over the last two dims (frequency, time).

    Accepts a (H, W) or (B, H, W) tensor and returns a tensor of the same shape.
    Used to keep the blending mask smooth in time and frequency.
    """
    if sigma < 0.5:
        return x
    
    kernel_size = int(6 * sigma) | 1  # covers about +/-3 sigma; forced odd
    kernel_size = max(3, kernel_size)
    
    # Create a normalized 1D Gaussian kernel
    coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - kernel_size // 2
    g = torch.exp(-coords**2 / (2 * sigma**2))
    g = g / g.sum()
    
    # Separable 2D convolution on an (N, 1, H, W) view
    orig_shape = x.shape
    pad = kernel_size // 2
    if x.dim() == 2:
        x = x.unsqueeze(0).unsqueeze(0)
    elif x.dim() == 3:
        x = x.unsqueeze(1)
    
    x = F.conv2d(F.pad(x, (pad, pad, 0, 0), mode='replicate'),
                 g.view(1, 1, 1, -1))  # blur along time (width)
    x = F.conv2d(F.pad(x, (0, 0, pad, pad), mode='replicate'),
                 g.view(1, 1, -1, 1))  # blur along frequency (height)
    
    # Restore the input shape (a bare .squeeze() would also drop a batch dim of size 1)
    return x.view(orig_shape)
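`gaussian_blur_2d` relies on separability. A 2D Gaussian kernel is the outer product of two 1D kernels, so blurring along the width and then along the height gives the same result as one pass with the full k×k kernel. The separable version costs 2k instead of k² multiply-adds per pixel. The self-contained check below uses zero padding for simplicity; the notebook's function uses replicate padding, for which the same identity holds:

```python
import torch
import torch.nn.functional as F

sigma, k = 2.0, 13                       # kernel size about 6 * sigma, odd
coords = torch.arange(k, dtype=torch.float64) - k // 2
g = torch.exp(-coords**2 / (2 * sigma**2))
g = g / g.sum()                          # normalized 1D Gaussian

torch.manual_seed(0)
x = torch.rand(1, 1, 32, 48, dtype=torch.float64)
pad = k // 2

# One pass with the full k x k kernel (outer product of the 1D kernels)
full = F.conv2d(x, torch.outer(g, g).view(1, 1, k, k), padding=pad)

# Two 1D passes: along the width, then along the height
sep = F.conv2d(x, g.view(1, 1, 1, k), padding=(0, pad))
sep = F.conv2d(sep, g.view(1, 1, k, 1), padding=(pad, 0))

print(torch.allclose(full, sep))  # True
```

With `sigma=5.0`, as used in this notebook, k = 31, so each pixel costs 62 multiply-adds instead of 961.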

## 3. Loss Functions

### Visual Loss: SSIM (Structural Similarity)
We compare the **visual appearance** of the mixed spectrogram to the target image using SSIM,
which captures structural patterns better than pixel-wise L1/L2.
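For reference, SSIM compares the local means, variances and covariance of two images $x$ and $y$. The standard definition is evaluated in sliding Gaussian windows and averaged over the image. `pytorch_msssim` defaults to an 11×11 window with $\sigma = 1.5$ and $K_1 = 0.01$, $K_2 = 0.03$:

$$\text{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}, \qquad C_1 = (K_1 L)^2,\; C_2 = (K_2 L)^2$$

Here $L$ is the data range, which is 1.0 because both inputs are normalized to [0, 1]. The visual loss is $1 - \text{SSIM}$.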

### Audio Loss: Log-Domain L1
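We compare audio in the **log-magnitude domain**, where an L1 distance is proportional to an error in decibels.
This roughly matches human loudness perception because errors are measured as ratios: being off by a factor of 2
costs the same in a quiet passage as in a loud one. As a common rule of thumb, a 10x increase in power (+10 dB)
is perceived as roughly "twice as loud".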
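An L1 distance between natural-log magnitudes converts to decibels by a constant factor, since $20\log_{10} m = \frac{20}{\ln 10}\ln m \approx 8.686 \ln m$. The self-contained sketch below checks that conversion and shows the scale invariance that makes the loss independent of absolute level. Its `log_l1` is a hypothetical stand-in that mirrors `calc_audio_loss`:

```python
import math
import numpy as np

def log_l1(a, b, eps=1e-8):
    """Mean |ln a - ln b| (hypothetical stand-in for calc_audio_loss)."""
    return float(np.mean(np.abs(np.log(a + eps) - np.log(b + eps))))

NEPER_TO_DB = 20 / math.log(10)  # about 8.686 dB per unit of ln-magnitude

# A uniform factor-of-2 magnitude error costs the same at any absolute level
quiet, loud = np.full(4, 0.01), np.full(4, 100.0)
print(log_l1(2 * quiet, quiet) * NEPER_TO_DB)  # ≈ 6.02 dB
print(log_l1(2 * loud, loud) * NEPER_TO_DB)    # ≈ 6.02 dB

# A 10x power ratio is a sqrt(10) magnitude ratio, i.e. 10 dB
print(math.log(math.sqrt(10)) * NEPER_TO_DB)   # ≈ 10.0
```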

In [None]:
def calc_image_loss(mixed_mag, target_mag):
    """Visual loss = 1 - SSIM, computed on min-max normalized log magnitudes.

    Both inputs are magnitude spectrograms of shape (B, F, T) or (1, F, T);
    the target is broadcast over the batch. Returns a (B,) tensor.
    """
    def log_norm01(mag):
        # Log-compress, then min-max normalize each spectrogram to [0, 1]
        log_mag = torch.log(mag + 1e-8)
        flat = log_mag.flatten(start_dim=-2)
        min_val = flat.min(dim=-1, keepdim=True).values.unsqueeze(-1)
        max_val = flat.max(dim=-1, keepdim=True).values.unsqueeze(-1)
        return (log_mag - min_val) / (max_val - min_val + 1e-8)
    
    mixed_norm = log_norm01(mixed_mag)
    target_norm = log_norm01(target_mag).expand_as(mixed_norm)
    
    # pytorch_msssim expects (N, C, H, W); size_average=False returns one SSIM per image
    s = ssim(mixed_norm.unsqueeze(1), target_norm.unsqueeze(1),
             data_range=1.0, size_average=False)
    return 1.0 - s  # Loss = 1 - SSIM


def calc_audio_loss(mixed_mag, target_audio_mag):
    """L1 loss in log-magnitude domain (perceptually weighted)."""
    mixed_log = torch.log(mixed_mag + 1e-8)
    target_log = torch.log(target_audio_mag + 1e-8)
    target_log = target_log.expand_as(mixed_log)
    
    loss = F.l1_loss(mixed_log, target_log, reduction='none').mean(dim=(1, 2))
    return loss

## 4. The Encoder: Converting Mask to Audio

In [None]:
class MaskEncoder(nn.Module):
    """Encodes a mask grid into a blended spectrogram and reconstructed audio."""
    
    def __init__(self, target_image, audio_path, grid_h=64, grid_w=128, sigma=5.0):
        super().__init__()
        self.grid_height = grid_h
        self.grid_width = grid_w
        self.sigma = sigma
        
        # Load audio
        audio, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            audio = torchaudio.functional.resample(audio, sr, SAMPLE_RATE)
        audio = audio.mean(dim=0, keepdim=True)  # Mono
        self.register_buffer('audio_waveform', audio)
        
        # Compute STFT
        window = torch.hann_window(N_FFT)
        stft = torch.stft(audio, n_fft=N_FFT, hop_length=HOP_LENGTH, 
                          win_length=WIN_LENGTH, window=window, return_complex=True)
        
        self.audio_mag = stft.abs() + 1e-8
        self.audio_phase = stft.angle()
        self.full_h, self.full_w = self.audio_mag.shape[1], self.audio_mag.shape[2]
        
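        # Process image: flip vertically so the top of the picture lands in the
        # high-frequency bins (spectrograms are drawn with origin='lower')
        img = target_image.unsqueeze(0) if target_image.dim() == 3 else target_image
        img = torch.flip(img, dims=[-2])
        img = F.interpolate(img, size=(self.full_h, self.full_w), mode='bilinear', align_corners=False)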
        
        # Map image to audio dynamic range
        audio_log = torch.log(self.audio_mag)
        audio_log_max = torch.quantile(audio_log, 0.98)
        audio_log_floor = torch.quantile(audio_log, 0.01)
        
        img_01 = img.squeeze(0)
        img_01 = (img_01 - img_01.min()) / (img_01.max() - img_01.min() + 1e-8)
        img_01 = img_01.pow(1.8)  # Gamma correction
        
        image_log = img_01 * (audio_log_max - audio_log_floor) + audio_log_floor
        self.image_mag = torch.exp(image_log)
        
        # Reference for loss calculation
        self.image_mag_ref = self.image_mag
        self.audio_mag_ref = self.audio_mag
        
    def forward(self, mask_flat, return_wav=True):
        """Forward pass: mask -> blended spectrogram -> audio."""
        B = mask_flat.shape[0]
        
        # Reshape mask
        mask = mask_flat.view(B, self.grid_height, self.grid_width)
        
        # Upsample to full resolution
        mask_up = F.interpolate(mask.unsqueeze(1), size=(self.full_h, self.full_w),
                                mode='bilinear', align_corners=False).squeeze(1)
        
        # Apply Gaussian smoothing (forces musical smoothness)
        mask_smooth = gaussian_blur_2d(mask_up, self.sigma)
        mask_smooth = mask_smooth.clamp(0, 1)
        
        # Blend spectrograms
        mixed_mag = mask_smooth * self.image_mag + (1 - mask_smooth) * self.audio_mag
        
        if not return_wav:
            return None, mixed_mag
        
        # Reconstruct audio via ISTFT
        mixed_complex = mixed_mag * torch.exp(1j * self.audio_phase)
        window = torch.hann_window(N_FFT, device=mixed_mag.device)
        
        wavs = []
        for i in range(B):
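            # length= makes the output match the original waveform length
            wav = torch.istft(mixed_complex[i], n_fft=N_FFT, hop_length=HOP_LENGTH,
                              win_length=WIN_LENGTH, window=window,
                              length=self.audio_waveform.shape[-1])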
            wavs.append(wav)
        
        return torch.stack(wavs), mixed_mag
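The image is placed inside the audio's loudness range rather than at arbitrary levels. Pixel values are gamma-corrected and then mapped linearly onto the log-magnitude interval between the audio's 1st and 98th percentiles. Black pixels therefore become quiet bins and white pixels loud ones. Below is a NumPy sketch of that mapping on synthetic data; the random "audio" and the 3-pixel "image" are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
audio_log = np.log(rng.exponential(1.0, size=(513, 200)) + 1e-8)  # fake log-magnitudes

log_max = np.quantile(audio_log, 0.98)    # level assigned to white pixels
log_floor = np.quantile(audio_log, 0.01)  # level assigned to black pixels

pixels = np.array([0.0, 0.5, 1.0])        # already min-max normalized to [0, 1]
pixels = pixels ** 1.8                    # gamma > 1 darkens mid-tones

image_log = pixels * (log_max - log_floor) + log_floor
image_mag = np.exp(image_log)

# Position of each pixel within the audio's dynamic range
print((image_log - log_floor) / (log_max - log_floor))  # ≈ [0, 0.287, 1]
```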

## 5. Load Data: Mona Lisa + Tchaikovsky

In [None]:
# Load image
img_path = 'data/input/monalisa.jpg'
img = Image.open(img_path).convert('L')  # Grayscale
img_tensor = torch.from_numpy(np.array(img)).float() / 255.0
img_tensor = img_tensor.unsqueeze(0)  # Add channel dim

# Audio path
audio_path = 'data/input/tchaikovsky.mp3'

# Create encoder
encoder = MaskEncoder(img_tensor, audio_path, grid_h=64, grid_w=128, sigma=5.0)

print(f"Spectrogram shape: {encoder.full_h} x {encoder.full_w}")
print(f"Control grid: {encoder.grid_height} x {encoder.grid_width}")
print(f"Total parameters: {encoder.grid_height * encoder.grid_width}")

In [None]:
# Visualize inputs
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Target image
axes[0].imshow(img, cmap='gray')
axes[0].set_title('Target Image: Mona Lisa')
axes[0].axis('off')

# Image as spectrogram
axes[1].imshow(20*torch.log10(encoder.image_mag[0]+1e-8).numpy(), 
               origin='lower', aspect='auto', cmap='magma')
axes[1].set_title('Image → Spectrogram')
axes[1].set_xlabel('Time')
axes[1].set_ylabel('Frequency')

# Audio spectrogram
axes[2].imshow(20*torch.log10(encoder.audio_mag[0]+1e-8).numpy(), 
               origin='lower', aspect='auto', cmap='magma')
axes[2].set_title('Audio: Tchaikovsky')
axes[2].set_xlabel('Time')

plt.tight_layout()
plt.show()

## 6. Single-Objective Optimization

First, let's optimize for a **single weighted combination** of the two objectives:

$$\mathcal{L}_{total} = w_{vis} \cdot \mathcal{L}_{visual} + w_{aud} \cdot \mathcal{L}_{audio}$$
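A weighted sum has a known blind spot, which motivates the Pareto search in Section 7. Minimizing $w_{vis}\mathcal{L}_{visual} + w_{aud}\mathcal{L}_{audio}$ can only return solutions on the convex hull of the achievable loss pairs. In the toy NumPy example below, which uses three made-up loss pairs rather than MOSS results, the middle point is non-dominated, yet no weighting ever selects it:

```python
import numpy as np

# Made-up (visual, audio) loss pairs; the middle one sits in a non-convex dent
points = np.array([[0.0, 1.0],
                   [0.6, 0.6],
                   [1.0, 0.0]])

chosen = set()
for w_vis in np.linspace(0, 1, 101):
    w_aud = 1 - w_vis
    scores = w_vis * points[:, 0] + w_aud * points[:, 1]
    chosen.add(int(np.argmin(scores)))

print(sorted(chosen))  # [0, 2]: the point (0.6, 0.6) is never picked
```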

In [None]:
def single_objective_optimize(encoder, w_vis=0.5, w_aud=0.5, steps=200, lr=0.1):
    """Optimize a single mask for a weighted combination of objectives."""
    
    # Initialize mask in logit space (unbounded, sigmoid to [0,1])
    mask_logits = nn.Parameter(torch.zeros(1, encoder.grid_height * encoder.grid_width))
    optimizer = optim.Adam([mask_logits], lr=lr)
    
    history = []
    
    for step in range(steps):
        optimizer.zero_grad()
        
        mask = torch.sigmoid(mask_logits)
        _, mixed_mag = encoder(mask, return_wav=False)
        
        loss_vis = calc_image_loss(mixed_mag, encoder.image_mag_ref).mean()
        loss_aud = calc_audio_loss(mixed_mag, encoder.audio_mag_ref).mean()
        
        total_loss = w_vis * loss_vis + w_aud * loss_aud
        total_loss.backward()
        optimizer.step()
        
        history.append([loss_vis.item(), loss_aud.item()])
        
        if step % 50 == 0:
            print(f"Step {step}: Vis={loss_vis.item():.4f}, Aud={loss_aud.item():.4f}")
    
    return torch.sigmoid(mask_logits).detach(), np.array(history)

In [None]:
# Run three different weightings
print("\n📊 Optimizing: Visual Priority (80/20)")
mask_visual, hist_visual = single_objective_optimize(encoder, w_vis=0.8, w_aud=0.2, steps=200)

print("\n📊 Optimizing: Balanced (50/50)")
mask_balanced, hist_balanced = single_objective_optimize(encoder, w_vis=0.5, w_aud=0.5, steps=200)

print("\n📊 Optimizing: Audio Priority (20/80)")
mask_audio, hist_audio = single_objective_optimize(encoder, w_vis=0.2, w_aud=0.8, steps=200)

In [None]:
# Visualize results
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

masks = [mask_visual, mask_balanced, mask_audio]
titles = ['Visual Priority', 'Balanced', 'Audio Priority']

for i, (mask, title) in enumerate(zip(masks, titles)):
    _, mixed_mag = encoder(mask, return_wav=False)
    
    # Show mask
    axes[0, i].imshow(mask.view(encoder.grid_height, encoder.grid_width).numpy(), cmap='gray', vmin=0, vmax=1)
    axes[0, i].set_title(f'{title}\nMask')
    axes[0, i].axis('off')
    
    # Show spectrogram
    axes[1, i].imshow(20*torch.log10(mixed_mag[0]+1e-8).numpy(),
                      origin='lower', aspect='auto', cmap='magma')
    axes[1, i].set_title('Result Spectrogram')
    axes[1, i].set_xlabel('Time')

plt.tight_layout()
plt.show()

In [None]:
# Listen to results
print("🔊 Audio Priority Result:")
wav_audio, _ = encoder(mask_audio, return_wav=True)
display(Audio(wav_audio[0].numpy(), rate=SAMPLE_RATE))

print("\n🔊 Balanced Result:")
wav_balanced, _ = encoder(mask_balanced, return_wav=True)
display(Audio(wav_balanced[0].numpy(), rate=SAMPLE_RATE))

print("\n🔊 Visual Priority Result:")
wav_visual, _ = encoder(mask_visual, return_wav=True)
display(Audio(wav_visual[0].numpy(), rate=SAMPLE_RATE))

## 7. The Pareto Frontier: Multi-Objective Optimization

Instead of choosing a single weighting, we can find the **entire Pareto frontier** - 
all solutions where you cannot improve one objective without worsening the other.
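This definition translates directly into a non-dominated filter. A point is kept unless some other point is at least as good on both losses and strictly better on one. Below is a minimal O(n²) NumPy sketch; the function name and sample values are illustrative. NSGA-II performs this ranking internally via non-dominated sorting:

```python
import numpy as np

def pareto_mask(F):
    """Boolean mask of non-dominated rows of F (n_points x n_objectives), minimization."""
    n = len(F)
    keep = np.ones(n, dtype=bool)
    for i in range(n):
        # Row j dominates row i if it is <= on every objective and < on at least one
        dominated_by = np.all(F <= F[i], axis=1) & np.any(F < F[i], axis=1)
        keep[i] = not dominated_by.any()
    return keep

F = np.array([[0.1, 0.9],   # good visual, poor audio
              [0.5, 0.5],
              [0.6, 0.6],   # dominated by (0.5, 0.5)
              [0.9, 0.1]])  # poor visual, good audio
print(pareto_mask(F))  # [ True  True False  True]
```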

### Our Hybrid Approach:
1. **Gradient Seeding** (200 steps): Run Adam with different weight combinations in parallel
2. **Evolutionary Expansion** (NSGA-II): Use genetic operators to explore the frontier

In [None]:
def gradient_seeding(encoder, pop_size=10, steps=200, lr=0.05):
    """
    Phase 1: Generate diverse seeds using parallel gradient descent.
    Uses power-law weight distribution to counteract visual loss dominance.
    """
    # Initialize population in logit space
    n_params = encoder.grid_height * encoder.grid_width
    mask_logits = nn.Parameter(torch.randn(pop_size, n_params) * 0.5)
    
    # Power-law weights: (1 - alpha)^4 bunches most seeds towards a low visual
    # weight, since the 20x-boosted visual term below would otherwise dominate
    alpha = torch.linspace(0, 1, pop_size)
    weights_vis = (1.0 - alpha).pow(4.0)
    weights_aud = 1.0 - weights_vis  # the two weights already sum to 1
    
    # Anchors: seed 0 is the pure image (mask ≈ 1), the last seed is pure audio (mask ≈ 0)
    ANCHOR_LOGIT = 20.0
    with torch.no_grad():
        mask_logits[0].fill_(ANCHOR_LOGIT)
        mask_logits[-1].fill_(-ANCHOR_LOGIT)
    
    optimizer = optim.Adam([mask_logits], lr=lr)
    history = []
    
    for step in range(steps):
        optimizer.zero_grad()
        
        masks = torch.sigmoid(mask_logits)
        _, mixed_mag = encoder(masks, return_wav=False)
        
        # Visual objective: L1 on linear magnitudes (a cheaper proxy than SSIM)
        diff = torch.abs(mixed_mag - encoder.image_mag_ref)
        loss_vis = diff.mean(dim=(1, 2))
        loss_aud = calc_audio_loss(mixed_mag, encoder.audio_mag_ref)
        
        # Weighted sum with a 20x boost on the visual term
        total_loss = (weights_vis * loss_vis * 20.0 + weights_aud * loss_aud).sum()
        total_loss.backward()
        optimizer.step()
        
        # Re-pin the anchors, which Adam has just moved
        with torch.no_grad():
            mask_logits[0].fill_(ANCHOR_LOGIT)
            mask_logits[-1].fill_(-ANCHOR_LOGIT)
        
        if step % 10 == 0:
            history.append(np.column_stack([loss_vis.detach().numpy(), 
                                           loss_aud.detach().numpy()]))
    
    return torch.sigmoid(mask_logits).detach().numpy(), history

print("🌱 Running Gradient Seeding...")
seed_masks, seed_history = gradient_seeding(encoder, pop_size=10, steps=200)
print(f"Generated {len(seed_masks)} seed solutions")
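The power-law schedule concentrates seeds at small visual weights. The sketch below recomputes the ten weights in NumPy, using the same formula as `gradient_seeding` with `pop_size=10`. Six of the ten seeds get a visual weight below 0.1:

```python
import numpy as np

pop_size = 10
alpha = np.linspace(0, 1, pop_size)
weights_vis = (1.0 - alpha) ** 4.0   # 1.0 for the pure-image anchor, 0.0 for pure audio
weights_aud = 1.0 - weights_vis

print(np.round(weights_vis, 4))
n_low = int((weights_vis < 0.1).sum())
print(n_low, "of", pop_size, "seeds have w_vis < 0.1")  # 6 of 10 ...
```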

In [None]:
# Visualize seeding progress
plt.figure(figsize=(10, 6))

for i, frame in enumerate(seed_history):
    alpha = 0.3 + 0.7 * (i / len(seed_history))
    plt.scatter(frame[:, 0], frame[:, 1], c='purple', alpha=alpha, s=40)

plt.xlabel('Visual Loss')
plt.ylabel('Audio Loss')
plt.title('Phase 1: Gradient Seeding Progress')
plt.grid(True, alpha=0.3)
plt.show()

### Phase 2: Evolutionary Expansion (NSGA-II)

NSGA-II is a widely used multi-objective genetic algorithm that:
1. Ranks solutions with **non-dominated sorting**
2. Maintains **diversity** along the front via crowding distance
3. Applies **crossover** (simulated binary crossover, SBX) and **mutation** (polynomial mutation, PM) operators
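Crowding distance is what keeps the front spread out. For each solution, it sums over objectives the normalized gap between the solution's two neighbours along that objective. Extreme points get infinite distance, so they are always kept. Below is a minimal NumPy sketch of that standard computation for a single front; it is not pymoo's internal code:

```python
import numpy as np

def crowding_distance(F):
    """Crowding distance of each row of F (one non-dominated front, minimization)."""
    n, m = F.shape
    dist = np.zeros(n)
    for k in range(m):
        order = np.argsort(F[:, k])
        f = F[order, k]
        span = f[-1] - f[0]
        dist[order[0]] = dist[order[-1]] = np.inf  # boundary solutions always survive
        if span == 0:
            continue
        # Interior points: gap between next and previous neighbour, normalized by the span
        dist[order[1:-1]] += (f[2:] - f[:-2]) / span
    return dist

F = np.array([[0.0, 1.0], [0.25, 0.5], [0.5, 0.25], [1.0, 0.0]])
d = crowding_distance(F)
print(d)  # boundary points are inf, the two interior points get 1.25 each
```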

In [None]:
from pymoo.core.problem import Problem
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.optimize import minimize
from pymoo.operators.crossover.sbx import SBX
from pymoo.operators.mutation.pm import PM
from pymoo.core.callback import Callback


class MOSSProblem(Problem):
    """Pymoo problem definition for MOSS."""
    
    def __init__(self, encoder):
        self.encoder = encoder
        n_params = encoder.grid_height * encoder.grid_width
        super().__init__(n_var=n_params, n_obj=2, xl=0.0, xu=1.0)
    
    def _evaluate(self, x, out, *args, **kwargs):
        with torch.no_grad():
            mask = torch.from_numpy(x).float()
            _, mixed_mag = self.encoder(mask, return_wav=False)
            
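            # Visual objective: L1 on linear magnitudes (a cheaper proxy than SSIM,
            # same as in gradient_seeding)
            diff = torch.abs(mixed_mag - self.encoder.image_mag_ref)
            loss_vis = diff.mean(dim=(1, 2)).numpy()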
            loss_aud = calc_audio_loss(mixed_mag, self.encoder.audio_mag_ref).numpy()
            
            out['F'] = np.column_stack([loss_vis, loss_aud])


class HistoryCallback(Callback):
    def __init__(self):
        super().__init__()
        self.history = []
    
    def notify(self, algorithm):
        self.history.append(algorithm.pop.get('F').copy())

In [None]:
# Run NSGA-II
print("Running NSGA-II Evolution...")

problem = MOSSProblem(encoder)

# Initialize the population with the gradient seeds, topped up with random masks
POP_SIZE = 100
n_random = POP_SIZE - len(seed_masks)
X_init = np.vstack([seed_masks, np.random.rand(n_random, problem.n_var)])

algorithm = NSGA2(
    pop_size=POP_SIZE,
    sampling=X_init,
    crossover=SBX(eta=15, prob=0.9),
    mutation=PM(eta=20),
    eliminate_duplicates=True
)

callback = HistoryCallback()

result = minimize(
    problem,
    algorithm,
    ('n_gen', 50),
    callback=callback,
    verbose=True
)

print(f"\n✅ Found {len(result.F)} non-dominated solutions")

In [None]:
# Visualize final Pareto front
plt.figure(figsize=(10, 6))

# Final front
plt.scatter(result.F[:, 0], result.F[:, 1], c='#4ade80', s=80, 
            edgecolors='white', linewidths=1, label='Pareto Front', zorder=10)

# Seeds for reference
seed_F = seed_history[-1] if seed_history else None
if seed_F is not None:
    plt.scatter(seed_F[:, 0], seed_F[:, 1], c='purple', s=60, alpha=0.5, label='Initial Seeds')

plt.xlabel('Visual Loss ↓')
plt.ylabel('Audio Loss ↓')
plt.title('Pareto Frontier: Visual vs Audio Quality')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## 8. Animated Evolution

In [None]:
# Create animation
fig, ax = plt.subplots(figsize=(10, 6), facecolor='#09090b')
ax.set_facecolor('#09090b')
ax.set_title('Hybrid Pareto Optimization', color='white', fontsize=14)
ax.set_xlabel('Visual Loss', color='#666')
ax.set_ylabel('Audio Loss', color='#666')
ax.tick_params(colors='#444')
for spine in ax.spines.values():
    spine.set_color('#333')
ax.grid(True, alpha=0.15, color='#444')

# Compute bounds
all_F = np.vstack(seed_history + callback.history)
ax.set_xlim(all_F[:, 0].min() * 0.9, all_F[:, 0].max() * 1.1)
ax.set_ylim(all_F[:, 1].min() * 0.9, all_F[:, 1].max() * 1.1)

scat = ax.scatter([], [], c='#a855f7', s=60, edgecolors='white', linewidths=0.5)
text = ax.text(0.02, 0.98, '', transform=ax.transAxes, color='white', fontsize=10, va='top')

full_history = seed_history + callback.history
phase1_len = len(seed_history)

def update(frame):
    if frame >= len(full_history):
        return scat, text
    
    data = full_history[frame]
    scat.set_offsets(data)
    
    if frame < phase1_len:
        scat.set_facecolors('#a855f7')
        text.set_text(f'Phase 1: Gradient Seeding (Step {frame * 10})')
    else:
        scat.set_facecolors('#4ade80')
        text.set_text(f'Phase 2: Evolution (Gen {frame - phase1_len})')
    
    return scat, text

ani = animation.FuncAnimation(fig, update, frames=len(full_history), interval=100, blit=True)
plt.close()

# Display in notebook
HTML(ani.to_jshtml())

## 9. Explore the Pareto Front

Each point on the front represents a different trade-off. Let's listen to a few!
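The cell below labels the median-index solution as "Balanced". An alternative is to normalize both losses to [0, 1] across the front and pick the point closest to the ideal corner (0, 0). The sketch below implements that as a hypothetical helper; it is not part of MOSS:

```python
import numpy as np

def closest_to_ideal(F):
    """Index of the front point nearest the ideal point after min-max normalization."""
    F_min, F_max = F.min(axis=0), F.max(axis=0)
    # Guard against a zero range on either objective
    F_norm = (F - F_min) / np.where(F_max > F_min, F_max - F_min, 1.0)
    return int(np.argmin(np.linalg.norm(F_norm, axis=1)))

# Made-up front whose visual and audio losses live on different scales
F = np.array([[0.02, 3.0], [0.05, 1.0], [0.30, 0.8], [0.90, 0.7]])
print(closest_to_ideal(F))  # 1
```

Applied to the real front, this would be `closest_to_ideal(result.F)`. Normalizing first keeps the objective with the larger numeric range from dominating the distance.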

In [None]:
# Sort by visual loss
sorted_idx = np.argsort(result.F[:, 0])

# Pick samples from different parts of the front
samples = [sorted_idx[0], sorted_idx[len(sorted_idx)//2], sorted_idx[-1]]
labels = ['Best Visual', 'Balanced', 'Best Audio']

for idx, label in zip(samples, labels):
    mask = torch.from_numpy(result.X[idx:idx+1]).float()
    wav, mixed_mag = encoder(mask, return_wav=True)
    
    print(f"\n🎵 {label} (Vis: {result.F[idx, 0]:.4f}, Aud: {result.F[idx, 1]:.4f})")
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 3))
    axes[0].imshow(mask.view(encoder.grid_height, encoder.grid_width).numpy(),
                   cmap='gray', vmin=0, vmax=1)
    axes[0].set_title('Mask')
    axes[0].axis('off')
    
    axes[1].imshow(20*torch.log10(mixed_mag[0]+1e-8).numpy(),
                   origin='lower', aspect='auto', cmap='magma')
    axes[1].set_title('Spectrogram')
    plt.tight_layout()
    plt.show()
    
    display(Audio(wav[0].numpy(), rate=SAMPLE_RATE))

## 10. Summary

### Key Design Decisions:

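1. **Mask-Based Blending**: Instead of generating spectrograms from scratch, we blend
   existing image and audio spectrograms using a learnable mask. This helps preserve audio quality:
   where the mask is near 0 the original magnitudes are kept, and the original phase is reused everywhere.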

2. **Gaussian Smoothing**: Forces the mask to vary slowly in time-frequency, creating
   smoother, more musical transitions instead of harsh noise.

3. **SSIM for Visual Loss**: Captures structural patterns, not just pixel differences.
   A blurry but structurally correct image scores better than a sharp but wrong one.
   (The Pareto search uses a cheaper L1 distance on linear magnitudes as its visual objective.)

4. **Log-Domain Audio Loss**: Matches human perception where equal dB differences feel
   equally significant regardless of absolute level.

5. **Hybrid Optimization**: Gradient descent finds initial seeds efficiently,
   while NSGA-II explores the trade-off space more thoroughly.

6. **Power-Law Weight Distribution**: Counteracts the natural dominance of visual loss
   to ensure balanced exploration of the Pareto front.

---
*Made with 💜 by Vojtěch Kuchař*