# Image Colorization - Data Visualization

This notebook visualizes:
1. Train/Val/Test split statistics
2. Sample images from each split
3. Grayscale vs. color comparisons
4. Augmentation effects
5. Baseline metrics distribution


In [None]:
import sys
sys.path.append('./python')

import numpy as np
import matplotlib.pyplot as plt
import needle.data as data

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')


## 1. Load Train/Val/Test Splits


In [None]:
# Build the three colorization splits; augment_train=True turns on
# augmentation for the training split.
train_ds, val_ds, test_ds = data.create_colorization_splits(
    base_folder='./data/cifar-10-batches-py',
    val_size=5000,
    seed=42,
    augment_train=True,
)

# Measure each split once and reuse the counts for both the rows and the total.
n_train, n_val, n_test = len(train_ds), len(val_ds), len(test_ds)

print("Dataset Split Statistics")
print("=" * 40)
print(f"Training samples:   {n_train:,}")
print(f"Validation samples: {n_val:,}")
print(f"Test samples:       {n_test:,}")
print(f"Total:              {n_train + n_val + n_test:,}")


## 2. Visualize Sample Images


In [None]:
def show_colorization_samples(dataset, title, num_samples=4, rng=None):
    """Plot random samples as rows of (grayscale input, gray baseline, ground truth).

    Each row shows one dataset item: the grayscale input channel, a "no color"
    baseline made by copying that channel into R, G and B, and the ground-truth
    RGB image.

    Args:
        dataset: Indexable object supporting ``len()`` whose items unpack as
            ``(L, ab, rgb)``. ``L[0]`` is shown as the (H, W) grayscale image
            and ``rgb`` is transposed from (C, H, W) to (H, W, C) for display.
        title: Figure-level title.
        num_samples: Number of distinct samples (rows), drawn without
            replacement. Clamped to ``len(dataset)``.
        rng: Optional ``numpy.random.Generator`` used to pick the samples.
            When ``None`` the legacy global ``np.random`` state is used, so an
            earlier ``np.random.seed(...)`` still controls the selection.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Raises:
        ValueError: If the dataset is empty or ``num_samples`` is less than 1.
    """
    n_rows = min(num_samples, len(dataset))
    if n_rows < 1:
        raise ValueError(
            f"need at least one sample to plot (num_samples={num_samples}, "
            f"len(dataset)={len(dataset)})"
        )

    # squeeze=False keeps `axes` 2-D even when n_rows == 1, so per-row
    # indexing works for a single sample too.
    fig, axes = plt.subplots(n_rows, 3, figsize=(10, 3 * n_rows), squeeze=False)

    chooser = np.random if rng is None else rng
    indices = chooser.choice(len(dataset), n_rows, replace=False)

    for row_axes, idx in zip(axes, indices):
        L, _ab, rgb = dataset[idx]

        # Grayscale input channel, (H, W).
        gray = L[0]
        # Ground truth: channel-first -> channel-last for imshow.
        rgb_gt = np.transpose(rgb, (1, 2, 0))
        # "No color" baseline: grayscale replicated into three channels.
        # NOTE(review): imshow clips float RGB data to [0, 1]; this assumes L is
        # scaled to [0, 1] (not LAB's 0-100) — confirm, otherwise the baseline
        # column renders saturated while the 'gray' cmap column autoscales.
        rgb_baseline = np.stack([gray, gray, gray], axis=-1)

        panels = (
            (gray, 'Grayscale Input', {'cmap': 'gray'}),
            (rgb_baseline, 'Baseline (No Color)', {}),
            (rgb_gt, 'Ground Truth', {}),
        )
        for ax, (image, panel_title, imshow_kwargs) in zip(row_axes, panels):
            ax.imshow(image, **imshow_kwargs)
            ax.set_title(panel_title)
            ax.axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig

# Seed NumPy's global RNG so the sampled rows are reproducible; the
# validation/test cells below reuse this global state without reseeding.
np.random.seed(42)
fig_train = show_colorization_samples(train_ds, "Training Set Samples")
fig_train.savefig('train_samples.png', dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# Validation and test grids, exported with the same settings as the training grid.
export_kwargs = {'dpi': 150, 'bbox_inches': 'tight'}

fig_val = show_colorization_samples(val_ds, "Validation Set Samples")
fig_val.savefig('val_samples.png', **export_kwargs)
plt.show()

fig_test = show_colorization_samples(test_ds, "Test Set Samples")
fig_test.savefig('test_samples.png', **export_kwargs)
plt.show()


## 3. Augmentation Pipeline Visualization


In [None]:
# Show augmentation effects on the same image
from needle.data.datasets.colorization_dataset import ColorizationAugmentation

# Get a raw image (no augmentation) directly from CIFAR-10.
raw_cifar = data.CIFAR10Dataset('./data/cifar-10-batches-py', train=True, split='train')
sample_img, label = raw_cifar[100]

# Every transform gets probability 0.5 here, presumably so each effect shows
# up often in the demo grid.
# NOTE(review): the Summary section lists lower rates for crop/jitter/rotation
# (30%/20%/20%); confirm which values the training pipeline actually uses.
aug = ColorizationAugmentation(flip_prob=0.5, crop_prob=0.5, jitter_prob=0.5, rotation_prob=0.5)

# Re-seed so that re-running this cell on its own redraws the same examples,
# instead of depending on how much global RNG state earlier cells consumed.
# NOTE(review): assumes ColorizationAugmentation draws from NumPy's global RNG;
# if it uses Python's `random` module, seed that as well.
np.random.seed(42)

# 2x4 grid: the untouched original in the first slot, then 7 independent
# augmentations of the same image (row-major order).
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
axes[0, 0].imshow(np.transpose(sample_img, (1, 2, 0)))
axes[0, 0].set_title('Original', fontweight='bold')
axes[0, 0].axis('off')

for i, ax in enumerate(axes.flat[1:], start=1):
    # Pass a copy in case the augmentation modifies its input in place.
    aug_img = aug(sample_img.copy())
    ax.imshow(np.transpose(aug_img, (1, 2, 0)))
    ax.set_title(f'Augmented #{i}')
    ax.axis('off')

fig.suptitle('Augmentation Pipeline Effects', fontsize=14, fontweight='bold')
fig.tight_layout()
fig.savefig('augmentation_examples.png', dpi=150, bbox_inches='tight')
plt.show()


## 4. Baseline Metrics (PSNR/SSIM)


In [None]:
# Score the grayscale-only baseline on a fixed, seeded subset of samples.
from baseline_metrics import run_baseline_metrics

BASELINE_NUM_SAMPLES = 100
BASELINE_SEED = 42
results = run_baseline_metrics(num_samples=BASELINE_NUM_SAMPLES, seed=BASELINE_SEED)


In [None]:
# Plot the per-sample distribution of each baseline metric, one panel per
# metric, with the sample mean marked by a dashed red line.
# NOTE(review): raw_scores uses the key 'l1' while the summary table reads
# results['metrics']['l1_error']; confirm both keys come from run_baseline_metrics.
metric_panels = [
    # (raw_scores key, bar colour, x-axis label, panel title, legend format)
    ('psnr', 'steelblue', 'PSNR (dB)', 'PSNR Distribution', 'Mean: {:.2f} dB'),
    ('ssim_windowed', 'darkorange', 'SSIM', 'SSIM Distribution (Windowed)', 'Mean: {:.4f}'),
    ('l1', 'forestgreen', 'L1 Error', 'L1 Error Distribution', 'Mean: {:.4f}'),
]

fig, axes = plt.subplots(1, len(metric_panels), figsize=(14, 4))

for ax, (key, color, xlabel, panel_title, mean_fmt) in zip(axes, metric_panels):
    scores = results['raw_scores'][key]
    mean_score = np.mean(scores)  # computed once, used for the line and its label
    ax.hist(scores, bins=20, edgecolor='black', alpha=0.7, color=color)
    ax.axvline(mean_score, color='red', linestyle='--', linewidth=2,
               label=mean_fmt.format(mean_score))
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(panel_title)
    ax.legend()

fig.suptitle('Baseline Metrics (Grayscale Only - No Colorization)', fontsize=14, fontweight='bold')
fig.tight_layout()
fig.savefig('baseline_metrics_distribution.png', dpi=150, bbox_inches='tight')
plt.show()


## 5. Summary

### Dataset Splits
- **Training**: 45,000 samples (with augmentation)
- **Validation**: 5,000 samples (no augmentation)
- **Test**: 10,000 samples (no augmentation)

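### Augmentation Pipeline
- Random horizontal flip (50%)
- Random crop & resize (30%)
- Color jitter - brightness/contrast (20%)
- Random 90-degree rotation (20%)

Note: the example grid in Section 3 creates `ColorizationAugmentation` with every probability set to 0.5. It therefore shows augmentations more often than the rates listed above.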

### Baseline Metrics (Grayscale → No Color)
The baseline uses the grayscale input as its prediction, with no color added. It is the reference point our colorization model must beat, and a trained model should improve on these metrics as it learns to predict color.


In [None]:
# Print a fixed-width summary table of the baseline metrics.
# The sample count is derived from the per-sample scores instead of being
# hard-coded, so the header stays correct if num_samples is changed above.
# NOTE(review): assumes results['raw_scores']['psnr'] holds one score per
# evaluated sample; confirm against run_baseline_metrics.
n_scored = len(results['raw_scores']['psnr'])

summary_rows = [
    # (row label, results['metrics'] key, decimal places)
    ('PSNR (dB)', 'psnr', 2),
    ('SSIM (windowed)', 'ssim_windowed', 4),
    ('L1 Error', 'l1_error', 4),
]

print("=" * 60)
print(f"BASELINE METRICS SUMMARY ({n_scored} samples)")
print("=" * 60)
print(f"{'Metric':<20} {'Mean':>12} {'Std':>12}")
print("-" * 60)
for row_label, metric_key, digits in summary_rows:
    stats = results['metrics'][metric_key]
    print(f"{row_label:<20} {stats['mean']:>12.{digits}f} {stats['std']:>12.{digits}f}")
print("=" * 60)
print("\nNote: Higher PSNR/SSIM = better, Lower L1 = better")
print("These are BASELINE values (grayscale only, no colorization)")
print("Trained models should significantly improve these metrics.")
