# Visualize Lambda Production Results

This notebook reads parquet files from the Lambda production run, combines them into an xarray Dataset decoded with xdggs, and visualizes the results with matplotlib and cartopy.

**Input:** Parquet files in `s3://{bucket}/{prefix}/{morton}.parquet`

**Columns in each parquet file:**
- `child_morton`: Morton index at order 12
- `child_healpix`: HEALPix cell ID at order 12
- `count`: Number of observations
- `h_mean`: Weighted mean elevation
- `h_sigma`: Uncertainty in mean
- `h_min`, `h_max`: Elevation range
- `h_variance`, `h_q25`, `h_q50`, `h_q75`: Distribution stats

## 1. Imports and Configuration

In [None]:
import warnings

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import s3fs
import xarray as xr
import xdggs
from matplotlib.colors import LogNorm
warnings.filterwarnings('ignore')

print("Imports complete")

In [None]:
# --- Configuration: everything a reader might want to change lives here ---
S3_BUCKET: str = "xagg"               # bucket holding the production output
S3_PREFIX: str = "atl06/production"   # key prefix under which the parquet files sit
CHILD_ORDER: int = 12                 # HEALPix order ("level") of the child cells

print(f"Reading from: s3://{S3_BUCKET}/{S3_PREFIX}/")

## 2. List and Load Parquet Files from S3

In [None]:
# Open an authenticated (non-anonymous) S3 filesystem handle.
s3 = s3fs.S3FileSystem(anon=False)

# Glob for every parquet object directly under the configured prefix.
listing_pattern = f"{S3_BUCKET}/{S3_PREFIX}/*.parquet"
parquet_files = s3.glob(listing_pattern)
print(f"Found {len(parquet_files)} parquet files")

# Show a short, sorted preview so the reader can sanity-check the listing.
if parquet_files:
    preview = sorted(parquet_files)[:5]
    print("\nFirst 5 files:")
    print("\n".join(f"  s3://{key}" for key in preview))

In [None]:
# Read every listed parquet file and combine the rows that contain data
# (count > 0) into a single DataFrame, `df_all`.
#
# Loading is best-effort per file: a file that fails to read is reported and
# skipped so one bad object does not abort the whole run. Failures are
# tallied so a partial load is visible, and an empty result is reported with
# a clear error instead of the opaque "No objects to concatenate" that
# pd.concat raises on an empty list.
print(f"Reading {len(parquet_files)} parquet files...")

all_dfs = []
failed_paths = []  # paths that raised while being read or filtered
for i, parquet_path in enumerate(sorted(parquet_files)):
    try:
        with s3.open(parquet_path, 'rb') as f:
            df = pd.read_parquet(f)
            # Only keep rows with data
            df = df[df['count'] > 0]
            if len(df) > 0:
                all_dfs.append(df)

        if (i + 1) % 200 == 0:
            print(f"  Loaded {i + 1}/{len(parquet_files)} files...")
    except Exception as e:
        failed_paths.append(parquet_path)
        print(f"  Error loading {parquet_path}: {e}")

if failed_paths:
    print(
        f"\nWARNING: {len(failed_paths)} of {len(parquet_files)} files could not be "
        "loaded; the combined results are incomplete."
    )

if not all_dfs:
    raise RuntimeError(
        "No rows with count > 0 were loaded "
        f"({len(parquet_files)} files listed, {len(failed_paths)} failed to load); "
        "check the S3 bucket/prefix configuration and credentials."
    )

# Combine all DataFrames
df_all = pd.concat(all_dfs, ignore_index=True)
print(f"\nLoaded {len(df_all):,} cells with data from {len(all_dfs)} files")
print(f"\nDataFrame shape: {df_all.shape}")
print(f"Columns: {list(df_all.columns)}")
df_all.head()

## 3. Convert to xdggs Dataset

In [None]:
# Assemble an xarray Dataset with one entry per HEALPix cell along the
# 'cell_ids' dimension.
cell_ids = df_all['child_healpix'].values

# 'count' is stored as int32; every elevation statistic as float32.
float_columns = [
    'h_mean', 'h_sigma', 'h_min', 'h_max',
    'h_variance', 'h_q25', 'h_q50', 'h_q75',
]
data_vars = {'count': ('cell_ids', df_all['count'].values.astype(np.int32))}
data_vars.update({
    column: ('cell_ids', df_all[column].values.astype(np.float32))
    for column in float_columns
})

# Grid metadata attached to the cell_ids coordinate (read later by xdggs.decode).
healpix_grid_attrs = {
    'grid_name': 'healpix',
    'level': CHILD_ORDER,
    'indexing_scheme': 'nested',
}
coords = {
    'cell_ids': ('cell_ids', cell_ids, healpix_grid_attrs),
    'morton': ('cell_ids', df_all['child_morton'].values),
}

global_attrs = {
    'title': 'ATL06 Lambda Production Results',
    'child_order': CHILD_ORDER,
    'grid_type': 'healpix',
    'indexing_scheme': 'nested',
    'source': f's3://{S3_BUCKET}/{S3_PREFIX}/',
}

ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=global_attrs)

print("Created xarray Dataset:")
print(ds)

In [None]:
# Let xdggs interpret the cell_ids coordinate (requesting its "moc" index
# kind), then attach per-cell latitude/longitude coordinates in one chain.
moc_index_options = {"index_kind": "moc"}
ds = xdggs.decode(ds, index_options=moc_index_options).dggs.assign_latlon_coords()

print("Decoded with xdggs and added lat/lon coords:")
print(ds)

## 4. Summary Statistics

In [None]:
# Plain-text summary of the combined dataset: size, elevation, uncertainty,
# per-cell observation counts and geographic coverage.
rule = "=" * 60
print(rule)
print("SUMMARY STATISTICS")
print(rule)

# Short handles for the variables reported below.
elev = ds['h_mean']
sigma = ds['h_sigma']
n_obs = ds['count']
lat = ds['latitude']
lon = ds['longitude']

print(f"\nTotal cells with data: {len(ds['cell_ids']):,}")
print(f"Total observations: {n_obs.sum().values:,}")

print("\nElevation (h_mean):")
print(f"  Min:  {elev.min().values:.2f} m")
print(f"  Max:  {elev.max().values:.2f} m")
print(f"  Mean: {elev.mean().values:.2f} m")
print(f"  Std:  {elev.std().values:.2f} m")

print("\nUncertainty (h_sigma):")
print(f"  Min:  {sigma.min().values:.4f} m")
print(f"  Max:  {sigma.max().values:.2f} m")
print(f"  Mean: {sigma.mean().values:.4f} m")

print("\nObservation counts per cell:")
print(f"  Min:  {n_obs.min().values}")
print(f"  Max:  {n_obs.max().values:,}")
print(f"  Mean: {n_obs.mean().values:.1f}")

print("\nCoverage:")
print(f"  Lat: {lat.min().values:.2f} to {lat.max().values:.2f}")
print(f"  Lon: {lon.min().values:.2f} to {lon.max().values:.2f}")

## 5. Visualization - Antarctic Overview

In [None]:
# Four-panel Antarctic overview on a south polar stereographic map:
# mean elevation, observation count (log colour scale), uncertainty and
# per-cell elevation range (h_max - h_min).
proj = ccrs.SouthPolarStereo()
data_crs = ccrs.PlateCarree()  # the lon/lat values are passed in this CRS


def _scatter_panel(ax, lon, lat, values, valid, title, cbar_label, **scatter_kwargs):
    """Draw one overview panel: a scatter of the valid cells plus title and colorbar.

    Parameters
    ----------
    ax : map axis to draw on.
    lon, lat : 1-D arrays of cell longitudes / latitudes.
    values : 1-D array used to colour the points (same length as lon/lat).
    valid : boolean mask selecting which cells to plot.
    title : panel title; the number of plotted cells is appended to it.
    cbar_label : label for the panel's colorbar.
    **scatter_kwargs : panel-specific colour options forwarded to ``ax.scatter``
        (e.g. ``cmap``, ``vmin``, ``vmax``, ``norm``).

    Returns
    -------
    The collection returned by ``ax.scatter``.
    """
    points = ax.scatter(
        lon[valid],
        lat[valid],
        c=values[valid],
        s=0.5, alpha=0.8,
        transform=data_crs,
        **scatter_kwargs,
    )
    ax.set_title(f'{title} ({np.sum(valid):,} cells)', fontsize=14, weight='bold')
    plt.colorbar(points, ax=ax, label=cbar_label, shrink=0.7)
    return points


fig, axes = plt.subplots(2, 2, figsize=(18, 16),
                         subplot_kw={'projection': proj})

# Shared map decoration: coastline, land shading, gridlines, and an extent
# covering everything from the pole out to 60°S.
for ax in axes.flat:
    ax.coastlines(resolution='50m', linewidth=0.5)
    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3)
    ax.gridlines(draw_labels=False, alpha=0.3)
    ax.set_extent([-180, 180, -90, -60], crs=data_crs)

# Pull the coordinate arrays once; every panel indexes them with its own mask.
lon = ds['longitude'].values
lat = ds['latitude'].values

# 1. Mean elevation
h_mean = ds['h_mean'].values
_scatter_panel(
    axes[0, 0], lon, lat, h_mean, ~np.isnan(h_mean),
    'Mean Elevation', 'Elevation (m)',
    cmap='terrain', vmin=0, vmax=4000,
)

# 2. Observation count (log colour scale, so only strictly positive counts)
counts = ds['count'].values
_scatter_panel(
    axes[0, 1], lon, lat, counts, counts > 0,
    'Observation Count', 'Count (log scale)',
    cmap='viridis', norm=LogNorm(vmin=1),
)

# 3. Uncertainty (sigma)
h_sigma = ds['h_sigma'].values
_scatter_panel(
    axes[1, 0], lon, lat, h_sigma, ~np.isnan(h_sigma),
    'Uncertainty (h_sigma)', 'Uncertainty (m)',
    cmap='plasma', vmax=1.0,
)

# 4. Elevation range (max - min)
h_range = ds['h_max'].values - ds['h_min'].values
_scatter_panel(
    axes[1, 1], lon, lat, h_range, ~np.isnan(h_range),
    'Elevation Range (max-min)', 'Range (m)',
    cmap='hot', vmax=100,
)

# NOTE(review): the cycle number in the title is hard-coded and not derived
# from the configuration or the data — confirm it matches the run being shown.
plt.suptitle('ATL06 Lambda Production Results - Cycle 22', fontsize=16, weight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 6. Histograms

In [None]:
# 2x2 grid of distributions: mean elevation, log10 observation count,
# uncertainty (restricted to < 10 m) and cell latitude.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(ax_elev, ax_count), (ax_sigma, ax_lat) = axes
hist_style = dict(edgecolor='none', alpha=0.7)

# Elevation distribution, with the dataset mean marked by a dashed red line
elev = ds['h_mean'].values
elev_mean = ds['h_mean'].mean().values
ax_elev.hist(elev[~np.isnan(elev)], bins=100, **hist_style)
ax_elev.set(xlabel='Elevation (m)', ylabel='Count', title='Mean Elevation Distribution')
ax_elev.axvline(elev_mean, color='red', linestyle='--', label=f'Mean: {elev_mean:.0f}m')
ax_elev.legend()

# Observation count distribution (log scale); zero counts are excluded before log10
counts = ds['count'].values
ax_count.hist(np.log10(counts[counts > 0]), bins=50, **hist_style)
ax_count.set(
    xlabel='log10(Observation Count)',
    ylabel='Number of Cells',
    title='Observation Count Distribution',
)

# Uncertainty distribution, dropping NaNs and values of 10 m or more
sigma = ds['h_sigma'].values
ax_sigma.hist(sigma[~np.isnan(sigma) & (sigma < 10)], bins=100, **hist_style)
ax_sigma.set(
    xlabel='Uncertainty (m)',
    ylabel='Count',
    title='Uncertainty Distribution (h_sigma < 10m)',
)

# Latitude distribution
ax_lat.hist(ds['latitude'].values, bins=50, **hist_style)
ax_lat.set(xlabel='Latitude', ylabel='Count', title='Latitude Distribution')

plt.tight_layout()
plt.show()

## 7. Regional Zoom - West Antarctica

In [None]:
# Regional close-up of West Antarctica (Thwaites/Pine Island region),
# coloured by mean elevation.
proj = ccrs.SouthPolarStereo()
data_crs = ccrs.PlateCarree()
west_antarctica_extent = [-140, -70, -85, -70]  # x0, x1, y0, y1 in data_crs (lon/lat degrees)

fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={'projection': proj})

# Map decoration first, then restrict the view to the region of interest.
ax.coastlines(resolution='10m', linewidth=0.5)
ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3)
ax.gridlines(draw_labels=True, alpha=0.3)
ax.set_extent(west_antarctica_extent, crs=data_crs)

# Plot every cell that has a mean elevation; the extent above decides what is visible.
elev = ds['h_mean'].values
has_elev = ~np.isnan(elev)
scatter = ax.scatter(
    ds['longitude'].values[has_elev],
    ds['latitude'].values[has_elev],
    c=elev[has_elev],
    s=2,
    cmap='terrain',
    alpha=0.9,
    vmin=0,
    vmax=2500,
    transform=data_crs,
)

plt.colorbar(scatter, ax=ax, label='Elevation (m)', shrink=0.7)
ax.set_title(
    'West Antarctica - Mean Elevation (Thwaites/Pine Island Region)',
    fontsize=14,
    weight='bold',
)
plt.show()

## 8. Save Combined Dataset

In [None]:
# Saving the combined dataset is opt-in: both zarr writes below are left
# commented out, so as written this cell only prints a reminder.
# NOTE(review): `ds` has been decoded by xdggs by this point — confirm it
# serializes cleanly with to_zarr before relying on either option.

# Option A: write to S3 next to the source parquet files
# (mode='w' overwrites an existing store at that path)
# output_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/combined.zarr"
# ds.to_zarr(output_path, mode='w')
# print(f"Saved to: {output_path}")

# Option B: write to a local zarr store in the working directory
# ds.to_zarr("production_results_combined.zarr", mode='w')
# print("Saved to: production_results_combined.zarr")

print("To save the combined dataset, uncomment the code above.")