# Rate Distortion Theory Tutorial

## Introduction

**Rate Distortion Theory** is a branch of information theory introduced by **Claude Shannon** in 1948 and developed fully in his 1959 paper. It addresses the trade-off between the rate of lossy data compression and the distortion (quality loss) that the compression introduces.

### Key Concepts:

1. **Rate (R)**: The number of bits per source symbol needed to represent the data after compression.
2. **Distortion (D)**: A measure of the difference between the original and reconstructed data.
3. **Rate-Distortion Function R(D)**: The minimum rate needed to achieve a given distortion level.

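Shannon's key insight was that for any given distortion level D there is a theoretical minimum rate R(D). No code operating below R(D) can keep the expected distortion at or below D, while codes with rates arbitrarily close to R(D) can achieve it if they encode sufficiently long blocks of source symbols.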

## Mathematical Foundation

For a memoryless source with distribution $p(x)$, the rate-distortion function is defined as:

$$R(D) = \min_{p(\hat{x}|x): E[d(X,\hat{X})] \leq D} I(X;\hat{X})$$

where:
- $I(X;\hat{X})$ is the mutual information between the source $X$ and reconstruction $\hat{X}$
- $d(X,\hat{X})$ is the distortion measure
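- $E[d(X,\hat{X})]$ is the expected distortion
- the minimum is taken over all conditional distributions $p(\hat{x}|x)$ (test channels) whose expected distortion is at most $D$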
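
For most sources this minimization has no closed form and is solved numerically. A standard method is the Blahut–Arimoto algorithm, which, for a fixed Lagrange multiplier β, alternates between updating the test channel $p(\hat{x}|x)$ and the reproduction marginal $q(\hat{x})$. Each β yields one point on the R(D) curve. The sketch below assumes finite alphabets, a distortion matrix `dist[x, x_hat]`, and a fixed iteration count instead of a convergence test. `blahut_arimoto_rd` is a hypothetical helper, not a library function. It is checked against the known closed form for a Bernoulli(0.3) source with Hamming distortion, $R(D) = H_b(0.3) - H_b(D)$ for $D \le 0.3$.

```python
import numpy as np

def blahut_arimoto_rd(p_x, dist, beta, n_iter=500):
    """Sketch of Blahut-Arimoto: one (D, R) point on R(D), R in bits/symbol.

    p_x  : source pmf, shape (n_x,)
    dist : distortion matrix d(x, x_hat), shape (n_x, n_hat)
    beta : Lagrange multiplier >= 0; larger beta -> lower distortion, higher rate
    """
    p_x = np.asarray(p_x, dtype=float)
    dist = np.asarray(dist, dtype=float)
    A = np.exp(-beta * dist)                           # exp(-beta * d(x, x_hat))
    q = np.full(dist.shape[1], 1.0 / dist.shape[1])    # initial reproduction pmf q(x_hat)
    for _ in range(n_iter):
        Q = q * A                                      # unnormalized test channel p(x_hat | x)
        Q /= Q.sum(axis=1, keepdims=True)              # normalize each row over x_hat
        q = p_x @ Q                                    # q(x_hat) = sum_x p(x) p(x_hat | x)
    joint = p_x[:, None] * Q                           # joint pmf p(x, x_hat)
    D = float(np.sum(joint * dist))                    # expected distortion
    ratio = np.where(joint > 0, Q / q, 1.0)            # skip log(0) on zero-probability pairs
    R = float(np.sum(joint * np.log2(ratio)))          # mutual information I(X; X_hat)
    return D, R

def binary_entropy(p):
    """Binary entropy H_b(p) in bits."""
    return -p * np.log2(p) - (1 - p) * np.log2(1 - p)

# Bernoulli(0.3) source, Hamming distortion
p_x = [0.7, 0.3]
hamming = [[0, 1], [1, 0]]
D, R = blahut_arimoto_rd(p_x, hamming, beta=3.0)
print(f"D = {D:.4f}, R = {R:.4f}, closed form = {binary_entropy(0.3) - binary_entropy(D):.4f}")
```

Sweeping β from 0 upward traces the whole curve, from (D_max, 0) toward (0, H(X)).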

### Gaussian Source Example

For a memoryless Gaussian source with variance $\sigma^2$ and squared-error distortion $d(x,\hat{x}) = (x-\hat{x})^2$, the rate-distortion function has a closed form:

$$R(D) = \begin{cases} 
\frac{1}{2}\log_2\left(\frac{\sigma^2}{D}\right) & \text{if } D < \sigma^2 \\
0 & \text{if } D \geq \sigma^2
\end{cases}$$
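
Inverting the formula for $D < \sigma^2$ gives the distortion-rate function $D(R) = \sigma^2 2^{-2R}$. Each additional bit per symbol therefore divides the minimum achievable MSE by 4, which improves the signal-to-distortion ratio by about 6.02 dB. The quick check below uses a helper, `distortion_rate_gaussian`, whose name is our own and not a library function.

```python
import numpy as np

def distortion_rate_gaussian(R, sigma_squared):
    """D(R) = sigma^2 * 2^(-2R): inverse of the Gaussian R(D), valid for R >= 0."""
    R = np.asarray(R, dtype=float)
    return sigma_squared * 2.0 ** (-2.0 * R)

rates = np.arange(5)                          # 0, 1, 2, 3, 4 bits per symbol
D = distortion_rate_gaussian(rates, sigma_squared=1.0)
sdr_db = 10 * np.log10(1.0 / D)               # signal-to-distortion ratio in dB
for r, d, s in zip(rates, D, sdr_db):
    print(f"R = {r} bit(s): D = {d:.6f}, SDR = {s:5.2f} dB")
```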

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import ipywidgets as widgets
from IPython.display import display

# Set random seed for reproducibility
np.random.seed(42)

# Configure matplotlib
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

## Rate-Distortion Function for Gaussian Source

Let's implement and visualize the rate-distortion function for a Gaussian source.

In [None]:
def rate_distortion_gaussian(D, sigma_squared):
    """
    Calculate the rate-distortion function for a Gaussian source.
    
    Parameters:
    -----------
    D : float or array
        Distortion level(s) under mean squared error; must be positive
    sigma_squared : float
        Variance of the Gaussian source
    
    Returns:
    --------
    R : float or array
        Rate in bits per symbol
    """
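    D = np.asarray(D, dtype=float)  # float dtype so integer inputs are not truncated
    R = np.zeros_like(D)            # R stays 0 wherever D >= sigma_squared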
    
    # Only calculate for D < sigma_squared
    mask = D < sigma_squared
    R[mask] = 0.5 * np.log2(sigma_squared / D[mask])
    
    return R

def plot_rate_distortion(sigma_squared=1.0):
    """
    Plot the rate-distortion function for a Gaussian source.
    """
    # Create distortion range
    D_max = sigma_squared * 1.5
    D = np.linspace(0.001, D_max, 1000)
    
    # Calculate rate-distortion function
    R = rate_distortion_gaussian(D, sigma_squared)
    
    # Create plot
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(D, R, 'b-', linewidth=2, label=f'R(D) for σ² = {sigma_squared}')
    ax.axvline(x=sigma_squared, color='r', linestyle='--', 
               label=f'D = σ² = {sigma_squared}', alpha=0.7)
    
    ax.set_xlabel('Distortion (D)', fontsize=12)
    ax.set_ylabel('Rate (R) [bits/symbol]', fontsize=12)
    ax.set_title('Rate-Distortion Function for Gaussian Source', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, D_max)
    ax.set_ylim(0, max(R) * 1.1)
    
    plt.tight_layout()
    plt.show()

# Display static plot
plot_rate_distortion(sigma_squared=1.0)

## Interactive Exploration

Use the slider below to adjust the source variance (σ²) and observe how the rate-distortion curve changes.

In [None]:
# Create interactive widget
interactive_plot = widgets.interactive(
    plot_rate_distortion,
    sigma_squared=widgets.FloatSlider(
        value=1.0,
        min=0.1,
        max=5.0,
        step=0.1,
        description='Variance (σ²):',
        style={'description_width': 'initial'}
    )
)

display(interactive_plot)
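
The effect of the slider can also be read off the closed form. Scaling the variance by a factor k adds $\frac{1}{2}\log_2 k$ bits at every distortion below both variances, so the curve shifts up by a constant and its zero-rate point moves out to $D = \sigma^2$. This sketch re-declares a compact version of the Gaussian R(D) so it runs on its own. `gaussian_rd` is a hypothetical name.

```python
import numpy as np

def gaussian_rd(D, sigma_squared):
    """Closed-form R(D) in bits for a memoryless Gaussian source under MSE."""
    D = np.asarray(D, dtype=float)
    # np.maximum clips the rate at zero once D >= sigma^2
    return np.maximum(0.0, 0.5 * np.log2(sigma_squared / D))

D = np.array([0.05, 0.1, 0.2])                      # below both variances
shift = gaussian_rd(D, 4.0) - gaussian_rd(D, 1.0)   # quadruple the variance
print(shift)                                        # every entry is 0.5 * log2(4) = 1 bit
print(gaussian_rd([1.0, 2.0, 4.0], 4.0))            # rate falls to 0 at D = sigma^2 = 4
```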

## Practical Example: Signal Compression

Let's demonstrate rate-distortion theory with a practical example: compressing a Gaussian signal.

We'll:
1. Generate a random Gaussian signal
2. Compress it using uniform quantization (a simple compression method)
3. Calculate the resulting rate and distortion
4. Compare with the theoretical rate-distortion bound

In [None]:
def uniform_quantizer(signal, num_levels):
    """
    Uniformly quantize a signal over its own range [min, max], reconstructing
    each sample at the midpoint of its bin.
    
    Parameters:
    -----------
    signal : array
        Input signal
    num_levels : int
        Number of quantization levels
    
    Returns:
    --------
    quantized : array
        Quantized signal
    """
    signal_min = signal.min()
    signal_max = signal.max()
    
    # Create quantization boundaries
    boundaries = np.linspace(signal_min, signal_max, num_levels + 1)
    levels = (boundaries[:-1] + boundaries[1:]) / 2
    
    # Quantize
    indices = np.digitize(signal, boundaries) - 1
    indices = np.clip(indices, 0, num_levels - 1)
    quantized = levels[indices]
    
    return quantized

def calculate_mse(original, reconstructed):
    """
    Calculate Mean Squared Error.
    """
    return np.mean((original - reconstructed) ** 2)

def demonstrate_compression(signal_length=1000, sigma_squared=1.0, num_levels=16):
    """
    Demonstrate signal compression and compare with rate-distortion bound.
    """
    # Generate Gaussian signal
    signal = np.random.normal(0, np.sqrt(sigma_squared), signal_length)
    
    # Quantize signal
    quantized = uniform_quantizer(signal, num_levels)
    
    # Calculate distortion (MSE)
    distortion = calculate_mse(signal, quantized)
    
    # Rate of a fixed-length code for the quantizer indices (bits per sample)
    rate = np.log2(num_levels)
    
    # Theoretical rate-distortion bound
    D_range = np.linspace(0.001, sigma_squared, 1000)
    R_theoretical = rate_distortion_gaussian(D_range, sigma_squared)
    
    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Plot 1: Original vs Quantized Signal (first 200 samples)
    ax1 = axes[0, 0]
    sample_range = min(200, signal_length)
    ax1.plot(signal[:sample_range], 'b-', alpha=0.7, label='Original', linewidth=1)
    ax1.plot(quantized[:sample_range], 'r-', alpha=0.7, label='Quantized', linewidth=1)
    ax1.set_xlabel('Sample Index')
    ax1.set_ylabel('Amplitude')
    ax1.set_title('Original vs Quantized Signal (First 200 samples)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Error signal
    ax2 = axes[0, 1]
    error = signal - quantized
    ax2.plot(error[:sample_range], 'g-', linewidth=1)
    ax2.set_xlabel('Sample Index')
    ax2.set_ylabel('Error')
    ax2.set_title(f'Quantization Error (MSE = {distortion:.4f})')
    ax2.grid(True, alpha=0.3)
    
    # Plot 3: Rate-Distortion curve with operating point
    ax3 = axes[1, 0]
    ax3.plot(D_range, R_theoretical, 'b-', linewidth=2, label='Theoretical R(D)')
    ax3.plot(distortion, rate, 'ro', markersize=10, label=f'Operating Point\n(R={rate:.2f}, D={distortion:.4f})')
    ax3.axhline(y=rate, color='r', linestyle='--', alpha=0.3)
    ax3.axvline(x=distortion, color='r', linestyle='--', alpha=0.3)
    ax3.set_xlabel('Distortion (D)')
    ax3.set_ylabel('Rate (R) [bits/symbol]')
    ax3.set_title('Rate-Distortion Function')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_xlim(0, sigma_squared * 1.2)
    
    # Plot 4: Histograms
    ax4 = axes[1, 1]
    ax4.hist(signal, bins=50, alpha=0.5, label='Original', density=True)
    ax4.hist(quantized, bins=num_levels, alpha=0.5, label='Quantized', density=True)
    
    # Add theoretical Gaussian
    x = np.linspace(signal.min(), signal.max(), 100)
    ax4.plot(x, stats.norm.pdf(x, 0, np.sqrt(sigma_squared)), 'k-', 
             linewidth=2, label='Theoretical Gaussian')
    
    ax4.set_xlabel('Amplitude')
    ax4.set_ylabel('Density')
    ax4.set_title('Distribution Comparison')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary
    print(f"\n{'='*60}")
    print(f"COMPRESSION SUMMARY")
    print(f"{'='*60}")
    print(f"Source variance (σ²):        {sigma_squared:.4f}")
    print(f"Number of quantization levels: {num_levels}")
    print(f"Rate (bits/symbol):          {rate:.4f}")
    print(f"Distortion (MSE):            {distortion:.4f}")
    print(f"Theoretical R(D):            {rate_distortion_gaussian(distortion, sigma_squared):.4f}")
    print(f"Rate gap from bound:         {rate - rate_distortion_gaussian(distortion, sigma_squared):.4f} bits/symbol")
    print(f"{'='*60}\n")

# Run demonstration
demonstrate_compression(signal_length=1000, sigma_squared=1.0, num_levels=16)
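
The demo charges $\log_2(\text{num\_levels})$ bits per sample, which is the rate of a fixed-length code for the quantizer indices. If the indices were entropy coded (for example with Huffman or arithmetic coding), the rate could approach the entropy of the indices. That entropy is lower, because the outer bins of a Gaussian are rarely used. The sketch below estimates it with a plug-in entropy estimate. It re-implements the same min/max uniform quantizer so it runs standalone, and it ignores the small cost of sending min/max as side information. The helper names are hypothetical.

```python
import numpy as np

def uniform_quantize(signal, num_levels):
    """Uniform quantizer over [min, max]; returns (bin indices, reconstruction)."""
    boundaries = np.linspace(signal.min(), signal.max(), num_levels + 1)
    levels = (boundaries[:-1] + boundaries[1:]) / 2                 # bin midpoints
    idx = np.clip(np.digitize(signal, boundaries) - 1, 0, num_levels - 1)
    return idx, levels[idx]

def empirical_entropy_bits(indices, num_levels):
    """Plug-in entropy estimate (bits/symbol) of the index distribution."""
    counts = np.bincount(indices, minlength=num_levels)
    p = counts[counts > 0] / counts.sum()
    return float(-np.sum(p * np.log2(p)))

rng = np.random.default_rng(0)
signal = rng.normal(0.0, 1.0, 100_000)                              # sigma^2 = 1
num_levels = 16
idx, recon = uniform_quantize(signal, num_levels)
H = empirical_entropy_bits(idx, num_levels)
mse = np.mean((signal - recon) ** 2)
r_bound = 0.5 * np.log2(1.0 / mse)                                  # Gaussian R(D) at this MSE
print(f"fixed-length: {np.log2(num_levels):.3f} bits, entropy-coded ≈ {H:.3f} bits, "
      f"R(D) = {r_bound:.3f} bits")
```

Entropy coding closes much of the gap to R(D), but a scalar quantizer still cannot reach the bound.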

## Interactive Compression Demo

Adjust the parameters below to explore how different quantization levels affect the rate-distortion trade-off:

- **Variance (σ²)**: Controls the power of the source signal
- **Quantization Levels**: Number of discrete levels used to represent the signal; the rate is log₂(levels) bits per sample

In [None]:
# Create interactive demonstration
interactive_demo = widgets.interactive(
    demonstrate_compression,
    signal_length=widgets.fixed(1000),
    sigma_squared=widgets.FloatSlider(
        value=1.0,
        min=0.5,
        max=3.0,
        step=0.1,
        description='Variance (σ²):',
        style={'description_width': 'initial'}
    ),
    num_levels=widgets.IntSlider(
        value=16,
        min=4,
        max=256,
        step=4,
        description='Quantization Levels:',
        style={'description_width': 'initial'}
    )
)

display(interactive_demo)
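
Uniform quantization is simple, but it is not optimal for a Gaussian source. It spends as many levels on the sparse tails as on the dense center. Lloyd's algorithm (Lloyd-Max quantization) is a common alternative for comparison. It alternates two steps: it places each decision boundary halfway between neighboring levels, then it moves each level to the mean of the samples in its cell. Neither step increases the training MSE. The sketch below trains on samples rather than on the exact Gaussian density. `lloyd_quantizer` and `quantize` are hypothetical helpers, not part of the libraries imported above.

```python
import numpy as np

def quantize(samples, levels):
    """Map each sample to its nearest level (levels must be sorted ascending)."""
    boundaries = (levels[:-1] + levels[1:]) / 2      # midpoints between neighboring levels
    return levels[np.digitize(samples, boundaries)]

def lloyd_quantizer(samples, num_levels, n_iter=50):
    """Train a scalar quantizer on samples with Lloyd's algorithm; returns sorted levels."""
    levels = np.quantile(samples, (np.arange(num_levels) + 0.5) / num_levels)  # initial guess
    for _ in range(n_iter):
        idx = np.digitize(samples, (levels[:-1] + levels[1:]) / 2)  # nearest-neighbor cells
        for k in range(num_levels):
            cell = samples[idx == k]
            if cell.size:                  # leave an empty cell's level unchanged
                levels[k] = cell.mean()    # centroid condition
    return levels

rng = np.random.default_rng(1)
x = rng.normal(0.0, 1.0, 20_000)                 # sigma^2 = 1
L = 16

lloyd_levels = lloyd_quantizer(x, L)
b = np.linspace(x.min(), x.max(), L + 1)
uniform_levels = (b[:-1] + b[1:]) / 2            # same midpoints as uniform_quantizer above

mse_lloyd = np.mean((x - quantize(x, lloyd_levels)) ** 2)
mse_uniform = np.mean((x - quantize(x, uniform_levels)) ** 2)
print(f"{L} levels: uniform MSE = {mse_uniform:.4f}, Lloyd MSE = {mse_lloyd:.4f}, "
      f"D(R) bound = {2.0 ** (-2 * np.log2(L)):.4f}")
```

At the same fixed rate of log₂(L) bits, the Lloyd quantizer lowers the MSE but remains above the Gaussian D(R) bound.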

## Key Observations

From the interactive examples above, you should observe:

1. **Trade-off**: As you increase the number of quantization levels:
   - Rate (R) increases (more bits needed per sample)
   - Distortion (D) decreases (better quality reconstruction)

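2. **Theoretical Bound**: The operating point (red dot) should lie above the theoretical R(D) curve. This is because:
   - The R(D) curve is the minimum rate any code can achieve at that distortion
   - Uniform scalar quantization with a fixed-length code is not optimal (but it is simple to implement)
   - Better schemes, such as non-uniform quantizers, entropy coding of the indices, and vector quantization over long blocks, can get closer to the R(D) bound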

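3. **Source Variance Effect**: A higher source variance means:
   - More bits are needed to achieve the same distortion level: scaling σ² by a factor k adds ½·log₂k bits at every D below the variance
   - The R(D) curve shifts upward and reaches zero only at the larger distortion D = σ²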

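4. **Diminishing Returns**: The rate-distortion curve is decreasing and convex:
   - Each halving of the distortion costs another ½ bit per symbol, so the rate grows without bound as D → 0 (equivalently, each extra bit divides the distortion by 4)
   - The largest absolute reductions in distortion come from the first few bits, so moderate rates are often a practical compromise
   - For D ≥ σ², no bits are needed (R = 0): reconstructing every sample as the source mean already gives distortion σ²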
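
Observations 1 and 2 can also be checked without the widgets. The sketch below sweeps the number of levels for the same min/max uniform quantizer and compares each operating point with the Gaussian bound. `sweep_uniform` is a hypothetical helper, and the sample size and seed are arbitrary choices.

```python
import numpy as np

def sweep_uniform(num_levels_list, n=100_000, sigma_squared=1.0, seed=0):
    """Return (rate, mse, bound) for a uniform quantizer over [min, max] at each level count."""
    rng = np.random.default_rng(seed)
    x = rng.normal(0.0, np.sqrt(sigma_squared), n)
    rows = []
    for L in num_levels_list:
        b = np.linspace(x.min(), x.max(), L + 1)
        levels = (b[:-1] + b[1:]) / 2                           # bin midpoints
        idx = np.clip(np.digitize(x, b) - 1, 0, L - 1)
        mse = np.mean((x - levels[idx]) ** 2)
        rate = np.log2(L)                                       # fixed-length code
        bound = max(0.0, 0.5 * np.log2(sigma_squared / mse))    # Gaussian R(D) at this MSE
        rows.append((rate, mse, bound))
    return rows

rows = sweep_uniform([4, 8, 16, 32, 64])
for rate, mse, bound in rows:
    print(f"R = {rate:.0f} bits  D = {mse:.5f}  R(D) = {bound:.3f}  gap = {rate - bound:.3f}")
```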

## Applications of Rate Distortion Theory

Rate distortion theory has numerous practical applications:

1. **Image and Video Compression** (JPEG, MPEG, H.264, H.265)
   - Determines optimal bit allocation
   - Guides perceptual coding strategies

2. **Audio Compression** (MP3, AAC, Opus)
   - Balances file size vs audio quality
   - Perceptual coding based on human hearing

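3. **Wireless Communications**
   - Joint source-channel design: by the source-channel separation theorem, a source can be delivered over a channel with distortion D only if R(D) does not exceed the channel capacity (per source symbol)
   - Adaptive source coding rates matched to changing channel conditions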

4. **Data Storage**
   - Optimizing storage efficiency
   - Lossy compression for large datasets

5. **Machine Learning**
   - Neural network quantization
   - Model compression for edge devices
   - Information bottleneck principle

## Summary

**Rate Distortion Theory** provides a fundamental framework for understanding lossy compression:

- Developed by Claude Shannon as part of information theory
- Establishes theoretical limits on compression efficiency
- The rate-distortion function R(D) defines the minimum rate for a given distortion
- For Gaussian sources: $R(D) = \frac{1}{2}\log_2\left(\frac{\sigma^2}{D}\right)$ for $D < \sigma^2$
- No compression scheme can operate below the R(D) curve; practical schemes approach it from above
- Essential for designing efficient compression algorithms across many domains

## Further Reading

1. Shannon, C. E. (1959). "Coding theorems for a discrete source with a fidelity criterion"
2. Cover, T. M., & Thomas, J. A. (2006). "Elements of Information Theory" (Chapter 10)
3. Berger, T. (1971). "Rate Distortion Theory: A Mathematical Basis for Data Compression"

---

**Experiment with the interactive widgets above to develop intuition about rate-distortion trade-offs!**