In [28]:
import os
from typing import Optional
import xarray as xr
from typing import Any
import numpy as np
import geopandas as gpd
import pprint
import pandas as pd

In [None]:
# Loader

def load_era5(
    path: str,
    engine: Optional[str] = "cfgrib",
    backend_kwargs: Optional[dict] = None,
):
    """Load an ERA5 GRIB file as an xarray Dataset.

    The file is first opened with ``engine``; if that fails and an engine was
    given, a second attempt lets xarray pick the backend itself.

    Args:
        path: path to the .grib file
        engine: preferred xarray engine (default: 'cfgrib'). ``None`` lets
            xarray guess the backend straight away.
        backend_kwargs: optional dict forwarded unchanged to
            ``xr.open_dataset(..., backend_kwargs=...)``. With cfgrib this is
            how variables that cannot share one dataset are selected, e.g.
            ``{"filter_by_keys": {"shortName": "tp"}}``.

    Returns:
        xarray.Dataset

    Raises:
        RuntimeError: if the file could not be opened by any attempt. The
            error raised by the first attempt is attached as ``__cause__``.
    """
    # Only pass backend_kwargs when supplied so the default call is unchanged.
    extra = {} if backend_kwargs is None else {"backend_kwargs": backend_kwargs}

    try:
        return xr.open_dataset(path, engine=engine, **extra)
    except Exception as first_err:
        # With engine=None the first call already let xarray guess, so a
        # second identical call could only fail the same way.
        if engine is not None:
            try:
                return xr.open_dataset(path, **extra)
            except Exception:
                # Fall through: the first error is the one reported below.
                pass
        raise RuntimeError(
            f"Could not open ERA5 file at {path}. Tried engine='{engine}'. "
            "Make sure 'cfgrib' is installed for GRIB support (pip install cfgrib)."
        ) from first_err


In [None]:
# Visualisation utils

# String Utils
def _safe_str(x: Any) -> str:
    try:
        return str(x)
    except Exception:
        return repr(x)

# Grib summary
def summary(ds) -> str:
    """Return a one-page, multi-line text summary of an xarray.Dataset.

    The summary lists, in order:
    - the data variables (name, dtype, dims, shape)
    - the dimension sizes
    - the time range, when a ``time`` coordinate exists
    - the first coordinate whose name starts with "lat" / "lon", with its length
    - the dataset attributes

    Args:
        ds: object to describe; expected to be an ``xarray.Dataset``.

    Returns:
        The summary text. ``None`` and non-Dataset inputs produce a one-line
        placeholder instead of raising.
    """
    # Checked first so the None case never needs to touch xarray at all.
    if ds is None:
        return "<None dataset>"

    if xr is None:
        return "xarray not available"

    if not isinstance(ds, xr.Dataset):
        return f"Not an xarray.Dataset (type={type(ds)})"

    lines = [f"Dataset summary: {ds.__class__.__name__}"]

    # Variables: one failing variable must not prevent the rest being listed.
    vars_info = []
    for name, da in ds.data_vars.items():
        try:
            vars_info.append(f"{name}: {da.dtype} {list(da.dims)} shape={tuple(da.shape)}")
        except Exception:
            vars_info.append(f"{name}: <unable to inspect>")

    lines.append(f"Data variables ({len(vars_info)}):")
    lines.extend(f"  - {v}" for v in vars_info)

    # Dimensions
    lines.append("Dimensions:")
    lines.extend(f"  - {d}: {size}" for d, size in ds.sizes.items())

    # Time range
    if "time" in ds.coords:
        try:
            t = ds.coords["time"]
            t0 = pd.to_datetime(t.values.min())
            t1 = pd.to_datetime(t.values.max())
            lines.append(f"Time range: {t0} -> {t1} (n={t.size})")
        except Exception:
            lines.append("Time range: <unable to determine>")

    # Latitude / longitude detection by name prefix. Coordinate names are not
    # guaranteed to be strings, hence the str() before lower().
    lat = next((k for k in ds.coords if str(k).lower().startswith("lat")), None)
    lon = next((k for k in ds.coords if str(k).lower().startswith("lon")), None)

    if lat is not None:
        lines.append(f"Latitude coord: '{lat}' (n={ds.coords[lat].size})")
    if lon is not None:
        lines.append(f"Longitude coord: '{lon}' (n={ds.coords[lon].size})")

    # attrs
    if getattr(ds, "attrs", None):
        lines.append("Dataset attributes:")
        for k, v in ds.attrs.items():
            lines.append(f"  - {k}: {_safe_str(v)}")

    return "\n".join(lines)

# Show era5 structure
def show_structure(ds) -> None:
    """Print the detailed structure of an xarray.Dataset to stdout.

    Prints the dimension sizes, every coordinate with up to six example
    values, each data variable's dims/shape/dtype plus a few common metadata
    attributes (units, long_name, ...), and finally the global attributes.

    Args:
        ds: object to describe; expected to be an ``xarray.Dataset``. ``None``
            and non-Dataset inputs print a one-line placeholder.
    """
    # Checked first so the None case never needs to touch xarray at all.
    if ds is None:
        print("<None dataset>")
        return

    if xr is None:
        print("xarray not available")
        return

    if not isinstance(ds, xr.Dataset):
        print(f"Not an xarray.Dataset (type={type(ds)})")
        return

    print("--- Dataset structure ---")
    print(f"Class: {ds.__class__.__name__}")
    print("\nDimensions:")
    for d, size in ds.sizes.items():
        print(f"  {d}: {size}")

    print("\nCoordinates (name: example values):")
    for name, coord in ds.coords.items():
        try:
            # ravel() flattens 0-d (scalar) and N-d coordinates alike, so the
            # slice below works for every coordinate, not only 1-D ones.
            head = np.asarray(coord.values).ravel()[:6]
            if head.dtype.kind in "Mm":
                # tolist() would turn datetime64/timedelta64[ns] values into
                # raw integers; keep them human-readable instead.
                sample = [str(v) for v in head]
            else:
                sample = head.tolist()
            print(f"  - {name}: {sample} (n={coord.size})")
        except Exception:
            print(f"  - {name}: <unavailable>")

    print("\nData variables:")
    for name, da in ds.data_vars.items():
        print(f"  - {name}")
        try:
            print(f"      dims: {list(da.dims)}")
            print(f"      shape: {tuple(da.shape)}")
            print(f"      dtype: {da.dtype}")
            # common metadata, only the keys actually present on the variable
            md = {
                k: da.attrs[k]
                for k in ("units", "long_name", "standard_name", "GRIB_name", "GRIB_code")
                if k in da.attrs
            }
            if md:
                print(f"      attrs: {pprint.pformat(md)}")
        except Exception:
            print("      <unable to inspect variable>")

    # global attrs
    if getattr(ds, "attrs", None):
        print("\nGlobal attributes:")
        pprint.pprint(ds.attrs)


In [None]:
# Load ERA5 GRIB
era5_path = os.path.join('data', 'era5_meteo_data.grib')
ds = load_era5(era5_path)

# Print summary
print('ERA5 dataset summary:')
print(summary(ds))

# show detailed structure to help understand dims/coords/vars
print('\n--- Detailed structure ---')
show_structure(ds)

# Yearly mean per grid cell: group the time axis by calendar year and average.
# NOTE(review): a year that is only partly present in the file is averaged over
# the months that exist. The recorded summary ends at 2025-08, so the last year
# is not comparable with the full years; confirm this is intended.
ds_year = ds.groupby("time.year").mean(dim="time")
print('\nCleaned dataset summary:')
print(summary(ds_year))
print('\n--- Yearly dataset structure ---')
show_structure(ds_year)



skipping variable: paramId==228 shortName='tp'
Traceback (most recent call last):
  File "/home/raspigeon/Documents/HEIG/Semestre 5/GML/.venv/lib64/python3.12/site-packages/cfgrib/dataset.py", line 726, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/home/raspigeon/Documents/HEIG/Semestre 5/GML/.venv/lib64/python3.12/site-packages/cfgrib/dataset.py", line 642, in dict_merge
    raise DatasetBuildError(
cfgrib.dataset.DatasetBuildError: key present and new value is different: key='time' value=Variable(dimensions=('time',), data=array([ 315532800,  318211200,  320716800,  323395200,  325987200,
        328665600,  331257600,  333936000,  336614400,  339206400,
        341884800,  344476800,  347155200,  349833600,  352252800,
        354931200,  357523200,  360201600,  362793600,  365472000,
        368150400,  370742400,  373420800,  376012800,  378691200,
        381369600,  383788800,  386467200,  389059200,  391737600,
        394329600,  397008000,  399686

ERA5 dataset summary:
Dataset summary: Dataset
Data variables (8):
  - u10: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
  - v10: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
  - d2m: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
  - t2m: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
  - sst: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
  - sp: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
  - skt: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
  - blh: float32 ['time', 'latitude', 'longitude'] shape=(548, 721, 1440)
Dimensions:
  - time: 548
  - latitude: 721
  - longitude: 1440
Time range: 1980-01-01 00:00:00 -> 2025-08-01 00:00:00 (n=548)
Latitude coord: 'latitude' (n=721)
Longitude coord: 'longitude' (n=1440)
Dataset attributes:
  - GRIB_edition: 1
  - GRIB_centre: ecmf
  - GRIB_centreDescription: European Centre for Medium-Range Weather Forecasts
 

In [None]:
# Build country mask using geopandas spatial join
#
# Result: mask_2d, a (latitude, longitude) float array holding, for each grid
# point, the row index of its country in `world` (NaN where no country matched).

# read country polygons (GeoJSON contains ISO3166-1-Alpha-3)
world = gpd.read_file("https://raw.githubusercontent.com/datasets/geo-countries/master/data/countries.geojson")

# Filter out invalid entries
world = world[world['ISO3166-1-Alpha-3'] != "-99"].reset_index(drop=True)

print(world)

# The country polygons use longitudes in -180..180. If the grid runs 0..360,
# every point east of 180 would fall outside all polygons, so bring the grid
# to the polygon convention BEFORE building the mask. Doing it here (rather
# than after the mask exists) also keeps the mask columns in the same order as
# the dataset that is aggregated later.
if float(ds_year.longitude.max()) > 180:
    ds_year = ds_year.assign_coords(
        longitude=(((ds_year.longitude + 180) % 360) - 180)
    ).sortby("longitude")

# 1D lon/lat from ds_year
lons = ds_year.longitude.values
lats = ds_year.latitude.values

# make 2D grid of points and flatten (row-major: latitude varies slowest)
lon2d, lat2d = np.meshgrid(lons, lats)
pts = np.column_stack((lon2d.ravel(), lat2d.ravel()))

# GeoDataFrame of points (keep original index so we can reorder after spatial join)
pts_gdf = gpd.GeoDataFrame(
    {"pt_index": np.arange(pts.shape[0]), "lon": pts[:, 0], "lat": pts[:, 1]},
    geometry=gpd.points_from_xy(pts[:, 0], pts[:, 1]),
    crs="EPSG:4326",
)

# spatial join: points -> country polygons
joined = gpd.sjoin(pts_gdf, world[["ISO3166-1-Alpha-3", "geometry"]], how="left", predicate="within")

# restore original order. A point lying inside two overlapping polygons comes
# back once per match; keep a single row per point, otherwise the reshape
# below no longer matches the grid size.
joined = joined.sort_values("pt_index", kind="stable").drop_duplicates(subset="pt_index", keep="first")
iso_series = joined["ISO3166-1-Alpha-3"].values  # length = n_points, NaN for unmatched points

# map iso codes to integer indices (consistent ordering)
iso_codes = list(world["ISO3166-1-Alpha-3"].values)
iso_to_idx = {code: i for i, code in enumerate(iso_codes)}

# convert to index array (NaN where no country); float dtype so NaN is representable
idx_arr = joined["ISO3166-1-Alpha-3"].map(iso_to_idx).to_numpy(dtype=float, na_value=np.nan)

# reshape to (lat, lon)
mask_2d = idx_arr.reshape(len(lats), len(lons))

# convenience lists used by aggregation code
region_indices = np.unique(mask_2d[~np.isnan(mask_2d)].astype(int))
region_codes = [iso_codes[int(i)] for i in region_indices]

                         name ISO3166-1-Alpha-3 ISO3166-1-Alpha-2  \
0                   Indonesia               IDN                ID   
1                    Malaysia               MYS                MY   
2                       Chile               CHL                CL   
3                     Bolivia               BOL                BO   
4                        Peru               PER                PE   
..                        ...               ...               ...   
231                     Palau               PLW                PW   
232                      Guam               GUM                GU   
233  Northern Mariana Islands               MNP                MP   
234                   Bahrain               BHR                BH   
235               Macao S.A.R               MAC                MO   

                                              geometry  
0    MULTIPOLYGON (((117.70361 4.16342, 117.70361 4...  
1    MULTIPOLYGON (((117.70361 4.16342, 117.69711 4...  


In [None]:
# Aggregate yearly means to country-level and create period CSVs directly

# Convert longitudes to -180..180 and ensure ordering. Skipped when the grid is
# already in that convention so re-running the cell does not re-sort the data.
if float(ds_year.longitude.max()) > 180:
    ds_year = ds_year.assign_coords(longitude=(((ds_year.longitude + 180) % 360) - 180)).sortby("longitude")

# area weights (cosine latitude)
w = np.cos(np.deg2rad(ds_year.latitude))
w_da = xr.DataArray(w, coords={"latitude": ds_year.latitude}, dims=("latitude",))

out_dir = os.path.join("export")
os.makedirs(out_dir, exist_ok=True)

# Define time ranges (inclusive start/end years, keyed by the CSV name suffix)
time_ranges = {
    "1721": (2017, 2021),
    "8021": (1980, 2021),
    "2225": (2022, 2025)
}

# Per-year DataFrames collected under every range the year belongs to
range_data = {range_name: [] for range_name in time_ranges}

spatial_dims = ("latitude", "longitude")
years = [int(y) for y in ds_year["year"].values]
rows_by_year = {y: [] for y in years}

# Iterate regions (mask built once per region) and aggregate all years at once
for r, code in zip(region_indices, region_codes):
    in_region = mask_2d == r
    # Crop the mask to the rows/columns the region actually touches so each
    # reduction works on the region's bounding box instead of the whole globe.
    rows_used = in_region.any(axis=1)
    cols_used = in_region.any(axis=0)
    # Label the mask with the grid it was BUILT on (lats/lons), not with the
    # dataset's current coordinates: alignment is then done by coordinate
    # value, so a re-ordered longitude axis cannot shift the mask.
    mask_bool = xr.DataArray(
        in_region[np.ix_(rows_used, cols_used)],
        coords={"latitude": lats[rows_used], "longitude": lons[cols_used]},
        dims=spatial_dims,
    )

    region_rows = {y: {"Year": y, "Country Code": code} for y in years}
    for var in ds_year.data_vars:
        # Keep only the grid points present in both the data and the mask.
        field, region_mask = xr.align(ds_year[var], mask_bool, join="inner")
        # weighted().mean() drops the weights of NaN cells from the
        # denominator and yields NaN (not 0.0) when a region has no valid cell.
        means = field.where(region_mask).weighted(w_da).mean(dim=spatial_dims)
        # means is expected to be 1-D over "year", in the same order as `years`
        for y, value in zip(years, means.values):
            region_rows[y][var] = float(value)

    for y in years:
        rows_by_year[y].append(region_rows[y])

# Assign each year's table to the appropriate time ranges
for y in years:
    df_year = pd.DataFrame(rows_by_year[y])
    for range_name, (start_year, end_year) in time_ranges.items():
        if start_year <= y <= end_year:
            range_data[range_name].append(df_year)

# Create CSV for each time range
for range_name, dfs in range_data.items():
    if dfs:
        df_combined = pd.concat(dfs, ignore_index=True)
        out_path = os.path.join(out_dir, f"era5_climate_country_{range_name}.csv")
        df_combined.to_csv(out_path, index=False)
        
        start_year, end_year = time_ranges[range_name]
        years_in_range = len(dfs)
        print(f"Created {out_path} with {len(df_combined)} rows ({years_in_range} years: {start_year}-{end_year})")
    else:
        print(f"Warning: No data found for range {range_name}")

print(f"\nAll period CSVs written to {out_dir}")

Wrote per-year country CSVs to export/by_year_country
