# Multi-Crop Named Fields

This notebook demonstrates `MultiRandomResizedCrop` and `MultiCropPipeline` —
the new Phase 5c pipeline classes for SSL multi-crop training with:

1. **Named crop fields** — each crop gets a top-level batch key (e.g., `batch["global_0"]`)
2. **Per-crop parameters** — different sizes, scales, ratios, and seeds per crop
3. **Mixed-size crops** — e.g., 2 global 224px + 4 local 96px (DINO-style)
4. **Per-crop augmentation** via `MultiCropPipeline`
5. **Comparison** with the original `MultiCropRandomResizedCrop`

In [None]:
# Remote (S3) location of the data used throughout this notebook; it is passed
# to SlipstreamDataset as `remote_dir`.
# NOTE(review): the path suggests the ImageNet-1k validation split stored as
# JPEG bytes — confirm against the dataset documentation.
LITDATA_VAL_PATH = "s3://visionlab-datasets/imagenet1k/pre-processed/s256-l512-jpgbytes-q100-streaming/val/"

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from slipstream import SlipstreamDataset

# Open the remote dataset. `decode_images=False` is passed through as-is
# (presumably so decoding is left to the loader pipelines — TODO confirm).
dataset = SlipstreamDataset(remote_dir=LITDATA_VAL_PATH, decode_images=False)
print(f"Dataset: {len(dataset):,} samples")

In [None]:
def show_batch(images, title="", nrow=8, figsize=None):
    """Display up to two rows of images from a batch as a grid.

    Args:
        images: Batch of images indexable along its first dimension. Each item
            is permuted from (C, H, W) to (H, W, C) before being drawn; a
            ``torch.Tensor`` batch is moved to the CPU first.
        title: Optional figure-level title; skipped when empty.
        nrow: Maximum number of images per grid row. At most ``nrow * 2``
            images are shown.
        figsize: Optional ``(width, height)`` in inches. Defaults to 1.8 inches
            per grid cell.

    Returns:
        None. The figure is rendered with ``plt.show()``. Nothing is drawn when
        there are no images to show.
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu()
    n = min(len(images), nrow * 2)
    if n <= 0:
        # An empty batch (or non-positive nrow) would otherwise divide by zero
        # when computing the row count below.
        return
    ncols = min(n, nrow)
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    if figsize is None:
        figsize = (ncols * 1.8, nrows * 1.8)
    # squeeze=False makes subplots() always return a 2-D grid of Axes, so the
    # axes[i][j] indexing below works for every layout, including a single
    # column with several rows (e.g. nrow=1).
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for i in range(nrows):
        for j in range(ncols):
            idx = i * ncols + j
            ax = axes[i][j]
            if idx < n:
                # CHW -> HWC, the layout imshow expects.
                img = images[idx].permute(1, 2, 0).numpy()
                ax.imshow(img)
            # Trailing cells of a partially filled last row stay blank.
            ax.axis('off')
    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

## 1. Basic usage: 2 named crops at the same size

Equivalent to `MultiCropRandomResizedCrop(num_crops=2, size=224)` but with
named output fields instead of a list.

In [None]:
from slipstream import SlipstreamLoader, MultiRandomResizedCrop

# Two same-size views of every image, each configured with its own crop seed.
# The view names ("view_0", "view_1") are the keys read from the batch below.
loader_2x = SlipstreamLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    seed=42,
    pipelines={
        'image': [
            MultiRandomResizedCrop({
                "view_0": {"size": 224, "seed": 42},
                "view_1": {"size": 224, "seed": 43},
            }),
        ],
    },
    exclude_fields=['path'],
)

# Pull a single batch, report what came back, then release the loader.
batch = next(iter(loader_2x))
print(f"Batch keys: {list(batch.keys())}")
print(f"view_0 shape: {batch['view_0'].shape}, dtype: {batch['view_0'].dtype}")
print(f"view_1 shape: {batch['view_1'].shape}, dtype: {batch['view_1'].dtype}")
loader_2x.shutdown()

In [None]:
# Plot both views of each sample in one column: view_0 on top, view_1 below.
n = min(6, batch['view_0'].shape[0])
# squeeze=False keeps `axes` 2-D even if only one column is drawn.
fig, axes = plt.subplots(2, n, figsize=(n * 2.2, 4.6), squeeze=False)
for row, key in enumerate(['view_0', 'view_1']):
    for i in range(n):
        ax = axes[row][i]
        ax.imshow(batch[key][i].permute(1, 2, 0).numpy())  # CHW -> HWC
        # Hide ticks and frame rather than calling axis('off'): axis('off')
        # also suppresses axis labels, which would make the row label invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        if i == 0:
            ax.set_ylabel(key, fontsize=11)
fig.suptitle('MultiRandomResizedCrop: 2 named views @ 224px', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
from slipstream import SlipstreamLoader, MultiRandomResizedCrop

# Same two-view setup with small crops: 98px outputs and a scale range of
# [0.05, 0.30], each view again configured with its own seed.
loader_2x = SlipstreamLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    seed=42,
    pipelines={
        'image': [
            MultiRandomResizedCrop({
                "view_0": {"size": 98, "scale": [0.05, 0.30], "seed": 42},
                "view_1": {"size": 98, "scale": [0.05, 0.30], "seed": 43},
            }),
        ],
    },
    exclude_fields=['path'],
)

# Pull a single batch, report what came back, then release the loader.
batch = next(iter(loader_2x))
print(f"Batch keys: {list(batch.keys())}")
print(f"view_0 shape: {batch['view_0'].shape}, dtype: {batch['view_0'].dtype}")
print(f"view_1 shape: {batch['view_1'].shape}, dtype: {batch['view_1'].dtype}")
loader_2x.shutdown()

In [None]:
# Plot both small views of each sample in one column: view_0 on top, view_1 below.
n = min(6, batch['view_0'].shape[0])
# Read the crop size from the tensor's last dimension so the title always
# matches the configured crops instead of a hard-coded number.
crop_px = batch['view_0'].shape[-1]
# squeeze=False keeps `axes` 2-D even if only one column is drawn.
fig, axes = plt.subplots(2, n, figsize=(n * 2.2, 4.6), squeeze=False)
for row, key in enumerate(['view_0', 'view_1']):
    for i in range(n):
        ax = axes[row][i]
        ax.imshow(batch[key][i].permute(1, 2, 0).numpy())  # CHW -> HWC
        # Hide ticks and frame rather than calling axis('off'): axis('off')
        # also suppresses axis labels, which would make the row label invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        if i == 0:
            ax.set_ylabel(key, fontsize=11)
fig.suptitle(f'MultiRandomResizedCrop: 2 named views @ {crop_px}px', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 2. DINO-style: 2 global (224) + 4 local (96) crops

Mixed-size crops with different scale ranges — the key new capability.
All 6 crops are decoded from the same JPEG (decode-once fusion).

In [None]:
# DINO-style multi-crop: two 224px "global" crops and four 96px "local" crops,
# with a separate scale range per group and a distinct seed per crop.
loader_dino = SlipstreamLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    seed=42,
    pipelines={
        'image': [
            MultiRandomResizedCrop({
                "global_0": {"size": 224, "scale": (0.4, 1.0), "seed": 42},
                "global_1": {"size": 224, "scale": (0.4, 1.0), "seed": 43},
                "local_0": {"size": 96, "scale": (0.05, 0.4), "seed": 44},
                "local_1": {"size": 96, "scale": (0.05, 0.4), "seed": 45},
                "local_2": {"size": 96, "scale": (0.05, 0.4), "seed": 46},
                "local_3": {"size": 96, "scale": (0.05, 0.4), "seed": 47},
            }),
        ],
    },
    exclude_fields=['path'],
)

# Fetch one batch and list the shape of every crop field it contains.
batch = next(iter(loader_dino))
crop_keys = [
    field for field in batch.keys()
    if field.startswith(('global_', 'local_'))
]
print(f"Batch keys: {list(batch.keys())}")
for k in crop_keys:
    print(f"  {k}: {batch[k].shape}")
loader_dino.shutdown()

In [None]:
# Show all 6 views (rows) for up to the first 6 images (columns) of the batch
n_images = min(6, batch['global_0'].shape[0])
crop_names = ['global_0', 'global_1', 'local_0', 'local_1', 'local_2', 'local_3']
n_crops = len(crop_names)

# squeeze=False keeps `axes` 2-D even if only one column is drawn.
fig, axes = plt.subplots(n_crops, n_images, figsize=(n_images * 2.5, n_crops * 2.5),
                         squeeze=False)
for row, name in enumerate(crop_names):
    crop = batch[name]
    for col in range(n_images):
        ax = axes[row][col]
        img = crop[col].permute(1, 2, 0).numpy()  # CHW -> HWC
        ax.imshow(img)
        # Hide ticks and frame rather than calling axis('off'): axis('off')
        # also suppresses axis labels, which would make the row labels invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        if col == 0:
            # Label each row with the crop name and its side length in pixels.
            size = crop.shape[-1]
            ax.set_ylabel(f"{name}\n({size}px)", fontsize=10)
        if row == 0:
            ax.set_title(f"Image {col}", fontsize=10)
fig.suptitle('DINO-style: 2 global (224px) + 4 local (96px) crops per image',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 3. Per-crop augmentation with MultiCropPipeline

`MultiCropPipeline` applies a different transform chain to each named crop.
Here we normalize the global crop and leave the local crop as uint8 to show
the composability.

In [None]:
from slipstream import MultiCropPipeline, Normalize

# Two pipeline stages on the image field: first produce the named crops, then
# run a separate transform list on each crop by name.
loader_pipe = SlipstreamLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    seed=42,
    pipelines={
        'image': [
            MultiRandomResizedCrop({
                "global_0": {"size": 224, "scale": (0.4, 1.0)},
                "local_0": {"size": 96, "scale": (0.05, 0.4)},
            }),
            MultiCropPipeline({
                "global_0": [Normalize()],  # float32, normalized
                "local_0": [],              # uint8, untouched
            }),
        ],
    },
    exclude_fields=['path'],
)

# Fetch one batch and report shape, dtype and value range of each crop.
batch = next(iter(loader_pipe))
print(f"global_0: shape={batch['global_0'].shape}, dtype={batch['global_0'].dtype}, "
      f"range=[{batch['global_0'].min():.3f}, {batch['global_0'].max():.3f}]")
print(f"local_0:  shape={batch['local_0'].shape}, dtype={batch['local_0'].dtype}, "
      f"range=[{batch['local_0'].min()}, {batch['local_0'].max()}]")
loader_pipe.shutdown()

## 4. Seed reproducibility

Per-crop seeds produce deterministic crops across runs.

In [None]:
def get_dino_batch(seed):
    """Build a fresh loader and return its first batch.

    The per-crop seeds are fixed (100 for ``global_0``, 200 for ``local_0``);
    only the loader-level ``seed`` varies between calls.

    Args:
        seed: Value passed to ``SlipstreamLoader`` as ``seed``.

    Returns:
        The first batch yielded by the loader.
    """
    loader = SlipstreamLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        seed=seed,
        pipelines={'image': [
            MultiRandomResizedCrop({
                "global_0": dict(size=224, scale=(0.4, 1.0), seed=100),
                "local_0":  dict(size=96,  scale=(0.05, 0.4), seed=200),
            }),
        ]},
        exclude_fields=['path'],
        verbose=False,
    )
    try:
        return next(iter(loader))
    finally:
        # Always release the loader, even if fetching the batch raises;
        # otherwise a failed call would leave it running.
        loader.shutdown()

# Two runs with the same loader seed, then one run with a different seed.
batch_a, batch_b, batch_c = (get_dino_batch(seed=s) for s in (42, 42, 99))

# Same loader seed: compare the sampled indices and both crop tensors exactly.
same_indices = torch.equal(batch_a['_indices'], batch_b['_indices'])
same_global = torch.equal(batch_a['global_0'], batch_b['global_0'])
same_local = torch.equal(batch_a['local_0'], batch_b['local_0'])
print(f"seed=42 vs seed=42: same indices={same_indices}, same global={same_global}, same local={same_local}")

# Different loader seed: check whether the sampled indices differ.
diff_indices = not torch.equal(batch_a['_indices'], batch_c['_indices'])
print(f"seed=42 vs seed=99: different indices={diff_indices}")

In [None]:
# One row per run, showing the first 6 global crops of each batch.
fig, axes = plt.subplots(3, 6, figsize=(11, 5.5))
for row, (b, label) in enumerate([
    (batch_a, 'seed=42 (run 1)'),
    (batch_b, 'seed=42 (run 2)'),
    (batch_c, 'seed=99'),
]):
    for i in range(6):
        ax = axes[row][i]
        ax.imshow(b['global_0'][i].permute(1, 2, 0).numpy())  # CHW -> HWC
        # Hide ticks and frame rather than calling axis('off'): axis('off')
        # also suppresses axis labels, which would make the row labels invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
    axes[row][0].set_ylabel(label, fontsize=10)
fig.suptitle('Seed reproducibility: same seed \u2192 identical crops', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Comparison with original MultiCropRandomResizedCrop

The original returns `batch["image"]` as a **list** of tensors.
The new API returns **named top-level keys** — less ambiguous downstream.

In [None]:
from slipstream import MultiCropRandomResizedCrop

# Original API: one transform yields every crop; the crops are read back below
# by position from batch['image'].
loader_old = SlipstreamLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    seed=42,
    pipelines={
        'image': [MultiCropRandomResizedCrop(num_crops=2, size=224)],
    },
    exclude_fields=['path'],
)
batch_old = next(iter(loader_old))
loader_old.shutdown()

# New API: each crop is declared by name and read back below by that name.
loader_new = SlipstreamLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    seed=42,
    pipelines={
        'image': [
            MultiRandomResizedCrop({
                "view_0": {"size": 224},
                "view_1": {"size": 224},
            }),
        ],
    },
    exclude_fields=['path'],
)
batch_new = next(iter(loader_new))
loader_new.shutdown()

# Summarise how the crops are exposed by each API.
print("Original API:")
print(f"  batch['image'] type: {type(batch_old['image'])}")
print(f"  batch['image'][0] shape: {batch_old['image'][0].shape}")

print("\nNew API:")
print(f"  batch keys: {[k for k in batch_new.keys() if k != '_indices']}")
print(f"  batch['view_0'] shape: {batch_new['view_0'].shape}")

In [None]:
# Side-by-side: old (list) vs new (named)
# One (row label, image batch) pair per grid row, in display order.
rows = [
    ('old [0]', batch_old['image'][0]),
    ('old [1]', batch_old['image'][1]),
    ('new view_0', batch_new['view_0']),
    ('new view_1', batch_new['view_1']),
]
n = min(6, batch_old['image'][0].shape[0])
# squeeze=False keeps `axes` 2-D even if only one column is drawn.
fig, axes = plt.subplots(len(rows), n, figsize=(n * 2.2, 9), squeeze=False)

for row, (label, images) in enumerate(rows):
    for i in range(n):
        ax = axes[row][i]
        ax.imshow(images[i].permute(1, 2, 0).numpy())  # CHW -> HWC
        # Hide ticks and frame rather than calling axis('off'): axis('off')
        # also suppresses axis labels, which would make the row labels invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        if i == 0:
            ax.set_ylabel(label, fontsize=10)

fig.suptitle('Original (list) vs New (named) — same images, different random crops',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()