# Phase 4: Cache & Format Enhancements

This notebook tests the three new Phase 4 features:

1. **`compute_normalization_stats`** — per-channel RGB mean/std from a slip cache
2. **YUV output mode** — `DecodeYUVFullRes` and `DecodeYUVPlanes` pipelines
3. **`sync_s3_dataset`** — fast S3→local sync via s5cmd

Prereqs: run `uv sync` to rebuild the C extension with new YUV kernels.

In [None]:
# S3 URI passed as `remote_dir` to SlipstreamDataset and as the source for the S3 sync below.
# NOTE(review): the path components suggest the ImageNet-1k validation split stored as
# JPEG bytes (short side 256, long side <= 512, quality 100) — confirm against the bucket.
LITDATA_VAL_PATH = "s3://visionlab-datasets/imagenet1k/pre-processed/s256-l512-jpgbytes-q100-streaming/val/"

In [None]:
from slipstream import SlipstreamDataset

# Open the remote dataset with image decoding switched off, then report
# how many samples it holds and where its local cache lives.
dataset = SlipstreamDataset(remote_dir=LITDATA_VAL_PATH, decode_images=False)

print(f"Dataset: {len(dataset):,} samples")
print(f"Cache path: {dataset.cache_path}")

## 1. Normalization Stats

Compute per-channel RGB mean and std using Welford's online algorithm.
First sanity-check a 200-image subset against a manually computed reference, then run on the full dataset and compare JPEG vs YUV420 results.

In [None]:
from slipstream import SlipstreamLoader, DecodeCenterCrop
from slipstream.cache import OptimizedCache

# Creating the loader makes sure the slip cache exists (it is built on first use);
# the cache handle is kept for the statistics cells that follow.
loader = SlipstreamLoader(
    dataset,
    batch_size=256,
    pipelines={"image": [DecodeCenterCrop(224)]},
    exclude_fields=["path"],
    verbose=True,
)

cache = loader.cache
print(cache)

In [None]:
from slipstream import compute_normalization_stats
import numpy as np
import torch

# Number of leading images used for the small-subset sanity check. Shared by the
# manual reference loop, the slipstream call, and the printed labels so the three
# can never drift apart.
N_REF_IMAGES = 200

# First, sanity-check: manually compute stats on a small batch via PIL/torchvision
from slipstream import SlipstreamDataset
ds_pil = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=True,
    to_pil=False,  # returns CHW tensor
)

# Accumulate per-channel sum and sum of squares in float64 over pixels scaled to [0, 1]
channel_sum = np.zeros(3, dtype=np.float64)
channel_sum_sq = np.zeros(3, dtype=np.float64)
pixel_count = 0
for i in range(N_REF_IMAGES):
    img = ds_pil[i]['image']  # CHW uint8 tensor
    # CHW -> HWC -> (num_pixels, 3) so that axis 0 enumerates pixels
    pixels = img.permute(1, 2, 0).numpy().reshape(-1, 3).astype(np.float64) / 255.0
    channel_sum += pixels.sum(axis=0)
    channel_sum_sq += (pixels * pixels).sum(axis=0)
    pixel_count += pixels.shape[0]

ref_mean = channel_sum / pixel_count
# E[x^2] - E[x]^2 can land slightly below zero through floating-point rounding;
# clamp so the square root never yields NaN.
ref_var = np.maximum(channel_sum_sq / pixel_count - ref_mean ** 2, 0.0)
ref_std = np.sqrt(ref_var)
print(f"Reference (torchvision, {N_REF_IMAGES} imgs): mean={tuple(ref_mean.round(4))}, std={tuple(ref_std.round(4))}")

# Now compute with slipstream on the same number of images
stats_200 = compute_normalization_stats(cache, image_format="jpeg", num_samples=N_REF_IMAGES)
print(f"Slipstream ({N_REF_IMAGES} imgs):             mean={stats_200['mean']}, std={stats_200['std']}")

# Full dataset
stats = compute_normalization_stats(cache, image_format="jpeg")
print(f"\nSlipstream (full dataset): {stats}")

In [None]:
from slipstream import IMAGENET_MEAN, IMAGENET_STD

# Packaged ImageNet constants (train) side by side with the stats computed on val
print(f"ImageNet reference mean (train): {IMAGENET_MEAN}")
print(f"ImageNet reference std  (train): {IMAGENET_STD}")
print()
print(f"Computed mean (val):  {stats['mean']}")
print(f"Computed std  (val):  {stats['std']}")
print()

# Per-channel absolute differences between the two sets of statistics
mean_diff = np.abs(np.asarray(IMAGENET_MEAN) - np.asarray(stats['mean']))
std_diff = np.abs(np.asarray(IMAGENET_STD) - np.asarray(stats['std']))
print(f"Mean abs diff (train vs val): {tuple(np.round(mean_diff, 4))}")
print(f"Std abs diff  (train vs val): {tuple(np.round(std_diff, 4))}")

### YUV Normalization Stats

In [None]:
import numpy as np

# Build YUV420 cache if needed (uses the existing JPEG cache)
loader_yuv = SlipstreamLoader(
    dataset, batch_size=256,
    image_format="yuv420",
    pipelines={'image': [DecodeCenterCrop(224)]},
    exclude_fields=['path'],
    verbose=True,
)

try:
    stats_yuv = compute_normalization_stats(cache, image_format="yuv420")

    print(f"\nJPEG mean: {stats['mean']}")
    print(f"YUV  mean: {stats_yuv['mean']}")
    print(f"JPEG std:  {stats['std']}")
    print(f"YUV  std:  {stats_yuv['std']}")

    # Per-channel absolute difference between the JPEG-derived and YUV-derived stats
    mean_diff = np.abs(np.array(stats['mean']) - np.array(stats_yuv['mean']))
    std_diff = np.abs(np.array(stats['std']) - np.array(stats_yuv['std']))
    print(f"\nMean abs diff: {mean_diff} (expected < 0.01)")
    print(f"Std abs diff:  {std_diff} (expected < 0.01)")
finally:
    # Release the loader even when the stats pass raises, so a failed run
    # does not leave it alive in the kernel.
    loader_yuv.shutdown()

## 2. YUV Output Mode

Two new pipelines for emitting YUV instead of RGB:
- **`DecodeYUVFullRes`**: nearest-neighbor upsample U/V → `[H, W, 3]` (Y, U, V)
- **`DecodeYUVPlanes`**: raw planes → `(Y, U, V)` with Y at full res, U/V at half res

In [None]:
from slipstream import SlipstreamDataset, SlipstreamLoader, DecodeYUVFullRes, DecodeYUVPlanes

# Same dataset setup as at the top of the notebook, repeated in this section
# (presumably so the YUV output cells can be run on their own — confirm).
LITDATA_VAL_PATH = "s3://visionlab-datasets/imagenet1k/pre-processed/s256-l512-jpgbytes-q100-streaming/val/"

dataset = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=False,
)
print(f"Dataset: {len(dataset):,} samples")
print(f"Cache path: {dataset.cache_path}")

# Full-res YUV output
loader_yuv_full = SlipstreamLoader(
    dataset, batch_size=8, shuffle=False,
    image_format="yuv420",
    pipelines={'image': [DecodeYUVFullRes()]},
    exclude_fields=['path'],
    verbose=False,
)

try:
    # Only the first batch is needed for the demonstration
    batch = next(iter(loader_yuv_full))
    yuv_images = batch['image']  # list of [H, W, 3] arrays
    print(f"DecodeYUVFullRes: {len(yuv_images)} images")
    print(f"  Shape: {yuv_images[0].shape}, dtype: {yuv_images[0].dtype}")
    print("  Channels are (Y, U, V) — not RGB")
finally:
    # Shut the loader down even if fetching or inspecting the batch fails
    loader_yuv_full.shutdown()

In [None]:
# DecodeYUVFullRes?

In [None]:
# Bare expression: the cell output shows the shape of the first image
# taken from the DecodeYUVFullRes batch above.
yuv_images[0].shape

In [None]:
# Raw planes output
loader_yuv_planes = SlipstreamLoader(
    dataset, batch_size=8, shuffle=False,
    image_format="yuv420",
    pipelines={'image': [DecodeYUVPlanes()]},
    exclude_fields=['path'],
    verbose=False,
)

try:
    # Only the first batch is needed for the demonstration
    batch = next(iter(loader_yuv_planes))
    planes = batch['image']  # list of (Y, U, V) tuples
    y, u, v = planes[0]
    print(f"DecodeYUVPlanes: {len(planes)} images")
    print(f"  Y shape: {y.shape} (full resolution)")
    print(f"  U shape: {u.shape} (half resolution)")
    print(f"  V shape: {v.shape} (half resolution)")
finally:
    # Shut the loader down even if fetching or unpacking the batch fails
    loader_yuv_planes.shutdown()

### Visual comparison: RGB vs YUV channels

In [None]:
import matplotlib.pyplot as plt
from slipstream import DecodeRandomResizedCrop, DecodeCenterCrop, DecodeOnly

# Get same image as RGB and as YUV full-res
loader_rgb = SlipstreamLoader(
    dataset, batch_size=1, shuffle=False,
    image_format="yuv420",
    pipelines={'image': [DecodeOnly()]},
    exclude_fields=['path'], verbose=False,
)
loader_yuv_vis = SlipstreamLoader(
    dataset, batch_size=1, shuffle=False,
    image_format="yuv420",
    pipelines={'image': [DecodeYUVFullRes()]},
    exclude_fields=['path'], verbose=False,
)

try:
    # Both loaders are unshuffled, so their first batches hold the same sample
    rgb_batch = next(iter(loader_rgb))
    yuv_batch = next(iter(loader_yuv_vis))

    rgb_img = rgb_batch['image'][0]  # [H, W, 3]
    yuv_img = yuv_batch['image'][0]  # [H, W, 3] — variable size, not cropped

    fig, axes = plt.subplots(1, 5, figsize=(15, 3))
    axes[0].imshow(rgb_img)
    # The RGB image comes from DecodeOnly (no crop), so the title names that pipeline
    axes[0].set_title('RGB (DecodeOnly)')
    axes[1].imshow(yuv_img[:, :, 0], cmap='gray')
    axes[1].set_title('Y (luma)')
    axes[2].imshow(yuv_img[:, :, 1], cmap='gray')
    axes[2].set_title('U (chroma)')
    axes[3].imshow(yuv_img[:, :, 2], cmap='gray')
    axes[3].set_title('V (chroma)')
    # Feeding the YUV array to imshow as if it were RGB gives a false-color view
    axes[4].imshow(yuv_img)
    axes[4].set_title('YUV as-is (false color)')
    for ax in axes:
        ax.axis('off')
    plt.suptitle('RGB vs YUV channel visualization', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
finally:
    # Release both loaders even when decoding or plotting raises
    loader_rgb.shutdown()
    loader_yuv_vis.shutdown()

### YUV→RGB round-trip verification

Manually convert YUV full-res back to RGB and compare against the direct RGB decode.

In [None]:
import numpy as np
from slipstream.decoders.yuv420_decoder import YUV420NumbaBatchDecoder
from slipstream.cache import load_yuv420_cache

def yuv_to_rgb_bt601(yuv: np.ndarray) -> np.ndarray:
    """Convert full-resolution YUV to RGB with the full-range BT.601 matrix.

    Args:
        yuv: Array whose last axis holds (Y, U, V). U and V are expected to
            carry a +128 offset, which is removed before the matrix is
            applied; Y is used as-is. Any leading shape is accepted,
            e.g. ``[H, W, 3]``.

    Returns:
        ``uint8`` array with the same leading shape and (R, G, B) on the last
        axis. Values are rounded to the nearest integer and clipped to
        ``[0, 255]``.
    """
    y = yuv[..., 0].astype(np.float32)
    u = yuv[..., 1].astype(np.float32) - 128.0
    v = yuv[..., 2].astype(np.float32) - 128.0

    def _to_uint8(channel: np.ndarray) -> np.ndarray:
        # Round before casting: astype(uint8) on its own truncates toward zero,
        # which biases every channel downward by up to one level.
        return np.clip(np.rint(channel), 0, 255).astype(np.uint8)

    r = _to_uint8(y + 1.402 * v)
    g = _to_uint8(y - 0.344136 * u - 0.714136 * v)
    b = _to_uint8(y + 1.772 * u)
    return np.stack([r, g, b], axis=-1)

# Load raw YUV420 data and decode both ways with the same decoder
storage = load_yuv420_cache(cache.cache_dir, "image")
indices = np.arange(8, dtype=np.int64)
batch_data = storage.load_batch(indices)

decoder = YUV420NumbaBatchDecoder(num_threads=4)

try:
    # Same raw batch through both decode paths: direct RGB and full-res YUV
    rgb_images = decoder.decode_batch(
        batch_data['data'], batch_data['sizes'],
        batch_data['heights'], batch_data['widths'],
    )
    yuv_images = decoder.decode_batch_yuv_fullres(
        batch_data['data'], batch_data['sizes'],
        batch_data['heights'], batch_data['widths'],
    )

    for i in range(len(rgb_images)):
        rgb_direct = rgb_images[i]
        rgb_roundtrip = yuv_to_rgb_bt601(yuv_images[i])
        # Widen to int16 so the subtraction cannot wrap around in uint8
        diff = np.abs(rgb_direct.astype(np.int16) - rgb_roundtrip.astype(np.int16))
        print(f"  Image {i}: shape={rgb_direct.shape}, max diff={diff.max()}, mean diff={diff.mean():.2f}")

    print("\nExpected: max diff ≤ 1 (fixed-point rounding), mean ≈ 0")
finally:
    # Stop the decoder even if decoding or the comparison raises
    decoder.shutdown()

## 3. Fast S3 Sync

The `sync_s3_dataset` utility uses `s5cmd sync` for fast parallel S3→local copies.
Also available as `presync_s3=True` on `SlipstreamLoader`.

In [None]:
import shutil

# Report whether the s5cmd executable can be found on PATH
if not shutil.which("s5cmd"):
    print("s5cmd not found — install with: brew install peak/tap/s5cmd")
    print("Skipping S3 sync test.")
else:
    print("s5cmd found ✓")

In [None]:
from slipstream import sync_s3_dataset
from slipstream.s3_sync import _deterministic_local_dir
from slipstream.dataset import get_default_cache_dir
from pathlib import Path

# Resolve the deterministic local directory for this S3 path and report
# whether it exists and how many regular files it already contains.
local_dir = _deterministic_local_dir(LITDATA_VAL_PATH, get_default_cache_dir())
print(f"Sync target: {local_dir}")
print(f"Exists: {local_dir.exists()}")
if local_dir.exists():
    # Count only direct children that are files (sub-directories are ignored)
    n_files = len([entry for entry in local_dir.iterdir() if entry.is_file()])
    print(f"Files already present: {n_files}")

In [None]:
import shutil

# Opt-in switch: the sync downloads ~7GB for ImageNet val, so it stays off
# unless explicitly enabled. This keeps "Restart & Run All" cheap.
RUN_S3_SYNC = False

if not RUN_S3_SYNC:
    print("S3 sync disabled — set RUN_S3_SYNC = True to run it.")
elif shutil.which("s5cmd") is None:
    # The notebook text describes sync_s3_dataset as a wrapper around `s5cmd sync`,
    # so without the executable there is nothing to run.
    print("s5cmd not found — skipping S3 sync.")
else:
    local_path = sync_s3_dataset(LITDATA_VAL_PATH)
    print(f"Synced to: {local_path}")

In [None]:
# Alternative: use presync_s3 param on the loader
# loader = SlipstreamLoader(
#     dataset, batch_size=256,
#     presync_s3=True,
#     pipelines={'image': [DecodeRandomResizedCrop(224)]},
# )