# ImageFolder Dataset Support

This notebook demonstrates slipstream's support for ImageFolder-style datasets.

**Supported formats:**
- Local ImageFolder directories (torchvision-style class subdirectories)
- S3 tar archives (auto-download, hash, and extract)
- Automatic format detection via `SlipstreamDataset`

## 1. Dataset Creation from S3 Tar Archive

The easiest way to load an ImageFolder dataset is via `SlipstreamDataset`, which auto-detects the format.

In [None]:
from slipstream import SlipstreamDataset, SlipstreamImageFolder

# Open the ImageNet validation split straight from an S3 tar archive.
# SlipstreamDataset recognises the tarball and, behind the scenes:
#   1. downloads the archive into the local cache,
#   2. computes its SHA256 hash so the same archive is not cached twice,
#   3. extracts it into the cache directory,
#   4. hands back a SlipstreamImageFolder.
# decode_images=False asks for images without decoding them.
dataset = SlipstreamDataset(
    remote_dir="s3://visionlab-datasets/imagenet1k-raw/val.tar.gz",
    decode_images=False,
)

# Confirm which concrete class auto-detection picked, then summarise the data.
print(f"Dataset type: {type(dataset).__name__}")
print(f"Is SlipstreamImageFolder: {isinstance(dataset, SlipstreamImageFolder)}")
print(f"Number of samples: {len(dataset):,}")
print(f"Number of classes: {len(dataset.classes)}")
print(f"Field types: {dataset.field_types}")

In [None]:
from IPython.display import display
# Import the submodule explicitly: a bare `import PIL` does not load
# `PIL.Image`, so `PIL.Image.Image` below would only resolve if some other
# library happened to import it first.
import PIL.Image

# Inspect the first sample: its fields, label (with class name), index and path.
sample = dataset[0]

print(f"Sample keys: {list(sample.keys())}")
print(f"Label: {sample['label']} ({dataset.classes[sample['label']]})")
print(f"Index: {sample['index']}")
print(f"Path: {sample['path']}")

# The image field is either an already-decoded PIL image (rendered inline)
# or raw encoded data (summarised by its leading bytes and total size).
if isinstance(sample['image'], PIL.Image.Image):
    display(sample['image'])
else:
    print(f"Image bytes (first 20): {sample['image'][:20]}...")
    print(f"Image size: {len(sample['image']):,} bytes")

In [None]:
# sample['image']

In [None]:
# Show the first few class names next to their integer label
# (dataset.classes is indexed by label, so position == label here).
print("First 10 classes:")
first_classes = dataset.classes[:10]
for i, cls in enumerate(first_classes):
    print(f"  {i}: {cls}")

## 2. Alternative Creation Methods

### Using open_imagefolder() explicitly

In [None]:
from slipstream import open_imagefolder

# Calling open_imagefolder directly gives the same dataset as the
# auto-detecting SlipstreamDataset, with more control over options.
# For example, to override the cache location pass:
#     cache_dir="/custom/cache/path"
dataset = open_imagefolder(
    "s3://visionlab-datasets/imagenet1k-raw/val.tar.gz",
)

print(dataset)

### Using input_dir parameter

In [None]:
# The archive location can equally be passed as input_dir instead of remote_dir.
dataset = SlipstreamDataset(
    input_dir="s3://visionlab-datasets/imagenet1k-raw/val.tar.gz",
)

print(f"Loaded {len(dataset):,} samples")

## 3. Dataset Type Detection

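Slipstream can auto-detect the format of a local dataset directory. Here we run the detection helpers on the cache directory that the S3 archive was extracted into.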

In [None]:
from slipstream.dataset import detect_local_dataset_type, is_imagefolder_structure

# Where the archive was extracted on local disk.
# NOTE(review): `_root_path` is a private attribute; prefer a public accessor
# if slipstream exposes one.
local_path = dataset._root_path
print(f"Local extracted path: {local_path}")

# Ask the generic detector which dataset format the directory holds ...
dataset_type = detect_local_dataset_type(local_path)
print(f"Detected type: {dataset_type}")

# ... and check specifically for the ImageFolder class-subdirectory layout.
is_imagefolder = is_imagefolder_structure(local_path)
print(f"Is ImageFolder structure: {is_imagefolder}")

## 4. DataLoader Creation

### Option A: SlipstreamLoader (high-performance loading, no augmentation)

In [None]:
# ToTorchImage is also used by the augmentation pipeline further down.
from slipstream import SlipstreamLoader, DecodeCenterCrop, ToTorchImage, Normalize

# Batched loader whose "image" field runs through a single stage that
# decodes and center-crops to 224.
loader = SlipstreamLoader(
    dataset,
    batch_size=32,
    pipelines={"image": [DecodeCenterCrop(size=224)]},
    # NOTE(review): presumably forces any cached loader state to be rebuilt
    # on every run; confirm, and drop it for faster re-runs.
    force_rebuild=True,
)

print(f"Loader batches: {len(loader):,}")

In [None]:
# Look at the structure of the first batch only; the loop exits after one pass.
for batch in loader:
    print(f"Batch keys: {list(batch.keys())}")
    # NOTE(review): this pipeline has no ToTorchImage stage, so the layout is
    # not necessarily [B, C, H, W]; check the printed shape.
    print(f"Image shape: {batch['image'].shape}")
    print(f"Image dtype: {batch['image'].dtype}")
    print(f"Labels: {batch['label'][:8].tolist()}...")
    break

### Option B: SlipstreamLoader with Training Augmentations

In [None]:
from slipstream import DecodeRandomResizedCrop
from slipstream.transforms import RandomHorizontalFlip, RandomColorJitterHSV

# Training-style loader: shuffled sample order plus random augmentations.
# The "image" stages, in list order, are:
#   decode + random resized crop to 224 -> ToTorchImage on the CPU
#   -> random horizontal flip (p=0.5) -> random HSV color jitter.
train_loader = SlipstreamLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    pipelines={
        "image": [
            DecodeRandomResizedCrop(size=224, scale=(0.08, 1.0)),
            ToTorchImage(device="cpu"),
            RandomHorizontalFlip(p=0.5),
            RandomColorJitterHSV(value=0.4, contrast=0.4, saturation=0.2, hue=0.1),
        ],
    },
)

# Pull one batch and report its shape and the range of its pixel values.
batch = next(iter(train_loader))
print(f"Training batch shape: {batch['image'].shape}")
print(f"Value range: [{batch['image'].min():.2f}, {batch['image'].max():.2f}]")

In [None]:
# RandomColorJitterHSV?

### Option C: Using Pipeline Presets

In [None]:
from slipstream.pipelines import supervised_train, supervised_val

# Evaluation loader: keeps dataset order and uses the supervised_val preset
# (224-pixel output, processed on the CPU) as its pipelines.
val_loader = SlipstreamLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    pipelines=supervised_val(size=224, device="cpu"),
)

# One batch is enough to check the output shape.
batch = next(iter(val_loader))
print(f"Val batch shape: {batch['image'].shape}")

In [None]:
# Training loader: shuffled order with the supervised_train preset.
# seed=42 presumably makes the random augmentations reproducible; confirm.
train_loader = SlipstreamLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    pipelines=supervised_train(size=224, seed=42, device="cpu"),
)

# One batch is enough to check the output shape.
batch = next(iter(train_loader))
print(f"Train batch shape: {batch['image'].shape}")

## 5. Visualize Some Images

In [None]:
import matplotlib.pyplot as plt
import PIL.Image
from slipstream import decode_image

# Show a 2x4 grid of samples spread across the dataset.
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

# Take every 100th image. Shrink the stride for small datasets so every
# requested index stays in range; large datasets keep the 100-sample stride.
n_samples = len(dataset)
step = min(100, max(1, n_samples // axes.size))

for i, ax in enumerate(axes.flat):
    ax.axis('off')
    idx = i * step
    if idx >= n_samples:
        # Fewer samples than grid cells: leave this cell blank.
        continue
    sample = dataset[idx]
    image = sample['image']
    # This dataset was reopened without decode_images=False, so the image
    # may already be a PIL image; only decode raw encoded data.
    if isinstance(image, PIL.Image.Image):
        img = image
    else:
        img = decode_image(image, to_pil=True)
    ax.imshow(img)
    # Truncate long class names so titles do not overlap.
    ax.set_title(f"{dataset.classes[sample['label']][:15]}")

plt.tight_layout()
plt.show()

## Summary

**Dataset Creation:**
```python
# Auto-detection (recommended)
dataset = SlipstreamDataset(remote_dir="s3://bucket/data.tar.gz")
dataset = SlipstreamDataset(local_dir="/path/to/imagefolder")

# Explicit
dataset = open_imagefolder("s3://bucket/data.tar.gz")
dataset = SlipstreamImageFolder("/path/to/imagefolder")
```

**DataLoader Creation:**
```python
# High-performance training
loader = SlipstreamLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    pipelines=supervised_train(224, device="cuda"),
)

# Simple iteration with a plain PyTorch DataLoader
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=32, collate_fn=list_collate_fn)
```