# HuggingFace Dataset Support

This notebook demonstrates how to use HuggingFace datasets with slipstream via the `hf://` URI scheme.

## HuggingFace Image Format

HuggingFace stores images in Parquet as dicts:
```python
{'bytes': b'\x89PNG...', 'path': None}  # inline bytes
{'bytes': None, 'path': '/path/to/image.jpg'}  # path reference
```

Slipstream automatically detects and decodes this format when `decode_images=True`.
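
To make the two dict forms concrete, here is a minimal sketch of how the encoded bytes could be recovered from either one. The helper name `resolve_hf_image_bytes` is hypothetical and is not slipstream's implementation; it assumes only the `bytes`/`path` layout shown above.

```python
from pathlib import Path
from typing import Any


def resolve_hf_image_bytes(value: Any) -> bytes:
    """Return encoded image bytes from a HuggingFace-style image dict.

    Hypothetical helper for illustration; not part of slipstream.
    """
    if not isinstance(value, dict) or not {"bytes", "path"}.issubset(value):
        raise TypeError("expected a dict with 'bytes' and 'path' keys")
    if value["bytes"] is not None:
        return bytes(value["bytes"])  # inline form
    if value["path"] is not None:
        return Path(value["path"]).read_bytes()  # path-reference form
    raise ValueError("both 'bytes' and 'path' are None")


# Inline form: the PNG signature stands in for a full image here.
inline = {"bytes": b"\x89PNG\r\n\x1a\n", "path": None}
print(resolve_hf_image_bytes(inline)[:4])
```

The returned bytes are still encoded (PNG, JPEG, and so on), so they would need to go through an image decoder before use.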

## Prerequisites

The required dependencies are already included in slipstream:
- `huggingface_hub` - HuggingFace Hub client
- `hf_transfer` - Fast file transfers (optional, for high-bandwidth networks)

To enable faster downloads through `hf_transfer`, set this environment variable in your shell:
```bash
export HF_HUB_ENABLE_HF_TRANSFER=1
```
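
The same flag can be set from Python. The sketch below is a hypothetical helper (not part of slipstream or `huggingface_hub`) that sets the flag only when the `hf_transfer` package can actually be imported. It assumes, as is typically the case, that `huggingface_hub` reads the variable when it is first imported, so the helper should run before that import.

```python
import importlib.util
import os


def enable_hf_transfer_if_available() -> bool:
    """Set HF_HUB_ENABLE_HF_TRANSFER only when hf_transfer can be imported.

    Hypothetical helper for illustration; not part of slipstream.
    """
    if importlib.util.find_spec("hf_transfer") is None:
        return False  # package missing: leave the environment untouched
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    return True


# Call this before importing huggingface_hub or slipstream.
print("hf_transfer enabled:", enable_hf_transfer_if_available())
```

Guarding the flag this way avoids turning on the fast-transfer path on a machine where the optional package is absent.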

In [None]:
import os
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

from slipstream import SlipstreamDataset, SlipstreamLoader
from slipstream.dataset import is_hf_image_dict, is_image_bytes, decode_image
import matplotlib.pyplot as plt

## Understanding HuggingFace Image Dicts

Let's look at what the raw HuggingFace format looks like:

In [None]:
# Load MNIST without decoding to see the raw HuggingFace format
dataset_raw = SlipstreamDataset(
    input_dir="hf://datasets/ylecun/mnist/mnist/",
    decode_images=False,  # See raw format
)

sample = dataset_raw[0]
print("Raw sample keys:", list(sample.keys()))
print("Raw 'image' type:", type(sample['image']))
print("Raw 'image' value:", {k: type(v).__name__ for k, v in sample['image'].items()})

In [None]:
# Slipstream automatically detects HuggingFace image dicts
hf_dict = sample['image']
print(f"Is HF image dict: {is_hf_image_dict(hf_dict)}")
print(f"Is valid image: {is_image_bytes(hf_dict)}")
print(f"Detected field types: {dataset_raw.field_types}")
print(f"Image fields: {dataset_raw.image_fields}")

In [None]:
# Manually decode using decode_image()
img_tensor = decode_image(hf_dict, to_pil=False)
img_pil = decode_image(hf_dict, to_pil=True)

print(f"Tensor shape: {img_tensor.shape} (CHW)")
print(f"PIL size: {img_pil.size}")
img_pil
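
The two printed values use different axis orders: the tensor shape is labeled CHW (channels, height, width), while PIL's `size` is `(width, height)`. The sketch below shows the relationship using plain tuples. `describe_chw` is a hypothetical helper for illustration, and the non-square shape is chosen only to make the axis order visible.

```python
from typing import Dict, Tuple


def describe_chw(shape: Tuple[int, int, int]) -> Dict[str, tuple]:
    """Map a CHW shape to the equivalent PIL size and HWC shape."""
    channels, height, width = shape
    return {
        "pil_size": (width, height),             # PIL reports (width, height)
        "hwc_shape": (height, width, channels),  # layout imshow expects for RGB arrays
    }


info = describe_chw((3, 480, 640))
print(info["pil_size"])   # (640, 480)
print(info["hwc_shape"])  # (480, 640, 3)
```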

## Automatic Decoding with SlipstreamDataset

Set `decode_images=True` to automatically decode HuggingFace image dicts:

In [None]:
# Load MNIST with automatic decoding
dataset = SlipstreamDataset(
    input_dir="hf://datasets/ylecun/mnist/mnist/",
    decode_images=True,
    to_pil=True,
)

print(dataset)
sample = dataset[0]
print(f"\nSample 'image' type: {type(sample['image'])}")
sample['image']

## CIFAR-10 Example

CIFAR-10 names its image field `img` instead of `image`. Slipstream handles this automatically:

In [None]:
# Load CIFAR-10 - note the field is 'img' not 'image'
dataset = SlipstreamDataset(
    input_dir="hf://datasets/uoft-cs/cifar10/plain_text/",
    decode_images=True,
    to_pil=True,
)

print(dataset)
sample = dataset[0]
print(f"\nField names: {list(sample.keys())}")
sample['img']  # CIFAR-10 uses 'img' field name

In [None]:
# Visualize some CIFAR-10 samples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))

cifar10_classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

for i, ax in enumerate(axes.flat):
    sample = dataset[i]
    ax.imshow(sample['img'])
    ax.set_title(cifar10_classes[sample['label']])
    ax.axis('off')

plt.suptitle('CIFAR-10 from HuggingFace (auto-decoded)', fontsize=14)
plt.tight_layout()
plt.show()

## Field Name Variability

HuggingFace doesn't enforce column names. Common patterns:
- Images: `image`, `img`, `pixel_values`
- Labels: `label`, `labels`, `fine_label`, `coarse_label`

Slipstream detects images by format (not name), so any field containing
HuggingFace image dicts will be automatically detected and decoded.
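
Detection by format rather than by name can be sketched as follows. This is an illustrative approximation, not slipstream's actual detection logic: the function names and the list of signatures are assumptions, and the check inspects only the leading magic bytes of each value.

```python
from typing import Any, Dict, List

# Magic-number prefixes for a few common encoded image formats.
_IMAGE_SIGNATURES = (
    b"\x89PNG\r\n\x1a\n",  # PNG
    b"\xff\xd8\xff",       # JPEG
    b"GIF87a",             # GIF
    b"GIF89a",             # GIF
)


def looks_like_encoded_image(data: Any) -> bool:
    """True when data is bytes that start with a known image signature."""
    return isinstance(data, (bytes, bytearray)) and bytes(data).startswith(_IMAGE_SIGNATURES)


def find_image_fields(sample: Dict[str, Any]) -> List[str]:
    """Return names of fields whose values look like images, whatever they are called."""
    fields = []
    for name, value in sample.items():
        if isinstance(value, dict) and {"bytes", "path"}.issubset(value):
            payload, path = value["bytes"], value["path"]
            # Inline bytes must carry an image signature; a path reference is accepted as-is.
            if looks_like_encoded_image(payload) or (payload is None and path is not None):
                fields.append(name)
        elif looks_like_encoded_image(value):
            fields.append(name)
    return fields


sample = {"img": {"bytes": b"\xff\xd8\xff\xe0...", "path": None}, "label": 3}
print(find_image_fields(sample))
```

Because the check looks at values, renaming `img` to `image` or `pixel_values` in the sample would give the same result under that new name.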

## Other HuggingFace Datasets

You can load any LitData-compatible HuggingFace dataset using the `hf://` URI:

```python
# ImageNet subset
dataset = SlipstreamDataset(input_dir="hf://datasets/imagenet-1k/data")

# MNIST
dataset = SlipstreamDataset(input_dir="hf://datasets/ylecun/mnist/mnist/")

# Custom datasets
dataset = SlipstreamDataset(input_dir="hf://datasets/username/my-dataset/data")
```

Note: The dataset must be in a format compatible with LitData's streaming protocol.

# SlipstreamLoader

`SlipstreamLoader` automatically converts a HuggingFace dataset to the `.slip` format:

In [None]:
from slipstream import SlipstreamDataset, SlipstreamLoader, DecodeCenterCrop, DecodeYUVFullRes
from PIL import Image

dataset = SlipstreamDataset(
    input_dir="hf://datasets/uoft-cs/cifar10/plain_text/",
    decode_images=False,
)
print(dataset)

loader = SlipstreamLoader(
    dataset,
    batch_size=100,
    pipelines={'img': [DecodeCenterCrop(size=32)]},
    force_rebuild=True,  # Rebuild cache with correct format
)

In [None]:
batch = next(iter(loader))
batch.keys()

In [None]:
batch['img'].shape

In [None]:
print(batch['label'][0])
Image.fromarray(batch['img'][0])

In [None]:
from slipstream import SlipstreamDataset, SlipstreamLoader, DecodeCenterCrop
from PIL import Image

dataset = SlipstreamDataset(
    input_dir="hf://datasets/ylecun/mnist/mnist/",
    # input_dir="hf://datasets/uoft-cs/cifar10/plain_text/",
    decode_images=False,
)
print(dataset)

loader = SlipstreamLoader(
    dataset,
    batch_size=100,
    pipelines={'image': [DecodeCenterCrop(28)]},
    force_rebuild=True,  # Rebuild cache with correct format
)

In [None]:
dataset[0]

In [None]:
batch = next(iter(loader))
batch.keys()

In [None]:
batch['image'][0].shape

In [None]:
Image.fromarray(batch['image'][0])