# Image Colorization Demo

This notebook demonstrates the image colorization system built on the Needle framework.

## Overview
- Load a trained colorization model
- Colorize grayscale CIFAR-10 images
- Visualize results and compare with ground truth
- Compute evaluation metrics (SSIM, PSNR, L1 error)


In [None]:
import sys
sys.path.append('./python')

import needle as ndl
import needle.nn as nn
from needle import ops
import matplotlib.pyplot as plt
import pickle
import random
import math

%matplotlib inline


## Load Dataset


In [None]:
# Held-out CIFAR-10 split (train=False); the colorization wrapper derives
# model inputs and targets from these images.
test_dataset = ndl.data.CIFAR10Dataset(base_folder="./data/cifar-10-batches-py", train=False)

# return_rgb=True makes each item also carry the original RGB image,
# which the visualization cell uses as ground truth.
test_color_dataset = ndl.data.ColorizationDataset(test_dataset, return_rgb=True)

num_test_samples = len(test_color_dataset)
print(f"Test samples: {num_test_samples}")


Test samples: 10000


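## Load Model

No checkpoint is restored in this demo: the model below is randomly initialized, so the colorized outputs are not meaningful until trained weights are loaded.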


In [3]:
# Build the colorization network on CPU in float32; later cells reuse
# `device` and `dtype` when wrapping inputs as needle Tensors.
device = ndl.cpu()
dtype = "float32"

model = nn.ColorizationModel(device=device, dtype=dtype)

# Everything below is inference only, so disable training-mode behaviour
# (e.g. dropout / batch-norm statistic updates, if the architecture has them).
model.eval()

# NOTE: no checkpoint is loaded here -- the weights are random, so predictions
# are not meaningful until trained weights are restored.
print("Using initialized model (train first for better results)")


Using initialized model (train first for better results)


## Helper Functions for Visualization


In [None]:
def lab_to_rgb_python(L_2d, ab, H, W):
    """Convert an L + ab image to RGB using plain Python arithmetic.

    Args:
        L_2d: 2-D array-like indexed ``[i, j]``. Each value is multiplied by
            100 to get Lab lightness, so it is presumably normalized to [0, 1].
        ab: 3-D array-like holding the a/b chroma channels, either
            channel-first ``(2, H, W)`` or channel-last ``(H, W, 2)``. Values
            are used as-is (no rescaling). NOTE(review): confirm the producer
            emits raw Lab a/b units.
        H, W: number of rows / columns to convert.

    Returns:
        Nested list ``[H][W][3]`` of floats, each channel clipped to [0, 1].

    NOTE(review): the XYZ -> RGB matrix yields *linear* sRGB and no gamma
    companding is applied afterwards. This is the exact inverse only if the
    RGB -> Lab step that produced the data also skipped sRGB linearization.
    """
    # Choose the layout once instead of per pixel. A 3-D input is treated as
    # channel-first unless its first axis cannot be the channel axis
    # (size != 2) while its last axis can (size == 2). Checking only the rank
    # cannot tell (2, H, W) from (H, W, 2).
    channels_last = len(ab.shape) != 3 or (ab.shape[0] != 2 and ab.shape[-1] == 2)

    if channels_last:
        def get_ab(i, j):
            return float(ab[i, j, 0]), float(ab[i, j, 1])
    else:
        def get_ab(i, j):
            return float(ab[0, i, j]), float(ab[1, i, j])

    delta = 6 / 29

    def f_inv(t):
        # Inverse of the CIE Lab companding function: cube above `delta`,
        # linear segment below it.
        return t ** 3 if t > delta else 3 * delta ** 2 * (t - 4 / 29)

    def clip01(v):
        return max(0, min(1, v))

    rgb_result = []
    for i in range(H):
        row = []
        for j in range(W):
            l = float(L_2d[i, j]) * 100.0
            a, b = get_ab(i, j)

            # Lab -> XYZ, scaled by the D65 reference white.
            fy = (l + 16) / 116
            fx = a / 500 + fy
            fz = fy - b / 200
            X = 0.95047 * f_inv(fx)
            Y = 1.00000 * f_inv(fy)
            Z = 1.08883 * f_inv(fz)

            # XYZ -> linear sRGB.
            R = 3.2406 * X - 1.5372 * Y - 0.4986 * Z
            G = -0.9689 * X + 1.8758 * Y + 0.0415 * Z
            B = 0.0557 * X - 0.2040 * Y + 1.0570 * Z

            row.append([clip01(R), clip01(G), clip01(B)])
        rgb_result.append(row)
    return rgb_result


## Visualize Colorization Results


In [None]:
# Colorize a few random test images and show input / prediction / ground truth side by side.
num_samples = 4
H, W = 32, 32  # CIFAR-10 spatial size

random.seed(42)  # fixed seed so the same images are shown on every run
indices = random.sample(range(len(test_color_dataset)), num_samples)

# squeeze=False keeps `axes` 2-D even when num_samples == 1.
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

for i, idx in enumerate(indices):
    # With return_rgb=True each item is (gray, ab_target, rgb_original).
    gray_np, ab_target_np, rgb_original_np = test_color_dataset[idx]

    # Add batch and channel axes, predict, then drop the batch axis again.
    gray_reshaped = gray_np.reshape(1, 1, H, W)
    gray_tensor = ndl.Tensor(gray_reshaped, device=device, dtype=dtype)
    ab_pred = model.predict_ab(gray_tensor).numpy()[0]

    # L channel as (H, W)
    L_2d = gray_np[0]

    # Reconstruct RGB from the input L and the predicted ab.
    rgb_pred = lab_to_rgb_python(L_2d, ab_pred, H, W)

    # Ground truth is stored channel-first (C, H, W); imshow expects (H, W, C).
    rgb_gt = [[[float(rgb_original_np[c, hi, wi]) for c in range(3)] for wi in range(W)] for hi in range(H)]
    gray_display = [[float(L_2d[hi, wi]) for wi in range(W)] for hi in range(H)]

    # Pin the grayscale range so imshow does not contrast-stretch each image
    # (L is in [0, 1]; lab_to_rgb_python scales it by 100 to Lab lightness).
    panels = [
        (gray_display, 'Grayscale Input', {'cmap': 'gray', 'vmin': 0.0, 'vmax': 1.0}),
        (rgb_pred, 'Colorized (Predicted)', {}),
        (rgb_gt, 'Ground Truth', {}),
    ]
    for ax, (image, title, imshow_kwargs) in zip(axes[i], panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

fig.tight_layout()
fig.savefig('colorization_results.png', dpi=150, bbox_inches='tight')
plt.show()