In [5]:
# test_grid_alignment.ipynb

import sys, os
sys.path.append(os.path.abspath(".."))

import xarray as xr
import matplotlib.pyplot as plt


# === Project modules ===
from data_loading.load_data import load_dataset, load_ensemble_any_latlon
from data_loading.subset_time import get_common_time_range, subset_time
from data_loading.subset_region import subset_dataset, get_reference_extent
from utils.temporal_stats import aggregate_to_daily
from utils.bias_metrics import mean_error
# The new functions we'll test:
from data_loading.unify_grids import force_match_dimensions

# === 1. Load Example Data ===
ensemble_pattern = os.path.join("..", "data", "total_precipitation_2017010*.nc")
ds_ensemble = load_ensemble_any_latlon(ensemble_pattern)

reference_file = os.path.join("..", "data", "SPARTACUS2-DAILY_RR_2017.nc")
ds_ref = xr.open_dataset(reference_file)

print("Original Ensemble:", ds_ensemble)
print("Original Reference:", ds_ref)

from data_loading.subset_region import unify_lat_lon_names

ds_ensemble = unify_lat_lon_names(ds_ensemble, lat_name="lat", lon_name="lon")

ds_ref = unify_lat_lon_names(ds_ref, lat_name="lat", lon_name="lon")

# === 2. Spatial Subset (Optional) ===
# If we want a smaller region, we find the ref extent and subset both
ref_bounds = get_reference_extent(ds_ref, lat_var="lat", lon_var="lon")
ds_ensemble_sub = subset_dataset(ds_ensemble, lat_var="lat", lon_var="lon", bounds=ref_bounds)
ds_ref_sub      = subset_dataset(ds_ref, lat_var="lat", lon_var="lon", bounds=ref_bounds)

# === 3. Temporal Subset (Optional) ===
common_time = get_common_time_range(ds_ensemble_sub, ds_ref_sub, time_var="time")
ds_ensemble_sub = subset_time(ds_ensemble_sub, time_bounds=common_time)
ds_ref_sub      = subset_time(ds_ref_sub, time_bounds=common_time)

# === 4. Daily Aggregation if Ensemble is Hourly ===
ds_ens_daily = aggregate_to_daily(ds_ensemble_sub, "precipitation", method="sum", compute_ens_mean=True)
ds_ref_daily = ds_ref_sub["RR"]  # Already daily

print("\nEnsemble daily shape:", ds_ens_daily.shape)
print("Reference daily shape:", ds_ref_daily.shape)





Original Ensemble: <xarray.Dataset> Size: 2GB
Dimensions:        (member: 11, time: 48, lat: 492, lon: 594)
Coordinates:
  * member         (member) object 88B '00' '01' '02' '03' ... '08' '09' '10'
    lon            (lat, lon) float32 1MB 5.498 5.526 5.554 ... 22.05 22.07 22.1
  * time           (time) datetime64[ns] 384B 2017-01-01 ... 2017-01-02T23:00:00
Dimensions without coordinates: lat
Data variables:
    latitude       (time, lat, lon, member) float32 617MB 42.98 42.98 ... 51.82
    precipitation  (time, lat, lon, member) float64 1GB 0.0 0.0 0.0 ... 0.0 0.0
Attributes:
    history:  Fri Jan 31 09:51:14 2025: ncrename -v lat,latitude -v lon,longi...
    NCO:      netCDF Operators version 5.1.9 (Homepage = http://nco.sf.net, C...
Original Reference: <xarray.Dataset> Size: 563MB
Dimensions:                  (time: 365, y: 329, x: 584)
Coordinates:
    lambert_conformal_conic  float64 8B ...
    lat                      (y, x) float32 769kB ...
    lon                      (y, x) 

In [6]:
# ===========================================================
#  A) Force Dimension Names (Quick approach)
# ===========================================================
try:
    ds_ref_matched, ds_ens_daily_matched = force_match_dimensions(
        ds_ref_daily, 
        ds_ens_daily, 
        ref_dims=("y", "x"),  # or your actual dims
        ens_dims=("lat", "lon")
    )
    print("\n=== Forcing dimension names successful! ===")
    # Now ds_ref_matched and ds_ens_daily_matched should share dims: (time, lat, lon)
    
    # Compute mean error
    bias_forced = ds_ens_daily_matched - ds_ref_matched
    bias_me_forced = bias_forced.mean_error() if hasattr(bias_forced, "mean_error") else bias_forced.mean()
    
    print("Mean Error (forced dims):", bias_me_forced.values)
except ValueError as e:
    print("\n[Force Dims] Could not force dimension names:", e)




TypeError: tuple indices must be integers or slices, not str

In [1]:
# test_bias_metrics.ipynb

import sys
import os
sys.path.append(os.path.abspath(".."))  # Ensure we can import from the parent directory

import xarray as xr
import matplotlib.pyplot as plt

# === 1. Import from our project modules ===
from data_loading.load_data import load_ensemble_any_latlon, load_dataset
from data_loading.subset_region import (
    get_reference_extent,
    subset_dataset,
    unify_lat_lon_names,  # if you're using a separate function to unify lat/lon names
)
from data_loading.subset_time import (
    get_common_time_range,
    subset_time
)
from utils.temporal_stats import aggregate_to_daily
from utils.bias_metrics import (
    mean_error,
    mean_absolute_error,
    root_mean_squared_error,
    compute_all_bias_metrics
)

# === 2. Define file paths / patterns ===
ensemble_pattern = os.path.join("..", "data", "total_precipitation_2017010*.nc")
reference_file   = os.path.join("..", "data", "SPARTACUS2-DAILY_RR_2017.nc")

# === 3. Load datasets ===
ds_ensemble = load_ensemble_any_latlon(ensemble_pattern)
ds_ref      = xr.open_dataset(reference_file)

print(ds_ensemble)

<xarray.Dataset> Size: 2GB
Dimensions:        (member: 11, time: 48, lat: 492, lon: 594)
Coordinates:
  * member         (member) object 88B '00' '01' '02' '03' ... '08' '09' '10'
    lon            (lat, lon) float32 1MB 5.498 5.526 5.554 ... 22.05 22.07 22.1
  * time           (time) datetime64[ns] 384B 2017-01-01 ... 2017-01-02T23:00:00
Dimensions without coordinates: lat
Data variables:
    latitude       (time, lat, lon, member) float32 617MB 42.98 42.98 ... 51.82
    precipitation  (time, lat, lon, member) float64 1GB 0.0 0.0 0.0 ... 0.0 0.0
Attributes:
    history:  Fri Jan 31 09:51:14 2025: ncrename -v lat,latitude -v lon,longi...
    NCO:      netCDF Operators version 5.1.9 (Homepage = http://nco.sf.net, C...


In [2]:
# Optional: unify lat/lon variable names if needed
ds_ensemble = unify_lat_lon_names(ds_ensemble, lat_name="lat", lon_name="lon")
ds_ref      = unify_lat_lon_names(ds_ref, lat_name="lat", lon_name="lon")

# === 4. Spatial Subsetting ===
ref_bounds = get_reference_extent(ds_ref, lat_var="lat", lon_var="lon")
ds_ens_subset = subset_dataset(ds_ensemble, lat_var="lat", lon_var="lon", bounds=ref_bounds)
ds_ref_subset = subset_dataset(ds_ref,      lat_var="lat", lon_var="lon", bounds=ref_bounds)

# === 5. Temporal Subsetting ===
common_time = get_common_time_range(ds_ens_subset, ds_ref_subset, time_var="time")
ds_ens_subset = subset_time(ds_ens_subset, time_bounds=common_time, time_var="time")
ds_ref_subset = subset_time(ds_ref_subset, time_bounds=common_time, time_var="time")

# === 6. Daily Aggregation ===
# Ensemble might be hourly -> aggregate to daily sums; also compute ensemble mean if needed
ds_ens_daily = aggregate_to_daily(ds_ens_subset, "precipitation", method="sum", compute_ens_mean=True)

# Reference is already daily -> just pick the variable (adjust name if needed)
ds_ref_daily = ds_ref_subset["RR"]

print("\nEnsemble daily precipitation:", ds_ens_daily)
print("Reference daily:", ds_ref_daily)


Ensemble daily precipitation: <xarray.DataArray 'precipitation' (time: 2, lat: 168, lon: 285)> Size: 766kB
array([[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         3.46790661e-07, 3.46790661e-07, 3.46790661e-07],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         6.93581321e-07, 3.46790661e-07, 3.46790661e-07],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         6.93581321e-07, 6.93581321e-07, 3.46790661e-07],
        ...,
        [1.80234042e-01, 1.58514196e-01, 1.43266158e-01, ...,
         1.91775235e-04, 5.45848500e-04, 5.02915816e-03],
        [2.16471239e-01, 1.92285018e-01, 1.73217773e-01, ...,
         0.00000000e+00, 3.67598100e-05, 1.56888095e-03],
        [2.42437189e-01, 2.21659227e-01, 2.04820806e-01, ...,
         1.68887052e-04, 3.88405540e-05, 3.19047408e-04]],

       [[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+

In [None]:
# === 7. Compute Bias Metrics ===
# Option A: Individually
#bias_me   = mean_error(ds_ens_daily, ds_ref_daily)
#bias_mae  = mean_absolute_error(ds_ens_daily, ds_ref_daily)
#bias_rmse = root_mean_squared_error(ds_ens_daily, ds_ref_daily)

# Option B: All metrics at once (returns a dictionary or dataset, depending on your function)
#bias_all = compute_all_bias_metrics(ds_ens_daily, ds_ref_daily)

#print("Mean Error:", bias_me.compute().values if hasattr(bias_me, 'compute') else bias_me.values)
#print("MAE:", bias_mae.compute().values if hasattr(bias_mae, 'compute') else bias_mae.values)
#print("RMSE:", bias_rmse.compute().values if hasattr(bias_rmse, 'compute') else bias_rmse.values)
#print("All Metrics:", bias_all)



In [None]:
# === 8. Visualize the Bias (e.g., spatial mean bias over the time dimension) ===
# Compute the mean error over time, resulting in lat-lon grid (or y-x grid).
bias_spatial_me = (ds_ens_daily - ds_ref_daily).mean(dim="time")

# Plot using xarray's .plot() (works if lat/lon are recognized as coords or you have 2D lat/lon)
plt.figure(figsize=(8, 6))
bias_spatial_me.plot(
    robust=True,  # improves color scaling
    cmap="RdBu",
    cbar_kwargs={"label": "Mean Error (mm/day)"}
)
plt.title("Spatial Mean Error (Ensemble Mean - Reference)")
plt.show()
