In [22]:
import xarray as xr
import xesmf as xe
import os
import numpy as np
from pyproj import CRS, Transformer
import matplotlib.pyplot as plt

In [23]:
# Display the kernel's current working directory; the relative data paths
# used by the loading cells are resolved against it.
os.getcwd()

'/scratch2/mg963/aifs-tp'

In [24]:
# Open the AIFS and IFS forecast stores with identical options
# (consolidated metadata, `chunks={}`), AIFS first and IFS second.
aifs, ifs = (
    xr.open_zarr(store_path, consolidated=True, chunks={})
    for store_path in ("AIFS_TP_FXX24_DE.zarr", "IFS_TP_FXX24_DE.zarr")
)

In [25]:
# HYRAS precipitation observations (path is relative to the working directory).
obs = xr.open_dataset(
    filename_or_obj="../../aifs-data/obs/pr_hyras_1_2025_v6-1_de.nc",
)

In [26]:
# Standardize coordinate names so both model datasets expose "lat" / "lon".
def _standardize_latlon(ds):
    """Return ``ds`` with ``latitude``/``longitude`` coordinates renamed to ``lat``/``lon``.

    Each long-form name is renamed independently and only when it is actually
    present in ``ds.coords``. A dataset that carries just one of the two names
    (or that has already been renamed) is therefore handled without raising,
    and re-running the cell is a no-op.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose coordinate names should be normalized.

    Returns
    -------
    xarray.Dataset
        The renamed dataset, or ``ds`` itself when nothing needs renaming.
    """
    name_map = {
        long_name: short_name
        for long_name, short_name in (("latitude", "lat"), ("longitude", "lon"))
        if long_name in ds.coords
    }
    return ds.rename(name_map) if name_map else ds


aifs = _standardize_latlon(aifs)
ifs = _standardize_latlon(ifs)

In [27]:
# Source grid for xESMF, taken from the observation file: `lat` and `lon` are
# stored as data variables on the (y, x) dimensions, and y / x are the only
# coordinates. All values are cast to float64.
grid_in = xr.Dataset(
    data_vars={
        "lat": (("y", "x"), obs["lat"].values.astype("float64")),
        "lon": (("y", "x"), obs["lon"].values.astype("float64")),
    },
    coords={
        "y": obs["y"].values.astype("float64"),
        "x": obs["x"].values.astype("float64"),
    },
)
# Discard any remaining non-index coordinates.
grid_in = grid_in.reset_coords(drop=True)

In [28]:
# Target (model) grid for xESMF: the 1-D AIFS `lat` / `lon` axes as float64,
# passed through `data_vars` with no explicit `coords=` argument.
# NOTE(review): only the AIFS axes are used; this presumes IFS sits on the same
# grid. Confirm.
# NOTE(review): xarray normally promotes a 1-D variable named after its own
# dimension to an index coordinate, so these probably end up as coords anyway.
# Verify if that distinction matters here.
grid_out = xr.Dataset(
    data_vars={
        "lat": ("lat", aifs["lat"].values.astype("float64")),
        "lon": ("lon", aifs["lon"].values.astype("float64")),
    }
)

In [44]:
# Build the two regridders between the observation grid (grid_in) and the
# model grid (grid_out). Both are non-periodic, use no extrapolation and
# ignore degenerate cells.
# To cache weights, pass `filename=...` together with `reuse_weights=True`;
# each direction needs its own weights file.

# Observation grid -> model grid, method "nearest_d2s".
# NOTE(review): check in the xESMF docs how "nearest_d2s" combines several
# source points that land on the same destination cell, and whether a
# conservative method would suit fine-to-coarse precipitation better.
regridder = xe.Regridder(
    ds_in=grid_in,
    ds_out=grid_out,
    method="nearest_d2s",
    periodic=False,
    extrap_method=None,
    ignore_degenerate=True,
)

# Model grid -> observation grid, method "bilinear".
inverse_regridder = xe.Regridder(
    ds_in=grid_out,
    ds_out=grid_in,
    method="bilinear",
    periodic=False,
    extrap_method=None,
    ignore_degenerate=True,
)

In [194]:
# Step 1: mask on the HYRAS grid taken from the first time step:
#         1.0 where `pr` is finite, NaN everywhere else.
# NOTE(review): presumes the missing-data pattern is the same for every time
# step. Confirm.
_first_step_pr = obs["pr"].isel(time=0)
src_mask = xr.where(np.isfinite(_first_step_pr), 1.0, np.nan).astype("float32")

# Step 2: carry the mask over to the model grid, then force it to exactly
#         1.0 (any positive value) or NaN.
footprint = regridder(src_mask).rename("footprint")
footprint = xr.where(footprint > 0, 1.0, np.nan)
_inside_footprint = np.isfinite(footprint)

# Step 3: regrid the observed precipitation and blank everything outside
#         the footprint.
obs_on_model = regridder(obs["pr"]).where(_inside_footprint).rename("pr_on_model")

# Step 4: apply the same footprint to both model datasets.
aifs_crop = aifs.where(_inside_footprint)
ifs_crop = ifs.where(_inside_footprint)

In [56]:
# Interpolate the model precipitation onto the HYRAS grid and keep only the
# cells where HYRAS has a finite value at the first time step.
src_mask = xr.where(np.isfinite(obs["pr"].isel(time=0)), 1.0, np.nan).astype("float32")
_has_obs = np.isfinite(src_mask)

# Model grid -> HYRAS grid.
aifs_on_obs = inverse_regridder(aifs["tp"])
ifs_on_obs = inverse_regridder(ifs["tp"])

# Blank every cell without observations.
aifs_on_obs_crop = aifs_on_obs.where(_has_obs)
ifs_on_obs_crop = ifs_on_obs.where(_has_obs)

  result_var = func(*data_vars)
  intermediate = blockwise(
  result_var = func(*data_vars)
  intermediate = blockwise(


In [229]:
# OPTION 1: compare on the model grid (observations regridded to the model).
# NOTE(review): AIFS is used unscaled, IFS is multiplied by 1000 and the
# observations are divided by 1000. Confirm that these three factors really
# bring all fields to the same unit.
# NOTE: the OPTION 2 cell also assigns `obs_tp`; whichever cell ran last wins.
aifs_tp = aifs_crop.tp
ifs_tp = ifs_crop.tp * 1000
obs_tp = obs_on_model / 1000

In [60]:
# OPTION 2: compare on the HYRAS grid (models interpolated to the observations).
# NOTE(review): AIFS is used unscaled, IFS is multiplied by 1000 and the
# observations are divided by 1000. Confirm that these three factors really
# bring all fields to the same unit.
# NOTE: the OPTION 1 cell also assigns `obs_tp`; whichever cell ran last wins.
aifs_interp_tp = aifs_on_obs_crop
ifs_interp_tp = ifs_on_obs_crop * 1000
obs_tp = obs["pr"] / 1000

In [230]:
# OPTION 1: inner-join the three model-grid arrays on their shared coordinates.
# The OPTION 2 cell binds the same three names, so run only one of the two.
aifs_aln, ifs_aln, obs_aln = xr.align(
    aifs_tp,
    ifs_tp,
    obs_tp,
    join="inner",
)

In [61]:
# OPTION 2: inner-join the three HYRAS-grid arrays on their shared coordinates.
# The OPTION 1 cell binds the same three names, so run only one of the two.
aifs_aln, ifs_aln, obs_aln = xr.align(
    aifs_interp_tp,
    ifs_interp_tp,
    obs_tp,
    join="inner",
)

In [62]:
# Collect the three aligned arrays (from whichever OPTION cell ran last) in
# one Dataset.
pair = xr.Dataset(dict(aifs=aifs_aln, ifs=ifs_aln, obs=obs_aln))

# Keep a value only where all three fields are finite; everywhere else every
# variable becomes NaN.
_all_finite = (
    np.isfinite(pair["aifs"])
    & np.isfinite(pair["ifs"])
    & np.isfinite(pair["obs"])
)
pair = pair.where(_all_finite)

In [63]:
# Write the masked dataset to a zarr store (mode "w") with consolidated metadata.
pair.to_zarr(
    store="AIFS_IFS_OBS_GERMANY_HIGHRES.zarr",
    mode="w",
    consolidated=True,
)



<xarray.backends.zarr.ZarrStore at 0x749738067e20>

In [64]:
# Also export the dataset as a single NetCDF file.
# NOTE: the recorded run of this cell ended in KeyboardInterrupt, so the file
# on disk may be missing or incomplete.
pair.to_netcdf(
    path="AIFS_IFS_OBS_GERMANY_HIGHRES.nc",
)

KeyboardInterrupt: 