# YUV Colorspace Decoders

This notebook demonstrates the YUV-output decoder stages that keep images in YUV colorspace
rather than converting to RGB. These are useful for:

1. **Cross-channel contrastive learning** (e.g., Isola's colorization work) where Y, U, V
   channels are treated as separate prediction targets
2. **Perceptual loss** using luminance (Y channel) for structure
3. **Research** on color-blind or luminance-only models

## Available Stages

| Stage | Output | Description |
|-------|--------|-------------|
| `DecodeYUVFullRes` | List of `[H,W,3]` | Full image, variable size, YUV colorspace |
| `DecodeYUVPlanes` | List of `(Y, U, V)` tuples | Raw planes at native resolution (Y: H×W; U, V: H/2 × W/2 for 4:2:0) |
| `DecodeYUVCenterCrop` | `[B,H,W,3]` | Center crop, fixed size, YUV colorspace |
| `DecodeYUVRandomResizedCrop` | `[B,H,W,3]` | Random resized crop, fixed size, YUV colorspace |
| `DecodeYUVResizeCrop` | `[B,H,W,3]` | Resize + center crop, fixed size, YUV colorspace |
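The planes from `DecodeYUVPlanes` relate to the `[H,W,3]` stages through chroma upsampling. The sketch below assembles full-resolution YUV from 4:2:0 planes by nearest-neighbor upsampling of U and V. It makes two assumptions. First, odd image dimensions round the chroma size up. Second, nearest-neighbor is only one choice: this notebook does not say which upsampling filter the library uses. `planes_to_yuv` is a hypothetical helper, not part of slipstream.

```python
import numpy as np


def planes_to_yuv(y: np.ndarray, u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Stack 4:2:0 planes into an [H, W, 3] uint8 array (hypothetical helper).

    Y is [H, W]; U and V are assumed to be [ceil(H/2), ceil(W/2)].
    Chroma is upsampled by repeating each sample 2x2 (nearest neighbor).
    """
    h, w = y.shape
    expected = ((h + 1) // 2, (w + 1) // 2)
    if u.shape != expected or v.shape != expected:
        raise ValueError(f"chroma planes must be {expected}, got {u.shape} / {v.shape}")
    # Repeat every chroma sample along both axes, then trim the overhang for odd sizes.
    u_full = np.repeat(np.repeat(u, 2, axis=0), 2, axis=1)[:h, :w]
    v_full = np.repeat(np.repeat(v, 2, axis=0), 2, axis=1)[:h, :w]
    return np.stack([y, u_full, v_full], axis=-1)


# Usage with synthetic planes for a 5x7 image.
y = np.arange(35, dtype=np.uint8).reshape(5, 7)
u = np.full((3, 4), 100, dtype=np.uint8)
v = np.full((3, 4), 150, dtype=np.uint8)
yuv = planes_to_yuv(y, u, v)
print(yuv.shape)  # (5, 7, 3)
```

Nearest-neighbor keeps the example exact and easy to check. A bilinear or "fancy" upsampler would produce smoother chroma edges.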

In [None]:
LITDATA_VAL_PATH = "s3://visionlab-datasets/imagenet1k/pre-processed/s256-l512-jpgbytes-q100-streaming/val/"

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from slipstream import SlipstreamDataset, SlipstreamLoader
from slipstream.decoders import (
    # RGB output (standard)
    DecodeCenterCrop,
    DecodeRandomResizedCrop,
    # YUV output (keeps colorspace)
    DecodeYUVCenterCrop,
    DecodeYUVRandomResizedCrop,
    DecodeYUVResizeCrop,
    DecodeYUVFullRes,
    DecodeYUVPlanes,
)

dataset = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=False,
)
print(f"Dataset: {len(dataset):,} samples")

## Helper: Visualize YUV images

YUV images have three channels: Y (luminance), U (blue-difference chroma), and V (red-difference chroma).
To display them, we can either:
1. Convert back to RGB for viewing, or
2. Show each channel separately.
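For reference, JPEG (JFIF) derives these channels from RGB using the full-range BT.601 transform. U and V are offset by 128, so any gray pixel (R = G = B) has U = V = 128. This notebook does not state whether the slipstream decoders emit full-range values, so treat this as the JPEG convention rather than a guarantee about the decoder.

$$
\begin{aligned}
Y &= 0.299\,R + 0.587\,G + 0.114\,B \\
U &= 128 - 0.168736\,R - 0.331264\,G + 0.5\,B \\
V &= 128 + 0.5\,R - 0.418688\,G - 0.081312\,B
\end{aligned}
$$

The `yuv_to_rgb` helper below applies the inverse of this transform.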

In [None]:
def yuv_to_rgb(yuv: np.ndarray) -> np.ndarray:
    """Convert YUV [H,W,3] uint8 to RGB [H,W,3] uint8.
    
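    Uses full-range BT.601 (JFIF) coefficients, i.e. the YCbCr convention of JPEG.
    Limited-range (video) input would need to be expanded to full range first.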
    """
    yuv = yuv.astype(np.float32)
    y, u, v = yuv[..., 0], yuv[..., 1] - 128, yuv[..., 2] - 128
    
    r = y + 1.402 * v
    g = y - 0.344136 * u - 0.714136 * v
    b = y + 1.772 * u
    
    rgb = np.stack([r, g, b], axis=-1)
    # Round before casting; astype(np.uint8) alone would truncate toward zero.
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)


def show_yuv_channels(yuv: np.ndarray, title: str = ""):
    """Display Y, U, V channels and RGB reconstruction."""
    fig, axes = plt.subplots(1, 4, figsize=(12, 3))
    
    axes[0].imshow(yuv[..., 0], cmap='gray')
    axes[0].set_title('Y (Luminance)')
    axes[0].axis('off')
    
    axes[1].imshow(yuv[..., 1], cmap='coolwarm', vmin=0, vmax=255)
    axes[1].set_title('U (Blue-diff)')
    axes[1].axis('off')
    
    axes[2].imshow(yuv[..., 2], cmap='coolwarm', vmin=0, vmax=255)
    axes[2].set_title('V (Red-diff)')
    axes[2].axis('off')
    
    axes[3].imshow(yuv_to_rgb(yuv))
    axes[3].set_title('RGB (converted)')
    axes[3].axis('off')
    
    if title:
        fig.suptitle(title, fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()
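As a sanity check on the constants hard-coded in `yuv_to_rgb`, the sketch below does three things. It builds the forward full-range BT.601 (JFIF) matrix, inverts it numerically, and compares the result with those constants. It also defines a forward `rgb_to_yuv`, a hypothetical helper that is not part of slipstream. That helper can be handy when comparing RGB-decoder output against YUV-decoder output.

```python
import numpy as np

# Forward full-range BT.601 (JFIF) matrix: rows give Y, U-128, V-128 from R, G, B.
RGB_TO_YUV = np.array([
    [0.299, 0.587, 0.114],
    [-0.168736, -0.331264, 0.5],
    [0.5, -0.418688, -0.081312],
])

# Inverse coefficients, as hard-coded in yuv_to_rgb: rows give R, G, B from Y, U-128, V-128.
YUV_TO_RGB = np.array([
    [1.0, 0.0, 1.402],
    [1.0, -0.344136, -0.714136],
    [1.0, 1.772, 0.0],
])


def rgb_to_yuv(rgb: np.ndarray) -> np.ndarray:
    """Convert RGB [..., 3] uint8 to full-range YUV [..., 3] uint8 (hypothetical helper)."""
    yuv = rgb.astype(np.float64) @ RGB_TO_YUV.T
    yuv[..., 1:] += 128.0  # offset chroma so that gray maps to U = V = 128
    return np.clip(np.rint(yuv), 0, 255).astype(np.uint8)


# The numerically inverted forward matrix should match the hard-coded inverse.
print(np.allclose(np.linalg.inv(RGB_TO_YUV), YUV_TO_RGB, atol=1e-4))  # True

# Gray pixels have neutral chroma.
print(rgb_to_yuv(np.array([[[128, 128, 128]]], dtype=np.uint8)))  # [[[128 128 128]]]
```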

## DecodeYUVCenterCrop vs DecodeCenterCrop

Compare RGB output (standard) with YUV output (keeps colorspace).

In [None]:
# Load the same images with RGB and YUV decoders (shuffle=False keeps the order fixed)
def load_batch(pipeline, image_format="yuv420"):
    loader = SlipstreamLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        image_format=image_format,
        pipelines={'image': [pipeline]},
        exclude_fields=['path'],
        verbose=False,
    )
    batch = next(iter(loader))
    loader.shutdown()
    return batch['image']

# RGB output (standard decoder)
rgb_images = load_batch(DecodeCenterCrop(224))
print(f"RGB output: {rgb_images.shape}")

# YUV output (keeps colorspace)
yuv_images = load_batch(DecodeYUVCenterCrop(224))
print(f"YUV output: {yuv_images.shape}")

In [None]:
# Show first image: RGB vs YUV channels
fig, axes = plt.subplots(2, 4, figsize=(14, 6))

# Top row: RGB image shown as RGB and split into R, G, B
rgb = rgb_images[0]  # [H, W, 3]
axes[0, 0].imshow(rgb)
axes[0, 0].set_title('RGB (full)')
axes[0, 1].imshow(rgb[..., 0], cmap='Reds')
axes[0, 1].set_title('R channel')
axes[0, 2].imshow(rgb[..., 1], cmap='Greens')
axes[0, 2].set_title('G channel')
axes[0, 3].imshow(rgb[..., 2], cmap='Blues')
axes[0, 3].set_title('B channel')

# Bottom row: YUV image shown as converted RGB and split into Y, U, V
yuv = yuv_images[0]  # [H, W, 3]
axes[1, 0].imshow(yuv_to_rgb(yuv))
axes[1, 0].set_title('YUV→RGB')
axes[1, 1].imshow(yuv[..., 0], cmap='gray')
axes[1, 1].set_title('Y (luminance)')
axes[1, 2].imshow(yuv[..., 1], cmap='coolwarm', vmin=0, vmax=255)
axes[1, 2].set_title('U (blue-diff)')
axes[1, 3].imshow(yuv[..., 2], cmap='coolwarm', vmin=0, vmax=255)
axes[1, 3].set_title('V (red-diff)')

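for ax in axes.flat:
    # Hide ticks and frame but keep the axes "on" so the row labels below render;
    # ax.axis('off') would hide the y-labels as well.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

axes[0, 0].set_ylabel('DecodeCenterCrop\n(RGB output)', fontsize=10)
axes[1, 0].set_ylabel('DecodeYUVCenterCrop\n(YUV output)', fontsize=10)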

fig.suptitle('RGB vs YUV Decoder Output', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## DecodeYUVRandomResizedCrop

Random crops in YUV colorspace for training.
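`DecodeYUVRandomResizedCrop` presumably follows the usual RandomResizedCrop recipe. The sketch below reproduces torchvision's box-sampling logic in plain Python:

- scale `(0.08, 1.0)` and aspect ratio `(3/4, 4/3)`;
- 10 sampling attempts;
- a center-crop fallback if every attempt fails.

Several points are assumptions and are not confirmed by this notebook: whether slipstream uses exactly these defaults, and how it aligns crop offsets to the 2x2 chroma grid of 4:2:0 data. `sample_rrc_box` is a hypothetical name.

```python
import math
import random


def sample_rrc_box(height, width, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), rng=None):
    """Return (top, left, h, w) for a RandomResizedCrop, torchvision-style (hypothetical helper)."""
    rng = rng or random.Random()
    area = height * width
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(10):
        # Sample a target area fraction and a log-uniform aspect ratio.
        target_area = area * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(*log_ratio))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Fallback: the largest central crop whose aspect ratio lies within `ratio`.
    in_ratio = width / height
    if in_ratio < min(ratio):
        w, h = width, int(round(width / min(ratio)))
    elif in_ratio > max(ratio):
        h, w = height, int(round(height * max(ratio)))
    else:
        h, w = height, width
    return (height - h) // 2, (width - w) // 2, h, w


# The sampled box lies inside the 375x500 image; the decoder then resizes it to 224x224.
rng = random.Random(42)
print(sample_rrc_box(375, 500, rng=rng))
```

Sampling the aspect ratio log-uniformly makes ratios r and 1/r equally likely.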

In [None]:
yuv_rrc = load_batch(DecodeYUVRandomResizedCrop(224, seed=42))
print(f"YUV RRC output: {yuv_rrc.shape}")

# Show multiple crops
n = 4
fig, axes = plt.subplots(n, 4, figsize=(12, n * 2.5))

for i in range(n):
    yuv = yuv_rrc[i]
    axes[i, 0].imshow(yuv_to_rgb(yuv))
    axes[i, 0].set_title('RGB' if i == 0 else '')
    axes[i, 1].imshow(yuv[..., 0], cmap='gray')
    axes[i, 1].set_title('Y' if i == 0 else '')
    axes[i, 2].imshow(yuv[..., 1], cmap='coolwarm', vmin=0, vmax=255)
    axes[i, 2].set_title('U' if i == 0 else '')
    axes[i, 3].imshow(yuv[..., 2], cmap='coolwarm', vmin=0, vmax=255)
    axes[i, 3].set_title('V' if i == 0 else '')

for ax in axes.flat:
    ax.axis('off')

fig.suptitle('DecodeYUVRandomResizedCrop(224)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## DecodeYUVResizeCrop (Validation Transform)

Standard ImageNet validation transform: resize the shorter edge to 256, then center-crop to 224×224.
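The geometry of this transform is simple arithmetic. The sketch below makes three assumptions, because the exact rounding inside `DecodeYUVResizeCrop` is not documented here:

- the resize scales the shorter side to exactly `resize_size`;
- the longer side is truncated to an integer, as torchvision's `Resize` does;
- the crop is centered using floor division.

`resize_crop_geometry` is a hypothetical helper.

```python
def resize_crop_geometry(height, width, resize_size=256, crop_size=224):
    """Return the resized (h, w) and the (top, left) offset of the center crop (hypothetical helper)."""
    short, long = min(height, width), max(height, width)
    # Shorter side becomes resize_size; longer side keeps the aspect ratio (truncated).
    new_short, new_long = resize_size, int(resize_size * long / short)
    new_h, new_w = (new_short, new_long) if height <= width else (new_long, new_short)
    top = (new_h - crop_size) // 2
    left = (new_w - crop_size) // 2
    return (new_h, new_w), (top, left)


# A typical 375x500 (H x W) ImageNet photo.
print(resize_crop_geometry(375, 500))  # ((256, 341), (16, 58))
```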

In [None]:
yuv_val = load_batch(DecodeYUVResizeCrop(resize_size=256, crop_size=224))
print(f"YUV ResizeCrop output: {yuv_val.shape}")

# Show a few samples
fig, axes = plt.subplots(2, 4, figsize=(12, 5))
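for i in range(4):
    yuv = yuv_val[i]
    axes[0, i].imshow(yuv_to_rgb(yuv))
    axes[1, i].imshow(yuv[..., 0], cmap='gray')

# Hide ticks and frames instead of ax.axis('off'), which would also hide the row labels.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

axes[0, 0].set_ylabel('RGB view', fontsize=10)
axes[1, 0].set_ylabel('Y channel', fontsize=10)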
fig.suptitle('DecodeYUVResizeCrop(256, 224)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Use Case: Cross-Channel Prediction

For colorization or cross-channel contrastive learning, you can use Y as input
and U, V as prediction targets (or vice versa).

In [None]:
# Example: Split YUV into separate channels for model input/target
yuv_batch = yuv_rrc  # [B, H, W, 3]

# For colorization: predict color (U, V) from grayscale (Y)
luminance = yuv_batch[..., 0:1]  # [B, H, W, 1]
chrominance = yuv_batch[..., 1:3]  # [B, H, W, 2]

print(f"Luminance (input):    {luminance.shape}")
print(f"Chrominance (target): {chrominance.shape}")

# Visualize the split
fig, axes = plt.subplots(1, 4, figsize=(14, 3))
i = 0
axes[0].imshow(yuv_to_rgb(yuv_batch[i]))
axes[0].set_title('Original (YUV→RGB)')
axes[1].imshow(luminance[i, ..., 0], cmap='gray')
axes[1].set_title('Input: Y (grayscale)')
axes[2].imshow(chrominance[i, ..., 0], cmap='coolwarm', vmin=0, vmax=255)
axes[2].set_title('Target: U')
axes[3].imshow(chrominance[i, ..., 1], cmap='coolwarm', vmin=0, vmax=255)
axes[3].set_title('Target: V')

for ax in axes:
    ax.axis('off')
fig.suptitle('Cross-Channel Prediction Setup', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
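A network usually needs floating-point, channels-first arrays, so the split above must be converted before use. The NumPy sketch below assumes an NCHW model (as in PyTorch). It uses a simple scaling: Y is mapped to [0, 1], and U/V are centered on 0 by subtracting the neutral value 128. `to_colorization_pair` is a hypothetical helper. You may prefer the per-channel statistics from the next section instead.

```python
import numpy as np


def to_colorization_pair(yuv_batch: np.ndarray):
    """Split a uint8 [B, H, W, 3] YUV batch into (input, target) float32 NCHW arrays.

    input:  Y scaled to [0, 1], shape [B, 1, H, W]
    target: U, V shifted and scaled to about [-1, 1], shape [B, 2, H, W]
    """
    x = yuv_batch.astype(np.float32)
    y = x[..., 0:1] / 255.0
    uv = (x[..., 1:3] - 128.0) / 128.0  # 128 is neutral chroma (gray)
    # NHWC -> NCHW
    return y.transpose(0, 3, 1, 2), uv.transpose(0, 3, 1, 2)


batch = np.full((2, 4, 4, 3), 128, dtype=np.uint8)  # all-gray synthetic batch
inp, tgt = to_colorization_pair(batch)
print(inp.shape, tgt.shape)  # (2, 1, 4, 4) (2, 2, 4, 4)
```

For the reverse direction (predicting Y from U and V, as in one half of a split-brain setup), swap the two returned arrays.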

## Channel Normalization

YUV channels have different value ranges:
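- **Y** (luminance): 16-235 in limited ("video") range, 0-255 in full range (as used by JPEG)
- **U, V** (chrominance): centered at 128 (neutral gray); 16-240 in limited range, 0-255 in full range

Because of this, RGB normalization constants (such as the usual ImageNet RGB mean/std) do not carry over; compute per-channel YUV statistics instead.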
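Two sources may disagree on range, for example video-decoded YUV versus JPEG's full-range YUV. The usual BT.601 expansion reconciles them: it maps limited-range Y from 16-235 onto 0-255, and U/V from 16-240 onto roughly 0-255 around the 128 midpoint. The sketch below implements this with a hypothetical helper, `limited_to_full`. This notebook does not say which range the slipstream decoders produce, so apply it only if you know your data is limited range.

```python
import numpy as np


def limited_to_full(yuv: np.ndarray) -> np.ndarray:
    """Expand limited-range YUV [..., 3] to full range, returned as float32 clipped to [0, 255]."""
    x = yuv.astype(np.float32)
    out = np.empty_like(x)
    out[..., 0] = (x[..., 0] - 16.0) * (255.0 / 219.0)              # Y: 16..235 -> 0..255
    out[..., 1:] = (x[..., 1:] - 128.0) * (255.0 / 224.0) + 128.0   # U, V: 128 +/- 112 -> 128 +/- 127.5
    return np.clip(out, 0.0, 255.0)


print(limited_to_full(np.array([[16, 128, 128], [235, 240, 16]], dtype=np.uint8)))
```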

In [None]:
# Compute raw channel statistics (0-255 units) from a batch
yuv_float = yuv_batch.astype(np.float32)

print("Channel statistics (raw 0-255):")
for c, name in enumerate(['Y', 'U', 'V']):
    channel = yuv_float[..., c]
    print(f"  {name}: mean={channel.mean():.1f}, std={channel.std():.1f}, "
          f"min={channel.min():.0f}, max={channel.max():.0f}")

print("\nNote: the neutral (gray) value of U and V is 128")

In [None]:
# Compute channel statistics from a batch
yuv_float = yuv_batch.astype(np.float32) / 255.0

print("Channel statistics (normalized 0-1):")
for c, name in enumerate(['Y', 'U', 'V']):
    channel = yuv_float[..., c]
    print(f"  {name}: mean={channel.mean():.3f}, std={channel.std():.3f}, "
          f"min={channel.min():.3f}, max={channel.max():.3f}")

print("\nNote: in [0, 1] units the neutral (gray) value of U and V is 128/255 ≈ 0.502")

In [None]:
from slipstream import compute_normalization_stats

loader = SlipstreamLoader(
    dataset,
    batch_size=256,
    shuffle=False,
    image_format="yuv420",
    pipelines={'image': [DecodeYUVCenterCrop(224)]},
    exclude_fields=['path'],
    verbose=False,
)

# Compute YUV stats (colorspace="yuv" keeps YUV instead of converting to RGB)
yuv_stats = compute_normalization_stats(
    loader.cache,
    image_format="yuv420",
    colorspace="yuv",  # keep YUV colorspace
    num_samples=5000,  # subset for speed
)

# Compare with RGB stats (default behavior)
print("\nFor comparison, RGB stats (converts YUV→RGB):")
rgb_stats = compute_normalization_stats(
    loader.cache,
    image_format="yuv420",
    colorspace="rgb",  # default: converts to RGB
    num_samples=5000,
)

loader.shutdown()
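Once you have per-channel means and standard deviations, applying them works the same way as for RGB, just with YUV-specific numbers. This notebook does not show the structure returned by `compute_normalization_stats`. The sketch below therefore takes plain `mean`/`std` sequences rather than reading `yuv_stats` directly, and assumes they are in 0-255 units. `normalize_yuv` is a hypothetical helper, and the numbers in the usage lines are placeholders, not measured statistics.

```python
import numpy as np


def normalize_yuv(batch: np.ndarray, mean, std) -> np.ndarray:
    """Per-channel standardization of a uint8 [B, H, W, 3] YUV batch to float32.

    `mean` and `std` are length-3 sequences (Y, U, V) in the same 0-255 units as the batch.
    """
    mean = np.asarray(mean, dtype=np.float32).reshape(1, 1, 1, 3)
    std = np.asarray(std, dtype=np.float32).reshape(1, 1, 1, 3)
    return (batch.astype(np.float32) - mean) / std


# Placeholder statistics for illustration only; substitute the values you computed.
placeholder_mean, placeholder_std = (110.0, 128.0, 128.0), (60.0, 20.0, 20.0)
batch = np.random.default_rng(0).integers(0, 256, size=(2, 8, 8, 3), dtype=np.uint8)
normed = normalize_yuv(batch, placeholder_mean, placeholder_std)
print(normed.shape, normed.dtype)  # (2, 8, 8, 3) float32
```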

## Summary

The `DecodeYUV*` stages cover crop, full-resolution, and raw-plane decoding, all producing
YUV colorspace instead of RGB:

| Stage | Use Case |
|-------|----------|
| `DecodeYUVCenterCrop` | Validation with YUV output |
| `DecodeYUVRandomResizedCrop` | Training with YUV output |
| `DecodeYUVResizeCrop` | ImageNet-style validation with YUV output |
| `DecodeYUVFullRes` | Variable-size full images in YUV |
| `DecodeYUVPlanes` | Raw Y/U/V planes at native resolution |

These are useful for:
- Cross-channel contrastive learning (colorization, split-brain autoencoders)
- Luminance-only processing (grayscale models using just the Y channel)
- Research on color perception and representation