# Transform Visual Verification

Visual comparison of slipstream GPU batch transforms vs torchvision v2 equivalents.

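1. **Deterministic transforms** — exact match expected (difference heatmap)
2. **Geometric transforms** — fixed-parameter comparison (difference heatmap)
3. **Color transforms** — fixed-parameter comparison where possible (difference heatmap); qualitative side-by-side for the YIQ jitter
4. **Effect transforms** — fixed-parameter comparison (difference heatmap)
5. **Slipstream-only transforms** — before/after
6. **SSL replay & seed reproducibility** — the same seed gives identical output; `apply_last` and `Compose(..., replay=True)` replay identical params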

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms import v2

# Prefer the GPU when PyTorch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Device: {DEVICE}")

## Section 0: Helpers

In [None]:
def show_batch(images, title="", nrow=8, figsize=None):
    """Display up to two rows of images from a [B, C, H, W] batch.

    Args:
        images: A [B, C, H, W] tensor, or a sequence of [C, H, W] tensors,
            with values expected in [0, 1]. Values are clamped to [0, 1]
            before display.
        title: Optional bold figure title.
        nrow: Images per row. At most ``2 * nrow`` images are shown.
        figsize: Matplotlib figure size. Defaults to 1.8 inches per cell.

    Returns:
        None. If there is nothing to show (empty batch or ``nrow <= 0``),
        returns without creating a figure.
    """
    n = min(len(images), 2 * nrow)
    if n <= 0:
        # Nothing to draw; also avoids a division by zero when computing rows.
        return
    if not isinstance(images, torch.Tensor):
        # Accept a list/tuple of per-image tensors by stacking the shown ones.
        images = torch.stack([torch.as_tensor(im) for im in images[:n]])
    images = images.detach().cpu().clamp(0, 1)

    ncols = min(n, nrow)
    nrows = -(-n // ncols)  # ceiling division
    if figsize is None:
        figsize = (ncols * 1.8, nrows * 1.8)
    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    # axes.flat walks the grid row-major, matching image index i * ncols + j.
    for idx, ax in enumerate(axes.flat):
        if idx < n:
            ax.imshow(images[idx].permute(1, 2, 0).numpy())
        ax.axis('off')
    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def show_comparison(original, ss_out, tv_out, title, nrow=4):
    """Plot a 3-row grid: Original | Slipstream | Torchvision.

    Args:
        original: [B, C, H, W] input batch.
        ss_out: Slipstream output batch.
        tv_out: Torchvision output batch.
        title: Bold figure title.
        nrow: Maximum number of images (columns) to show.

    All rows are clamped to [0, 1] for display. Returns without plotting when
    there is nothing to show.
    """
    n = min(len(original), nrow)
    if n <= 0:
        return
    rows = [
        ('Original', original.detach().cpu().clamp(0, 1)),
        ('Slipstream', ss_out.detach().cpu().clamp(0, 1)),
        ('Torchvision', tv_out.detach().cpu().clamp(0, 1)),
    ]
    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(len(rows), n, figsize=(n * 2, 6.5), squeeze=False)
    for r, (label, batch) in enumerate(rows):
        for i in range(n):
            axes[r][i].imshow(batch[i].permute(1, 2, 0).numpy())
        axes[r][0].set_ylabel(label, fontsize=11)
    # axis('off') would also hide the row labels, so strip only ticks and frame.
    for ax in axes.flat:
        ax.set(xticks=[], yticks=[], frame_on=False)
    fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def show_comparison_with_diff(original, ss_out, tv_out, title, nrow=4):
    """Plot a 4-row grid: Original | Slipstream | Torchvision | |Diff| heatmap.

    The difference is computed on the unclamped float outputs, so values
    outside [0, 1] (e.g. after Normalize) still count. The image rows are
    clamped to [0, 1] for display only. The heatmap shows the channel-mean
    absolute difference on a fixed [0, 1] scale, so identical pixels render
    black and only meaningful differences are visible. The max/mean
    statistics in the title cover the whole batch, not just the shown columns.

    Args:
        original: [B, C, H, W] input batch.
        ss_out: Slipstream output, same shape as ``tv_out``.
        tv_out: Torchvision reference output.
        title: Figure title; the diff statistics are appended on a new line.
        nrow: Maximum number of images (columns) to show.

    Returns without plotting when there is nothing to show.
    """
    n = min(len(original), nrow)
    if n <= 0:
        return
    original = original.detach().cpu().clamp(0, 1)
    ss_out = ss_out.detach().cpu().float()
    tv_out = tv_out.detach().cpu().float()
    diff = (ss_out - tv_out).abs()

    image_rows = [
        ('Original', original),
        ('Slipstream', ss_out.clamp(0, 1)),
        ('Torchvision', tv_out.clamp(0, 1)),
    ]
    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(4, n, figsize=(n * 2, 8.5), squeeze=False)
    for r, (label, batch) in enumerate(image_rows):
        for i in range(n):
            axes[r][i].imshow(batch[i].permute(1, 2, 0).numpy())
        axes[r][0].set_ylabel(label, fontsize=11)
    for i in range(n):
        # Mean diff across channels, fixed scale [0, 1] so zero diff = black.
        axes[3][i].imshow(diff[i].mean(0).numpy(), cmap='hot', vmin=0, vmax=1.0)
    axes[3][0].set_ylabel('|Diff|', fontsize=11)
    # axis('off') would also hide the row labels, so strip only ticks and frame.
    for ax in axes.flat:
        ax.set(xticks=[], yticks=[], frame_on=False)

    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    fig.suptitle(f"{title}\nmax |diff|={max_diff:.6f}, mean |diff|={mean_diff:.6f}", fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.show()


def show_before_after(original, transformed, title, nrow=4):
    """Plot a 2-row grid: Original | Transformed.

    The transformed batch may have any channel count:
      * 3 channels: shown as RGB, clamped to [0, 1];
      * 1 channel: shown with a gray colormap;
      * anything else: only channel 0 is shown, with the viridis colormap.

    Args:
        original: [B, 3, H, W] input batch (clamped to [0, 1] for display).
        transformed: [B, C, H, W] output batch.
        title: Bold figure title.
        nrow: Maximum number of images (columns) to show.

    Returns without plotting when there is nothing to show.
    """
    n = min(len(original), nrow)
    if n <= 0:
        return
    original = original.detach().cpu().clamp(0, 1)
    transformed = transformed.detach().cpu()
    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(2, n, figsize=(n * 2, 4.5), squeeze=False)
    for i in range(n):
        axes[0][i].imshow(original[i].permute(1, 2, 0).numpy())
        t = transformed[i]
        if t.shape[0] == 3:
            axes[1][i].imshow(t.clamp(0, 1).permute(1, 2, 0).numpy())
        elif t.shape[0] == 1:
            axes[1][i].imshow(t[0].numpy(), cmap='gray')
        else:
            # Neither RGB nor single-channel: show channel 0 only.
            axes[1][i].imshow(t[0].numpy(), cmap='viridis')
    axes[0][0].set_ylabel('Original', fontsize=11)
    axes[1][0].set_ylabel('Transformed', fontsize=11)
    # axis('off') would also hide the row labels, so strip only ticks and frame.
    for ax in axes.flat:
        ax.set(xticks=[], yticks=[], frame_on=False)
    fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def load_test_batch(
    n=8,
    remote_dir="s3://visionlab-datasets/imagenet1k/pre-processed/s256-l512-jpgbytes-q100-streaming/val/",
):
    """Load the first ``n`` ImageNet-val images via SlipstreamLoader.

    Images are center-cropped to 224 and returned as a [B, C, H, W] float32
    tensor in [0, 1].

    Args:
        n: Batch size, i.e. the number of images to load.
        remote_dir: Location of the pre-processed LitData validation split.

    Raises:
        Whatever ``slipstream`` raises when the dataset cannot be reached. The
        loader is shut down in every case.
    """
    from slipstream import SlipstreamDataset, SlipstreamLoader, CenterCrop
    dataset = SlipstreamDataset(remote_dir=remote_dir, decode_images=False)
    loader = SlipstreamLoader(
        dataset, batch_size=n, shuffle=False,
        pipelines={'image': [CenterCrop(224)]},
        exclude_fields=['path'], verbose=False,
    )
    try:
        batch = next(iter(loader))
    finally:
        # Release the loader even if fetching the first batch fails.
        loader.shutdown()
    # Convert uint8 [0,255] -> float32 [0,1]
    images = batch['image'].float() / 255.0
    return images

print("Helpers defined.")

In [None]:
# Prefer real ImageNet validation images. If the dataset (or slipstream) is
# unavailable for any reason, fall back to uniform random noise of the same shape.
try:
    images = load_test_batch(n=8)
except Exception as err:
    print(f"Could not load real images ({err}), using synthetic data")
    images = torch.rand(8, 3, 224, 224, dtype=torch.float32)
else:
    print(f"Loaded real images: {images.shape}, dtype={images.dtype}")

images = images.to(DEVICE)
show_batch(images, title="Test Images")

## Section 1: Deterministic Transforms (exact match expected)

In [None]:
from slipstream.transforms import Normalize

# ImageNet per-channel statistics, used identically by both libraries.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

ss_norm = Normalize(MEAN, STD, device=DEVICE)
tv_norm = v2.Normalize(MEAN, STD)

# Slipstream processes the whole batch on-device. Torchvision is applied one
# CPU image at a time, and the results are re-stacked and moved back.
ss_out = ss_norm(images.clone())
tv_out = torch.stack([tv_norm(img) for img in images.cpu()]).to(DEVICE)

show_comparison_with_diff(images, ss_out, tv_out, "Normalize vs v2.Normalize")

In [None]:
from slipstream.transforms import ToGrayscale

# Both produce a 3-channel grayscale image, so shapes line up for the diff.
ss_gray = ToGrayscale(num_output_channels=3)
tv_gray = v2.Grayscale(3)

ss_out = ss_gray(images.clone())
tv_out = torch.stack([tv_gray(img) for img in images.cpu()]).to(DEVICE)

show_comparison_with_diff(images, ss_out, tv_out, "ToGrayscale vs v2.Grayscale(3)")

## Section 2: Geometric Transforms (fixed-parameter comparison)

Force identical parameters in both Slipstream (SS) and torchvision (TV) so their outputs can be compared pixel-by-pixel with difference heatmaps.

In [None]:
from slipstream.transforms import RandomHorizontalFlip

# With p=1.0 the "random" flip fires on every image, so the output is deterministic.
ss_flip = RandomHorizontalFlip(p=1.0, device=DEVICE)
tv_flip = v2.RandomHorizontalFlip(p=1.0)

ss_out = ss_flip(images.clone())
tv_out = torch.stack([tv_flip(img) for img in images.cpu()]).to(DEVICE)

show_comparison_with_diff(images, ss_out, tv_out, "RandomHorizontalFlip (p=1.0, deterministic)")

In [None]:
# Scratch cell: put `v2.RandomRotation?` on its own line to show torchvision's docstring in IPython.

In [None]:
from slipstream.transforms import RandomRotate
from torchvision.transforms import functional as TF

B = len(images)

# Pin every image to the same angle so both libraries perform the same rotation.
# NOTE(review): TF.rotate defaults to nearest-neighbour interpolation. If
# slipstream resamples bilinearly, a small residual diff is expected — confirm.
for angle in (15.0, 30.0):
    ss_rot = RandomRotate(p=1.0, max_deg=angle, device=DEVICE)
    deg = torch.full((B,), angle, device=DEVICE)
    ss_out = ss_rot(images.clone(), deg=deg)
    tv_out = torch.stack([TF.rotate(img, angle) for img in images.cpu()]).to(DEVICE)

    show_comparison_with_diff(images, ss_out, tv_out, f"RandomRotate — fixed {angle}°")

In [None]:
from slipstream.transforms import RandomZoom

B = len(images)

# Slipstream's zoom=0.5 magnifies 2x, while TF.affine's scale=0.5 shrinks the
# image to half size. The equivalent torchvision scale is therefore 1/zoom.
for zoom_val in (0.5, 0.75):
    tv_scale = 1.0 / zoom_val
    ss_zoom = RandomZoom(p=1.0, zoom=(zoom_val, zoom_val), device=DEVICE)
    zoom_t = torch.full((B,), zoom_val, device=DEVICE)
    ss_out = ss_zoom(images.clone(), zoom=zoom_t)
    tv_out = torch.stack(
        [TF.affine(img, angle=0, translate=[0, 0], scale=tv_scale, shear=0) for img in images.cpu()]
    ).to(DEVICE)

    show_comparison_with_diff(images, ss_out, tv_out, f"RandomZoom(zoom={zoom_val}) vs TF.affine(scale={tv_scale:.2f})")

## Section 3: Color Transforms

Deterministic where possible (fixed params); qualitative where algorithms differ (YIQ vs RGB).

In [None]:
from slipstream.transforms import RandomGrayscale

# At p=1.0 every image is converted, so compare against the deterministic
# v2.Grayscale rather than v2.RandomGrayscale.
ss_rg = RandomGrayscale(p=1.0, num_output_channels=3, device=DEVICE)
tv_rg = v2.Grayscale(3)

ss_out = ss_rg(images.clone())
tv_out = torch.stack([tv_rg(img) for img in images.cpu()]).to(DEVICE)
show_comparison_with_diff(images, ss_out, tv_out, "RandomGrayscale(p=1.0) vs v2.Grayscale(3) — should match")

# p=0.5 on 16 images (the first 4 tiled 4x), to show that the grayscale
# decision is made per image rather than per batch.
ss_rg_rand = RandomGrayscale(p=0.5, num_output_channels=3, device=DEVICE)
images16 = images[:4].repeat(4, 1, 1, 1)
ss_out_rand = ss_rg_rand(images16.clone())
show_batch(ss_out_rand, title="RandomGrayscale(p=0.5) — per-image randomization demo", nrow=8)

In [None]:
from slipstream.transforms import RandomBrightness

B = len(images)

# A degenerate range (factor, factor) pins the brightness multiplier.
for factor in (0.6, 1.4):
    ss_br = RandomBrightness(p=1.0, scale_range=(factor, factor), device=DEVICE)
    ss_out = ss_br(images.clone())
    tv_out = torch.stack([TF.adjust_brightness(img, factor) for img in images.cpu()]).to(DEVICE)

    show_comparison_with_diff(images, ss_out, tv_out, f"RandomBrightness — fixed factor={factor}")

In [None]:
from slipstream.transforms import RandomContrast

B = len(images)

# A degenerate range (factor, factor) pins the contrast multiplier.
for factor in (0.6, 1.4):
    ss_ct = RandomContrast(p=1.0, scale_range=(factor, factor), device=DEVICE)
    ss_out = ss_ct(images.clone())
    tv_out = torch.stack([TF.adjust_contrast(img, factor) for img in images.cpu()]).to(DEVICE)

    show_comparison_with_diff(images, ss_out, tv_out, f"RandomContrast — fixed factor={factor}")

In [None]:
from slipstream.transforms import ColorJitter as SSColorJitter

B = len(images)

# HSV ColorJitter vs torchvision's functional colour ops, one component at a time.
# SS params: hue, saturation, value (=brightness), contrast
# TV params: brightness, contrast, saturation, hue
# Each case enables a single component on the slipstream ColorJitter and passes
# the per-image factor explicitly through the matching forward-call keyword
# (h / s / v / c). Both libraries therefore apply the same fixed adjustment.
jitter_cases = [
    # (SS constructor kwargs, SS forward kwarg, fixed values, TV reference op, title label)
    ({'hue': 0.5}, 'h', (0.1, -0.2), TF.adjust_hue, "hue only"),
    ({'saturation': 2.0}, 's', (0.5, 1.5), TF.adjust_saturation, "saturation only"),
    ({'value': 2.0}, 'v', (0.7, 1.3), TF.adjust_brightness, "value(brightness) only"),
    ({'contrast': 2.0}, 'c', (0.6, 1.4), TF.adjust_contrast, "contrast only"),
]

for ss_kwargs, forward_key, fixed_values, tv_op, label in jitter_cases:
    for val in fixed_values:
        # Fresh transform per value, as each comparison is independent.
        ss_jit = SSColorJitter(p=1.0, device=DEVICE, **ss_kwargs)
        fixed = torch.full((B,), val, device=DEVICE)
        ss_out = ss_jit(images.clone(), **{forward_key: fixed})
        tv_out = torch.stack([tv_op(img, val) for img in images.cpu()]).to(DEVICE)
        show_comparison_with_diff(images, ss_out, tv_out, f"HSV ColorJitter — {label}, fixed {forward_key}={val}")

In [None]:
from slipstream.transforms import RandomColorJitterYIQ

# The YIQ jitter has no exact torchvision counterpart (and draws random params),
# so this is a qualitative side-by-side only.
# TODO: review YIQ jitter algorithm for correctness
# Slipstream's hue is passed as 20 while torchvision gets 20/360, presumably
# degrees vs. a fraction of a full turn — confirm.
ss_jit = RandomColorJitterYIQ(p=1.0, hue=20, saturation=0.3, value=0.3, brightness=0.3, contrast=0.3, device=DEVICE)
tv_jit = v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=20/360)

ss_out = ss_jit(images.clone())
tv_out = torch.stack([tv_jit(img) for img in images.cpu()]).to(DEVICE)

show_comparison(images, ss_out, tv_out, "RandomColorJitterYIQ vs v2.ColorJitter — YIQ vs RGB (qualitative, YIQ flagged for review)")

## Section 4: Effect Transforms (fixed-parameter comparison)

In [None]:
from slipstream.transforms import RandomGaussianBlur

B = len(images)

for sigma in (0.5, 2.0):
    # A degenerate sigma range with a single sigma bucket pins the blur strength.
    ss_blur = RandomGaussianBlur(p=1.0, kernel_size=7, sigma_range=(sigma, sigma), num_sigmas=1, device=DEVICE)
    ss_out = ss_blur(images.clone())
    tv_out = torch.stack(
        [TF.gaussian_blur(img, kernel_size=7, sigma=sigma) for img in images.cpu()]
    ).to(DEVICE)

    show_comparison_with_diff(images, ss_out, tv_out, f"GaussianBlur — fixed sigma={sigma}, kernel=7")

In [None]:
from slipstream.transforms import RandomSolarization

B = len(images)

# p=1.0 plus a fixed threshold makes solarization deterministic.
for threshold in (0.3, 0.5, 0.7):
    ss_sol = RandomSolarization(p=1.0, threshold=threshold, device=DEVICE)
    ss_out = ss_sol(images.clone())
    tv_out = torch.stack([TF.solarize(img, threshold) for img in images.cpu()]).to(DEVICE)

    show_comparison_with_diff(images, ss_out, tv_out, f"Solarization — fixed threshold={threshold}")

## Section 5: Slipstream-Only Transforms (before/after)

No torchvision equivalent — showing Original → Transformed.

In [None]:
from slipstream.transforms import RandomPatchShuffle

# Always applied (p=1.0) to 224-px images, with sizes=0.25.
ss_ps = RandomPatchShuffle(p=1.0, sizes=0.25, img_size=224, device=DEVICE)
ss_out = ss_ps(images.clone())
show_before_after(images, ss_out, "RandomPatchShuffle(sizes=0.25)")

In [None]:
# Scratch cell: put `CircularMask?` on its own line to show slipstream's docstring in IPython.

In [None]:
from slipstream.transforms import CircularMask

# 224-px circular mask. blur_span presumably controls the width of the soft
# edge — confirm against the CircularMask docstring.
ss_cm = CircularMask(output_size=224, blur_span=8.0, device=DEVICE)
ss_out = ss_cm(images.clone())
show_before_after(images, ss_out, "CircularMask(224)")

In [None]:
from slipstream.transforms import FixedOpticalDistortion

# The sign of `distortion` selects barrel (negative) vs pincushion (positive).
ss_barrel = FixedOpticalDistortion(output_size=(224, 224), distortion=-0.5, device=DEVICE)
ss_pincushion = FixedOpticalDistortion(output_size=(224, 224), distortion=0.5, device=DEVICE)

barrel_out = ss_barrel(images.clone())
pincushion_out = ss_pincushion(images.clone())

for distorted, caption in (
    (barrel_out, "FixedOpticalDistortion(distortion=-0.5, barrel)"),
    (pincushion_out, "FixedOpticalDistortion(distortion=+0.5, pincushion)"),
):
    show_before_after(images, distorted, caption)

In [None]:
from slipstream.transforms import RandomRotate, CircularMask, FixedOpticalDistortion, Compose

# Pipeline trick: random rotation followed by a barrel distortion. Optionally, a
# circular mask can be inserted between them to hide the black borders left by
# the rotation. The mask stage is currently disabled, so it sits behind a flag,
# and the figure title reports which stages actually ran.
resolution = 224
USE_CIRCULAR_MASK = False

stages = [RandomRotate(p=1.0, max_deg=30, x=0.5, y=0.5, device=DEVICE)]
if USE_CIRCULAR_MASK:
    stages.append(CircularMask(resolution, blur_span=6, device=DEVICE))
stages.append(FixedOpticalDistortion(resolution, distortion=-.5, device=DEVICE))
pipeline = Compose(stages)

# Several runs on the same images, to show different random rotation angles.
n_runs = 4
all_outs = [pipeline(images[:4].clone()) for _ in range(n_runs)]

# Display: original row, then one row per pipeline run.
n = min(4, len(images))
rows = [('Original', images)] + [(f'Run {k + 1}', run_out) for k, run_out in enumerate(all_outs)]
fig, axes = plt.subplots(len(rows), n, figsize=(n * 2.2, len(rows) * 2.2), squeeze=False)
for r, (label, batch) in enumerate(rows):
    for i in range(n):
        axes[r][i].imshow(batch[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
        # axis('off') would also hide the row label, so strip only ticks and frame.
        axes[r][i].set(xticks=[], yticks=[], frame_on=False)
    axes[r][0].set_ylabel(label, fontsize=10)

if USE_CIRCULAR_MASK:
    stage_label = 'Rotate + CircularMask + BarrelDistortion'
    subtitle = 'Random rotation without visible black border artifacts'
else:
    stage_label = 'Rotate + BarrelDistortion'
    subtitle = 'Random rotation (circular mask disabled)'
fig.suptitle(f'{stage_label} pipeline\n{subtitle}', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
from slipstream.transforms import RandomRotateObject

# Always applied, with up to 30° of rotation and a scale drawn from [1.0, 1.5].
ss_ro = RandomRotateObject(p=1.0, max_deg=30, scale=(1.0, 1.5), device=DEVICE)
ss_out = ss_ro(images.clone())
show_before_after(images, ss_out, "RandomRotateObject(30°, scale=1.0-1.5)")

In [None]:
from slipstream.transforms import SRGBToLMS

ss_lms = SRGBToLMS()
lms_out = ss_lms(images.clone())

# Show the RGB input above each of its L, M, S channels (as grayscale).
n = min(4, len(images))
lms_cpu = lms_out.detach().cpu()
fig, axes = plt.subplots(4, n, figsize=(n * 2, 8.5), squeeze=False)
for i in range(n):
    axes[0][i].imshow(images[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
    for ch in range(3):
        axes[ch + 1][i].imshow(lms_cpu[i, ch].numpy(), cmap='gray')
axes[0][0].set_ylabel('RGB', fontsize=12)
for ch, label in enumerate(['L', 'M', 'S']):
    axes[ch + 1][0].set_ylabel(label, fontsize=12)
# axis('off') would also hide the row labels, so strip only ticks and frame.
for ax in axes.flat:
    ax.set(xticks=[], yticks=[], frame_on=False)
fig.suptitle('SRGBToLMS — L, M, S channels', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
from slipstream.transforms import RGBToLGN

ss_lgn = RGBToLGN(device=DEVICE)
lgn_out = ss_lgn(images.clone())
print(f"RGBToLGN output: {lgn_out.shape}")  # [B, 5, H, W]

# Show the RGB input above each of its 5 output channels for the first 4 images.
n = min(4, len(images))
channel_names = ['Parvo L-M', 'Parvo M-L', 'Magno ON', 'Magno OFF', 'Konio S-(L+M)']
lgn_cpu = lgn_out.detach().cpu()
fig, axes = plt.subplots(6, n, figsize=(n * 2, 13), squeeze=False)
for i in range(n):
    axes[0][i].imshow(images[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
    for ch in range(5):
        axes[ch + 1][i].imshow(lgn_cpu[i, ch].numpy(), cmap='gray')
axes[0][0].set_ylabel('RGB', fontsize=10)
for ch, name in enumerate(channel_names):
    axes[ch + 1][0].set_ylabel(name, fontsize=9)
# axis('off') would also hide the row labels, so strip only ticks and frame.
for ax in axes.flat:
    ax.set(xticks=[], yticks=[], frame_on=False)
fig.suptitle('RGBToLGN — 5-channel output', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
from slipstream.transforms import RGBToMagno

ss_magno = RGBToMagno(device=DEVICE)
magno_out = ss_magno(images.clone())
print(f"RGBToMagno output: {magno_out.shape}")  # [B, 2, H, W]

# Show the RGB input above its ON and OFF channels for the first 4 images.
n = min(4, len(images))
magno_cpu = magno_out.detach().cpu()
fig, axes = plt.subplots(3, n, figsize=(n * 2, 6.5), squeeze=False)
for i in range(n):
    axes[0][i].imshow(images[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
    for ch in range(2):
        axes[ch + 1][i].imshow(magno_cpu[i, ch].numpy(), cmap='gray')
for row, label in enumerate(['RGB', 'Magno ON', 'Magno OFF']):
    axes[row][0].set_ylabel(label, fontsize=11)
# axis('off') would also hide the row labels, so strip only ticks and frame.
for ax in axes.flat:
    ax.set(xticks=[], yticks=[], frame_on=False)
fig.suptitle('RGBToMagno — 2-channel output', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Section 6: SSL Replay & Seed Reproducibility

1. **Seed reproducibility** — same seed produces identical output across separate runs
2. **SSL replay** — `apply_last` replays identical params on a different view

In [None]:
from slipstream.transforms import (
    RandomHorizontalFlip, RandomRotate, RandomGaussianBlur,
    RandomColorJitterYIQ, Compose,
)


def make_pipeline(seed):
    """Build the demo augmentation pipeline, passing ``seed`` to every stage.

    NOTE(review): all four stages receive the same seed. If each stage seeds its
    own generator directly from it, their random draws are correlated (e.g. the
    p=0.5 flip and the p=0.5 blur could fire on the same images). That does not
    affect the reproducibility check below; confirm whether it matters for
    training pipelines.
    """
    return Compose([
        RandomHorizontalFlip(p=0.5, seed=seed, device=DEVICE),
        RandomRotate(p=0.8, max_deg=30, seed=seed, device=DEVICE),
        RandomGaussianBlur(p=0.5, kernel_size=7, sigma_range=(0.1, 2.0), seed=seed, device=DEVICE),
        RandomColorJitterYIQ(p=1.0, hue=20, saturation=0.3, brightness=0.3, seed=seed, device=DEVICE),
    ])

# Runs 1 and 2 use fresh pipelines with the same seed; run 3 uses a different seed.
pipe_a = make_pipeline(seed=42)
out_a = pipe_a(images.clone())

pipe_b = make_pipeline(seed=42)
out_b = pipe_b(images.clone())

pipe_c = make_pipeline(seed=123)
out_c = pipe_c(images.clone())

same_seed_match = torch.equal(out_a, out_b)
diff_seed_differ = not torch.equal(out_a, out_c)
print(f"seed=42 run1 vs seed=42 run2: identical={same_seed_match}")
print(f"seed=42 vs seed=123: different={diff_seed_differ}")

n = min(4, len(images))
rows = [
    ('Original', images),
    ('seed=42 (run 1)', out_a),
    ('seed=42 (run 2)', out_b),
    ('seed=123', out_c),
]
fig, axes = plt.subplots(len(rows), n, figsize=(n * 2.2, 8.5), squeeze=False)
for r, (label, batch) in enumerate(rows):
    for i in range(n):
        axes[r][i].imshow(batch[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
        # axis('off') would also hide the row label, so strip only ticks and frame.
        axes[r][i].set(xticks=[], yticks=[], frame_on=False)
    axes[r][0].set_ylabel(label, fontsize=10)
fig.suptitle(
    f'Seed reproducibility: same seed → identical output (match={same_seed_match})\n'
    f'Different seed → different output (differ={diff_seed_differ})',
    fontsize=13, fontweight='bold',
)
plt.tight_layout()
plt.show()

In [None]:
from slipstream.transforms import RandomColorJitter, RandomColorJitterYIQ

# Two "views" of the same images. In practice these would be different crops;
# here they are identical copies, so replaying the params must reproduce the
# output exactly.
view1 = images.clone()
view2 = images.clone()

t = RandomColorJitter(p=1.0, hue=.4, saturation=0.3, value=0.3, contrast=0.3, seed=42, device=DEVICE)
# Transform clones, following the `.clone()` convention used throughout this
# notebook. view1/view2 then stay unmodified for display below and for reuse
# by the next cell, even if a transform writes to its input.
view1_aug = t(view1.clone())             # samples new random params
view2_aug = t.apply_last(view2.clone())  # replays same params

# Verify params are the same
diff = (view1_aug - view2_aug).abs()
print(f"Max pixel diff between view1_aug and view2_aug: {diff.max().item():.6f}")
print("(Should be 0.0 since same input + same params = same output)")

# Display: View1 Orig -> View1 Aug -> View2 Orig -> View2 Aug
n = min(4, len(images))
rows = [
    ('View 1', view1),
    ('View 1 Aug', view1_aug),
    ('View 2', view2),
    ('View 2 Aug', view2_aug),
]
fig, axes = plt.subplots(len(rows), n, figsize=(n * 2, 8.5), squeeze=False)
for r, (label, batch) in enumerate(rows):
    for i in range(n):
        axes[r][i].imshow(batch[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
        # axis('off') would also hide the row label, so strip only ticks and frame.
        axes[r][i].set(xticks=[], yticks=[], frame_on=False)
    axes[r][0].set_ylabel(label, fontsize=11)
fig.suptitle('SSL Replay: apply_last() replays identical color jitter\n(same color shift visible on both views)', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
from slipstream.transforms import RandomHorizontalFlip, RandomColorJitter, Compose

# Compose replay: the flip and colour-jitter params sampled on view1 are
# replayed identically on view2 via `replay=True`.
aug = Compose([
    RandomHorizontalFlip(p=0.5, seed=42, device=DEVICE),
    RandomColorJitter(p=1.0, hue=.4, saturation=0.3, value=0.3, contrast=0.3, seed=42, device=DEVICE),
])

# Transform clones so the shared view1/view2 tensors are never modified in place.
view1_aug = aug(view1.clone())               # forward pass: samples params
view2_aug = aug(view2.clone(), replay=True)  # replay pass: same params

diff = (view1_aug - view2_aug).abs()
print(f"Compose replay max diff: {diff.max().item():.6f} (should be 0.0)")

n = min(4, len(images))
rows = [('View 1', view1_aug), ('View 2 (replay)', view2_aug)]
fig, axes = plt.subplots(len(rows), n, figsize=(n * 2, 4.5), squeeze=False)
for r, (label, batch) in enumerate(rows):
    for i in range(n):
        axes[r][i].imshow(batch[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
        # axis('off') would also hide the row label, so strip only ticks and frame.
        axes[r][i].set(xticks=[], yticks=[], frame_on=False)
    axes[r][0].set_ylabel(label, fontsize=11)
fig.suptitle('Compose replay: Flip + ColorJitter — identical augmentation on both views', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()