# Convert Mining Dataset to Zarr Format and Benchmark

This notebook converts the existing memory-mapped `.npy` tiles (one directory per tile) into a single uncompressed Zarr store chunked along the tile axis, then benchmarks data loading against the original layout.

## Objectives
1. Convert the existing data structure (manifests + `.npy` files) to Zarr
2. Use a chunk size of 8 tiles, chosen to suit batched reads (see the optional chunk-size test in section 9)
3. Benchmark data loading speed: old design vs. Zarr
4. Analyze memory usage and I/O patterns

In [14]:
import sys
from pathlib import Path
import time
import numpy as np
import zarr
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

# Add src to path
sys.path.insert(0, str(Path.cwd().parent))

from config import Config
from manifest_reader import ManifestReader
from network.data_loader import MiningSegmentationDataLoader

print(f"Zarr version: {zarr.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

Zarr version: 3.1.5
NumPy version: 2.3.5
PyTorch version: 2.10.0+cu130


## 1. Setup and Explore Current Data Structure

In [15]:
# Initialize config and manifest reader
config = Config()
manifests_dir = config.DATA_DIR / "manifests"
mmap_dir = config.MMAP_DIR
zarr_dir = config.DATA_DIR / "landsat_zarr"

print(f"Manifests directory: {manifests_dir}")
print(f"MMAP directory: {mmap_dir}")
print(f"Zarr output directory: {zarr_dir}")
print(f"\nManifests exist: {manifests_dir.exists()}")
print(f"MMAP data exist: {mmap_dir.exists()}")

Manifests directory: /scicore/home/meiera/schulz0022/projects/mining-net/data/manifests
MMAP directory: /scicore/home/meiera/schulz0022/projects/mining-net/data/landsat_mmap
Zarr output directory: /scicore/home/meiera/schulz0022/projects/mining-net/data/landsat_zarr

Manifests exist: True
MMAP data exist: True


In [16]:
# Explore available data using manifest reader
reader = ManifestReader(manifests_dir)
all_manifests = reader.list_all_manifests()

print(f"Found {len(all_manifests)} clusters\n")
print("Sample clusters:")
for manifest in all_manifests[:5]:
    print(f"  Cluster {manifest['cluster_id']:>10} ({manifest['country_code']}) - "
          f"{manifest['tile_count']:>4} tiles, years {manifest['years']}")

if len(all_manifests) > 5:
    print(f"  ... and {len(all_manifests) - 5} more clusters")

Found 227 clusters

Sample clusters:
  Cluster 6155687546600015 (ZAF) -   72 tiles, years [2019]
  Cluster 16714675725600725 (ZAF) -  169 tiles, years [2019]
  Cluster 17315396644879081 (ZAF) -  110 tiles, years [2019]
  Cluster 52652953960301590 (ZAF) -  100 tiles, years [2019]
  Cluster 62084203559214856 (ZAF) -   63 tiles, years [2019]
  ... and 222 more clusters


In [17]:
# Get total tile count and data statistics
total_tiles = sum(m['tile_count'] for m in all_manifests)
countries = set(m['country_code'] for m in all_manifests)
all_years = set()
for m in all_manifests:
    all_years.update(m['years'])

print(f"Total tiles: {total_tiles:,}")
print(f"Countries: {sorted(countries)}")
print(f"Years: {sorted(all_years)}")
print(f"\nEstimated data size (64x64 tiles, 7 bands + 1 label, float32):")
# Tiles are stored as (7, 64, 64) features + (1, 64, 64) labels (see the sample tile below)
tile_size = 64 * 64 * 8 * 4  # 8 channels x 4 bytes per float32
total_size_gb = (total_tiles * tile_size) / (1024**3)
print(f"  ~{total_size_gb:.2f} GB")

Total tiles: 43,023
Countries: ['ZAF']
Years: [2019]

Estimated data size (64x64 tiles, 7 bands + 1 label, float32):
  ~5.25 GB


## 2. Inspect a Sample Tile

In [18]:
# Load a sample tile to understand data structure
sample_manifest = reader.read_manifest(all_manifests[0]['cluster_id'])
sample_tile = sample_manifest['tiles'][0]

print("Sample tile metadata:")
print(f"  Cluster: {sample_manifest['cluster_id']}")
print(f"  Country: {sample_manifest['country_code']}")
print(f"  Year: {sample_tile['year']}")
print(f"  Tile indices: ({sample_tile['tile_ix']}, {sample_tile['tile_iy']})")

# Load the actual data
tile_path = mmap_dir / str(sample_manifest['cluster_id']) / str(sample_tile['year']) / f"{sample_tile['tile_ix']}_{sample_tile['tile_iy']}"
print(f"\nTile path: {tile_path}")
print(f"Exists: {tile_path.exists()}")

if tile_path.exists():
    features = np.load(tile_path / "features.npy", mmap_mode='r')
    labels = np.load(tile_path / "labels.npy", mmap_mode='r')
    
    print(f"\nFeatures shape: {features.shape}")
    print(f"Features dtype: {features.dtype}")
    print(f"Labels shape: {labels.shape}")
    print(f"Labels dtype: {labels.dtype}")
    print(f"\nFeatures range: [{features.min():.4f}, {features.max():.4f}]")
    print(f"Labels unique values: {np.unique(labels)}")

Sample tile metadata:
  Cluster: 6155687546600015
  Country: ZAF
  Year: 2019
  Tile indices: (6841, 11646)

Tile path: /scicore/home/meiera/schulz0022/projects/mining-net/data/landsat_mmap/6155687546600015/2019/6841_11646
Exists: True

Features shape: (7, 64, 64)
Features dtype: float32
Labels shape: (1, 64, 64)
Labels dtype: float32

Features range: [0.0514, 316.3787]
Labels unique values: [0.]


## 3. Design Zarr Structure

We'll create a single Zarr group with multiple arrays:
```
landsat_zarr/
├── data.zarr/           # Zarr group containing:
│   ├── features         # Shape: (N_tiles, C, H, W), chunks: (8, C, H, W)
│   ├── labels           # Shape: (N_tiles, 1, H, W), chunks: (8, 1, H, W)
│   ├── cluster_ids      # Shape: (N_tiles,) - cluster ID for each tile
│   ├── tile_ix          # Shape: (N_tiles,) - tile X index
│   ├── tile_iy          # Shape: (N_tiles,) - tile Y index
│   └── years            # Shape: (N_tiles,) - year for each tile
└── metadata/
    └── tiles.json       # Full tile index (also keeps country_code, geometry_hash)
```

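Benefits:
- A single Zarr group holds all image data plus row-aligned index arrays
- Chunks of 8 consecutive tiles let sequential batch reads fetch whole chunks; random reads still fetch a full chunk per tile
- Index arrays allow filtering and queries (e.g. by cluster) without reading any image data
- No compression, so reads skip decoding at the cost of disk space
- The same layout can live in cloud object stores (S3, GCS) through Zarr's storage backends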
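To put numbers on this layout, the sketch below counts chunk objects and bytes per chunk for the full dataset. It assumes the shapes reported elsewhere in this notebook (7×64×64 features, 1×64×64 labels, float32), 43,023 tiles, and one stored object per chunk (no sharding); `chunk_layout` is a hypothetical helper.

```python
import math

FLOAT32_BYTES = 4


def chunk_layout(n_tiles: int, tile_shape: tuple, chunk_tiles: int) -> dict:
    """Chunk count and uncompressed bytes per full chunk for an (N, C, H, W)
    array that is chunked only along the tile axis."""
    tile_bytes = math.prod(tile_shape) * FLOAT32_BYTES
    return {
        "n_chunks": math.ceil(n_tiles / chunk_tiles),
        "chunk_bytes": tile_bytes * chunk_tiles,
    }


n_tiles = 43_023  # total tile count from the manifests
features = chunk_layout(n_tiles, (7, 64, 64), chunk_tiles=8)
labels = chunk_layout(n_tiles, (1, 64, 64), chunk_tiles=8)

# Old layout: one features.npy and one labels.npy per tile
npy_files = 2 * n_tiles
# Zarr layout: one object per chunk in each data array (metadata files not counted)
zarr_chunk_files = features["n_chunks"] + labels["n_chunks"]

print(f"features: {features['n_chunks']:,} chunks of {features['chunk_bytes'] / 2**20:.3f} MiB")
print(f"labels:   {labels['n_chunks']:,} chunks of {labels['chunk_bytes'] / 2**20:.3f} MiB")
print(f".npy files: {npy_files:,} vs. Zarr chunk files: {zarr_chunk_files:,}")
```

Each full features chunk is under 1 MiB, so the chunk-size experiment in section 9 is a natural place to check whether larger chunks read faster on this filesystem.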

In [19]:
# Create Zarr directory structure
zarr_dir.mkdir(parents=True, exist_ok=True)
(zarr_dir / "metadata").mkdir(exist_ok=True)

print(f"Created Zarr directory: {zarr_dir}")

Created Zarr directory: /scicore/home/meiera/schulz0022/projects/mining-net/data/landsat_zarr


## 4. Convert Data to Zarr Format

In [None]:
import json
from typing import Optional


def get_tile_dir(mmap_dir: Path, tile_meta: dict) -> Path:
    """Directory holding features.npy and labels.npy for one tile."""
    return (
        mmap_dir / str(tile_meta['cluster_id']) / str(tile_meta['year'])
        / f"{tile_meta['tile_ix']}_{tile_meta['tile_iy']}"
    )


def convert_to_zarr(
    manifests_dir: Path,
    mmap_dir: Path,
    zarr_dir: Path,
    chunk_size: int = 8,
    max_tiles: Optional[int] = None,
):
    """
    Convert memory-mapped tile structure to Zarr format with index arrays.
    
    Args:
        manifests_dir: Directory with manifest files
        mmap_dir: Directory with .npy tiles
        zarr_dir: Output directory for Zarr group (an existing data.zarr is overwritten)
        chunk_size: Number of tiles per chunk
        max_tiles: Maximum tiles to convert (for testing)

    Returns:
        List of per-tile metadata dicts, in the same order as the Zarr rows.
    """
    reader = ManifestReader(manifests_dir)
    all_manifests = reader.list_all_manifests()
    
    # Build complete tile index
    print("Building tile index...")
    tile_index = []
    
    for manifest_meta in tqdm(all_manifests, desc="Reading manifests"):
        manifest = reader.read_manifest(manifest_meta['cluster_id'])
        if not manifest:
            continue
        
        for tile in manifest['tiles']:
            tile_index.append({
                'cluster_id': manifest['cluster_id'],
                'country_code': manifest['country_code'],
                'year': tile['year'],
                'tile_ix': tile['tile_ix'],
                'tile_iy': tile['tile_iy'],
                'geometry_hash': tile.get('geometry_hash'),
            })
            if max_tiles and len(tile_index) >= max_tiles:
                break
        
        if max_tiles and len(tile_index) >= max_tiles:
            break
    
    n_tiles = len(tile_index)
    print(f"Found {n_tiles:,} tiles to convert")
    
    # Load first tile to get dimensions
    sample_features = np.load(get_tile_dir(mmap_dir, tile_index[0]) / "features.npy", mmap_mode='r')
    n_channels, height, width = sample_features.shape
    print(f"Tile dimensions: {n_channels} channels, {height}x{width} pixels")
    
    # Create Zarr group with all arrays
    print(f"\nCreating Zarr group (chunk size: {chunk_size} tiles)...")
    zarr_group = zarr.open_group(store=str(zarr_dir / "data.zarr"), mode='w')
    
    # compressors=None stores raw chunks with no compression codec at all
    features_store = zarr_group.create_array(
        'features',
        shape=(n_tiles, n_channels, height, width),
        chunks=(chunk_size, n_channels, height, width),
        dtype=np.float32,
        compressors=None,
    )
    labels_store = zarr_group.create_array(
        'labels',
        shape=(n_tiles, 1, height, width),
        chunks=(chunk_size, 1, height, width),
        dtype=np.float32,
        compressors=None,
    )
    
    # Create index arrays (larger chunks for 1D index arrays)
    index_chunk = chunk_size * 1000
    index_dtypes = {'cluster_ids': np.int64, 'tile_ix': np.int32, 'tile_iy': np.int32, 'years': np.int32}
    index_stores = {
        name: zarr_group.create_array(
            name, shape=(n_tiles,), chunks=(index_chunk,), dtype=dtype, compressors=None
        )
        for name, dtype in index_dtypes.items()
    }
    
    print(f"Features array: shape={features_store.shape}, chunks={features_store.chunks}")
    print(f"Labels array: shape={labels_store.shape}, chunks={labels_store.chunks}")
    print(f"Index arrays: shape=({n_tiles},), chunks=({index_chunk},)")
    
    # Copy data one chunk at a time: buffer chunk_size tiles in memory, then write
    # them with a single chunk-aligned slice so every chunk is written exactly once
    print("\nCopying tiles...")
    start_time = time.time()
    failed = []
    
    with tqdm(total=n_tiles, desc="Converting tiles") as pbar:
        for start in range(0, n_tiles, chunk_size):
            stop = min(start + chunk_size, n_tiles)
            feat_buf = np.zeros((stop - start, n_channels, height, width), dtype=np.float32)
            label_buf = np.zeros((stop - start, 1, height, width), dtype=np.float32)
            
            for offset, tile_meta in enumerate(tile_index[start:stop]):
                tile_path = get_tile_dir(mmap_dir, tile_meta)
                try:
                    feat_buf[offset] = np.load(tile_path / "features.npy", mmap_mode='r')
                    label_buf[offset] = np.load(tile_path / "labels.npy", mmap_mode='r')
                except Exception as e:
                    print(f"\nError loading tile {start + offset}: {e}")
                    # Fill with zeros on error (features may already have been copied)
                    feat_buf[offset] = 0
                    label_buf[offset] = 0
                    failed.append(start + offset)
            
            features_store[start:stop] = feat_buf
            labels_store[start:stop] = label_buf
            pbar.update(stop - start)
    
    # Write each index array in one call; failed tiles get 0, like their data
    index_values = {
        'cluster_ids': [t['cluster_id'] for t in tile_index],
        'tile_ix': [t['tile_ix'] for t in tile_index],
        'tile_iy': [t['tile_iy'] for t in tile_index],
        'years': [t['year'] for t in tile_index],
    }
    for name, values in index_values.items():
        values = np.asarray(values, dtype=index_dtypes[name])
        if failed:
            values[failed] = 0
        index_stores[name][:] = values
    
    elapsed = time.time() - start_time
    print(f"\nConversion complete in {elapsed:.2f}s ({n_tiles/elapsed:.1f} tiles/s)")
    if failed:
        print(f"Warning: {len(failed):,} tiles failed to load and were zero-filled")
    
    # Save tile index as JSON for backward compatibility
    metadata_path = zarr_dir / "metadata" / "tiles.json"
    metadata_path.parent.mkdir(parents=True, exist_ok=True)
    with open(metadata_path, 'w') as f:
        json.dump(tile_index, f, indent=2)
    print(f"Saved metadata: {metadata_path}")
    
    # Print storage info
    zarr_size = sum(f.stat().st_size for f in zarr_dir.rglob('*') if f.is_file())
    zarr_size_gb = zarr_size / (1024**3)
    print(f"\nTotal Zarr size: {zarr_size_gb:.2f} GB")
    
    return tile_index
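Writing a single tile into an array chunked by 8 tiles makes Zarr read, merge and rewrite the enclosing chunk on each call; assigning a slice that lines up with chunk boundaries replaces each chunk once. The batching idea in isolation, as a minimal stdlib sketch (`chunk_aligned_ranges` is a hypothetical helper):

```python
from typing import Iterator, Tuple


def chunk_aligned_ranges(n_items: int, chunk_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) slices that coincide with chunk boundaries on axis 0.

    Every slice except possibly the last covers exactly one full chunk, so an
    assignment like array[start:stop] = buffer replaces whole chunks instead
    of patching them one tile at a time.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, n_items, chunk_size):
        yield start, min(start + chunk_size, n_items)


# Number of chunk writes for the 1,000-tile test run
per_tile_writes = 1000  # one partial-chunk write per tile
per_chunk_writes = len(list(chunk_aligned_ranges(1000, 8)))
print(per_tile_writes, per_chunk_writes)
```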

In [25]:
# Convert a subset for testing (set max_tiles=None to convert all)
tile_index = convert_to_zarr(
    manifests_dir=manifests_dir,
    mmap_dir=mmap_dir,
    zarr_dir=zarr_dir,
    chunk_size=8,
    max_tiles=1000  # Start with 1000 tiles for testing
)

Building tile index...


Reading manifests:   0%|          | 0/227 [00:00<?, ?it/s]

Found 1,000 tiles to convert
Tile dimensions: 7 channels, 64x64 pixels

Creating Zarr group (chunk size: 8 tiles)...
Features array: shape=(1000, 7, 64, 64), chunks=(8, 7, 64, 64)
Labels array: shape=(1000, 1, 64, 64), chunks=(8, 1, 64, 64)
Index arrays: shape=(1000,), chunks=(8000,)

Copying tiles...


Converting tiles:   0%|          | 0/1000 [00:00<?, ?it/s]


Conversion complete in 51.05s (19.6 tiles/s)
Saved metadata: /scicore/home/meiera/schulz0022/projects/mining-net/data/landsat_zarr/metadata/tiles.json


In [26]:
# Verify the conversion
print("Verifying converted Zarr group...")

zarr_group = zarr.open_group(store=str(zarr_dir / "data.zarr"), mode='r')

print(f"\nZarr group arrays:")
for name in zarr_group.array_keys():
    array = zarr_group[name]
    print(f"\n{name}:")
    print(f"  Shape: {array.shape}")
    print(f"  Chunks: {array.chunks}")
    print(f"  Dtype: {array.dtype}")
    if len(array.shape) > 1:  # Data arrays
        print(f"  Size: {array.nbytes / (1024**3):.2f} GB")

# Verify we can read data
print("\nVerifying data access...")
sample_features = zarr_group['features'][0]
sample_labels = zarr_group['labels'][0]
print(f"  Sample features: {sample_features.shape}, range=[{sample_features.min():.4f}, {sample_features.max():.4f}]")
print(f"  Sample labels: {sample_labels.shape}, unique values={np.unique(sample_labels)}")

# Verify index arrays
print(f"\n  Sample indices:")
print(f"    cluster_id: {zarr_group['cluster_ids'][0]}")
print(f"    tile_ix: {zarr_group['tile_ix'][0]}")
print(f"    tile_iy: {zarr_group['tile_iy'][0]}")
print(f"    year: {zarr_group['years'][0]}")

print("\n✓ Zarr conversion verified successfully!")

Verifying converted Zarr group...

Zarr group arrays:

tile_ix:
  Shape: (1000,)
  Chunks: (8000,)
  Dtype: int32

years:
  Shape: (1000,)
  Chunks: (8000,)
  Dtype: int32

features:
  Shape: (1000, 7, 64, 64)
  Chunks: (8, 7, 64, 64)
  Dtype: float32
  Size: 0.11 GB

labels:
  Shape: (1000, 1, 64, 64)
  Chunks: (8, 1, 64, 64)
  Dtype: float32
  Size: 0.02 GB

tile_iy:
  Shape: (1000,)
  Chunks: (8000,)
  Dtype: int32

cluster_ids:
  Shape: (1000,)
  Chunks: (8000,)
  Dtype: int64

Verifying data access...
  Sample features: (7, 64, 64), range=[0.0514, 316.3787]
  Sample labels: (1, 64, 64), unique values=[0.]

  Sample indices:
    cluster_id: 6155687546600015
    tile_ix: 6841
    tile_iy: 11646
    year: 2019

✓ Zarr conversion verified successfully!
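Reading tile 0 shows the arrays are readable, but not that every tile landed in the right row. A minimal sketch of a random spot check, assuming two callables that load tile `i` from the source and from the converted store (for example one that opens that tile's `features.npy` and one that reads `zarr_group['features'][i]`); `spot_check` is a hypothetical helper.

```python
from typing import Callable, List

import numpy as np


def spot_check(
    n_tiles: int,
    load_source: Callable[[int], np.ndarray],
    load_converted: Callable[[int], np.ndarray],
    n_samples: int = 20,
    seed: int = 0,
) -> List[int]:
    """Compare a random sample of tiles and return the indices that differ."""
    rng = np.random.default_rng(seed)
    sample = rng.choice(n_tiles, size=min(n_samples, n_tiles), replace=False)
    mismatches = []
    for i in sample:
        # Exact comparison: uncompressed float32 copies should match bit for bit;
        # equal_nan avoids flagging tiles that contain NaN in both stores
        if not np.array_equal(load_source(int(i)), load_converted(int(i)), equal_nan=True):
            mismatches.append(int(i))
    return sorted(mismatches)
```

An empty list means every sampled tile matched; raising `n_samples` trades run time for coverage.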


In [27]:
# Verify the index arrays
print("Verifying index arrays...")

zarr_group = zarr.open_group(store=str(zarr_dir / "data.zarr"), mode='r')

print(f"\nAvailable arrays:")
for name in zarr_group.array_keys():
    array = zarr_group[name]
    print(f"  {name}: shape={array.shape}, dtype={array.dtype}")

# Test accessing with indices
if 'cluster_ids' in zarr_group:
    print(f"\nSample data (first 5 tiles):")
    print(f"  cluster_ids: {zarr_group['cluster_ids'][:5]}")
    print(f"  tile_ix: {zarr_group['tile_ix'][:5]}")
    print(f"  tile_iy: {zarr_group['tile_iy'][:5]}")
    print(f"  years: {zarr_group['years'][:5]}")
    
    # Example: Find all tiles for a specific cluster
    cluster_id = zarr_group['cluster_ids'][0]
    mask = zarr_group['cluster_ids'][:] == cluster_id
    n_tiles_in_cluster = mask.sum()
    print(f"\nExample query: Cluster {cluster_id} has {n_tiles_in_cluster} tiles")
    
    print("\n✓ Index arrays are working correctly!")

Verifying index arrays...

Available arrays:
  labels: shape=(1000, 1, 64, 64), dtype=float32
  tile_ix: shape=(1000,), dtype=int32
  cluster_ids: shape=(1000,), dtype=int64
  tile_iy: shape=(1000,), dtype=int32
  features: shape=(1000, 7, 64, 64), dtype=float32
  years: shape=(1000,), dtype=int32

Sample data (first 5 tiles):
  cluster_ids: [6155687546600015 6155687546600015 6155687546600015 6155687546600015
 6155687546600015]
  tile_ix: [6841 6841 6841 6841 6841]
  tile_iy: [11646 11647 11648 11649 11650]
  years: [2019 2019 2019 2019 2019]

Example query: Cluster 6155687546600015 has 72 tiles

✓ Index arrays are working correctly!
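The single-cluster query generalizes: `np.unique` with `return_counts=True` gives the tile count of every cluster in one pass over the index array, and `np.isin` selects rows for several clusters at once. A minimal sketch on a toy stand-in for `zarr_group['cluster_ids'][:]`; both helper names are hypothetical.

```python
import numpy as np


def tiles_per_cluster(cluster_ids: np.ndarray) -> dict:
    """Map each cluster ID to its number of tiles."""
    ids, counts = np.unique(cluster_ids, return_counts=True)
    return {int(c): int(n) for c, n in zip(ids, counts)}


def rows_for_clusters(cluster_ids: np.ndarray, wanted) -> np.ndarray:
    """Row indices (valid for every array in the group) of tiles in `wanted` clusters."""
    return np.flatnonzero(np.isin(cluster_ids, wanted))


# Toy stand-in; in the notebook this would be zarr_group['cluster_ids'][:]
cluster_ids = np.array([11, 11, 11, 22, 22, 33], dtype=np.int64)
print(tiles_per_cluster(cluster_ids))
print(rows_for_clusters(cluster_ids, [22, 33]))
```

Since `convert_to_zarr` appends tiles manifest by manifest, each cluster's rows are contiguous, so the selected rows fall into few chunks.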


### 4.2 Convert the Entire Dataset to Zarr

Next, convert the full dataset rather than the 1,000-tile test subset. `convert_to_zarr` opens `data.zarr` with `mode='w'`, so this run replaces the test store.

In [None]:
# Convert the entire dataset
# WARNING: This will take significant time depending on dataset size
# Set max_tiles=None to convert ALL tiles
print("Starting full dataset conversion to Zarr...")
print("=" * 70)

full_tile_index = convert_to_zarr(
    manifests_dir=manifests_dir,
    mmap_dir=mmap_dir,
    zarr_dir=zarr_dir,
    chunk_size=8,
    max_tiles=None  # Convert ALL tiles
)

print("\n" + "=" * 70)
print(f"✓ Full conversion complete: {len(full_tile_index):,} tiles")
print("=" * 70)

Starting full dataset conversion to Zarr...
Building tile index...


Reading manifests:   0%|          | 0/227 [00:00<?, ?it/s]

Found 43,023 tiles to convert
Tile dimensions: 7 channels, 64x64 pixels

Creating Zarr group (chunk size: 8 tiles)...
Features array: shape=(43023, 7, 64, 64), chunks=(8, 7, 64, 64)
Labels array: shape=(43023, 1, 64, 64), chunks=(8, 1, 64, 64)
Index arrays: shape=(43023,), chunks=(8000,)

Copying tiles...


Converting tiles:   0%|          | 0/43023 [00:00<?, ?it/s]
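`convert_to_zarr` zero-fills tiles that fail to load instead of stopping, so after a full run it is worth counting all-zero tiles. The sample tile's minimum is 0.0514, so an all-zero feature tile is most likely a failed load (an assumption worth confirming against the source files). A minimal sketch that scans a block of tiles at a time, so it works on a NumPy array or a Zarr array without loading everything; `find_zero_tiles` is a hypothetical helper.

```python
import numpy as np


def find_zero_tiles(features, block: int = 8) -> np.ndarray:
    """Return indices of tiles whose features are all exactly zero.

    `features` is any (N, C, H, W) array that supports slicing on axis 0,
    e.g. zarr_group['features']; reading `block` tiles at a time keeps memory bounded.
    """
    n = features.shape[0]
    hits = []
    for start in range(0, n, block):
        tiles = np.asarray(features[start:start + block])
        # any() over C*H*W is False only when every value in a tile is zero
        nonzero = tiles.reshape(tiles.shape[0], -1).any(axis=1)
        hits.extend(start + np.flatnonzero(~nonzero))
    return np.asarray(hits, dtype=np.int64)
```

Matching `block` to the chunk size (8 here) means each call reads exactly one chunk.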

## 5. Implement Zarr-based DataLoader

In [None]:
from typing import Optional, Sequence


class ZarrMiningDataLoader(Dataset):
    """
    PyTorch Dataset for mining segmentation using the Zarr backend.
    
    Despite its name this is a `Dataset`; wrap it in a `DataLoader` for batching.
    Every item read decodes the whole chunk (8 tiles by default) that contains it.
    """
    
    def __init__(
        self,
        zarr_dir: Path,
        normalize: bool = True,
        band_means: Optional[Sequence[float]] = None,
        band_stds: Optional[Sequence[float]] = None,
        cluster_filter: Optional[Sequence[int]] = None,
    ):
        """
        Initialize Zarr-based data loader.
        
        Args:
            zarr_dir: Directory containing the Zarr group (data.zarr)
            normalize: Whether to normalize inputs (needs band_means and band_stds)
            band_means: Precomputed per-band means
            band_stds: Precomputed per-band standard deviations
            cluster_filter: Optional list of cluster IDs to include
        """
        self.zarr_dir = Path(zarr_dir)
        
        # Open Zarr group lazily: only array metadata is read here
        self.zarr_group = zarr.open_group(store=str(self.zarr_dir / "data.zarr"), mode='r')
        self.features = self.zarr_group['features']
        self.labels = self.zarr_group['labels']
        
        # Apply cluster filter by mapping dataset positions to Zarr rows
        if cluster_filter is not None:
            cluster_ids = self.zarr_group['cluster_ids'][:]
            self.valid_indices = np.where(np.isin(cluster_ids, cluster_filter))[0]
            print(f"Filtered to {len(self.valid_indices):,} tiles from {len(cluster_filter)} clusters")
        else:
            self.valid_indices = None
        
        self.band_means = band_means
        self.band_stds = band_stds
        self.normalize = normalize and band_means is not None and band_stds is not None
        # Build normalization tensors once instead of on every __getitem__ call
        if self.normalize:
            self.mean = torch.tensor(band_means, dtype=torch.float32).view(-1, 1, 1)
            self.std = torch.tensor(band_stds, dtype=torch.float32).view(-1, 1, 1)
        
        print("Loaded Zarr dataset:")
        print(f"  Features: shape={self.features.shape}, chunks={self.features.chunks}")
        print(f"  Labels: shape={self.labels.shape}, chunks={self.labels.chunks}")
        print(f"  Tiles: {len(self):,}")
    
    def __len__(self):
        if self.valid_indices is not None:
            return len(self.valid_indices)
        return self.features.shape[0]
    
    def _row(self, idx):
        """Map a dataset index to a row in the Zarr arrays."""
        if self.valid_indices is not None:
            return int(self.valid_indices[idx])
        return int(idx)
    
    def __getitem__(self, idx):
        """Load a single tile; Zarr reads the chunk that contains it."""
        row = self._row(idx)
        
        # Integer indexing on the tile axis returns a NumPy array
        features = torch.from_numpy(np.asarray(self.features[row], dtype=np.float32))
        labels = torch.from_numpy(np.asarray(self.labels[row], dtype=np.float32))
        
        if self.normalize:
            features = (features - self.mean) / (self.std + 1e-8)
        
        return features, labels
    
    def get_tile_metadata(self, idx):
        """Get metadata for a specific tile."""
        row = self._row(idx)
        return {
            'cluster_id': int(self.zarr_group['cluster_ids'][row]),
            'tile_ix': int(self.zarr_group['tile_ix'][row]),
            'tile_iy': int(self.zarr_group['tile_iy'][row]),
            'year': int(self.zarr_group['years'][row])
        }
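With `shuffle=True`, tiles in a batch usually come from different chunks, so each tile costs a full chunk read. A common alternative, not benchmarked in this notebook, is chunk-level shuffling: shuffle the order of chunks, then shuffle tiles only within small groups of chunks. It only pays off if re-reading a recently used chunk is cheap (e.g. served from the OS page cache), and its randomness is coarser than a full shuffle, which may matter for training. A minimal stdlib sketch; `chunk_shuffled_indices` and `chunks_per_group` are hypothetical names.

```python
import random
from typing import List


def chunk_shuffled_indices(n_items: int, chunk_size: int = 8,
                           chunks_per_group: int = 4, seed: int = 0) -> List[int]:
    """Permutation of range(n_items) that keeps consecutive reads chunk-local.

    Chunk order is shuffled, then tiles are shuffled only within groups of
    `chunks_per_group` consecutive chunks of that shuffled order.
    """
    rng = random.Random(seed)
    n_chunks = (n_items + chunk_size - 1) // chunk_size
    chunk_ids = list(range(n_chunks))
    rng.shuffle(chunk_ids)

    order: List[int] = []
    for g in range(0, n_chunks, chunks_per_group):
        group = [
            i
            for c in chunk_ids[g:g + chunks_per_group]
            for i in range(c * chunk_size, min((c + 1) * chunk_size, n_items))
        ]
        rng.shuffle(group)  # shuffle within the group only
        order.extend(group)
    return order
```

`DataLoader` accepts an iterable of indices with a length as `sampler`, e.g. `DataLoader(new_dataset, batch_size=32, sampler=chunk_shuffled_indices(len(new_dataset)))`; `shuffle` must then stay `False`. A fixed list repeats the same order every epoch, so pass a new seed per epoch.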

In [None]:
# Test Zarr loader
zarr_dataset = ZarrMiningDataLoader(zarr_dir)

print(f"\nDataset length: {len(zarr_dataset)}")

# Load a sample
features, labels = zarr_dataset[0]
print(f"\nSample tile:")
print(f"  Features: {features.shape}, dtype={features.dtype}")
print(f"  Labels: {labels.shape}, dtype={labels.dtype}")
print(f"  Metadata: {zarr_dataset.get_tile_metadata(0)}")

## 6. Benchmark: Old vs. Zarr Design

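We'll benchmark the following (timings include filesystem and OS caching effects, so run order and repeated runs can shift the results):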
1. **Single tile access**: Random single tile reads
2. **Sequential batch loading**: Loading batches in order
3. **Random batch loading**: Loading batches randomly (realistic training)
4. **Memory usage**: Peak memory during loading
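Because every Zarr read decodes at least one whole chunk, random access can read more data than it returns, while the per-tile `.npy` files only read the requested tile. A minimal stdlib sketch of that ratio for one batch, assuming uncompressed 8-tile chunks as configured above; `read_amplification` is a hypothetical helper, and caching can hide part of the cost in practice.

```python
from typing import Sequence


def read_amplification(indices: Sequence[int], chunk_size: int = 8) -> float:
    """Tiles decoded (whole chunks) per distinct tile requested, for one batch.

    Treats every chunk as full, so a short final chunk is slightly over-counted.
    """
    requested = len(set(indices))
    chunks_touched = {i // chunk_size for i in indices}
    return len(chunks_touched) * chunk_size / requested


print(read_amplification(range(32)))       # sequential batch of 32 tiles
print(read_amplification([0, 8, 16, 24]))  # every tile in its own chunk
```

For a shuffled batch of 32 tiles drawn from roughly 5,400 chunks, nearly every tile lands in its own chunk, so the ratio approaches 8.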

In [None]:
import itertools


def benchmark_single_tile_access(dataset, n_samples=100, seed=0):
    """
    Benchmark random single-tile access.
    
    Returns (total seconds, seconds per tile). A fixed seed makes runs
    repeatable and gives equally sized datasets the same indices.
    """
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(dataset), n_samples, replace=False)
    
    start = time.perf_counter()
    for idx in indices:
        features, labels = dataset[int(idx)]
    elapsed = time.perf_counter() - start
    
    return elapsed, elapsed / n_samples


def benchmark_batch_loading(dataset, batch_size=32, n_batches=50, shuffle=True, num_workers=0):
    """
    Benchmark batch loading with DataLoader.
    
    Returns (total seconds, tiles per second). At most `n_batches` batches are
    loaded, and throughput counts the tiles actually returned.
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=False  # Disable for fair comparison
    )
    
    samples_loaded = 0
    start = time.perf_counter()
    # islice stops after n_batches without fetching an extra batch
    for features, labels in itertools.islice(loader, n_batches):
        samples_loaded += features.shape[0]
    elapsed = time.perf_counter() - start
    
    return elapsed, samples_loaded / elapsed
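The functions above time loading only; the memory item in the benchmark plan needs a separate measurement. A minimal sketch using the standard-library `tracemalloc`, which reports the peak of allocations made through Python's allocators. NumPy arrays are included, but memory-mapped file pages and PyTorch's own allocations are not, so treat the number as a partial view. `peak_traced_mib` is a hypothetical helper.

```python
import tracemalloc
from typing import Callable


def peak_traced_mib(fn: Callable[[], object]) -> float:
    """Run fn() and return the peak memory traced by tracemalloc, in MiB."""
    tracemalloc.start()
    try:
        fn()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak / 2**20
```

Usage would look like `peak_traced_mib(lambda: benchmark_batch_loading(new_dataset, 32, 30))`; keep `num_workers=0`, because allocations in DataLoader worker processes are not traced by the parent.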

### 6.1 Initialize Both Datasets

In [None]:
# Old design (memory-mapped .npy)
print("Initializing old dataset (mmap .npy)...")
old_dataset = MiningSegmentationDataLoader(
    normalize=False,  # Disable for fair comparison
    auto_compute_stats=False
)

print(f"\nOld dataset: {len(old_dataset)} tiles")

In [None]:
# New design (Zarr)
print("Initializing new dataset (Zarr)...")
new_dataset = ZarrMiningDataLoader(
    zarr_dir=zarr_dir,
    normalize=False
)

print(f"\nNew dataset: {len(new_dataset)} tiles")

### 6.2 Benchmark Single Tile Access

In [None]:
print("=" * 60)
print("BENCHMARK 1: Single Tile Random Access")
print("=" * 60)

n_samples = 100

# Warmup
_ = old_dataset[0]
_ = new_dataset[0]

# Old design
print(f"\nTesting old design ({n_samples} random tiles)...")
old_total, old_per_tile = benchmark_single_tile_access(old_dataset, n_samples)
print(f"  Total time: {old_total:.3f}s")
print(f"  Per tile: {old_per_tile*1000:.2f}ms")

# New design
print(f"\nTesting new design ({n_samples} random tiles)...")
new_total, new_per_tile = benchmark_single_tile_access(new_dataset, n_samples)
print(f"  Total time: {new_total:.3f}s")
print(f"  Per tile: {new_per_tile*1000:.2f}ms")

# Comparison
speedup = old_per_tile / new_per_tile
print(f"\n{'='*60}")
print(f"Speedup: {speedup:.2f}x {'(Zarr faster)' if speedup > 1 else '(mmap faster)'}")
print(f"{'='*60}")

### 6.3 Benchmark Sequential Batch Loading

In [None]:
print("=" * 60)
print("BENCHMARK 2: Sequential Batch Loading (shuffle=False)")
print("=" * 60)

batch_size = 32
n_batches = 50

# Old design
print(f"\nTesting old design (batch_size={batch_size}, {n_batches} batches)...")
old_time, old_throughput = benchmark_batch_loading(old_dataset, batch_size, n_batches, shuffle=False)
print(f"  Total time: {old_time:.3f}s")
print(f"  Throughput: {old_throughput:.1f} tiles/s")

# New design
print(f"\nTesting new design (batch_size={batch_size}, {n_batches} batches)...")
new_time, new_throughput = benchmark_batch_loading(new_dataset, batch_size, n_batches, shuffle=False)
print(f"  Total time: {new_time:.3f}s")
print(f"  Throughput: {new_throughput:.1f} tiles/s")

# Comparison
speedup = new_throughput / old_throughput
print(f"\n{'='*60}")
print(f"Speedup: {speedup:.2f}x {'(Zarr faster)' if speedup > 1 else '(mmap faster)'}")
print(f"{'='*60}")

### 6.4 Benchmark Random Batch Loading (Training Scenario)

In [None]:
print("=" * 60)
print("BENCHMARK 3: Random Batch Loading (shuffle=True, realistic training)")
print("=" * 60)

batch_size = 32
n_batches = 50

# Old design
print(f"\nTesting old design (batch_size={batch_size}, {n_batches} batches)...")
old_time, old_throughput = benchmark_batch_loading(old_dataset, batch_size, n_batches, shuffle=True)
print(f"  Total time: {old_time:.3f}s")
print(f"  Throughput: {old_throughput:.1f} tiles/s")

# New design
print(f"\nTesting new design (batch_size={batch_size}, {n_batches} batches)...")
new_time, new_throughput = benchmark_batch_loading(new_dataset, batch_size, n_batches, shuffle=True)
print(f"  Total time: {new_time:.3f}s")
print(f"  Throughput: {new_throughput:.1f} tiles/s")

# Comparison
speedup = new_throughput / old_throughput
print(f"\n{'='*60}")
print(f"Speedup: {speedup:.2f}x {'(Zarr faster)' if speedup > 1 else '(mmap faster)'}")
print(f"{'='*60}")

### 6.5 Comprehensive Benchmark Comparison

In [None]:
# Run comprehensive benchmarks with multiple configurations
print("Running comprehensive benchmarks...\n")

results = []

# Test different batch sizes
for batch_size in [8, 16, 32, 64]:
    print(f"Testing batch_size={batch_size}...")
    
    # Old design
    old_time, old_throughput = benchmark_batch_loading(old_dataset, batch_size, 30, shuffle=True)
    
    # New design
    new_time, new_throughput = benchmark_batch_loading(new_dataset, batch_size, 30, shuffle=True)
    
    results.append({
        'batch_size': batch_size,
        'old_throughput': old_throughput,
        'new_throughput': new_throughput,
        'speedup': new_throughput / old_throughput
    })

# Create results dataframe
df_results = pd.DataFrame(results)
print("\nResults:")
print(df_results.to_string(index=False))
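Each configuration above runs once, and the first pass over a store pays cold-cache costs that later passes may not. A minimal stdlib sketch that repeats a benchmark and reports the median and spread; `repeat_benchmark` is hypothetical and expects a callable returning one throughput value, such as the second element returned by `benchmark_batch_loading`.

```python
import statistics
from typing import Callable, Dict


def repeat_benchmark(run: Callable[[], float], repeats: int = 5, warmup: int = 1) -> Dict[str, float]:
    """Call run() `warmup` times (discarded), then `repeats` times, and summarize."""
    for _ in range(warmup):
        run()  # discarded: warms caches and any lazy initialization
    values = [run() for _ in range(repeats)]
    return {
        "median": statistics.median(values),
        "min": min(values),
        "max": max(values),
    }
```

For example, `repeat_benchmark(lambda: benchmark_batch_loading(new_dataset, 32, 30)[1])`. Whether warm-cache numbers reflect training depends on whether the dataset fits in RAM.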

## 7. Visualize Benchmark Results

In [None]:
# Plot throughput comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Throughput comparison
x = df_results['batch_size']
width = 0.35
x_pos = np.arange(len(x))

axes[0].bar(x_pos - width/2, df_results['old_throughput'], width, label='Old (mmap .npy)', alpha=0.8)
axes[0].bar(x_pos + width/2, df_results['new_throughput'], width, label='New (Zarr)', alpha=0.8)
axes[0].set_xlabel('Batch Size')
axes[0].set_ylabel('Throughput (tiles/s)')
axes[0].set_title('Data Loading Throughput Comparison')
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(x)
axes[0].legend()
axes[0].grid(axis='y', alpha=0.3)

# Speedup
axes[1].plot(df_results['batch_size'], df_results['speedup'], marker='o', linewidth=2, markersize=8)
axes[1].axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='No speedup')
axes[1].set_xlabel('Batch Size')
axes[1].set_ylabel('Speedup (Zarr / mmap)')
axes[1].set_title('Zarr Speedup vs Batch Size')
axes[1].grid(alpha=0.3)
axes[1].legend()

plt.tight_layout()
plt.savefig('benchmark_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("Saved plot: benchmark_results.png")

## 8. Summary and Recommendations

In [None]:
print("=" * 70)
print("BENCHMARK SUMMARY")
print("=" * 70)

avg_speedup = df_results['speedup'].mean()
best_speedup = df_results['speedup'].max()
worst_speedup = df_results['speedup'].min()

print(f"\nAverage speedup: {avg_speedup:.2f}x")
print(f"Best speedup: {best_speedup:.2f}x (batch_size={df_results.loc[df_results['speedup'].idxmax(), 'batch_size']})")
print(f"Worst speedup: {worst_speedup:.2f}x (batch_size={df_results.loc[df_results['speedup'].idxmin(), 'batch_size']})")

print("\n" + "=" * 70)
print("KEY FINDINGS")
print("=" * 70)

if avg_speedup > 1.1:
    print("\n✓ Zarr is faster on average (>10% higher throughput)")
    print("  Recommendation: consider migrating to Zarr; confirm with repeated runs first")
elif avg_speedup > 0.9:
    print("\n≈ Zarr and mmap .npy perform similarly (within 10%)")
    print("  Recommendation: decide on other factors (file count, storage, scalability)")
else:
    print("\n✗ Zarr is slower than the current mmap approach (>10% lower throughput)")
    print("  Recommendation: keep the current mmap .npy design")

print("\nAdditional Benefits of Zarr:")
print("  • Chunk size can be tuned to the batch access pattern")
print("  • Fewer files: one file per 8-tile chunk instead of two .npy files per tile")
print("  • Optional compression for storage savings")
print("  • Cloud object storage support (S3, GCS) via Zarr stores")
print("  • Parallel writes to non-overlapping chunks")

print("\n" + "=" * 70)

## 9. Optional: Test Different Chunk Sizes

In [None]:
# This section can be used to experiment with different chunk sizes
# Uncomment and run to test chunk_size = 4, 8, 16, 32

# chunk_results = []
# for chunk_size in [4, 8, 16, 32]:
#     print(f"\nTesting chunk_size={chunk_size}...")
#     
#     # Convert with this chunk size
#     test_zarr_dir = config.DATA_DIR / f"landsat_zarr_chunk{chunk_size}"
#     convert_to_zarr(
#         manifests_dir=manifests_dir,
#         mmap_dir=mmap_dir,
#         zarr_dir=test_zarr_dir,
#         chunk_size=chunk_size,
#         max_tiles=500
#     )
#     
#     # Benchmark
#     test_dataset = ZarrMiningDataLoader(test_zarr_dir, normalize=False)
#     elapsed, throughput = benchmark_batch_loading(test_dataset, 32, 30, shuffle=True)
#     
#     chunk_results.append({
#         'chunk_size': chunk_size,
#         'throughput': throughput
#     })
# 
# df_chunks = pd.DataFrame(chunk_results)
# print("\nChunk size analysis:")
# print(df_chunks)

## Conclusion

This notebook demonstrates:
1. ✓ Conversion from mmap .npy tiles to Zarr format
2. ✓ Implementation of Zarr-based PyTorch DataLoader
3. ✓ Comprehensive benchmarks comparing both approaches
4. ✓ Performance analysis across different batch sizes

**Next Steps:**
- Run full conversion if Zarr shows better performance
- Update training scripts to use ZarrMiningDataLoader
- Consider cloud storage backends (S3, GCS) for distributed training
- Experiment with compression codecs if storage is a concern