# 13 - TFRecord DataLoader

This notebook implements a PyTorch `DataLoader` for the TFRecord format.

**Format:** TensorFlow's TFRecord (a sequence of length-prefixed binary records, each holding a serialized `tf.train.Example` protocol buffer)
- Reads from TFRecord shards (train-*.tfrecord[.gz], val-*.tfrecord[.gz])
- Supports multiple variants (different shard sizes and compression)
- Efficient binary serialization
- Sequential I/O friendly
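
On disk, a TFRecord file is a flat sequence of framed records: an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the payload bytes, and a 4-byte masked CRC-32C of the payload. A reader therefore only ever moves forward through the file, which is what makes the format sequential-I/O friendly. The minimal sketch below illustrates that framing in pure Python. It assumes the `.gz` variants are the same byte stream wrapped in gzip, and it skips CRC verification (the standard library has no CRC-32C). The helpers `write_tfrecord_demo` and `iter_tfrecords` are hypothetical and not part of TensorFlow; the demo writer emits zero placeholders instead of real checksums, so its output would be rejected by TensorFlow's own reader.

```python
import gzip
import io
import struct
from typing import BinaryIO, Iterator, List


def write_tfrecord_demo(stream: BinaryIO, payloads: List[bytes]) -> None:
    """Hypothetical demo writer: TFRecord-style framing with ZERO CRC placeholders."""
    for payload in payloads:
        stream.write(struct.pack("<Q", len(payload)))  # uint64 length, little-endian
        stream.write(b"\x00\x00\x00\x00")              # placeholder: masked CRC of length
        stream.write(payload)                          # serialized record bytes
        stream.write(b"\x00\x00\x00\x00")              # placeholder: masked CRC of payload


def iter_tfrecords(stream: BinaryIO) -> Iterator[bytes]:
    """Yield raw record payloads from a TFRecord-style stream (CRCs are NOT checked)."""
    while True:
        header = stream.read(12)  # 8-byte length + 4-byte length CRC
        if not header:
            return  # clean end of stream
        if len(header) < 12:
            raise EOFError("truncated record header")
        (length,) = struct.unpack("<Q", header[:8])
        payload = stream.read(length)
        footer = stream.read(4)  # masked CRC of payload, ignored in this sketch
        if len(payload) < length or len(footer) < 4:
            raise EOFError("truncated record body")
        yield payload


# Uncompressed variant: framing is read directly from the byte stream
raw = io.BytesIO()
write_tfrecord_demo(raw, [b"first", b"second record"])
raw.seek(0)
print(list(iter_tfrecords(raw)))  # [b'first', b'second record']

# gzip variant (assumed): the same framing inside a gzip stream
packed = io.BytesIO()
with gzip.GzipFile(fileobj=packed, mode="wb") as gz:
    write_tfrecord_demo(gz, [b"first", b"second record"])
packed.seek(0)
with gzip.GzipFile(fileobj=packed, mode="rb") as gz:
    print(list(iter_tfrecords(gz)))  # [b'first', b'second record']
```

Because gzip streams cannot be seeked cheaply, a compressed shard has to be decompressed from the start by whoever reads it, which matters when several workers share one shard.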

**Usage in other notebooks:**
```python
%run ./13_loader_tfrecord.ipynb
loader = make_dataloader('cifar10', 'train', batch_size=64, num_workers=4, variant='shard256_none')
```

In [1]:
import os
from pathlib import Path
from typing import Optional
import io

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
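import tensorflow as tf
# Hide GPUs from TensorFlow: tf.data only needs the CPU here, and this keeps TF
# from reserving GPU memory that PyTorch needs for training
try:
    tf.config.set_visible_devices([], 'GPU')
except RuntimeError:
    pass  # TF runtime already initialized (e.g., notebook re-run); leave as is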
from PIL import Image
import numpy as np

# Load common utilities
%run ./10_common_utils.ipynb

✓ Common utilities loaded successfully

Available functions:
  - set_seed(seed)
  - get_transforms(augment)
  - write_sysinfo(path)
  - time_first_batch(dataloader, device)
  - start_monitor(log_path, interval)
  - stop_monitor(thread, stop_event)
  - append_to_summary(path, row_dict)
  - compute_metrics_from_logs(log_path)
  - get_device()
  - format_bytes(bytes)
  - count_parameters(model)

Constants:
  - STANDARD_TRANSFORM


## TFRecord Dataset Class

In [2]:
class TFRecordDataset(IterableDataset):
    """
    PyTorch IterableDataset for TFRecord format.
    
    Args:
        tfrecord_paths: List of paths to TFRecord files
        transform: Torchvision transforms to apply
        compression: Compression type ('none' or 'gzip')
        shuffle_buffer: Buffer size for shuffling (0 = no shuffle)
    """
    
    def __init__(
        self,
        tfrecord_paths: list,
        transform=None,
        compression: str = 'none',
        shuffle_buffer: int = 0
    ):
        self.tfrecord_paths = [str(p) for p in tfrecord_paths]
        self.transform = transform
        self.compression = 'GZIP' if compression == 'gzip' else None
        self.shuffle_buffer = shuffle_buffer
        
        print(f"Loaded TFRecord dataset: {len(self.tfrecord_paths)} shard(s)")
    
    def _parse_example(self, serialized_example):
        """
        Parse a serialized TFRecord example.
        
        Args:
            serialized_example: Serialized tf.train.Example
        
        Returns:
            Tuple of (image_tensor, label)
        """
        # Define feature description
        feature_description = {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64),
        }
        
        # Parse the example
        example = tf.io.parse_single_example(serialized_example, feature_description)
        
        # Decode image
        image_bytes = example['image'].numpy()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        # Get label
        label = int(example['label'].numpy())
        
        return image, label
    
    def __iter__(self):
        """
        Iterate over the dataset.
        
        Yields:
            Tuple of (image_tensor, label)
        """
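        worker_info = torch.utils.data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        
        if len(self.tfrecord_paths) >= num_workers:
            # Enough shards: each worker reads a disjoint subset of files
            paths = self.tfrecord_paths[worker_id::num_workers]
            shard_records = False
        else:
            # Fewer shards than workers: every worker opens all files but keeps only
            # every num_workers-th record, so no example is duplicated across workers
            # (each worker still reads, and for gzip decompresses, the full stream)
            paths = self.tfrecord_paths
            shard_records = True
        
        num_parallel_reads = tf.data.AUTOTUNE if len(paths) > 1 else None
        dataset = tf.data.TFRecordDataset(
            paths,
            compression_type=self.compression,
            num_parallel_reads=num_parallel_reads
        )
        
        # Record-level split must happen before shuffling so that all workers
        # see the same record order when selecting their slice
        if shard_records:
            dataset = dataset.shard(num_shards=num_workers, index=worker_id)
        
        # Add shuffling if requested
        if self.shuffle_buffer > 0:
            dataset = dataset.shuffle(buffer_size=self.shuffle_buffer)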
        
        # Iterate and parse examples
        for serialized_example in dataset:
            try:
                image, label = self._parse_example(serialized_example)
                yield image, label
            except Exception as e:
                # Skip corrupted examples
                print(f"Warning: Failed to parse example: {e}")
                continue
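
How shards are split across `DataLoader` workers decides whether an epoch sees each record exactly once. The pure-Python simulation below (the helper `records_seen_by_workers` is hypothetical; no TensorFlow or PyTorch needed) models each shard as a list of record IDs and compares two policies for the case of fewer shards than workers: a record-level fallback, where every worker reads all files and keeps every `num_workers`-th record in the spirit of `tf.data.Dataset.shard`, and a read-everything fallback, where a worker with no assigned shard reads every shard. With enough shards, both policies use the same file-level split.

```python
from collections import Counter
from typing import List


def records_seen_by_workers(shards: List[List[int]], num_workers: int, fallback: str) -> Counter:
    """Count how often each record ID is yielded across all workers in one epoch.

    fallback="record_shard": if shards < workers, every worker reads all shards
        and keeps records[worker_id::num_workers].
    fallback="read_all": if shards < workers, a worker whose file slice is empty
        reads every shard.
    """
    counts: Counter = Counter()
    for worker_id in range(num_workers):
        if len(shards) >= num_workers:
            # File-level split: disjoint subsets of shard files
            for shard in shards[worker_id::num_workers]:
                counts.update(shard)
        elif fallback == "record_shard":
            # Same deterministic record order in every worker -> strided slices are disjoint
            all_records = [r for shard in shards for r in shard]
            counts.update(all_records[worker_id::num_workers])
        else:
            assigned = shards[worker_id::num_workers] or shards  # empty slice -> read all
            for shard in assigned:
                counts.update(shard)
    return counts


one_shard = [list(range(10))]  # e.g. a single large shard, 4 workers
print(set(records_seen_by_workers(one_shard, 4, "record_shard").values()))  # {1}
print(set(records_seen_by_workers(one_shard, 4, "read_all").values()))      # {4}
```

With one shard and four workers, the read-everything fallback delivers every record four times per epoch, while the record-level split keeps each record exactly once at the cost of every worker scanning the whole shard.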


## DataLoader Factory Function

In [3]:
def make_dataloader(
    dataset: str,
    split: str,
    batch_size: int,
    num_workers: int,
    pin_memory: bool = True,
    variant: str = 'shard256_none',
    shuffle: bool = True,
    transform=None,
) -> DataLoader:
    """
    Create a DataLoader for TFRecord format.
    
    Args:
        dataset: Dataset name (e.g., 'cifar10', 'imagenet-mini')
        split: Split name ('train' or 'val')
        batch_size: Batch size
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory for faster GPU transfer
        variant: TFRecord variant (e.g., 'shard256_none', 'shard64_gzip')
        shuffle: Whether to shuffle data (buffer-based; approximate unless the buffer covers the whole split)
        transform: Custom transform (uses STANDARD_TRANSFORM if None)
    
    Returns:
        PyTorch DataLoader
    """
    # Detect environment
    IS_KAGGLE = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
    BASE_DIR = Path('/kaggle/working/format-matters') if IS_KAGGLE else Path('..').resolve()
    
    # Build path to TFRecord shards
    tfr_dir = BASE_DIR / 'data' / 'built' / dataset / 'tfrecord' / variant
    
    if not tfr_dir.exists():
        raise FileNotFoundError(f"TFRecord directory not found: {tfr_dir}")
    
    # Find shard files for the split
    shard_files = sorted(tfr_dir.glob(f"{split}-*.tfrecord*"))
    
    if not shard_files:
        raise FileNotFoundError(f"No '{split}' TFRecord shards found in: {tfr_dir}")
    
    print(f"Found {len(shard_files)} shard(s) for {dataset}/{split} ({variant})")
    
    # Use standard transform if none provided
    if transform is None:
        transform = STANDARD_TRANSFORM
    
    # Determine compression from variant name
    compression = 'gzip' if 'gzip' in variant else 'none'
    
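    # Create dataset
    # Buffer of serialized records for tf.data's shuffle: 50000 covers the whole
    # CIFAR-10 train split (a full shuffle); larger datasets get only a local shuffle.
    # tf.data fills this buffer before yielding the first example.
    shuffle_buffer = 50000 if shuffle else 0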
    
    dataset_obj = TFRecordDataset(
        tfrecord_paths=shard_files,
        transform=transform,
        compression=compression,
        shuffle_buffer=shuffle_buffer
    )
    
    # Create dataloader
    dataloader = DataLoader(
        dataset_obj,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    
    return dataloader
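
`tf.data.Dataset.shuffle` keeps a fixed-size buffer: it first fills the buffer, then repeatedly emits a uniformly random buffered element and replaces it with the next input element. The sketch below mirrors that idea in pure Python; it is an illustration, not TensorFlow's implementation, and `buffer_shuffle` is a hypothetical name. It shows why `shuffle_buffer = 50000` yields a full shuffle of the 50,000-image CIFAR-10 train split but only a local one on larger datasets: no element can move ahead of its input position by more than `buffer_size - 1` slots.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffer_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximate shuffle with a bounded buffer (illustrative sketch)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling: nothing is emitted yet
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]        # emit a random buffered element...
        buffer[idx] = item       # ...and put the new element in its slot
    rng.shuffle(buffer)          # input exhausted: drain the rest in random order
    yield from buffer


out = list(buffer_shuffle(range(1000), buffer_size=10))
print(max(v - k for k, v in enumerate(out)) <= 9)  # True
```

The fill phase is also why the first batch is slow: with a 50,000-record buffer, the whole CIFAR-10 train split is read before anything is yielded.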

## Smoke Test

In [4]:
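# Note: `%run ./13_loader_tfrecord.ipynb` executes this notebook in the caller's
# namespace, where __name__ is also "__main__", so this guard does not skip the
# smoke test when the loader is loaded that way.
if __name__ == "__main__":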
    print("Running TFRecord DataLoader smoke test...\n")
    
    # Detect environment
    IS_KAGGLE = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
    BASE_DIR = Path('/kaggle/working/format-matters') if IS_KAGGLE else Path('..').resolve()
    BUILT_DIR = BASE_DIR / 'data' / 'built'
    
    # Find available datasets and variants
    available_configs = []
    for dataset_name in ['cifar10', 'imagenet-mini', 'tiny-imagenet-200']:
        tfr_base = BUILT_DIR / dataset_name / 'tfrecord'
        if tfr_base.exists():
            for variant_dir in tfr_base.iterdir():
                if variant_dir.is_dir():
                    train_shards = list(variant_dir.glob('train-*.tfrecord*'))
                    if train_shards:
                        available_configs.append((dataset_name, variant_dir.name))
    
    if not available_configs:
        print("⚠ No TFRecord datasets found. Run 04_build_tfrecord.ipynb first.")
    else:
        # Test with first available dataset/variant
        test_dataset, test_variant = available_configs[0]
        print(f"Testing with dataset: {test_dataset}, variant: {test_variant}\n")
        
        try:
            # Create dataloader
            loader = make_dataloader(
                dataset=test_dataset,
                split='train',
                batch_size=32,
                num_workers=0,
                pin_memory=False,
                variant=test_variant,
                shuffle=True
            )
            
            print(f"\nDataLoader created:")
            print(f"  Batch size: 32")
            print(f"  Num workers: 0")
            print(f"  Variant: {test_variant}")
            
            # Load first batch
            print("\nLoading first batch...")
            with Timer("First batch"):
                images, labels = next(iter(loader))
            
            print(f"\nBatch shapes:")
            print(f"  Images: {images.shape} ({images.dtype})")
            print(f"  Labels: {labels.shape} ({labels.dtype})")
            print(f"  Image range: [{images.min():.3f}, {images.max():.3f}]")
            print(f"  Label range: [{labels.min()}, {labels.max()}]")
            
            # Load a few more batches to test throughput
            print("\nTesting throughput (10 batches)...")
            with Timer("10 batches"):
                for i, (images, labels) in enumerate(loader):
                    if i >= 9:
                        break
            
            print("\n✓ TFRecord DataLoader smoke test passed!")
            
        except Exception as e:
            print(f"\n✗ Smoke test failed: {e}")
            import traceback
            traceback.print_exc()

Running TFRecord DataLoader smoke test...

Testing with dataset: cifar10, variant: shard1024_gzip

Found 1 shard(s) for cifar10/train (shard1024_gzip)
Loaded TFRecord dataset: 1 shard(s)

DataLoader created:
  Batch size: 32
  Num workers: 0
  Variant: shard1024_gzip

Loading first batch...
First batch took 1.40s

Batch shapes:
  Images: torch.Size([32, 3, 224, 224]) (torch.float32)
  Labels: torch.Size([32]) (torch.int64)
  Image range: [-2.118, 2.640]
  Label range: [0, 9]

Testing throughput (10 batches)...
10 batches took 1.03s

✓ TFRecord DataLoader smoke test passed!
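
The timings above can be turned into rough throughput figures. A minimal sketch (the helper `images_per_second` is hypothetical), assuming every batch was full at 32 images. Both numbers include pipeline start-up: each `iter(loader)` or `enumerate(loader)` starts a fresh iterator, which builds a new `tf.data` pipeline and refills the 50,000-record shuffle buffer, so they understate steady-state throughput.

```python
def images_per_second(num_batches: int, batch_size: int, seconds: float) -> float:
    """Throughput assuming every batch is full (drop_last=False can leave a short last batch)."""
    return num_batches * batch_size / seconds


# Numbers from the smoke-test output above (batch_size=32, num_workers=0)
print(f"first batch: {images_per_second(1, 32, 1.40):.1f} img/s")   # 22.9
print(f"10 batches:  {images_per_second(10, 32, 1.03):.1f} img/s")  # 310.7
```

Timing only the batches after the first from a single iterator would separate start-up cost from steady-state decoding speed.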


## ✅ TFRecord DataLoader Ready

**Usage:**
```python
# In training notebooks
%run ./13_loader_tfrecord.ipynb

train_loader = make_dataloader(
    dataset='cifar10',
    split='train',
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    variant='shard256_none',  # or 'shard64_gzip', etc.
    shuffle=True
)

val_loader = make_dataloader(
    dataset='cifar10',
    split='val',
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    variant='shard256_none',
    shuffle=False
)
```

**Features:**
- TFRecord binary format with optional compression
- Efficient serialization and deserialization
- Sequential I/O friendly
- Multiple variants (different shard sizes and compression)
- Buffer-based shuffling for training (a full shuffle only when the 50,000-record buffer covers the whole split)
- Standard PyTorch DataLoader interface
- Compatible with TensorFlow ecosystem

**Available variants:**
- `shard64_none`: 64MB shards, no compression
- `shard64_gzip`: 64MB shards, gzip compression
- `shard256_none`: 256MB shards, no compression
- `shard256_gzip`: 256MB shards, gzip compression
- `shard1024_none`: 1024MB shards, no compression
- `shard1024_gzip`: 1024MB shards, gzip compression
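
Variant names encode both settings as `shard<size in MB>_<compression>`. `make_dataloader` only checks whether `'gzip'` appears in the name; a stricter parser such as the hypothetical `parse_variant` below (a sketch assuming every variant follows this naming scheme) would reject malformed names with a clear message and expose the shard size as a number, for example when labelling benchmark results.

```python
import re
from typing import Tuple

# Assumed naming scheme: shard<MB>_<none|gzip>
_VARIANT_RE = re.compile(r"shard(\d+)_(none|gzip)")


def parse_variant(variant: str) -> Tuple[int, str]:
    """Split a variant name like 'shard256_gzip' into (shard_size_mb, compression)."""
    match = _VARIANT_RE.fullmatch(variant)
    if match is None:
        raise ValueError(f"Unrecognized TFRecord variant: {variant!r}")
    return int(match.group(1)), match.group(2)


print(parse_variant("shard256_gzip"))  # (256, 'gzip')
print(parse_variant("shard64_none"))   # (64, 'none')
```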

**Next steps:**
1. Create other format loaders (14)
2. Run training experiments (20-21)