# JPEG vs YUV420 Visual Verification

Head-to-head comparison of the JPEG and YUV420 image formats in `SlipstreamLoader`.

YUV420 stores decoded images as raw planar YUV 4:2:0 (chroma-subsampled) pixels, so loading skips
JPEG Huffman decoding and the IDCT entirely. The round trip (JPEG → RGB → YUV420P → RGB)
introduces small pixel-level errors, mostly from chroma subsampling. This notebook checks that
the outputs are visually indistinguishable and quantifies the pixel differences.
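For a sense of the storage trade-off, the sketch below computes the size of one 8-bit YUV420P frame (a full-resolution Y plane followed by quarter-size U and V planes) next to packed RGB. It assumes odd dimensions round the chroma planes up, as FFmpeg's `yuv420p` layout does; how `SlipstreamLoader` handles odd sizes is not stated here, and both helper names are hypothetical.

```python
def yuv420p_nbytes(height: int, width: int) -> int:
    """Bytes in one 8-bit planar YUV 4:2:0 frame: Y plane, then U plane, then V plane."""
    chroma_h = (height + 1) // 2  # chroma planes are half-height (rounded up for odd sizes)
    chroma_w = (width + 1) // 2   # ...and half-width
    return height * width + 2 * chroma_h * chroma_w


def rgb_nbytes(height: int, width: int) -> int:
    """Bytes in one packed 8-bit RGB frame (3 bytes per pixel)."""
    return height * width * 3


print(yuv420p_nbytes(224, 224), rgb_nbytes(224, 224))  # 75264 150528
```

For even dimensions, YUV420P works out to exactly 1.5 bytes per pixel, half the size of RGB; the price is the reduced chroma resolution.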

In [None]:
LITDATA_VAL_PATH = "s3://visionlab-datasets/imagenet1k/pre-processed/s256-l512-jpgbytes-q100-streaming/val/"

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from slipstream import SlipstreamDataset, SlipstreamLoader, DecodeCenterCrop

dataset = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=False,
)
print(f"Dataset: {len(dataset):,} samples")

## Load the same batch with JPEG and YUV420

We use `DecodeCenterCrop(224)` with `shuffle=False` so both formats decode exactly the same
images with the same deterministic crop.

**Note:** The first time `image_format="yuv420"` is used, the loader builds a YUV420
sibling cache by decoding all JPEG images and converting them to YUV420P format.
This is a one-time operation that takes a few minutes for 50k images.

In [None]:
# Note: First run with image_format="yuv420" triggers a one-time cache build
# that converts all JPEGs to YUV420 format. This takes a few minutes.

def load_batch(image_format, verbose=True):
    loader = SlipstreamLoader(
        dataset,
        batch_size=16,
        shuffle=False,
        image_format=image_format,
        pipelines={'image': [DecodeCenterCrop(224)]},
        exclude_fields=['path'],
        verbose=verbose,
    )
    batch = next(iter(loader))
    loader.shutdown()
    return batch

batch_jpeg = load_batch("jpeg", verbose=False)
batch_yuv = load_batch("yuv420", verbose=True)  # verbose=True shows YUV420 cache build progress

img_jpeg = batch_jpeg['image']  # numpy uint8 array, shape [B, H, W, C]
img_yuv = batch_yuv['image']
print(f"JPEG:   {img_jpeg.shape}, dtype={img_jpeg.dtype}")
print(f"YUV420: {img_yuv.shape}, dtype={img_yuv.dtype}")

## Side-by-side comparison

Top row: JPEG decode. Bottom row: YUV420 decode. The two rows should be visually indistinguishable.

In [None]:
n = 8
fig, axes = plt.subplots(2, n, figsize=(n * 1.8, 3.8))
for i in range(n):
    for row, imgs in enumerate((img_jpeg, img_yuv)):
        ax = axes[row][i]
        ax.imshow(imgs[i])  # already HWC uint8 numpy
        # Hide ticks and frame instead of ax.axis('off'), which would also hide the row labels
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
axes[0][0].set_ylabel('JPEG', fontsize=11)
axes[1][0].set_ylabel('YUV420', fontsize=11)
fig.suptitle('DecodeCenterCrop(224): JPEG vs YUV420', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Pixel-level difference analysis

The YUV420 round trip (RGB → YUV420P → RGB) introduces errors from two sources:
1. **Chroma subsampling.** The U/V planes are stored at half resolution in each dimension
   (one chroma sample per 2×2 block, obtained by averaging), which permanently discards color
   detail. At sharp color edges (e.g., a bright object against a contrasting background), a
   2×2 block that straddles the edge gets an averaged chroma value that is wrong for the pixels
   on both sides, so individual pixels can have large errors.
2. **Fixed-point BT.601 conversion.** Integer rounding in each direction adds about ±1 per channel.
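Both error sources can be reproduced with a small NumPy model of the round trip. This is a sketch, not the loader's implementation: it assumes full-range BT.601 coefficients (the JFIF convention), box-filter (2×2 mean) downsampling, nearest-neighbour upsampling, and floating-point math rounded to uint8 rather than fixed-point integer arithmetic. The function names are hypothetical.

```python
import numpy as np


def _to_u8(a: np.ndarray) -> np.ndarray:
    # Round to the nearest integer and clamp to the 8-bit range
    return np.clip(np.rint(a), 0, 255).astype(np.uint8)


def rgb_to_yuv420p(rgb: np.ndarray):
    """uint8 RGB [H, W, 3] with even H, W -> (Y [H, W], U [H/2, W/2], V [H/2, W/2])."""
    h, w, _ = rgb.shape
    assert h % 2 == 0 and w % 2 == 0, "this sketch only handles even dimensions"
    x = rgb.astype(np.float64)
    r, g, b = x[..., 0], x[..., 1], x[..., 2]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    u = 128 - 0.168736 * r - 0.331264 * g + 0.5 * b
    v = 128 + 0.5 * r - 0.418688 * g - 0.081312 * b
    # Chroma subsampling: average each 2x2 block into a single U and V sample
    u = u.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))
    v = v.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))
    return _to_u8(y), _to_u8(u), _to_u8(v)


def yuv420p_to_rgb(y: np.ndarray, u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Inverse of rgb_to_yuv420p; each chroma sample is reused for its whole 2x2 block."""
    u = np.repeat(np.repeat(u.astype(np.float64), 2, axis=0), 2, axis=1) - 128
    v = np.repeat(np.repeat(v.astype(np.float64), 2, axis=0), 2, axis=1) - 128
    yf = y.astype(np.float64)
    r = yf + 1.402 * v
    g = yf - 0.344136 * u - 0.714136 * v
    b = yf + 1.772 * u
    return _to_u8(np.stack([r, g, b], axis=-1))


def max_roundtrip_error(rgb: np.ndarray) -> int:
    out = yuv420p_to_rgb(*rgb_to_yuv420p(rgb))
    return int(np.abs(out.astype(np.int16) - rgb.astype(np.int16)).max())


# Flat color: only rounding error remains
flat = np.full((4, 4, 3), (37, 200, 90), dtype=np.uint8)
# Sharp edge inside a 2x2 block: a red column next to a green column
edge = np.zeros((2, 2, 3), dtype=np.uint8)
edge[:, 0] = (255, 0, 0)
edge[:, 1] = (0, 255, 0)
print(max_roundtrip_error(flat), max_roundtrip_error(edge))
```

Under these assumptions the flat image comes back within ±1 per channel, while the red/green block loses most of its saturation (the red pixel's R channel is off by more than 100), which is the rounding-versus-edge pattern described in the list above.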

Typical results: mean ~5, median ~2, with most values differing by 5 or less. Outliers up to ~150
occur at sharp color edges, where the subsampled chroma overshoots in one channel. This is the same
kind of error that 4:2:0 JPEG encoding and 4:2:0 video (e.g., H.264, H.265) introduce. For vision
model training, these errors are small compared to the perturbations from standard augmentations
(random crop, color jitter, etc.).
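To check the claim that large outliers sit on sharp color edges, one option is to compare an error map against a chroma-gradient map of the same image. The sketch below is an illustrative diagnostic, not part of `slipstream`: the function names and both thresholds (`err_thresh`, `grad_thresh`) are hypothetical choices, and computing chroma with full-range BT.601 (JFIF) coefficients is an assumption about the loader's color space.

```python
import numpy as np


def chroma_edge_mask(rgb: np.ndarray, grad_thresh: float = 20.0) -> np.ndarray:
    """Boolean [H, W] mask of pixels whose chroma (U, V) changes sharply."""
    x = rgb.astype(np.float64)
    r, g, b = x[..., 0], x[..., 1], x[..., 2]
    # Full-range BT.601 chroma without the +128 offset (it cancels in the gradient)
    u = -0.168736 * r - 0.331264 * g + 0.5 * b
    v = 0.5 * r - 0.418688 * g - 0.081312 * b
    du_dy, du_dx = np.gradient(u)
    dv_dy, dv_dx = np.gradient(v)
    magnitude = np.sqrt(du_dx**2 + du_dy**2 + dv_dx**2 + dv_dy**2)
    return magnitude > grad_thresh


def outlier_edge_overlap(rgb, pixel_diff, err_thresh=20.0, grad_thresh=20.0):
    """Return (fraction of outlier pixels on chroma edges, fraction of all pixels on edges)."""
    edges = chroma_edge_mask(rgb, grad_thresh)
    outliers = np.asarray(pixel_diff) > err_thresh
    n_outliers = int(outliers.sum())
    on_edges = float((outliers & edges).sum()) / n_outliers if n_outliers else float("nan")
    return on_edges, float(edges.mean())
```

In this notebook, `rgb` would be `img_jpeg[i]` and `pixel_diff` the per-pixel max over channels, `diff[i].max(dim=0).values.numpy()`, from the statistics cell. If the first returned fraction is much larger than the second, the outliers are concentrated on color edges rather than spread uniformly across the image.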

In [None]:
# Convert numpy HWC batches to torch for numerical analysis
img_jpeg_t = torch.from_numpy(img_jpeg).permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
img_yuv_t = torch.from_numpy(img_yuv).permute(0, 3, 1, 2)
# Cast to float before subtracting so uint8 arithmetic cannot wrap around
diff = (img_jpeg_t.float() - img_yuv_t.float()).abs()

# Statistics are computed over every channel value (B * C * H * W entries)
print("Pixel difference statistics (across entire batch, per channel value):")
print(f"  Mean absolute error:  {diff.mean():.3f}")
print(f"  Max absolute error:   {diff.max():.0f}")
print(f"  Median:               {diff.median():.0f}")
print(f"  % values with diff=0: {(diff == 0).float().mean() * 100:.1f}%")
print(f"  % values with diff≤1: {(diff <= 1).float().mean() * 100:.1f}%")
print(f"  % values with diff≤2: {(diff <= 2).float().mean() * 100:.1f}%")
print(f"  % values with diff>5: {(diff > 5).float().mean() * 100:.2f}%")

In [None]:
# Show difference heatmap for first 4 images
fig, axes = plt.subplots(3, 4, figsize=(9, 7))
for i in range(4):
    axes[0][i].imshow(img_jpeg[i])  # Already HWC numpy
    axes[0][i].set_title(f'JPEG #{i}', fontsize=9)
    axes[0][i].axis('off')

    axes[1][i].imshow(img_yuv[i])  # Already HWC numpy
    axes[1][i].set_title(f'YUV420 #{i}', fontsize=9)
    axes[1][i].axis('off')

    # Per-pixel max diff across channels (uses the torch tensor computed above)
    pixel_diff = diff[i].max(dim=0).values.numpy()
    # Color scale saturates at 5 so small rounding errors stay visible; the title shows the true max
    im = axes[2][i].imshow(pixel_diff, cmap='hot', vmin=0, vmax=5)
    axes[2][i].set_title(f'|diff| (max={pixel_diff.max():.0f})', fontsize=9)
    axes[2][i].axis('off')

fig.colorbar(im, ax=axes[2], shrink=0.8, label='Max |diff| over channels (clipped at 5)')
fig.suptitle('Pixel-level difference: JPEG vs YUV420', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Histogram of per-pixel differences
fig, ax = plt.subplots(figsize=(7, 3.5))
values = diff.flatten().numpy()
max_val = int(values.max())
counts, bins, _ = ax.hist(values, bins=np.arange(-0.5, max_val + 1.5, 1),
                          edgecolor='black', linewidth=0.5)
ax.set_xlabel('Absolute pixel difference')
ax.set_ylabel('Count')
ax.set_title('Distribution of per-pixel errors (JPEG vs YUV420)', fontweight='bold')
ax.set_yscale('log')
ax.set_xlim(-0.5, max_val + 0.5)
plt.tight_layout()
plt.show()