# SlipstreamDataset Basics

This notebook demonstrates the basic usage of `SlipstreamDataset` for loading streaming datasets.

## Features

- **Intuitive API**: Use `remote_dir` and `cache_dir` instead of `Dir(...)`
- **Automatic field detection**: Identifies image fields automatically
- **Flexible decoding**: Optional automatic image decoding
- **Pipeline support**: Per-field transforms for training

In [None]:
# Test dataset path (ImageNet validation, streaming format)
LITDATA_VAL_PATH = "s3://visionlab-datasets/imagenet1k/pre-processed/s256-l512-jpgbytes-q100-streaming/val/"

## 1. Basic Usage: Load and Inspect Dataset

In [None]:
from slipstream import SlipstreamDataset

# Create dataset with automatic decoding
dataset = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=True,
    to_pil=True,
)

# Show dataset info
dataset

In [None]:
# Get a sample
sample = dataset[0]
print(f"Sample keys: {list(sample.keys())}")
print(f"Image type: {type(sample['image'])}")
print(f"Label: {sample['label']}")

In [None]:
# Display the image
sample['image']

## 2. Raw Bytes Mode (for high-performance training)

In [None]:
# Create dataset WITHOUT automatic decoding
# This is what you'd use with SlipstreamLoader for training
dataset_raw = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=False,
)

dataset_raw

In [None]:
# Get raw sample
sample_raw = dataset_raw[0]
print(f"Image type: {type(sample_raw['image'])}")
print(f"Image size: {len(sample_raw['image'])} bytes")
print(f"First 16 bytes (JPEG header): {sample_raw['image'][:16].hex()}")
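
As a quick sanity check on raw samples, the leading bytes can be compared against the JPEG start-of-image marker (`FF D8`, followed by `FF` at the start of the next marker). This is a minimal sketch; `looks_like_jpeg` is a hypothetical helper and not part of slipstream.

```python
def looks_like_jpeg(data: bytes) -> bool:
    """Return True if `data` starts with the JPEG SOI marker plus a marker prefix."""
    return len(data) >= 3 and data[:3] == b"\xff\xd8\xff"


jfif_header = bytes.fromhex("ffd8ffe000104a464946")  # typical JFIF file start
png_header = bytes.fromhex("89504e470d0a1a0a")  # PNG signature, for contrast
print(looks_like_jpeg(jfif_header), looks_like_jpeg(png_header))
```

In the notebook, the same check could be applied to `sample_raw['image']`, assuming the field holds a bytes-like object as the cell above suggests.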

In [None]:
# Manual decoding (what the loader will do)
from slipstream import decode_image

image_tensor = decode_image(sample_raw['image'], to_pil=False)
print(f"Decoded tensor shape: {image_tensor.shape}")
print(f"Decoded tensor dtype: {image_tensor.dtype}")

## 3. Using with DataLoaders (variable-sized images)

ImageNet images have varying sizes (256x384, 256x376, etc.), so the default collate function, which combines samples with `torch.stack`, fails on a batch of them.
We need a custom collate function that keeps the images as a list instead.
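
The idea behind such a collate function can be sketched in plain Python. This is an illustration only, not the actual implementation of `list_collate_fn`: the name `keep_as_list_collate` is hypothetical, and samples are assumed to be dicts that all share the same keys.

```python
def keep_as_list_collate(batch):
    """Group a list of sample dicts into one dict of per-field lists.

    Hypothetical helper for illustration; assumes every sample is a dict
    with the same keys. Nothing is stacked, so fields may differ in shape.
    """
    if not batch:
        return {}
    return {key: [sample[key] for sample in batch] for key in batch[0]}


# Stand-in samples: nested lists play the role of variable-sized images.
samples = [
    {"image": [[0] * 384] * 256, "label": 7},
    {"image": [[0] * 376] * 256, "label": 3},
]
collated = keep_as_list_collate(samples)
print(collated["label"])       # labels stay in sample order
print(len(collated["image"]))  # one entry per sample
```

Because each field stays a plain list, downstream code has to index into `collated["image"]` per sample rather than slice a single batch tensor.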

In [None]:
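from slipstream import SlipstreamDataset, list_collate_fn
# StreamingDataLoader is not defined elsewhere in this notebook;
# importing it from LitData is an assumption.
from litdata import StreamingDataLoader
import torch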

# Dataset with tensor output
dataset_tensor = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=True,
    to_pil=False,  # Return tensors instead of PIL
)

# StreamingDataLoader with custom collate
loader = StreamingDataLoader(
    dataset_tensor,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    collate_fn=list_collate_fn,
)

# Get a batch
batch = next(iter(loader))
print(f"Batch keys: {list(batch.keys())}")
print(f"Images: list of {len(batch['image'])} tensors")
print(f"  First image shape: {batch['image'][0].shape}")
print(f"  Second image shape: {batch['image'][1].shape}")
print(f"Label: {batch['label'][0]}")

In [None]:
# Note: The same collate_fn works with standard PyTorch DataLoader
from torch.utils.data import DataLoader

loader_pytorch = DataLoader(
    dataset_tensor,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    collate_fn=list_collate_fn,
)

batch = next(iter(loader_pytorch))
print("PyTorch DataLoader also works with list_collate_fn:")
print(f"  Images: list of {len(batch['image'])} tensors")
print(f"  Label: {batch['label'][0]}")

## 4. Using Pipelines for Uniform Sizes (enables torch.stack)

For training, you typically want stacked tensors `[B, C, H, W]`. Using `CenterCrop` or `RandomResizedCrop` in pipelines ensures all images are the same size.
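
To see why this yields a uniform shape, the size arithmetic of a shorter-side resize followed by a center crop can be worked through in plain Python. This is a sketch under the assumption that `Resize(256)` scales the shorter side to 256 while preserving aspect ratio and that `CenterCrop(224)` takes a centered 224x224 window. The helper names are hypothetical and not part of torchvision or slipstream, and the exact rounding used by the library may differ.

```python
def resize_shorter_side(height, width, size):
    """Return (height, width) after scaling the shorter side to `size`."""
    if height <= width:
        return size, int(size * width / height)
    return int(size * height / width), size


def center_crop_box(height, width, crop):
    """Return (top, left, crop_height, crop_width) for a centered square crop."""
    top = (height - crop) // 2
    left = (width - crop) // 2
    return top, left, crop, crop


# Different input sizes all end up as a 224x224 crop.
for h, w in [(256, 384), (256, 376), (500, 333)]:
    rh, rw = resize_shorter_side(h, w, 256)
    top, left, ch, cw = center_crop_box(rh, rw, 224)
    print(f"{h}x{w} -> resized {rh}x{rw} -> crop {ch}x{cw} at ({top}, {left})")
```

Whatever the input aspect ratio, the crop window has the same height and width, which is what makes the resulting tensors stackable.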

In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Define a pipeline that produces uniform 224x224 images
image_pipeline = transforms.Compose([
    transforms.Lambda(lambda x: decode_image(x, to_pil=True)),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create dataset with pipeline
dataset_pipeline = SlipstreamDataset(
    remote_dir=LITDATA_VAL_PATH,
    decode_images=False,  # Pipeline handles decoding
    pipelines={'image': image_pipeline},
)

dataset_pipeline

In [None]:
# Get a processed sample
sample_processed = dataset_pipeline[0]
print(f"Processed image shape: {sample_processed['image'].shape}")
print(f"Processed image dtype: {sample_processed['image'].dtype}")
print(f"Processed image range: [{sample_processed['image'].min():.3f}, {sample_processed['image'].max():.3f}]")

# All images are now 224x224, so a batch can be stacked into one [B, C, H, W] tensor
def collate_fn(batch):
    images = torch.stack([sample['image'] for sample in batch])
    labels = torch.tensor([sample['label'] for sample in batch])
    return {'image': images, 'label': labels}

loader_pytorch = DataLoader(
    dataset_pipeline,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

batch = next(iter(loader_pytorch))
print("\nPyTorch DataLoader batch:")
print(f"  Image shape: {batch['image'].shape}")  # [16, 3, 224, 224]
print(f"  Label shape: {batch['label'].shape}")

## Summary

**SlipstreamDataset** provides:
- **Intuitive API**: `remote_dir` and `cache_dir` instead of `Dir(...)`
- **Automatic field detection**: Identifies image fields automatically
- **Pipeline support**: Per-field transforms for training
- **LitData caching**: Automatic cache management under `~/.lightning/`

**Key patterns**:
- `decode_images=True` for interactive exploration (PIL/tensor output)
- `decode_images=False` with `pipelines` for training (custom transforms)
- Use `CenterCrop`/`RandomResizedCrop` in pipelines to enable `torch.stack`

**Next**: See `02_loader_benchmarks.ipynb` for high-performance batch loading with `SlipstreamLoader`.