# 12 - WebDataset DataLoader

This notebook implements a PyTorch-compatible DataLoader for the WebDataset (TAR shard) format.

**Format:** TAR-based shards with optional compression
- Reads from TAR shards (train-*.tar[.zst], val-*.tar[.zst])
- Supports multiple variants (different shard sizes and compression)
- Efficient streaming from disk or object storage
- Sequential I/O friendly
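
To make the shard layout concrete, here is a minimal sketch using only the standard library. It assumes the common WebDataset convention that files sharing a basename (for example `000000.jpg` and `000000.cls`) form one sample; `build_shard` and `read_samples` are hypothetical helpers for illustration and are not part of this notebook or of the `webdataset` package.

```python
import io
import tarfile
from itertools import groupby


def write_member(tar, name, payload):
    """Add one in-memory file to an open TAR archive."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


def build_shard(samples):
    """Return the bytes of a TAR shard holding (key, image_bytes, label) samples."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as tar:
        for key, image_bytes, label in samples:
            write_member(tar, f"{key}.jpg", image_bytes)
            write_member(tar, f"{key}.cls", str(label).encode("utf-8"))
    return buffer.getvalue()


def read_samples(shard_bytes):
    """Read the shard front to back and group consecutive members by basename."""
    # mode="r|" is tarfile's forward-only stream mode (no seeking).
    with tarfile.open(fileobj=io.BytesIO(shard_bytes), mode="r|") as tar:
        members = ((m.name, tar.extractfile(m).read()) for m in tar)
        for key, group in groupby(members, key=lambda item: item[0].split(".", 1)[0]):
            sample = {"__key__": key}
            for name, payload in group:
                sample[name.split(".", 1)[1]] = payload
            yield sample


# Placeholder payloads stand in for real JPEG bytes.
shard = build_shard([("000000", b"img0", 3), ("000001", b"img1", 7)])
for sample in read_samples(shard):
    print(sample["__key__"], sorted(k for k in sample if k != "__key__"))
```

Because members are consumed strictly in order, the same loop works on a file, a pipe, or a network stream, which is what makes the format friendly to sequential I/O.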

**Usage in other notebooks:**
```python
%run ./12_loader_webdataset.ipynb
loader = make_dataloader('cifar10', 'train', batch_size=64, num_workers=4, variant='shard256_none')
```

In [1]:
import os
from pathlib import Path
from typing import Optional
import io

import torch
from torch.utils.data import DataLoader
import webdataset as wds
from PIL import Image

# Load common utilities
%run ./10_common_utils.ipynb

✓ Common utilities loaded successfully

Available functions:
  - set_seed(seed)
  - get_transforms(augment)
  - write_sysinfo(path)
  - time_first_batch(dataloader, device)
  - start_monitor(log_path, interval)
  - stop_monitor(thread, stop_event)
  - append_to_summary(path, row_dict)
  - compute_metrics_from_logs(log_path)
  - get_device()
  - format_bytes(bytes)
  - count_parameters(model)

Constants:
  - STANDARD_TRANSFORM


## WebDataset Decoder Functions

In [2]:
def decode_image(sample):
    """
    Decode image from WebDataset sample.
    
    Args:
        sample: WebDataset sample dict with image bytes
    
    Returns:
        Sample dict with decoded PIL Image
    """
    # Find image key (jpg, jpeg, or png)
    img_key = None
    for key in ['jpg', 'jpeg', 'png']:
        if key in sample:
            img_key = key
            break
    
    if img_key is None:
        raise ValueError(f"No image found in sample keys: {list(sample.keys())}")
    
    # Decode image bytes to PIL Image
    img_bytes = sample[img_key]
    image = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    
    # Store the decoded PIL Image under 'image' (raw bytes stay under their original key)
    sample['image'] = image
    
    return sample


def decode_label(sample):
    """
    Decode label from WebDataset sample.
    
    Args:
        sample: WebDataset sample dict with label
    
    Returns:
        Sample dict with decoded integer label
    """
    # Decode label from bytes to int
    if 'cls' in sample:
        label_bytes = sample['cls']
        if isinstance(label_bytes, bytes):
            label = int(label_bytes.decode('utf-8'))
        else:
            label = int(label_bytes)
        sample['label'] = label
    
    return sample


def apply_transform(sample, transform):
    """
    Apply transform to image in sample.
    
    Args:
        sample: Sample dict with PIL Image
        transform: Torchvision transform
    
    Returns:
        Tuple of (transformed_image, label)
    """
    image = sample['image']
    label = sample['label']
    
    if transform:
        image = transform(image)
    
    return image, label

## DataLoader Factory Function

In [3]:
# NOTE: earlier draft, kept for reference only. The active implementation is in the next cell.
# def make_dataloader(
#     dataset: str,
#     split: str,
#     batch_size: int,
#     num_workers: int,
#     pin_memory: bool = True,
#     variant: str = 'shard256_none',
#     shuffle: bool = True,
#     transform=None,
# ) -> DataLoader:
#     """
#     Create a DataLoader for WebDataset format.
    
#     Args:
#         dataset: Dataset name (e.g., 'cifar10', 'imagenet-mini')
#         split: Split name ('train' or 'val')
#         batch_size: Batch size
#         num_workers: Number of data loading workers
#         pin_memory: Whether to pin memory for faster GPU transfer
#         variant: WebDataset variant (e.g., 'shard256_none', 'shard64_zstd')
#         shuffle: Whether to shuffle data (uses buffer for WebDataset)
#         transform: Custom transform (uses STANDARD_TRANSFORM if None)
    
#     Returns:
#         PyTorch DataLoader
#     """
#     # Detect environment
#     IS_KAGGLE = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
#     BASE_DIR = Path('/kaggle/working/format-matters') if IS_KAGGLE else Path('..').resolve()
    
#     # Build path to WebDataset shards
#     wds_dir = BASE_DIR / 'data' / 'built' / dataset / 'webdataset' / variant
    
#     if not wds_dir.exists():
#         raise FileNotFoundError(f"WebDataset directory not found: {wds_dir}")
    
#     # Build a file:// shard pattern from the POSIX form of the path
#     shard_pattern = "file://" + (wds_dir / f"{split}-%06d.tar").as_posix()

    
#     shard_files = sorted(wds_dir.glob(f"{split}-*.tar*"))
#     if not shard_files:
#         raise FileNotFoundError(f"No shards found matching: {shard_pattern}")
    
#     print(f"Found {len(shard_files)} shard(s) for {dataset}/{split} ({variant})")
    
#     # Use standard transform if none provided
#     if transform is None:
#         transform = STANDARD_TRANSFORM
    
#     dataset_obj = wds.WebDataset(shard_pattern)

#     if shuffle:
#         dataset_obj = dataset_obj.shuffle(min(5000, batch_size * 50))
    
#     dataset_obj = (
#         dataset_obj
#         .map(decode_image)
#         .map(decode_label)
#         .map(lambda sample: apply_transform(sample, transform))
#         .batched(batch_size, partial=False)
#     )
    
#     dataloader = wds.WebLoader(
#         dataset_obj,
#         num_workers=num_workers,
#         pin_memory=pin_memory,
#     )
    
    
#     # Unbatch for standard PyTorch interface (returns tensors, not lists)
#     dataloader = dataloader.unbatched().batched(batch_size, collation_fn=torch.utils.data.default_collate)
    
#     return dataloader

In [4]:
def make_dataloader(
    dataset: str,
    split: str,
    batch_size: int,
    num_workers: int,
    pin_memory: bool = True,
    variant: str = "shard256_none",
    shuffle: bool = True,
    transform=None,
) -> DataLoader:
    """
    Create a DataLoader for WebDataset format.

    Works on both Windows and Linux. Shard files are discovered with
    Path.glob and passed to WebDataset as an explicit, sorted list of
    file:// URLs. WebDataset expands brace patterns in shard URLs but not
    wildcards like *.tar, so the explicit list is used on every platform.

    Args:
        dataset: Dataset name (e.g., 'cifar10', 'imagenet-mini')
        split: Split name ('train' or 'val')
        batch_size: Batch size
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory for faster GPU transfer
        variant: WebDataset variant (e.g., 'shard256_none', 'shard64_zstd')
        shuffle: Whether to shuffle data (uses buffer for WebDataset)
        transform: Custom transform (uses STANDARD_TRANSFORM if None)

    Returns:
        PyTorch DataLoader
    """
    import webdataset as wds
    import torch
    import os
    from pathlib import Path

    # Detect environment (Kaggle or local)
    IS_KAGGLE = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
    BASE_DIR = (
        Path("/kaggle/working/format-matters")
        if IS_KAGGLE
        else Path("..").resolve()
    )

    # Build path to WebDataset shards
    wds_dir = BASE_DIR / "data" / "built" / dataset / "webdataset" / variant
    if not wds_dir.exists():
        raise FileNotFoundError(f"WebDataset directory not found: {wds_dir}")

    # ✅ Gather all shard files explicitly (matches both .tar and .tar.zst)
    shard_files = sorted(wds_dir.glob(f"{split}-*.tar*"))
    if not shard_files:
        raise FileNotFoundError(f"No shards found in {wds_dir}")

    print(f"Found {len(shard_files)} shard(s) for {dataset}/{split} ({variant})")
    print(f"Example files: {[f.name for f in shard_files[:5]]}")

    # ✅ Build file:// URLs from POSIX-style paths and pass the explicit list.
    # A "*" in a shard URL would be treated as a literal file name, so the
    # same explicit list is used on Windows, Linux, and Kaggle.
    shard_urls = ["file://" + p.as_posix() for p in shard_files]
    dataset_obj = wds.WebDataset(shard_urls, handler=wds.handlers.warn_and_continue)

    # Use standard transform if none provided
    if transform is None:
        transform = STANDARD_TRANSFORM

    # ✅ Optional shuffle (buffer-based)
    if shuffle:
        shuffle_buffer = 50000  # Enough to hold entire CIFAR-10 train set (50k samples)
        dataset_obj = dataset_obj.shuffle(shuffle_buffer)

    # ✅ Apply decoding and transforms
    dataset_obj = (
        dataset_obj
        .map(decode_image)
        .map(decode_label)
        .map(lambda sample: apply_transform(sample, transform))
    )

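    # ✅ Create WebLoader and properly batch samples
    # WebLoader wraps DataLoader, whose default batch_size=1 delivers each sample
    # as a batch of one. Unbatch first, then rebatch with default_collate so
    # images come out as (B, C, H, W) and labels as (B,).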
    dataloader = wds.WebLoader(
        dataset_obj,
        num_workers=num_workers,
        pin_memory=pin_memory,
    ).unbatched().batched(batch_size, collation_fn=torch.utils.data.default_collate)

    return dataloader
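
The chain of per-sample `.map(...)` stages followed by re-batching can be pictured as a series of lazy generators. The sketch below uses plain Python only, with no `webdataset` or `torch`; `map_stage` and `batched` are hypothetical stand-ins written for illustration, and the collation step here simply transposes samples into columns, whereas the loader above uses `torch.utils.data.default_collate` to stack tensors.

```python
from itertools import islice


def map_stage(stream, fn):
    """Lazily apply fn to each sample, like one .map(fn) stage."""
    for sample in stream:
        yield fn(sample)


def batched(stream, batch_size, partial=True):
    """Group a sample stream into batches and transpose each batch into columns."""
    iterator = iter(stream)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        if len(batch) < batch_size and not partial:
            return  # drop the incomplete final batch
        # [(img, label), ...] -> ([img, ...], [label, ...])
        yield tuple(list(column) for column in zip(*batch))


# Ten fake raw samples: label bytes under 'cls', placeholder image bytes under 'jpg'.
raw = ({"cls": str(i % 10).encode("utf-8"), "jpg": b"pixels-%d" % i} for i in range(10))
decoded = map_stage(raw, lambda s: {**s, "label": int(s["cls"].decode("utf-8"))})
pairs = map_stage(decoded, lambda s: (s["jpg"], s["label"]))

batches = list(batched(pairs, batch_size=4))
print([len(labels) for _, labels in batches])
```

Nothing is decoded until a batch is requested, so memory use stays proportional to the batch size (plus any shuffle buffer) rather than to the dataset size.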

## Smoke Test

In [5]:
if __name__ == "__main__":
    print("Running WebDataset DataLoader smoke test...\n")
    
    # Detect environment
    IS_KAGGLE = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
    BASE_DIR = Path('/kaggle/working/format-matters') if IS_KAGGLE else Path('..').resolve()
    BUILT_DIR = BASE_DIR / 'data' / 'built'
    
    # Find available datasets and variants
    available_configs = []
    for dataset_name in ['cifar10', 'imagenet-mini']:
        wds_base = BUILT_DIR / dataset_name / 'webdataset'
        if wds_base.exists():
            for variant_dir in wds_base.iterdir():
                if variant_dir.is_dir():
                    train_shards = list(variant_dir.glob('train-*.tar*'))
                    if train_shards:
                        available_configs.append((dataset_name, variant_dir.name))
    
    if not available_configs:
        print("⚠ No WebDataset datasets found. Run 03_build_webdataset.ipynb first.")
    else:
        # Test with first available dataset/variant
        test_dataset, test_variant = available_configs[0]
        print(f"Testing with dataset: {test_dataset}, variant: {test_variant}\n")
        
        try:
            # Create dataloader
            loader = make_dataloader(
                dataset=test_dataset,
                split='train',
                batch_size=16,
                num_workers=0,
                pin_memory=False,
                variant=test_variant,
                shuffle=True
            )
            
            print(f"\nDataLoader created:")
            print(f"  Batch size: 16")
            print(f"  Num workers: 0")
            print(f"  Variant: {test_variant}")
            print(available_configs)
            # Load first batch
            print("\nLoading first batch...")
            with Timer("First batch"):
                images, labels = next(iter(loader))
            
            print(f"\nBatch shapes:")
            print(f"  Images: {images.shape} ({images.dtype})")
            print(f"  Labels: {labels.shape} ({labels.dtype})")
            print(f"  Image range: [{images.min():.3f}, {images.max():.3f}]")
            print(f"  Label range: [{labels.min()}, {labels.max()}]")
            
            # Load a few more batches to test throughput
            print("\nTesting throughput (10 batches)...")
            with Timer("10 batches"):
                for i, (images, labels) in enumerate(loader):
                    if i >= 9:
                        break
            
            print("\n✓ WebDataset DataLoader smoke test passed!")
            
        except Exception as e:
            print(f"\n✗ Smoke test failed: {e}")
            import traceback
            traceback.print_exc()

Running WebDataset DataLoader smoke test...

Testing with dataset: cifar10, variant: shard1024_none

Found 1 shard(s) for cifar10/train (shard1024_none)
Example files: ['train-000000.tar']

DataLoader created:
  Batch size: 16
  Num workers: 0
  Variant: shard1024_none
[('cifar10', 'shard1024_none'), ('cifar10', 'shard1024_zstd'), ('cifar10', 'shard256_none'), ('cifar10', 'shard256_zstd'), ('cifar10', 'shard64_none'), ('cifar10', 'shard64_zstd'), ('imagenet-mini', 'shard1024_none'), ('imagenet-mini', 'shard1024_zstd'), ('imagenet-mini', 'shard256_none'), ('imagenet-mini', 'shard256_zstd'), ('imagenet-mini', 'shard64_none'), ('imagenet-mini', 'shard64_zstd')]

Loading first batch...
First batch took 0.08s

Batch shapes:
  Images: torch.Size([16, 3, 224, 224]) (torch.float32)
  Labels: torch.Size([16]) (torch.int64)
  Image range: [-2.118, 2.640]
  Label range: [0, 0]

Testing throughput (10 batches)...
10 batches took 0.28s

✓ WebDataset DataLoader smoke test passed!
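
The `Timer` context manager used in the smoke test is defined in the common utilities notebook and is not shown here. As a rough illustration of how such a timer and a throughput figure could be computed, here is a sketch with hypothetical names (`SimpleTimer`, `samples_per_second`); it is not the implementation from `10_common_utils.ipynb`.

```python
import time


class SimpleTimer:
    """Hypothetical stand-in for Timer: prints and stores the elapsed wall time."""

    def __init__(self, label):
        self.label = label
        self.elapsed = None

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = time.perf_counter() - self._start
        print(f"{self.label} took {self.elapsed:.2f}s")
        return False  # never swallow exceptions


def samples_per_second(num_batches, batch_size, elapsed_seconds):
    """Convert a timed run of full batches into a throughput figure."""
    return num_batches * batch_size / elapsed_seconds


with SimpleTimer("Busy loop") as timer:
    total = sum(range(100_000))

# Figures taken from the smoke test output above: 10 batches of 16 in 0.28 s.
print(f"{samples_per_second(10, 16, 0.28):.0f} samples/s")
```

With the numbers from the run above this gives about 571 samples/s. That figure comes from a single short run with `num_workers=0`, so it is a sanity check rather than a benchmark.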


## ✅ WebDataset DataLoader Ready

**Usage:**
```python
# In training notebooks
%run ./12_loader_webdataset.ipynb

train_loader = make_dataloader(
    dataset='cifar10',
    split='train',
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    variant='shard256_none',  # or 'shard64_zstd', etc.
    shuffle=True
)

val_loader = make_dataloader(
    dataset='cifar10',
    split='val',
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    variant='shard256_none',
    shuffle=False
)
```

**Features:**
- TAR-based shards with optional compression
- Efficient streaming from disk or object storage
- Sequential I/O friendly
- Multiple variants (different shard sizes and compression)
- Buffer-based shuffling for training
- Standard PyTorch DataLoader interface
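
Buffer-based shuffling differs from the index-based shuffling of a map-style dataset: only the samples currently in the buffer can be picked. The sketch below is a generic shuffle buffer in plain Python, written for illustration; `buffered_shuffle` is a hypothetical helper and is not WebDataset's implementation, which may differ in how it warms up and fills the buffer.

```python
import random


def buffered_shuffle(stream, bufsize, seed=0):
    """Yield samples in randomized order using a bounded buffer."""
    rng = random.Random(seed)
    buffer = []
    for sample in stream:
        if len(buffer) < bufsize:
            buffer.append(sample)  # warm-up: fill the buffer first
            continue
        # Emit a random buffered sample and put the new one in its slot.
        index = rng.randrange(bufsize)
        yield buffer[index]
        buffer[index] = sample
    # Input exhausted: flush what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


# A class-sorted stream: 1000 samples of label 0, then 1000 of label 1, and so on.
labels = [i // 1000 for i in range(10_000)]

small = list(buffered_shuffle(labels, bufsize=100))
large = list(buffered_shuffle(labels, bufsize=10_000))

print("first batch, small buffer:", sorted(set(small[:16])))
print("first batch, large buffer:", sorted(set(large[:16])))
```

With a buffer much smaller than one class block, the first 16 outputs can only come from roughly the first `bufsize + 16` inputs, so they all share one label. If early batches from this loader contain a single class, an effective buffer that is small relative to class-sorted shards is one possible cause worth checking.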

**Available variants:**
- `shard64_none`: 64MB shards, no compression
- `shard64_zstd`: 64MB shards, zstd compression
- `shard256_none`: 256MB shards, no compression
- `shard256_zstd`: 256MB shards, zstd compression
- `shard1024_none`: 1024MB shards, no compression
- `shard1024_zstd`: 1024MB shards, zstd compression
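
The variant names follow a `shard<size>_<compression>` pattern. A small parser can validate a name before it is used to build a path. This is a sketch only: `Variant` and `parse_variant` are hypothetical helpers that do not exist in this notebook, and the suffix mapping assumes the `.tar` / `.tar.zst` naming described at the top.

```python
import re
from typing import NamedTuple


class Variant(NamedTuple):
    """Parsed form of a variant name such as 'shard256_zstd'."""

    shard_mb: int
    compression: str

    @property
    def suffix(self):
        # Assumed mapping: zstd shards end in .tar.zst, uncompressed in .tar
        return ".tar.zst" if self.compression == "zstd" else ".tar"


_VARIANT_RE = re.compile(r"^shard(\d+)_(none|zstd)$")


def parse_variant(name):
    """Split a variant name into shard size (MB) and compression type."""
    match = _VARIANT_RE.match(name)
    if match is None:
        raise ValueError(f"Unrecognized variant name: {name!r}")
    return Variant(shard_mb=int(match.group(1)), compression=match.group(2))


for name in ["shard64_none", "shard256_zstd", "shard1024_none"]:
    variant = parse_variant(name)
    print(name, variant.shard_mb, variant.compression, variant.suffix)
```

Failing early on a mistyped variant gives a clearer error than the `FileNotFoundError` raised later when the shard directory is missing.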

**Next steps:**
1. Create other format loaders (13-14)
2. Run training experiments (20-21)