# Convert Zarr Data to PyTorch MemoryMappedTensor Format

This notebook converts existing Zarr-stored tiles to PyTorch's MemoryMappedTensor format for faster data loading.

**Benefits of MemoryMappedTensor:**
- Zero-copy memory mapping for instant access
- Native PyTorch integration (no conversion overhead)
- Extremely fast random access
- Lower memory footprint during training
- Direct GPU transfer without intermediate copies

In [None]:
import sys
import logging
from pathlib import Path
import numpy as np
import torch
import xarray as xr
from tqdm.auto import tqdm
import json
import hashlib

# Add src to path
sys.path.insert(0, str(Path.cwd().parent))

from data.database import DownloadDatabase
from data.config import Config
from data.tasks import compute_cluster_id
from odc.geo.geobox import GeoBox, GeoboxTiles

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## Configuration

## Database Migration (Old to New Schema)

If you're upgrading from the old database schema, run this cell first to add the new columns.


In [3]:
import sqlite3
from pathlib import Path

def migrate_database_schema(db_path: str):
    """
    Migrate database from old schema to new schema.
    Adds mmap_path and mmap_written columns to tiles table if they don't exist.
    
    Args:
        db_path: Path to the SQLite database file
    """
    db_path = Path(db_path)
    if not db_path.exists():
        print(f"Database does not exist: {db_path}")
        return
    
    conn = sqlite3.connect(str(db_path))
    cursor = conn.cursor()
    
    try:
        # Get current table info
        cursor.execute("PRAGMA table_info(tiles)")
        columns = {row[1] for row in cursor.fetchall()}
        
        print(f"Current columns in tiles table: {columns}")
        
        # Add missing columns if they don't exist
        columns_to_add = []
        
        if 'mmap_path' not in columns:
            columns_to_add.append("mmap_path TEXT")
        
        if 'mmap_written' not in columns:
            columns_to_add.append("mmap_written BOOLEAN DEFAULT 0")
        
        if 'mmap_written_at' not in columns:
            columns_to_add.append("mmap_written_at TEXT")
        
        if columns_to_add:
            print(f"\nAdding columns: {columns_to_add}")
            
            for column_def in columns_to_add:
                alter_sql = f"ALTER TABLE tiles ADD COLUMN {column_def}"
                try:
                    cursor.execute(alter_sql)
                    print(f"  ✓ Added: {column_def}")
                except sqlite3.OperationalError as e:
                    print(f"  ✗ Failed to add {column_def}: {e}")
            
            # Create indices on new columns
            index_commands = [
                ("idx_tiles_mmap_written", "ON tiles(mmap_written)"),
                ("idx_tiles_mmap_path", "ON tiles(mmap_path)")
            ]
            
            for idx_name, idx_def in index_commands:
                try:
                    cursor.execute(f"CREATE INDEX IF NOT EXISTS {idx_name} {idx_def}")
                    print(f"  ✓ Created index: {idx_name}")
                except sqlite3.OperationalError as e:
                    print(f"  ✗ Failed to create index {idx_name}: {e}")
            
            conn.commit()
            print("\n✓ Migration completed successfully!")
        else:
            print("✓ Schema is already up to date. No migration needed.")
    
    except Exception as e:
        print(f"✗ Migration failed: {e}")
        conn.rollback()
        raise
    finally:
        conn.close()


# Perform migration
if config.DB_PATH.exists():
    print(f"Migrating database: {config.DB_PATH}\n")
    migrate_database_schema(str(config.DB_PATH))
else:
    print(f"Database will be created fresh: {config.DB_PATH}")


Migrating database: C:\Users\schulz0022\Documents\mining-net\data\mining_segmentation.db

Current columns in tiles table: {'year', 'written_at', 'tile_ix', 'cluster_id', 'zarr_written', 'created_at', 'geometry_hash', 'tile_iy'}

Adding columns: ['mmap_path TEXT', 'mmap_written BOOLEAN DEFAULT 0', 'mmap_written_at TEXT']
  ✓ Added: mmap_path TEXT
  ✓ Added: mmap_written BOOLEAN DEFAULT 0
  ✓ Added: mmap_written_at TEXT
  ✓ Created index: idx_tiles_mmap_written
  ✓ Created index: idx_tiles_mmap_path

✓ Migration completed successfully!


In [4]:
# Configuration
config = Config()
db = DownloadDatabase(str(config.DB_PATH))

# Paths
ZARR_PATH = config.DATA_DIR / "global_landsat.zarr"
MMAP_DIR = config.DATA_DIR / "landsat_mmap"
MMAP_DIR.mkdir(exist_ok=True, parents=True)

# Bands to convert
BANDS = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'thermal']

print(f"Zarr source: {ZARR_PATH}")
print(f"Memory-mapped output: {MMAP_DIR}")
print(f"Bands: {BANDS}")

Zarr source: C:\Users\schulz0022\Documents\mining-net\data\global_landsat.zarr
Memory-mapped output: C:\Users\schulz0022\Documents\mining-net\data\landsat_mmap
Bands: ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'thermal']


## Load Zarr Dataset and Get Tiles to Convert

In [5]:
# Open zarr dataset
zarr_ds = xr.open_zarr(str(ZARR_PATH), consolidated=False, chunks='auto')

print("Zarr dataset loaded:")
print(zarr_ds)

Zarr dataset loaded:
<xarray.Dataset> Size: 26TB
Dimensions:           (latitude: 667916, longitude: 1335832)
Coordinates:
  * latitude          (latitude) float64 5MB 90.0 90.0 90.0 ... -90.0 -90.0
  * longitude         (longitude) float64 11MB -180.0 -180.0 ... 180.0 180.0
Data variables:
    swir1             (latitude, longitude) float32 4TB dask.array<chunksize=(5632, 5632), meta=np.ndarray>
    blue              (latitude, longitude) float32 4TB dask.array<chunksize=(5632, 5632), meta=np.ndarray>
    green             (latitude, longitude) float32 4TB dask.array<chunksize=(5632, 5632), meta=np.ndarray>
    nir               (latitude, longitude) float32 4TB dask.array<chunksize=(5632, 5632), meta=np.ndarray>
    red               (latitude, longitude) float32 4TB dask.array<chunksize=(5632, 5632), meta=np.ndarray>
    mining_footprint  (latitude, longitude) uint8 892GB dask.array<chunksize=(11264, 11264), meta=np.ndarray>
    swir2             (latitude, longitude) float32 4TB da

In [None]:
# Get all written tiles from database with country code for cluster ID computation
with db.get_connection() as conn:
    cursor = conn.cursor()
    cursor.execute("""
        SELECT t.tile_ix, t.tile_iy, t.cluster_id, t.year, 
               tasks.country_code, tasks.geometry_hash, tasks.mining_footprint_json
        FROM tiles t
        JOIN tasks ON t.geometry_hash = tasks.geometry_hash 
            AND t.year = tasks.year
        WHERE t.zarr_written = 1
        ORDER BY t.cluster_id, t.year, t.tile_ix, t.tile_iy
    """)
    tiles = [dict(row) for row in cursor.fetchall()]

print(f"Found {len(tiles)} tiles to convert")
if tiles:
    sample = tiles[0]
    print(f"\nSample tile: {sample}")
    print(f"  Old cluster_id: {sample['cluster_id']}")
    print(f"  Country code: {sample['country_code']}")
    print(f"  Has mining_footprint: {sample['mining_footprint_json'] is not None}")

Found 34931 tiles to convert

Sample tile: {'tile_ix': 7043, 'tile_iy': 12013, 'cluster_id': 0, 'year': 2019, 'country_code': 'ZAF', 'geometry_hash': '23fcc4766d9bf67565d3221bd0313d14149927ac5e6182b34b8a6c859a104d68'}


## Helper Functions

In [None]:
def load_tile_from_zarr(
    tile_ix: int,
    tile_iy: int,
    bands: list,
    include_footprint: bool = True
):
    """Load tile data from Zarr."""
    # Reconstruct world geobox
    world_geobox = GeoBox.from_bbox(
        [-180, -90, 180, 90],
        resolution=config.WORLD_GEOBOX_RESOLUTION,
        crs=4326
    )
    world_geobox_tiles = GeoboxTiles(
        world_geobox,
        tile_shape=config.WORLD_GEOBOX_TILE_SIZE
    )
    
    # Get tile geobox
    tile_geobox = world_geobox_tiles[tile_ix, tile_iy]
    bounds = tile_geobox.boundingbox
    
    # Load band data
    band_arrays = [
        zarr_ds[band].sel(
            latitude=slice(bounds.top, bounds.bottom),
            longitude=slice(bounds.left, bounds.right)
        ).values
        for band in bands
    ]
    
    # Stack to (H, W, C)
    features = np.stack(band_arrays, axis=-1).astype(np.float32)
    
    # Load mining footprint
    labels = None
    if include_footprint:
        labels = zarr_ds['mining_footprint'].sel(
            latitude=slice(bounds.top, bounds.bottom),
            longitude=slice(bounds.left, bounds.right)
        ).values
        
        if labels.ndim == 2:
            labels = labels[..., np.newaxis]
        
        labels = labels.astype(np.float32)
    
    return features, labels


def save_tile_as_mmap(
    features: np.ndarray,
    labels: np.ndarray,
    tile_ix: int,
    tile_iy: int,
    year: int,
    cluster_id: int,
    geometry_hash: str,
    country_code: str,
    mining_footprint_json: dict,
    output_dir: Path
) -> Path:
    """Save tile as memory-mapped PyTorch tensors.
    
    Args:
        features: Feature array (H, W, C)
        labels: Label array (H, W, C)
        tile_ix: Tile X index
        tile_iy: Tile Y index
        year: Year
        cluster_id: Original (old) cluster ID
        geometry_hash: Geometry hash
        country_code: ISO3 country code
        mining_footprint_json: Mining footprint GeoJSON
        output_dir: Output directory root
    
    Returns:
        Tuple of (tile_dir, new_cluster_id)
    """
    # Compute new globally unique cluster ID from country + mining footprint
    new_cluster_id = compute_cluster_id(country_code, cluster_id, mining_footprint_json)
    
    # Create tile directory: new_cluster_id/year/tile_ix_tile_iy/
    tile_dir = output_dir / str(new_cluster_id) / str(year) / f"{tile_ix}_{tile_iy}"
    tile_dir.mkdir(parents=True, exist_ok=True)
    
    # Replace NaN with 0
    features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
    labels = np.nan_to_num(labels, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Convert to torch tensors: (H, W, C) -> (C, H, W)
    features_tensor = torch.from_numpy(features).float().permute(2, 0, 1)
    labels_tensor = torch.from_numpy(labels).float().permute(2, 0, 1)
    
    # Save as memory-mapped tensors
    features_path = tile_dir / "features.pt"
    labels_path = tile_dir / "labels.pt"
    
    # Save with shared memory for fast loading
    torch.save(features_tensor, features_path)
    torch.save(labels_tensor, labels_path)
    
    # Save metadata (without mmap_written_at - database tracks that)
    metadata = {
        "tile_ix": tile_ix,
        "tile_iy": tile_iy,
        "year": year,
        "cluster_id_old": cluster_id,
        "cluster_id_new": new_cluster_id,
        "geometry_hash": geometry_hash,
        "country_code": country_code,
        "features_shape": list(features_tensor.shape),
        "labels_shape": list(labels_tensor.shape),
        "bands": BANDS,
        "dtype": "float32"
    }
    
    metadata_path = tile_dir / "metadata.json"
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    
    return tile_dir, new_cluster_id


def update_database_mmap_status(
    tile_ix: int,
    tile_iy: int,
    geometry_hash: str,
    year: int,
    new_cluster_id: int
):
    """Update database to track memory-mapped tile with new cluster ID.
    
    Args:
        tile_ix: Tile X index
        tile_iy: Tile Y index
        geometry_hash: Geometry hash
        year: Year
        new_cluster_id: New globally unique cluster ID
    """
    with db.get_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("""
            UPDATE tiles 
            SET cluster_id = ?, mmap_written = 1, mmap_written_at = datetime('now')
            WHERE tile_ix = ? AND tile_iy = ? 
              AND geometry_hash = ? AND year = ?
        """, (new_cluster_id, tile_ix, tile_iy, geometry_hash, year))

print("Helper functions defined with cluster ID migration")

Helper functions defined


## Convert Tiles

In [None]:
# Test conversion on a single tile first
if tiles:
    test_tile = tiles[0]
    print(f"Testing with tile: {test_tile}")
    
    # Load from zarr
    features, labels = load_tile_from_zarr(
        test_tile['tile_ix'],
        test_tile['tile_iy'],
        BANDS
    )
    
    print(f"Loaded features shape: {features.shape}, dtype: {features.dtype}")
    print(f"Loaded labels shape: {labels.shape}, dtype: {labels.dtype}")
    
    # Parse mining footprint
    mining_footprint = json.loads(test_tile['mining_footprint_json']) if test_tile['mining_footprint_json'] else None
    
    # Save as mmap (computes new cluster ID)
    tile_dir, new_cluster_id = save_tile_as_mmap(
        features, labels,
        test_tile['tile_ix'], test_tile['tile_iy'],
        test_tile['year'], test_tile['cluster_id'],
        test_tile['geometry_hash'],
        test_tile['country_code'],
        mining_footprint,
        MMAP_DIR
    )
    
    print(f"\nCluster ID migration:")
    print(f"  Old: {test_tile['cluster_id']}")
    print(f"  New: {new_cluster_id}")
    print(f"  Saved to: {tile_dir}")
    print(f"  Files: {list(tile_dir.iterdir())}")
    
    # Verify by loading back
    loaded_features = torch.load(tile_dir / "features.pt")
    loaded_labels = torch.load(tile_dir / "labels.pt")
    
    print(f"\nLoaded back features shape: {loaded_features.shape}")
    print(f"Loaded back labels shape: {loaded_labels.shape}")
    print(f"\nVerification: shapes match = {loaded_features.shape == torch.Size([len(BANDS), features.shape[0], features.shape[1]])}")
    
    # Check metadata doesn't include mmap_written_at
    with open(tile_dir / "metadata.json") as f:
        metadata = json.load(f)
    print(f"\nMetadata keys: {list(metadata.keys())}")
    print(f"Has mmap_written_at: {'mmap_written_at' in metadata} (should be False - tracked in DB)")
    
    # Show memory-mapped access speed
    import time
    start = time.time()
    for _ in range(10):
        _ = torch.load(tile_dir / "features.pt")
    elapsed = time.time() - start
    print(f"\n10 loads took {elapsed:.3f}s ({elapsed/10*1000:.2f}ms per load)")

Testing with tile: {'tile_ix': 7043, 'tile_iy': 12013, 'cluster_id': 0, 'year': 2019, 'country_code': 'ZAF', 'geometry_hash': '23fcc4766d9bf67565d3221bd0313d14149927ac5e6182b34b8a6c859a104d68'}
Loaded features shape: (64, 64, 7), dtype: float32
Loaded labels shape: (64, 64, 1), dtype: float32

Saved to: C:\Users\schulz0022\Documents\mining-net\data\landsat_mmap\0\2019\7043_12013
Files: [WindowsPath('C:/Users/schulz0022/Documents/mining-net/data/landsat_mmap/0/2019/7043_12013/features.pt'), WindowsPath('C:/Users/schulz0022/Documents/mining-net/data/landsat_mmap/0/2019/7043_12013/labels.pt'), WindowsPath('C:/Users/schulz0022/Documents/mining-net/data/landsat_mmap/0/2019/7043_12013/metadata.json')]

Loaded back features shape: torch.Size([7, 64, 64])
Loaded back labels shape: torch.Size([1, 64, 64])

Verification: shapes match = True

10 loads took 0.004s (0.36ms per load)


## Batch Convert All Tiles

In [None]:
# Convert all tiles
converted_count = 0
error_count = 0
errors = []
cluster_id_migrations = {}  # Track old -> new cluster ID mappings

print(f"Converting {len(tiles)} tiles...\n")

for tile in tqdm(tiles, desc="Converting tiles"):
    try:
        # Load from zarr
        features, labels = load_tile_from_zarr(
            tile['tile_ix'],
            tile['tile_iy'],
            BANDS
        )
        
        # Parse mining footprint
        mining_footprint = json.loads(tile['mining_footprint_json']) if tile['mining_footprint_json'] else None
        
        # Save as mmap (computes new cluster ID)
        tile_dir, new_cluster_id = save_tile_as_mmap(
            features, labels,
            tile['tile_ix'], tile['tile_iy'],
            tile['year'], tile['cluster_id'],
            tile['geometry_hash'],
            tile['country_code'],
            mining_footprint,
            MMAP_DIR
        )
        
        # Track cluster ID migration
        old_id = tile['cluster_id']
        if old_id not in cluster_id_migrations:
            cluster_id_migrations[old_id] = new_cluster_id
        
        # Update database with new cluster ID
        update_database_mmap_status(
            tile['tile_ix'], tile['tile_iy'],
            tile['geometry_hash'], tile['year'],
            new_cluster_id
        )
        
        converted_count += 1
        
    except Exception as e:
        error_count += 1
        errors.append({
            'tile': tile,
            'error': str(e)
        })
        logger.error(f"Error converting tile {tile}: {e}")

print(f"\n{'='*60}")
print(f"Conversion complete!")
print(f"Successfully converted: {converted_count}")
print(f"Errors: {error_count}")
print(f"Unique old cluster IDs migrated: {len(cluster_id_migrations)}")

if cluster_id_migrations:
    print(f"\nCluster ID migrations (first 5):")
    for old_id, new_id in list(cluster_id_migrations.items())[:5]:
        print(f"  {old_id} → {new_id}")

if errors:
    print(f"\nFirst few errors:")
    for err in errors[:5]:
        print(f"  Tile {err['tile']}: {err['error']}")

Converting 34931 tiles...



Converting tiles:   0%|          | 0/34931 [00:00<?, ?it/s]


Conversion complete!
Successfully converted: 34931
Errors: 0


## Performance Comparison

In [10]:
# Compare load times: Zarr vs MemoryMapped
import time

if tiles:
    test_tile = tiles[0]
    n_iterations = 20
    
    # Time Zarr loading
    start = time.time()
    for _ in range(n_iterations):
        features, labels = load_tile_from_zarr(
            test_tile['tile_ix'],
            test_tile['tile_iy'],
            BANDS
        )
        # Convert to torch (what data loader does)
        features_t = torch.from_numpy(features).float().permute(2, 0, 1)
        labels_t = torch.from_numpy(labels).float().permute(2, 0, 1)
    zarr_time = (time.time() - start) / n_iterations
    
    # Time MemoryMapped loading
    tile_dir = MMAP_DIR / str(test_tile['cluster_id']) / str(test_tile['year']) / f"{test_tile['tile_ix']}_{test_tile['tile_iy']}"
    
    start = time.time()
    for _ in range(n_iterations):
        features_t = torch.load(tile_dir / "features.pt")
        labels_t = torch.load(tile_dir / "labels.pt")
    mmap_time = (time.time() - start) / n_iterations
    
    speedup = zarr_time / mmap_time
    
    print(f"Performance Comparison ({n_iterations} iterations):")
    print(f"  Zarr + conversion:  {zarr_time*1000:.2f} ms/tile")
    print(f"  MemoryMapped:       {mmap_time*1000:.2f} ms/tile")
    print(f"  Speedup:            {speedup:.2f}x faster")

Performance Comparison (20 iterations):
  Zarr + conversion:  58.91 ms/tile
  MemoryMapped:       1.42 ms/tile
  Speedup:            41.49x faster


## Summary

## Verification

Verify the conversion was successful and check data integrity.


In [None]:
# Verify MMAP conversion
print("Verifying MMAP conversion...\n")

with db.get_connection() as conn:
    cursor = conn.cursor()
    
    # Count mmap_written tiles
    cursor.execute("SELECT COUNT(*) FROM tiles WHERE mmap_written = 1")
    mmap_count = cursor.fetchone()[0]
    
    # Get sample mmap tiles
    cursor.execute("""
        SELECT tile_ix, tile_iy, year, cluster_id
        FROM tiles
        WHERE mmap_written = 1
        LIMIT 5
    """)
    samples = cursor.fetchall()

print(f"Database Statistics:")
print(f"  Tiles marked as mmap_written: {mmap_count}")

if samples:
    print(f"\nSample tiles (with new cluster IDs):")
    for sample in samples:
        print(f"  Tile {sample[0]}_{sample[1]} (year={sample[2]}, cluster_id={sample[3]})")
        # Path can be reconstructed: {cluster_id}/{year}/{tile_ix}_{tile_iy}/
        path = MMAP_DIR / str(sample[3]) / str(sample[2]) / f"{sample[0]}_{sample[1]}"
        print(f"    Path: {path}")

# Verify file integrity
print(f"\nFile System Statistics:")
print(f"  MMAP directory: {MMAP_DIR}")
print(f"  MMAP directory size: {get_dir_size(MMAP_DIR) / (1024**3):.2f} GB")

# Check for missing metadata files
missing_metadata = 0
for cluster_dir in MMAP_DIR.iterdir():
    if not cluster_dir.is_dir():
        continue
    for year_dir in cluster_dir.iterdir():
        if not year_dir.is_dir():
            continue
        for tile_dir in year_dir.iterdir():
            if not tile_dir.is_dir():
                continue
            if not (tile_dir / "metadata.json").exists():
                missing_metadata += 1

if missing_metadata == 0:
    print(f"  ✓ All tiles have metadata.json")
else:
    print(f"  ✗ Missing metadata: {missing_metadata} tiles")

# Verify metadata doesn't store mmap_written_at (should be in DB only)
checked_metadata_count = 0
has_mmap_written_at = 0

for cluster_dir in MMAP_DIR.iterdir():
    if not cluster_dir.is_dir():
        continue
    for year_dir in cluster_dir.iterdir():
        if not year_dir.is_dir():
            continue
        for tile_dir in year_dir.iterdir():
            if not tile_dir.is_dir():
                continue
            metadata_path = tile_dir / "metadata.json"
            if metadata_path.exists():
                checked_metadata_count += 1
                with open(metadata_path) as f:
                    metadata = json.load(f)
                if 'mmap_written_at' in metadata:
                    has_mmap_written_at += 1
            if checked_metadata_count >= 10:
                break
        if checked_metadata_count >= 10:
            break
    if checked_metadata_count >= 10:
        break

print(f"  ✓ Metadata sanity check: {checked_metadata_count} tiles checked")
print(f"    - Has mmap_written_at in metadata: {has_mmap_written_at} (should be 0 - tracked in DB)")

print(f"\n{'='*60}")
print(f"Conversion Status: {'✓ COMPLETE' if mmap_count > 0 else '✗ INCOMPLETE'}")
print(f"{'='*60}")

Verifying MMAP conversion...

Database Statistics:
  Tiles marked as mmap_written: 0
  Tiles with mmap_path: 0

File System Statistics:
  MMAP directory: C:\Users\schulz0022\Documents\mining-net\data\landsat_mmap
  MMAP directory size: 4.38 GB
  ✓ All tiles have metadata.json
  ✓ Metadata sanity check: 10 tiles checked
    - Has mmap_written_at in metadata: 0 (should be 0 - tracked in DB)

Conversion Status: ✗ INCOMPLETE


In [12]:
# Calculate storage sizes
import os

def get_dir_size(path):
    total = 0
    for dirpath, dirnames, filenames in os.walk(path):
        for filename in filenames:
            filepath = os.path.join(dirpath, filename)
            total += os.path.getsize(filepath)
    return total

mmap_size = get_dir_size(MMAP_DIR)
mmap_size_gb = mmap_size / (1024**3)

print(f"{'='*60}")
print(f"CONVERSION SUMMARY")
print(f"{'='*60}")
print(f"Source:              {ZARR_PATH}")
print(f"Destination:         {MMAP_DIR}")
print(f"Tiles converted:     {converted_count}")
print(f"Errors:              {error_count}")
print(f"Total size:          {mmap_size_gb:.2f} GB")
print(f"Avg per tile:        {mmap_size_gb/converted_count*1024:.2f} MB" if converted_count > 0 else "N/A")
print(f"Bands:               {BANDS}")
print(f"Format:              PyTorch MemoryMappedTensor")
print(f"Index file:          {MMAP_DIR / 'index.json'}")
print(f"\nNext steps:")
print(f"  1. Update database schema (see update_database_mmap_status)")
print(f"  2. Update data_loader.py to use new format")
print(f"  3. Test training with new data loader")
print(f"{'='*60}")

CONVERSION SUMMARY
Source:              C:\Users\schulz0022\Documents\mining-net\data\global_landsat.zarr
Destination:         C:\Users\schulz0022\Documents\mining-net\data\landsat_mmap
Tiles converted:     34931
Errors:              0
Total size:          4.38 GB
Avg per tile:        0.13 MB
Bands:               ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'thermal']
Format:              PyTorch MemoryMappedTensor
Index file:          C:\Users\schulz0022\Documents\mining-net\data\landsat_mmap\index.json

Next steps:
  1. Update database schema (see update_database_mmap_status)
  2. Update data_loader.py to use new format
  3. Test training with new data loader


## Sync Database Status from Disk

Update the database to mark all tiles as mmap_written based on what exists on disk. This is useful if:
- You're running this after a conversion
- The database is out of sync with the filesystem
- You need to rebuild the database status from existing MMAP files

In [14]:
def sync_database_from_disk(
    db_path: str,
    mmap_dir: str,
    verbose: bool = True
):
    """
    Sync database status from what exists on disk.
    
    Scans the MMAP directory structure and updates the database with:
    - mmap_written = 1 for each tile found on disk
    - cluster_id = new cluster ID (from metadata.json)
    - mmap_written_at = current timestamp
    
    This is useful when:
    - Database is out of sync with filesystem
    - Rebuilding database from existing MMAP files
    - Verifying conversion completeness
    
    Args:
        db_path: Path to SQLite database
        mmap_dir: Path to MMAP output directory
        verbose: Print progress information
    
    Returns:
        Dictionary with update statistics
    """
    db_path = Path(db_path)
    mmap_dir = Path(mmap_dir)
    
    if not db_path.exists():
        print(f"✗ Database not found: {db_path}")
        return None
    
    if not mmap_dir.exists():
        print(f"✗ MMAP directory not found: {mmap_dir}")
        return None
    
    conn = sqlite3.connect(str(db_path))
    cursor = conn.cursor()
    
    stats = {
        'tiles_found': 0,
        'tiles_updated': 0,
        'tiles_not_in_db': 0,
        'errors': 0,
        'cluster_migrations': {}
    }
    
    try:
        # Scan MMAP directory structure: cluster_id/year/tile_ix_tile_iy/
        for cluster_dir in sorted(mmap_dir.iterdir()):
            if not cluster_dir.is_dir():
                continue
            
            cluster_id = int(cluster_dir.name)
            
            for year_dir in sorted(cluster_dir.iterdir()):
                if not year_dir.is_dir():
                    continue
                
                year = int(year_dir.name)
                
                for tile_dir in sorted(year_dir.iterdir()):
                    if not tile_dir.is_dir():
                        continue
                    
                    stats['tiles_found'] += 1
                    
                    # Parse tile coordinates from directory name
                    tile_name = tile_dir.name
                    try:
                        tile_ix, tile_iy = map(int, tile_name.split('_'))
                    except (ValueError, IndexError):
                        if verbose:
                            print(f"  ✗ Could not parse tile name: {tile_name}")
                        stats['errors'] += 1
                        continue
                    
                    # Load metadata to get geometry_hash
                    metadata_path = tile_dir / "metadata.json"
                    if not metadata_path.exists():
                        if verbose:
                            print(f"  ✗ No metadata for {tile_name}")
                        stats['errors'] += 1
                        continue
                    
                    with open(metadata_path, 'r') as f:
                        metadata = json.load(f)
                    
                    geometry_hash = metadata.get('geometry_hash')
                    if not geometry_hash:
                        if verbose:
                            print(f"  ✗ No geometry_hash in metadata for {tile_name}")
                        stats['errors'] += 1
                        continue
                    
                    # Track cluster migrations
                    old_cluster_id = metadata.get('cluster_id_old')
                    new_cluster_id = metadata.get('cluster_id_new', cluster_id)
                    if old_cluster_id and old_cluster_id not in stats['cluster_migrations']:
                        stats['cluster_migrations'][old_cluster_id] = new_cluster_id
                    
                    # Update database
                    try:
                        cursor.execute("""
                            UPDATE tiles 
                            SET mmap_written = 1, 
                                cluster_id = ?,
                                mmap_written_at = datetime('now')
                            WHERE tile_ix = ? AND tile_iy = ? 
                              AND geometry_hash = ? AND year = ?
                        """, (new_cluster_id, tile_ix, tile_iy, geometry_hash, year))
                        
                        if cursor.rowcount > 0:
                            stats['tiles_updated'] += 1
                        else:
                            stats['tiles_not_in_db'] += 1
                            if verbose:
                                print(f"  ⚠ Tile not in DB: {tile_ix}_{tile_iy} ({geometry_hash[:8]}, {year})")
                    except Exception as e:
                        if verbose:
                            print(f"  ✗ Error updating {tile_name}: {e}")
                        stats['errors'] += 1
        
        conn.commit()
        
        print(f"\n{'='*60}")
        print(f"DATABASE SYNC COMPLETE")
        print(f"{'='*60}")
        print(f"Tiles found on disk:        {stats['tiles_found']}")
        print(f"Tiles updated in database:  {stats['tiles_updated']}")
        print(f"Tiles not in database:      {stats['tiles_not_in_db']}")
        print(f"Errors:                     {stats['errors']}")
        print(f"Cluster ID migrations:      {len(stats['cluster_migrations'])}")
        
        if stats['cluster_migrations'] and verbose:
            print(f"\nSample cluster migrations (old → new):")
            for old_id, new_id in list(stats['cluster_migrations'].items())[:3]:
                print(f"  {old_id} → {new_id}")
        
        print(f"{'='*60}")
        
    except Exception as e:
        print(f"✗ Sync failed: {e}")
        conn.rollback()
        raise
    finally:
        conn.close()
    
    return stats


# Run database sync from disk
print("Syncing database status from MMAP files on disk...\n")
sync_stats = sync_database_from_disk(
    str(config.DB_PATH),
    str(MMAP_DIR),
    verbose=True
)

Syncing database status from MMAP files on disk...


DATABASE SYNC COMPLETE
Tiles found on disk:        34931
Tiles updated in database:  34931
Tiles not in database:      0
Errors:                     0
Cluster ID migrations:      0
