# 14 - LMDB DataLoader

This notebook implements a PyTorch `Dataset` and a `make_dataloader` factory for data stored in LMDB format.

**Format:** LMDB (Lightning Memory-Mapped Database)
- Reads from LMDB databases (train.lmdb, val.lmdb)
- Supports multiple variants (different compression options)
- Memory-mapped, so reads by index are fast, which suits random access patterns such as shuffled training

**Usage in other notebooks:**
```python
%run ./14_loader_lmdb.ipynb
loader = make_dataloader('cifar10', 'train', batch_size=64, num_workers=4, variant='compress_none')
```

In [1]:
import os
from pathlib import Path
from typing import Optional
import io
import pickle

import torch
from torch.utils.data import Dataset, DataLoader
import lmdb
from PIL import Image

# Optional compression libraries
try:
    import zstandard as zstd
    HAS_ZSTD = True
except ImportError:
    HAS_ZSTD = False

try:
    import lz4.frame
    HAS_LZ4 = True
except ImportError:
    HAS_LZ4 = False

# Load common utilities
%run ./10_common_utils.ipynb

✓ Common utilities loaded successfully

Available functions:
  - set_seed(seed)
  - get_transforms(augment)
  - write_sysinfo(path)
  - time_first_batch(dataloader, device)
  - start_monitor(log_path, interval)
  - stop_monitor(thread, stop_event)
  - append_to_summary(path, row_dict)
  - compute_metrics_from_logs(log_path)
  - get_device()
  - format_bytes(bytes)
  - count_parameters(model)

Constants:
  - STANDARD_TRANSFORM


## LMDB Dataset Class

In [2]:
class LMDBDataset(Dataset):
    """
    PyTorch Dataset for LMDB format.
    
    Args:
        lmdb_path: Path to LMDB database directory
        transform: Torchvision transforms to apply
    """
    
    def __init__(self, lmdb_path: Path, transform=None):
        self.lmdb_path = Path(lmdb_path)
        self.transform = transform
        
        # Open LMDB environment (read-only)
        self.env = lmdb.open(
            str(self.lmdb_path),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False
        )
        
        # Read metadata
        with self.env.begin() as txn:
            metadata_bytes = txn.get(b'__metadata__')
            if metadata_bytes:
                self.metadata = pickle.loads(metadata_bytes)
                self.length = self.metadata['num_samples']
                self.compression = self.metadata.get('compression', 'none')
            else:
                raise ValueError(f"No metadata found in LMDB: {lmdb_path}")
        
        # Use self.lmdb_path (always a Path) so a str argument does not break .name
        print(f"Loaded LMDB dataset: {self.length:,} samples from {self.lmdb_path.name}")
        print(f"  Compression: {self.compression}")
    
    def __len__(self):
        return self.length
    
    def _decompress(self, data: bytes, compression: str) -> bytes:
        """
        Decompress data using specified compression algorithm.
        
        Args:
            data: Compressed bytes
            compression: Compression type ('none', 'zstd', 'lz4')
        
        Returns:
            Decompressed bytes
        """
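        if compression == 'none':
            return data
        if compression == 'zstd':
            # Distinguish a missing optional package from an unknown codec
            if not HAS_ZSTD:
                raise ImportError("Data is zstd-compressed but the 'zstandard' package is not installed")
            return zstd.ZstdDecompressor().decompress(data)
        if compression == 'lz4':
            if not HAS_LZ4:
                raise ImportError("Data is lz4-compressed but the 'lz4' package is not installed")
            return lz4.frame.decompress(data)
        raise ValueError(f"Unsupported compression: {compression}")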
    
    def __getitem__(self, idx):
        # Create key
        key = f"{idx:08d}".encode('utf-8')
        
        # Read from LMDB
        with self.env.begin() as txn:
            entry_bytes = txn.get(key)
            
            if entry_bytes is None:
                raise IndexError(f"Index {idx} not found in LMDB")
            
            # Deserialize entry
            entry = pickle.loads(entry_bytes)
        
        # Decompress image if needed
        img_bytes = self._decompress(entry['image'], entry['compression'])
        
        # Load image
        try:
            image = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        except Exception as e:
            # Chain the original exception so the underlying decode error stays in the traceback
            raise RuntimeError(f"Failed to load image at index {idx}: {e}") from e
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        # Get label
        label = entry['label']
        
        return image, label
    
    def __del__(self):
        """Close LMDB environment on deletion."""
        if hasattr(self, 'env'):
            self.env.close()
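
The record layout that `LMDBDataset` reads can be illustrated without a real database. The following is a minimal sketch, assuming the builder notebook wrote entries the way the reader above consumes them: a pickled dict under a zero-padded 8-digit key, plus a `__metadata__` record. A plain `dict` stands in for the LMDB transaction, and `make_key`, `make_entry`, and `read_entry` are hypothetical helper names, not part of this notebook.

```python
import pickle


def make_key(idx: int) -> bytes:
    """Zero-padded decimal key, matching the f"{idx:08d}" format used in __getitem__."""
    return f"{idx:08d}".encode("utf-8")


def make_entry(image_bytes: bytes, label: int, compression: str = "none") -> bytes:
    """Serialize one sample as the pickled dict the reader expects (hypothetical helper)."""
    return pickle.dumps({"image": image_bytes, "label": label, "compression": compression})


# A plain dict stands in for the LMDB key-value store in this sketch.
store = {
    b"__metadata__": pickle.dumps({"num_samples": 2, "compression": "none"}),
    make_key(0): make_entry(b"fake-image-bytes-0", label=3),
    make_key(1): make_entry(b"fake-image-bytes-1", label=7),
}


def read_entry(idx: int) -> dict:
    """Mirror the lookup in __getitem__: build the key, fetch, unpickle."""
    raw = store.get(make_key(idx))
    if raw is None:
        raise IndexError(f"Index {idx} not found")
    return pickle.loads(raw)


metadata = pickle.loads(store[b"__metadata__"])
sample = read_entry(1)
print(make_key(1), sample["label"], metadata["num_samples"])
```

Zero-padding keeps every key the same width, so the byte-wise ordering LMDB uses for keys matches numeric order. Because `pickle.loads` can execute arbitrary code, this layout should only be read from databases you built yourself.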

## DataLoader Factory Function

In [3]:
def make_dataloader(
    dataset: str,
    split: str,
    batch_size: int,
    num_workers: int,
    pin_memory: bool = True,
    variant: str = 'compress_none',
    shuffle: bool = True,
    transform=None,
) -> DataLoader:
    """
    Create a DataLoader for LMDB format.
    
    Args:
        dataset: Dataset name (e.g., 'cifar10', 'imagenet-mini')
        split: Split name ('train' or 'val')
        batch_size: Batch size
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory for faster GPU transfer
        variant: LMDB variant (e.g., 'compress_none', 'compress_zstd')
        shuffle: Whether to shuffle data
        transform: Custom transform (uses STANDARD_TRANSFORM if None)
    
    Returns:
        PyTorch DataLoader
    """
    # Detect environment
    IS_KAGGLE = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
    BASE_DIR = Path('/kaggle/working/format-matters') if IS_KAGGLE else Path('..').resolve()
    
    # Build path to LMDB database
    lmdb_path = BASE_DIR / 'data' / 'built' / dataset / 'lmdb' / variant / f'{split}.lmdb'
    
    if not lmdb_path.exists():
        raise FileNotFoundError(f"LMDB database not found: {lmdb_path}")
    
    # Use standard transform if none provided
    if transform is None:
        transform = STANDARD_TRANSFORM
    
    # Create dataset
    dataset_obj = LMDBDataset(lmdb_path, transform=transform)
    
    # Create dataloader
    dataloader = DataLoader(
        dataset_obj,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    
    return dataloader
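
One caveat when `num_workers > 0`: `LMDBDataset` opens its environment in `__init__`, so worker processes end up with a handle created in the parent. LMDB's documentation advises against using an environment in a process other than the one that opened it, and with the spawn start method the dataset has to be pickled, which an open environment typically does not support. A common workaround is to open the environment lazily, on first access inside each worker. The following is a minimal sketch of that pattern, not this notebook's implementation; `LazyHandle` and `fake_open` are hypothetical names, and the stub opener stands in for `lmdb.open` so the example runs without a database.

```python
import os
from typing import Any, Callable, Optional


class LazyHandle:
    """Defer opening a resource until first use, once per process (hypothetical helper)."""

    def __init__(self, opener: Callable[[], Any]):
        self._opener = opener
        self._handle: Optional[Any] = None
        self._pid: Optional[int] = None

    def get(self) -> Any:
        # Open on first use, and re-open if this object was carried into
        # another process (for example a forked DataLoader worker).
        if self._handle is None or self._pid != os.getpid():
            self._handle = self._opener()
            self._pid = os.getpid()
        return self._handle

    def __getstate__(self) -> dict:
        # Drop the live handle when pickled (spawn-based workers), keeping
        # only the recipe for opening it again.
        state = self.__dict__.copy()
        state["_handle"] = None
        state["_pid"] = None
        return state


# Stub opener standing in for a call such as lmdb.open(path, readonly=True, lock=False)
open_calls = []


def fake_open() -> dict:
    open_calls.append(os.getpid())
    return {"opened_in": os.getpid()}


handle = LazyHandle(fake_open)
first = handle.get()
second = handle.get()
print(len(open_calls), first is second)
```

Applied to `LMDBDataset`, the same idea would mean keeping only the path in `__init__` and calling `lmdb.open(...)` the first time `__getitem__` runs in a given process. The metadata read in `__init__` would then need its own short-lived environment that is closed before the workers start.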

## Smoke Test

In [4]:
if __name__ == "__main__":
    print("Running LMDB DataLoader smoke test...\n")
    
    # Detect environment
    IS_KAGGLE = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
    BASE_DIR = Path('/kaggle/working/format-matters') if IS_KAGGLE else Path('..').resolve()
    BUILT_DIR = BASE_DIR / 'data' / 'built'
    
    # Find available datasets and variants
    available_configs = []
    for dataset_name in ['cifar10', 'imagenet-mini', 'tiny-imagenet-200']:
        lmdb_base = BUILT_DIR / dataset_name / 'lmdb'
        if lmdb_base.exists():
            for variant_dir in lmdb_base.iterdir():
                if variant_dir.is_dir():
                    train_db = variant_dir / 'train.lmdb'
                    if train_db.exists():
                        available_configs.append((dataset_name, variant_dir.name))
    
    if not available_configs:
        print("⚠ No LMDB datasets found. Run 05_build_lmdb.ipynb first.")
    else:
        # Test with first available dataset/variant
        test_dataset, test_variant = available_configs[0]
        print(f"Testing with dataset: {test_dataset}, variant: {test_variant}\n")
        
        try:
            # Create dataloader
            loader = make_dataloader(
                dataset=test_dataset,
                split='train',
                batch_size=32,
                num_workers=0,
                pin_memory=False,
                variant=test_variant,
                shuffle=True
            )
            
            print("\nDataLoader created:")
            print(f"  Dataset size: {len(loader.dataset):,}")
            print(f"  Batch size: {loader.batch_size}")
            print(f"  Num batches: {len(loader):,}")
            print(f"  Num workers: {loader.num_workers}")
            print(f"  Variant: {test_variant}")
            
            # Load first batch
            print("\nLoading first batch...")
            with Timer("First batch"):
                images, labels = next(iter(loader))
            
            print("\nBatch shapes:")
            print(f"  Images: {images.shape} ({images.dtype})")
            print(f"  Labels: {labels.shape} ({labels.dtype})")
            print(f"  Image range: [{images.min():.3f}, {images.max():.3f}]")
            print(f"  Label range: [{labels.min()}, {labels.max()}]")
            
            # Load a few more batches to test throughput
            print("\nTesting throughput (10 batches)...")
            with Timer("10 batches"):
                for i, (images, labels) in enumerate(loader):
                    if i >= 9:
                        break
            
            print("\n✓ LMDB DataLoader smoke test passed!")
            
        except Exception as e:
            print(f"\n✗ Smoke test failed: {e}")
            import traceback
            traceback.print_exc()

Running LMDB DataLoader smoke test...

Testing with dataset: cifar10, variant: compress_lz4

Loaded LMDB dataset: 50,000 samples from train.lmdb
  Compression: lz4

DataLoader created:
  Dataset size: 50,000
  Batch size: 32
  Num batches: 1,563
  Num workers: 0
  Variant: compress_lz4

Loading first batch...
First batch took 0.14s

Batch shapes:
  Images: torch.Size([32, 3, 224, 224]) (torch.float32)
  Labels: torch.Size([32]) (torch.int64)
  Image range: [-2.118, 2.640]
  Label range: [0, 9]

Testing throughput (10 batches)...
10 batches took 0.52s

✓ LMDB DataLoader smoke test passed!
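
The `Num batches: 1,563` figure in the output above follows from `drop_last=False` in `make_dataloader`: the final partial batch is kept, so the count is the ceiling of the dataset size divided by the batch size. A quick check of that arithmetic, using the numbers reported by the smoke test:

```python
import math

num_samples = 50_000  # dataset size reported by the smoke test
batch_size = 32

# drop_last=False keeps the final partial batch
num_batches = math.ceil(num_samples / batch_size)
last_batch_size = num_samples - (num_batches - 1) * batch_size

# For comparison: the count if drop_last=True were used instead
full_batches_only = num_samples // batch_size

print(num_batches, last_batch_size, full_batches_only)  # 1563 16 1562
```

The last batch therefore holds 16 images rather than 32, which is worth keeping in mind for code that assumes a fixed batch dimension.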


## ✅ LMDB DataLoader Ready

**Usage:**
```python
# In training notebooks
%run ./14_loader_lmdb.ipynb

train_loader = make_dataloader(
    dataset='cifar10',
    split='train',
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    variant='compress_none',  # or 'compress_zstd', 'compress_lz4'
    shuffle=True
)

val_loader = make_dataloader(
    dataset='cifar10',
    split='val',
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    variant='compress_none',
    shuffle=False
)
```

**Features:**
- Memory-mapped database with fast random access by index, which suits shuffled training
- Multiple variants (different compression options)
- Standard PyTorch DataLoader interface

**Available variants:**
- `compress_none`: No compression
- `compress_zstd`: Zstandard compression (requires the `zstandard` package)
- `compress_lz4`: LZ4 compression (requires the `lz4` package)
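
Since the compressed variants depend on optional packages, it can help to check what the current environment can read before choosing a `variant`. A minimal sketch, assuming the variant names listed above map to the `zstandard` and `lz4` packages imported at the top of this notebook; `VARIANT_PACKAGES` and `usable_variants` are hypothetical names:

```python
import importlib.util
from typing import Dict, List, Optional

# Variant name -> top-level package needed to decompress it (None = no extra package)
VARIANT_PACKAGES: Dict[str, Optional[str]] = {
    "compress_none": None,
    "compress_zstd": "zstandard",
    "compress_lz4": "lz4",
}


def usable_variants() -> List[str]:
    """Return the variants whose decompression package can be found in this environment."""
    return [
        variant
        for variant, package in VARIANT_PACKAGES.items()
        if package is None or importlib.util.find_spec(package) is not None
    ]


print(usable_variants())
```

This only checks that the package is installed; whether a given variant was actually built for a dataset still depends on the LMDB directories on disk.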

**Next steps:**
1. Run training experiments (20-21)
2. Run analysis notebooks (30-31, 40)