# Notebook 3: Color (RGB) Image Compression with SVD

**Goal:** Apply SVD to compress RGB color images

**Contents:**
1. Load an RGB image
2. Understand how SVD works on color images (per-channel compression)
3. Compress with different values of k
4. Evaluate quality
5. Compare with grayscale images

## 1. Setup

In [None]:
# Import
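import os
import sys
sys.path.append('../src')  # make the project's src/ modules importable

# plt.savefig does not create folders, so make sure the output directories exist
os.makedirs('../results/visualizations', exist_ok=True)
os.makedirs('../results/compressed', exist_ok=True)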

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from image_utils import load_image, save_image, get_image_info
from svd_compression import compress_rgb, compress_grayscale, get_svd_matrices, calculate_compression_ratio
from quality_metrics import calculate_all_metrics

plt.rcParams['figure.figsize'] = (16, 10)
sns.set_style("whitegrid")

print("Setup OK!")

## 2. Theory: SVD for RGB Images

### An RGB image has 3 color channels:
- **R** (Red): red channel
- **G** (Green): green channel
- **B** (Blue): blue channel

### Compression procedure:
1. **Split the image** into 3 matrices R, G, B (each m×n)
2. **Apply SVD separately** to each channel:
   - R = U_R × Σ_R × V^T_R
   - G = U_G × Σ_G × V^T_G
   - B = U_B × Σ_B × V^T_B
3. **Keep the k largest singular values** (and their singular vectors) for each channel
4. **Stack the three rank-k channels** back into an RGB image
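The four steps above can be sketched directly with NumPy. This is a minimal illustration, not the project's `compress_rgb` (its internals are not shown in this notebook). `truncated_svd_channel` and `compress_rgb_sketch` are hypothetical names. Rounding and clipping the result to [0, 255] is an assumption about how the reconstruction is turned back into an 8-bit image.

```python
import numpy as np


def truncated_svd_channel(channel: np.ndarray, k: int) -> np.ndarray:
    """Best rank-k approximation of one 2-D channel (float output, not clipped)."""
    U, S, Vt = np.linalg.svd(channel.astype(np.float64), full_matrices=False)
    k = min(k, S.size)
    # NumPy returns singular values in descending order, so the first k are the largest
    return (U[:, :k] * S[:k]) @ Vt[:k, :]


def compress_rgb_sketch(image: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical per-channel compressor: rank-k SVD of R, G and B separately."""
    channels = [truncated_svd_channel(image[:, :, c], k) for c in range(3)]
    stacked = np.stack(channels, axis=2)
    # Assumption: round and clip back to the valid 8-bit range
    return np.clip(np.rint(stacked), 0, 255).astype(np.uint8)


# Random stand-in image (not the notebook's test image)
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 48, 3)).astype(np.uint8)
approx = compress_rgb_sketch(img, k=10)
print(approx.shape, approx.dtype)  # (64, 48, 3) uint8
```

With k equal to min(m, n), the reconstruction is exact up to floating-point error, so rounding recovers the original pixels.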

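### Storage (counted in stored values, not bytes):
- **Original**: m × n × 3 values
- **Compressed**: 3 × k(m + n + 1) values (per channel: m×k for U_k, k singular values, k×n for V^T_k)
- **Space saved**: (1 - 3k(m+n+1)/(3mn)) × 100% = (1 - k(m+n+1)/(mn)) × 100%
- Compression only saves space when k(m + n + 1) < mn, i.e. k < mn/(m + n + 1)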
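The storage formulas can be checked with a little arithmetic. `svd_storage` and `break_even_k` below are hypothetical helpers that count stored values, not bytes. The dictionary keys mirror the ones printed later in this notebook, but this is not the project's `calculate_compression_ratio`.

```python
def svd_storage(m: int, n: int, k: int, channels: int = 3) -> dict:
    """Count stored values (not bytes) for a rank-k SVD of each channel."""
    original = channels * m * n
    # Per channel: m*k values of U, k singular values, k*n values of V^T
    compressed = channels * k * (m + n + 1)
    return {
        'original_size': original,
        'compressed_size': compressed,
        'space_saved_percent': (1 - compressed / original) * 100,
        'compression_ratio': original / compressed,
    }


def break_even_k(m: int, n: int) -> int:
    """Largest k that still stores fewer values than the original: k(m+n+1) < mn."""
    return (m * n - 1) // (m + n + 1)


stats = svd_storage(256, 256, 20)
print(stats['original_size'], stats['compressed_size'])  # 196608 30780
print(f"{stats['space_saved_percent']:.1f}% saved")       # 84.3% saved
print(break_even_k(256, 256))                             # 127
```

The `channels` factor cancels in the percentage. Any k above `break_even_k` stores more values than the raw image.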

## 3. Load the RGB Image

In [None]:
# Load a test color image
image_path = '../images/color/4.1.01.tiff'

try:
    original_rgb = load_image(image_path, mode='RGB')
    print("Image loaded successfully!")
    print(f"   File: {image_path}")
except FileNotFoundError:
    print(f"File not found: {image_path}")
    raise

# Image info
info = get_image_info(original_rgb)
print("\nImage info:")
for key, value in info.items():
    print(f"  {key}: {value}")

# Visualize
fig, axes = plt.subplots(1, 4, figsize=(18, 5))

# RGB image
axes[0].imshow(original_rgb)
axes[0].set_title(f'Original RGB\n{original_rgb.shape}', fontweight='bold')
axes[0].axis('off')

# Separate channels
channels = ['Red', 'Green', 'Blue']
colors = ['Reds', 'Greens', 'Blues']

for i, (channel, cmap) in enumerate(zip(channels, colors), start=1):
    axes[i].imshow(original_rgb[:, :, i-1], cmap=cmap)
    axes[i].set_title(f'{channel} Channel', fontweight='bold')
    axes[i].axis('off')

plt.tight_layout()
plt.savefig('../results/visualizations/rgb_channels.png', dpi=120, bbox_inches='tight')
plt.show()

## 4. Per-Channel SVD Analysis

In [None]:
# SVD for each channel
channels_data = {}
channel_names = ['Red', 'Green', 'Blue']

for i, name in enumerate(channel_names):
    channel = original_rgb[:, :, i]
    U, S, Vt = get_svd_matrices(channel)
    channels_data[name] = {'U': U, 'S': S, 'Vt': Vt}
    print(f"{name} channel:")
    print(f"  U: {U.shape}, S: {S.shape}, Vt: {Vt.shape}")
    print(f"  Top 5 singular values: {S[:5]}")
    print()
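Instead of looping over channels, `np.linalg.svd` also accepts a stack of matrices and factorizes each 2-D slice. The sketch below uses a random stand-in image, not the test image. It shows that moving the channel axis to the front gives the same per-channel singular values as separate calls. Singular vectors are only unique up to sign, so the check compares singular values.

```python
import numpy as np

# Random stand-in image (m=40, n=30), not the notebook's test image
rng = np.random.default_rng(1)
img = rng.integers(0, 256, size=(40, 30, 3)).astype(np.uint8)

# Move channels to the front, shape (3, m, n); np.linalg.svd factorizes each 2-D slice
stack = img.transpose(2, 0, 1).astype(np.float64)
U, S, Vt = np.linalg.svd(stack, full_matrices=False)
print(U.shape, S.shape, Vt.shape)  # (3, 40, 30) (3, 30) (3, 30, 30)

# S[:, None, :] scales the columns of each U slice; the product rebuilds every channel
recon = (U * S[:, None, :]) @ Vt
print(np.allclose(recon, stack))  # True

# Singular values match separate per-channel calls (vectors may differ in sign)
for c, name in enumerate(['Red', 'Green', 'Blue']):
    S_c = np.linalg.svd(img[:, :, c].astype(np.float64), compute_uv=False)
    print(name, np.allclose(S[c], S_c))
```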

In [None]:
# Compare singular values across channels
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot singular values
colors_plot = ['red', 'green', 'blue']
for name, color in zip(channel_names, colors_plot):
    S = channels_data[name]['S']
    axes[0].semilogy(S, label=name, linewidth=2, color=color, alpha=0.7)

axes[0].set_xlabel('Index')
axes[0].set_ylabel('Singular Value (log scale)')
axes[0].set_title('Singular Value Spectrum - All Channels', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Cumulative energy
for name, color in zip(channel_names, colors_plot):
    S = channels_data[name]['S']
    cumulative_energy = np.cumsum(S**2) / np.sum(S**2) * 100
    axes[1].plot(cumulative_energy, label=name, linewidth=2, color=color, alpha=0.7)

axes[1].axhline(y=90, color='gray', linestyle='--', alpha=0.5)
axes[1].set_xlabel('k')
axes[1].set_ylabel('Cumulative Energy (%)')
axes[1].set_title('Energy Preservation - All Channels', fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim([0, 102])

plt.tight_layout()
plt.savefig('../results/visualizations/rgb_svd_analysis.png', dpi=120, bbox_inches='tight')
plt.show()
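The dashed 90% line in the energy plot can be turned into a number for each channel. By the Eckart-Young theorem, the squared Frobenius error of the best rank-k approximation equals the sum of the discarded squared singular values. So `1 - cumulative_energy` at k is exactly the relative squared reconstruction error, before any rounding or clipping. `k_for_energy` is a hypothetical helper, and the demo uses a random matrix.

```python
import numpy as np


def k_for_energy(S: np.ndarray, threshold: float = 0.90) -> int:
    """Smallest k whose top-k singular values hold `threshold` of sum(S**2)."""
    energy = np.cumsum(S ** 2) / np.sum(S ** 2)
    # searchsorted returns the first index where cumulative energy >= threshold
    return min(int(np.searchsorted(energy, threshold)) + 1, S.size)


# Random stand-in matrix, not one of the notebook's channels
rng = np.random.default_rng(2)
A = rng.normal(size=(50, 40))
U, S, Vt = np.linalg.svd(A, full_matrices=False)

k = k_for_energy(S, 0.90)
A_k = (U[:, :k] * S[:k]) @ Vt[:k, :]

# Eckart-Young: relative squared error == discarded share of the energy
rel_err = np.linalg.norm(A - A_k, 'fro') ** 2 / np.linalg.norm(A, 'fro') ** 2
tail_share = np.sum(S[k:] ** 2) / np.sum(S ** 2)
print(k, round(rel_err, 6), round(tail_share, 6))
```

On this notebook's data, `{name: k_for_energy(channels_data[name]['S']) for name in channel_names}` would give the 90% point for each channel.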

## 5. Compression with Different k

In [None]:
# Test several values of k
k_values = [5, 10, 20, 50, 100]

# Adjust k based on image size
max_k = min(original_rgb.shape[:2])
k_values = [k for k in k_values if k <= max_k]

print(f"Testing with k values: {k_values}")
print(f"\n{'k':>5} | {'PSNR':>8} | {'MSE':>10} | {'Saved':>8}")
print("-" * 45)

results_rgb = {}

for k in k_values:
    # Compress
    compressed = compress_rgb(original_rgb, k)
    
    # Metrics
    metrics = calculate_all_metrics(original_rgb, compressed, include_ssim=False)
    stats = calculate_compression_ratio(original_rgb.shape, k, is_rgb=True)
    
    # Save results
    results_rgb[k] = {
        'compressed': compressed,
        'metrics': metrics,
        'stats': stats
    }
    
    # Print
    print(f"{k:5d} | {metrics['psnr']:8.2f} | {metrics['mse']:10.2f} | "
          f"{stats['space_saved_percent']:7.1f}%")
    
    # Save compressed image
    save_image(compressed, f'../results/compressed/rgb_k{k:03d}.png')

print("\nCompression completed!")
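The PSNR and MSE columns follow the standard definitions for 8-bit images: MSE = mean((original - compressed)²) and PSNR = 10·log₁₀(255²/MSE). The sketch below assumes MSE is averaged over all pixels and all three channels. `calculate_all_metrics` may differ in details, such as averaging per channel first.

```python
import numpy as np


def mse(original: np.ndarray, compressed: np.ndarray) -> float:
    """Mean squared error over all pixels and channels (cast first to avoid uint8 wrap-around)."""
    diff = original.astype(np.float64) - compressed.astype(np.float64)
    return float(np.mean(diff ** 2))


def psnr(original: np.ndarray, compressed: np.ndarray, max_value: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB; infinite for identical images."""
    err = mse(original, compressed)
    if err == 0:
        return float('inf')
    return float(10 * np.log10(max_value ** 2 / err))


a = np.zeros((4, 4, 3), dtype=np.uint8)
b = a + 1  # every value off by exactly 1
print(mse(a, b), round(psnr(a, b), 2))  # 1.0 48.13
```

Inverting the formula, the 30 dB and 40 dB thresholds used in Section 8 correspond to MSE ≈ 65.0 and MSE ≈ 6.5 respectively.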

## 6. Visual Comparison

In [None]:
# Compare visually
n_images = len(k_values) + 1
ncols = 3
nrows = (n_images + ncols - 1) // ncols

fig, axes = plt.subplots(nrows, ncols, figsize=(5*ncols, 5*nrows))
axes = axes.flatten()

# Original
axes[0].imshow(original_rgb)
axes[0].set_title('Original', fontsize=13, fontweight='bold')
axes[0].axis('off')

# Compressed versions
for idx, k in enumerate(k_values, start=1):
    compressed = results_rgb[k]['compressed']
    psnr = results_rgb[k]['metrics']['psnr']
    saved = results_rgb[k]['stats']['space_saved_percent']
    
    axes[idx].imshow(compressed)
    axes[idx].set_title(f'k={k}\nPSNR={psnr:.1f}dB, Saved={saved:.0f}%',
                        fontsize=12, fontweight='bold')
    axes[idx].axis('off')

# Hide unused
for idx in range(n_images, len(axes)):
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig('../results/visualizations/rgb_comparison.png', dpi=120, bbox_inches='tight')
plt.show()

## 7. Error Analysis

In [None]:
# Error maps for each k (grid sized to the number of tested k values)
n_panels = len(k_values) + 1
ncols = 3
nrows = (n_panels + ncols - 1) // ncols
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 5.5*nrows))
axes = axes.flatten()

axes[0].text(0.5, 0.5, 'Error Maps\n(mean over RGB channels)',
             ha='center', va='center', fontsize=14, fontweight='bold')
axes[0].axis('off')

for idx, k in enumerate(k_values, start=1):
    compressed = results_rgb[k]['compressed']
    
    # Per-pixel absolute error, averaged over R, G and B
    error = np.mean(np.abs(original_rgb.astype(np.float64) - compressed.astype(np.float64)), axis=2)
    
    # Fixed color scale (0-30) so panels are comparable; larger errors saturate
    im = axes[idx].imshow(error, cmap='hot', vmin=0, vmax=30)
    axes[idx].set_title(f'k={k}\nMax Error={error.max():.1f}',
                        fontsize=12, fontweight='bold')
    axes[idx].axis('off')
    plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)

# Hide unused panels
for idx in range(n_panels, len(axes)):
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig('../results/visualizations/rgb_error_maps.png', dpi=120, bbox_inches='tight')
plt.show()
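The error maps average the absolute error over R, G and B, which can hide a channel that degrades faster than the others. Below is a hypothetical helper, `per_channel_mse`, that keeps one MSE per channel. The mean of the three values equals the overall MSE.

```python
import numpy as np


def per_channel_mse(original, compressed, names=('Red', 'Green', 'Blue')) -> dict:
    """MSE for each color channel separately (hypothetical helper)."""
    diff = original.astype(np.float64) - compressed.astype(np.float64)
    # Average over rows and columns only, leaving one value per channel
    values = np.mean(diff ** 2, axis=(0, 1))
    return dict(zip(names, values.tolist()))


# Toy example: channels off by 1, 2 and 3 everywhere
original = np.zeros((8, 8, 3), dtype=np.uint8)
compressed = original + np.array([1, 2, 3], dtype=np.uint8)
result = per_channel_mse(original, compressed)
print(result)  # {'Red': 1.0, 'Green': 4.0, 'Blue': 9.0}
```

In the notebook, `per_channel_mse(original_rgb, results_rgb[k]['compressed'])` would apply it to any k in `k_values`.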

## 8. Quality vs k

In [None]:
# Sweep k in steps of 10 for smoother quality curves
k_range = list(range(5, min(original_rgb.shape[:2]), 10))
psnr_list_rgb = []
mse_list_rgb = []
saved_list_rgb = []

for k in k_range:
    comp = compress_rgb(original_rgb, k)
    m = calculate_all_metrics(original_rgb, comp, include_ssim=False)  # SSIM is not plotted here
    s = calculate_compression_ratio(original_rgb.shape, k, is_rgb=True)
    
    psnr_list_rgb.append(m['psnr'])
    mse_list_rgb.append(m['mse'])
    saved_list_rgb.append(s['space_saved_percent'])

# Plot
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# PSNR vs k
axes[0].plot(k_range, psnr_list_rgb, marker='o', linewidth=2, markersize=5)
axes[0].axhline(y=30, color='red', linestyle='--', label='Good (30 dB)')
axes[0].axhline(y=40, color='green', linestyle='--', label='Excellent (40 dB)')
axes[0].set_xlabel('k')
axes[0].set_ylabel('PSNR (dB)')
axes[0].set_title('PSNR vs k', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# MSE vs k
axes[1].plot(k_range, mse_list_rgb, marker='s', linewidth=2, markersize=5, color='orangered')
axes[1].set_xlabel('k')
axes[1].set_ylabel('MSE')
axes[1].set_title('MSE vs k', fontweight='bold')
axes[1].grid(True, alpha=0.3)

# Trade-off
axes[2].plot(saved_list_rgb, psnr_list_rgb, marker='D', linewidth=2, markersize=5, color='darkgreen')
axes[2].set_xlabel('Space Saved (%)')
axes[2].set_ylabel('PSNR (dB)')
axes[2].set_title('Quality vs Compression', fontweight='bold')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/visualizations/rgb_quality_vs_k.png', dpi=120, bbox_inches='tight')
plt.show()
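The sweep above calls `compress_rgb` once per k. If it recomputes the SVD on every call (likely, since it only receives the image and k, but not confirmed here), the three factorizations are repeated for every point. The sketch below factorizes once and reuses the factors for every k. `psnr_curve` is hypothetical, and it assumes 8-bit values that are rounded and clipped to [0, 255] after reconstruction.

```python
import numpy as np


def psnr_curve(image: np.ndarray, k_values) -> list:
    """PSNR for many k from a single SVD per channel (hypothetical helper)."""
    img = image.astype(np.float64)
    # One factorization of all three channels, shape (3, m, n)
    U, S, Vt = np.linalg.svd(img.transpose(2, 0, 1), full_matrices=False)
    curve = []
    for k in k_values:
        # Rank-k reconstruction of every channel from the cached factors
        approx = (U[:, :, :k] * S[:, None, :k]) @ Vt[:, :k, :]
        # Assumption: round and clip to the 8-bit range, then back to (m, n, 3)
        approx = np.clip(np.rint(approx), 0, 255).transpose(1, 2, 0)
        err = np.mean((img - approx) ** 2)
        curve.append(float('inf') if err == 0 else float(10 * np.log10(255.0 ** 2 / err)))
    return curve


# Random stand-in image (30x30), not the notebook's test image
rng = np.random.default_rng(4)
img = rng.integers(0, 256, size=(30, 30, 3)).astype(np.uint8)
print(psnr_curve(img, [2, 20, 30]))
```

In the notebook this would be `psnr_list_rgb = psnr_curve(original_rgb, k_range)`. Values may differ slightly from `calculate_all_metrics` if that function rounds or averages differently.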

## 9. Find Optimal k (smallest tested k with PSNR ≥ 30 dB)

In [None]:
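# Pick the smallest tested k whose PSNR reaches 30 dB (the "Good" threshold above)
optimal_k_rgb = None
for k in k_values:  # k_values is sorted ascending
    if results_rgb[k]['metrics']['psnr'] >= 30:
        optimal_k_rgb = k
        break

if optimal_k_rgb is None:
    # No tested k reached 30 dB: fall back to the largest k and say so
    optimal_k_rgb = k_values[-1]
    print(f"Warning: no tested k reached 30 dB; using the largest k = {optimal_k_rgb}")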

opt_metrics = results_rgb[optimal_k_rgb]['metrics']
opt_stats = results_rgb[optimal_k_rgb]['stats']

print("="*60)
print(f"OPTIMAL k = {optimal_k_rgb}")
print("="*60)
print(f"  PSNR: {opt_metrics['psnr']:.2f} dB")
print(f"  MSE:  {opt_metrics['mse']:.2f}")
print(f"  Space saved: {opt_stats['space_saved_percent']:.1f}%")
print(f"  Original size: {opt_stats['original_size']:,} values")
print(f"  Compressed size: {opt_stats['compressed_size']:,} values")
print("="*60)

# Save optimal
save_image(results_rgb[optimal_k_rgb]['compressed'], f'../results/compressed/rgb_optimal_k{optimal_k_rgb}.png')
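The reconstruction error is fully determined by the discarded singular values (Eckart-Young). So the smallest k that reaches a PSNR target can be predicted from one SVD per channel, without reconstructing any image, and over every k from 1 to min(m, n) rather than only the five tested values. The sketch below uses a hypothetical `smallest_k_for_psnr` and assumes 8-bit values. It ignores the final rounding and clipping, so the PSNR of the saved image may differ slightly from the prediction.

```python
import numpy as np


def smallest_k_for_psnr(image: np.ndarray, target_db: float = 30.0):
    """Predict the smallest k with PSNR >= target_db from singular values alone."""
    m, n, _ = image.shape
    S = np.linalg.svd(image.astype(np.float64).transpose(2, 0, 1), compute_uv=False)
    sq = S ** 2                                          # shape (3, min(m, n))
    # tail[:, j] = energy discarded when keeping the first k = j + 1 values
    tail = sq.sum(axis=1, keepdims=True) - np.cumsum(sq, axis=1)
    # Eckart-Young: total squared error / number of values = predicted MSE
    mse_pred = tail.sum(axis=0) / (3 * m * n)
    target_mse = 255.0 ** 2 / 10 ** (target_db / 10)
    ok = np.nonzero(mse_pred <= target_mse)[0]
    return int(ok[0]) + 1 if ok.size else None


# Random stand-in image, not the notebook's test image
rng = np.random.default_rng(3)
img = rng.integers(0, 256, size=(40, 36, 3)).astype(np.uint8)
print(smallest_k_for_psnr(img, 30.0))
```

On the notebook's data, `smallest_k_for_psnr(original_rgb, 30.0)` could be cross-checked against `compress_rgb` and `calculate_all_metrics`.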

## 10. Summary

### Key Findings:

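1. **Per-channel SVD works for color images**: R, G and B are each compressed independently with the same k
2. **Channels differ**: their singular value spectra (Section 4) decay at different rates, so they may need different k to retain the same energy
3. **Small k can already be acceptable**: the PSNR table in Section 5 shows where this image crosses 30 dB; the exact k is image-dependent
4. **The trade-off remains**: larger k raises PSNR but lowers the space saved (Section 8)

### Comparison with Grayscale:
- For the same k, RGB stores 3× as many values as grayscale: 3k(m+n+1) vs k(m+n+1)
- In exchange, color information is preserved
- The space-saved percentage at a given k is the same as for grayscale, because the original size also triples (3mn vs mn)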

### Next Steps:
- Compare RGB vs Grayscale compression
- Try different k for different channels
- Explore color space conversions (YCbCr)
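The last two next steps can be combined: convert to YCbCr, keep more components for luma (Y) than for chroma (Cb, Cr), and convert back. This is the same idea JPEG exploits with chroma subsampling. The JPEG/JFIF full-range conversion matrix is an assumption, since other YCbCr variants exist. `compress_ycbcr_sketch` and `ycbcr_storage` are hypothetical names. Whether this beats per-channel RGB at equal storage has to be measured on the test image.

```python
import numpy as np

# Assumption: JPEG/JFIF full-range RGB -> YCbCr transform
RGB_TO_YCBCR = np.array([
    [0.299, 0.587, 0.114],
    [-0.168736, -0.331264, 0.5],
    [0.5, -0.418688, -0.081312],
])
OFFSET = np.array([0.0, 128.0, 128.0])
YCBCR_TO_RGB = np.linalg.inv(RGB_TO_YCBCR)


def rank_k(channel: np.ndarray, k: int) -> np.ndarray:
    """Best rank-k approximation of a 2-D float array."""
    U, S, Vt = np.linalg.svd(channel, full_matrices=False)
    k = min(k, S.size)
    return (U[:, :k] * S[:k]) @ Vt[:k, :]


def compress_ycbcr_sketch(rgb: np.ndarray, k_luma: int, k_chroma: int) -> np.ndarray:
    """Hypothetical: rank-k_luma for Y, rank-k_chroma for Cb and Cr."""
    ycc = rgb.astype(np.float64) @ RGB_TO_YCBCR.T + OFFSET   # per-pixel matrix product
    ks = (k_luma, k_chroma, k_chroma)
    approx = np.stack([rank_k(ycc[:, :, c], ks[c]) for c in range(3)], axis=2)
    rgb_out = (approx - OFFSET) @ YCBCR_TO_RGB.T             # back to RGB
    return np.clip(np.rint(rgb_out), 0, 255).astype(np.uint8)


def ycbcr_storage(m: int, n: int, k_luma: int, k_chroma: int) -> int:
    """Stored values: k(m + n + 1) per channel, summed over Y, Cb and Cr."""
    return (k_luma + 2 * k_chroma) * (m + n + 1)


# Random stand-in image, not the notebook's test image
rng = np.random.default_rng(5)
img = rng.integers(0, 256, size=(32, 24, 3)).astype(np.uint8)
approx = compress_ycbcr_sketch(img, k_luma=12, k_chroma=4)
print(approx.shape, ycbcr_storage(32, 24, 12, 4))  # (32, 24, 3) 1140
```

For a fair comparison with per-channel RGB, pick k_luma and k_chroma so that `ycbcr_storage` matches 3k(m + n + 1) for the RGB k being compared.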