# Comparison to the Water Budget
We compare the changes of our reconstructed water storage anomalies to the water balance, derived from ERA5.

In [None]:
import string

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr
import cmcrameri.cm as cmc
from dask.diagnostics import ProgressBar
from matplotlib import ticker
from matplotlib_inline.backend_inline import set_matplotlib_formats

import deeprec  # noqa
from deeprec import metrics
from deeprec.regions import basins
from deeprec.utils import ROOT_DIR, repeat_by_weight

# Register a global dask progress bar. `minimum=5` is handed to dask as the
# display threshold (presumably seconds of runtime -- see the dask docs).
progress_bar = ProgressBar(minimum=5)
progress_bar.register()


In [None]:
# Directory the paper figures are written to
FIGURE_DIR = ROOT_DIR / "docs/figures/paper"

# Inline figures in "retina" format, styled with the project's paper style sheet
set_matplotlib_formats("retina")
paper_style = ROOT_DIR / "config/style_paper.mplstyle"
plt.style.use(paper_style)


## Data Loading

For preprocessing, see the corresponding notebook in `11-data-processing/water-budget-processing.ipynb`.

In [None]:
# Open the basin-scale TWSC evaluation store
twsc_path = ROOT_DIR / "data/evaluations/twsc_basin-scale.zarr"
twsc = xr.open_zarr(store=twsc_path)


Specify variables of interest and their descriptive names:

In [None]:
# Dataset variable names with a special role in the evaluation
our_name = "ours_era"  # our reconstruction
grace_name = "csr"  # GRACE solution
bdgt_name = "era5_smooth"  # water balance target

# Variables to evaluate, mapped to the labels used in the figures.
# The insertion order is the order in which the models are stacked.
vars_verbose = {
    our_name: "DeepRec (Ours)",
    "gap": "WGHM",
    "palazzoli_jpl_full": "Palazzoli's JPL Rec",
    "yin_csr_full": "Yin's CSR Rec",
    "li_csr_full": "Li's CSR Rec",
    "humphrey_gsfc_detrend": "Humphrey's GSFC Rec",
    grace_name: "CSR (GRACE)",
}
eval_names = vars_verbose.keys()


## Analysis

Define evaluation models and water balance target:

In [None]:
# Keep only the time steps for which no variable is NaN; the mask comes from
# the project's `.dr.time_notnull` accessor
valid_time = twsc.dr.time_notnull("region")
twsc_post = twsc.where(valid_time, drop=True)

# Water balance target and the evaluated models stacked along "model"
bdgt_post = twsc_post[bdgt_name]
eval_post = twsc_post[eval_names].to_dataarray(dim="model")

twsc_post


In [None]:
# Pre-GRACE period
pre_grace = slice("1980", "2001")
twsc_pre = twsc.sel(time=pre_grace)

# Water balance target and the evaluated models; the GRACE variable is left out
bdgt_pre = twsc_pre[bdgt_name]
eval_pre = twsc_pre[eval_names].drop_vars(grace_name).to_dataarray(dim="model")

twsc_pre


###  Calculate RMSE

In [None]:
# RMSE between the water balance and every model, reduced over time
rmse_pre = metrics.rmse(bdgt_pre, eval_pre, dim="time").compute()
rmse_post = metrics.rmse(bdgt_post, eval_post, dim="time").compute()

# Our reconstruction in both periods ...
our_rmse_pre, our_rmse_post = (
    rmse.sel(model=our_name) for rmse in (rmse_pre, rmse_post)
)
# ... and GRACE in the period covered by `rmse_post`
csr_rmse_post = rmse_post.sel(model=grace_name)


### Calculate NSE

In [None]:
# NSE between the water balance and every model, reduced over time
nse_pre = metrics.nse(bdgt_pre, eval_pre, dim="time").compute()
nse_post = metrics.nse(bdgt_post, eval_post, dim="time").compute()

# Our reconstruction in both periods ...
our_nse_pre, our_nse_post = (
    nse.sel(model=our_name) for nse in (nse_pre, nse_post)
)
# ... and GRACE in the period covered by `nse_post`
csr_nse_post = nse_post.sel(model=grace_name)


## Create Plot

In [None]:
# Keyword arguments for the RMSE maps (passed to `.dr.projplot_basins`).
# Colorbars are added per row later, hence `add_colorbar=False`.
rmse_kwargs = {
    "vmin": 0,
    "vmax": 45,
    "cmap": cmc.batlow,
    "rasterized": True,
    "coastlines": True,
    "gridlines": True,
    "coastlines_kwargs": {"rasterized": True, "linewidth": 0.5},
    "add_colorbar": False,
    "cbar_kwargs": None,
}

# Keyword arguments for the NSE maps; same layout, different range and colormap
nse_kwargs = {
    "vmin": -1,
    "vmax": 1,
    "cmap": cmc.bam,
    "rasterized": True,
    "coastlines": True,
    "gridlines": True,
    "coastlines_kwargs": {"rasterized": True, "linewidth": 0.5},
    "add_colorbar": False,
    "cbar_kwargs": None,
}

# Outline-only basin polygons drawn on top of the maps
basinlines_kwargs = {
    "edgecolor": "black",
    "linewidth": 0.4,
    "facecolor": "None",
    "rasterized": True,
    "zorder": 3.0,
}

# Horizontal colorbars below each map row; the NSE bar marks values below `vmin`
rmse_cbar_kwargs = {"location": "bottom", "aspect": 50, "shrink": 0.66}
nse_cbar_kwargs = {
    "location": "bottom",
    "aspect": 50,
    "shrink": 0.66,
    "extend": "min",
}


In [None]:
# First and last time step of the post-period evaluation, as "YYYY-MM - YYYY-MM"
time_index = eval_post.get_index("time")
start, end = time_index[0], time_index[-1]
timespan_str = f"{start.year}-{start.month:02} - {end.year}-{end.month:02}"
timespan_str


In [None]:
# Spatial dummy array handed to `.dr.projplot_basins()`:
# a global lat/lon grid of ones with cell centers every STEP_DEG degrees
STEP_DEG = 0.5
lats = np.arange(-89.75, 89.75 + STEP_DEG, STEP_DEG)
lons = np.arange(-179.75, 179.75 + STEP_DEG, STEP_DEG)

ones = np.ones((lats.size, lons.size))

grid_coords = {"lat": lats, "lon": lons}
spatial_dummy = xr.DataArray(ones, coords=grid_coords)


In [None]:
# GeoDataFrame with the basin shapes and areas, as returned by `basins(top=72)`
gdf_basin_shapes = basins(top=72)

# Basin areas as a Series indexed by region name (for scatter plotting)
basin_columns = {"riverbasin": "region", "sum_sub_ar": "area"}
gdf_by_region = gdf_basin_shapes.rename(columns=basin_columns).set_index("region")
s_basin_areas = gdf_by_region["area"]

# The same areas as a DataArray (for weighting)
da_basin_areas = xr.DataArray.from_series(s_basin_areas)

In [None]:
# Summary statistics of our reconstruction's basin RMSE in both periods.
# NOTE: the `=` specifier echoes the expression text in the output, so
# renaming or restructuring these expressions changes what is printed.

# Extremes over all basins, pre-period
print(f"{our_rmse_pre.min() = :.2f}")
print(f"{our_rmse_pre.max() = :.2f}\n")

# Extremes over all basins, post-period
print(f"{our_rmse_post.min() = :.2f}")
print(f"{our_rmse_post.max() = :.2f}\n")

# Basin-area-weighted mean and 90 % quantile, pre-period
print(f"{our_rmse_pre.weighted(da_basin_areas).mean() = :.2f}")
print(f"{our_rmse_pre.weighted(da_basin_areas).quantile(0.9) = :.2f}\n")

# Basin-area-weighted mean and 90 % quantile, post-period
print(f"{our_rmse_post.weighted(da_basin_areas).mean() = :.2f}")
print(f"{our_rmse_post.weighted(da_basin_areas).quantile(0.9) = :.2f}\n")


In [None]:
# Two stacked subfigures: RMSE maps on top, NSE maps below. Each row shows
# DeepRec before GRACE, DeepRec in the GRACE era, and GRACE itself.
fig = plt.figure(figsize=(7.2, 4))
subfigs = fig.subfigures(nrows=2, hspace=0.05)

axs_rmse = subfigs[0].subplots(1, 3, subplot_kw={"projection": ccrs.EqualEarth()})
axs_nse = subfigs[1].subplots(1, 3, subplot_kw={"projection": ccrs.EqualEarth()})

axs_geo = [*axs_rmse, *axs_nse]

# RMSE Ours Pre-GRACE (its mappable is reused for the row's colorbar; all
# panels of a row share the same value range and colormap)
p_rmse = our_rmse_pre.dr.projplot_basins(spatial_dummy, ax=axs_rmse[0], **rmse_kwargs)
axs_rmse[0].set(title="DeepRec, 1980--2001")

# RMSE Ours GRACE era
our_rmse_post.dr.projplot_basins(spatial_dummy, ax=axs_rmse[1], **rmse_kwargs)
axs_rmse[1].set(title="DeepRec, 2002--2019")

# RMSE GRACE
csr_rmse_post.dr.projplot_basins(spatial_dummy, ax=axs_rmse[2], **rmse_kwargs)
axs_rmse[2].set(title="GRACE, 2002--2019")

# NSE Ours Pre-GRACE (mappable reused for the row's colorbar)
p_nse = our_nse_pre.dr.projplot_basins(spatial_dummy, ax=axs_nse[0], **nse_kwargs)
axs_nse[0].set(title="DeepRec, 1980--2001")

# NSE Ours GRACE era
our_nse_post.dr.projplot_basins(spatial_dummy, ax=axs_nse[1], **nse_kwargs)
axs_nse[1].set(title="DeepRec, 2002--2019")

# NSE GRACE
csr_nse_post.dr.projplot_basins(spatial_dummy, ax=axs_nse[2], **nse_kwargs)
axs_nse[2].set(title="GRACE, 2002--2019")

# Outline the basins on every map
for ax in axs_geo:
    gdf_basin_shapes.dr.projplot(ax=ax, **basinlines_kwargs)

# One colorbar per row, each requested from the subfigure that owns the axes
subfigs[0].colorbar(p_rmse, ax=axs_rmse, label="RMSE (mm)", **rmse_cbar_kwargs)
subfigs[1].colorbar(p_nse, ax=axs_nse, label="NSE", **nse_cbar_kwargs)


## Weighted Boxplot: Ability to close the water budget

We calculate the error between the reconstructed TWSC and the ERA5 TWSC (the budget). We plot the basin-wise errors as a boxplot, weighting the basins according to their approximate size.

In [None]:
# Water budget target and the evaluated models (stacked along "model"),
# both taken from the full `twsc` dataset
da_bdgt = twsc[bdgt_name]
da_eval = twsc[eval_names].to_dataarray(dim="model")


Create a DataFrame with an integer column whose values are proportional to the basin area. Our weighting function requires integer values.

In [None]:
# Table of basin names and areas, with columns renamed to "region" and "area"
basin_sizes = gdf_basin_shapes[["riverbasin", "sum_sub_ar"]].rename(
    columns={"riverbasin": "region", "sum_sub_ar": "area"}
)
# The smallest basin gets the weight MULT (specifies the accuracy of the weighting)
MULT = 10
# Index the column explicitly: attribute access (`.area`) would resolve to the
# geometric area property if this frame were a GeoDataFrame
area_rel = basin_sizes["area"] / basin_sizes["area"].min() * MULT
# Integer weights, as required by the weighting function
basin_sizes["area_int"] = area_rel.round().astype(int)
basin_sizes


### Create the boxplot

In [None]:
# seaborn's "tab10" palette; the bare name displays it in the notebook
colors = sns.color_palette(palette="tab10")
colors


In [None]:
# Color lookup per model: the n-th model below gets the n-th palette color
# (zip stops after the last model, surplus palette colors stay unused)
colors_dict = dict(
    zip(
        (
            "ours_era",
            "csr",
            "gap",
            "palazzoli_jpl_full",
            "yin_csr_full",
            "li_csr_full",
            "humphrey_gsfc_detrend",
        ),
        colors,
    )
)


In [None]:
# Time periods (first year, last year) for calculating basin-wise errors.
# The bounds are kept as strings because they are used as time-slice labels.
_PERIOD_YEARS = ((1940, 1959), (1960, 1979), (1980, 2001), (2002, 2019))
TIME_PERIODS = [(str(first), str(last)) for first, last in _PERIOD_YEARS]


In [None]:
# Minimum fraction of not-NA time steps a model needs within a period
# in order to be evaluated for that period
THRES = 0.5

# List with one DataFrame of weighted basin RMSEs per period
df_errors = []

for period in TIME_PERIODS:
    # Select current time period
    da_bdgt_period = da_bdgt.sel(time=slice(*period))
    da_eval_period = da_eval.sel(time=slice(*period))

    # Remove models which are NA for most of the time
    n_time = len(da_eval_period.time)
    for model in da_eval_period.model.values:
        # Count the time steps left after dropping those that are entirely NA
        n_time_notna = len(
            da_eval_period.sel(model=model).dropna("time", how="all").time
        )
        # Drop model if less not-NA timesteps than threshold for current period
        if n_time_notna < THRES * n_time:
            da_eval_period = da_eval_period.drop_sel(model=model)

    # Calculate the RMSE
    da_error = metrics.rmse(da_bdgt_period, da_eval_period, dim="time", skipna=True)
    # Convert to a long data frame with one row per model/basin pair
    df_error = da_error.to_pandas().unstack().reset_index(name="error")
    # Merge with basin frame (adds the basin areas and integer weights)
    df_error = df_error.merge(basin_sizes)
    # Add descriptive names
    df_error["model_verbose"] = df_error["model"].map(vars_verbose)
    # Repeat rows according to the basin size
    df_error = repeat_by_weight(df_error, weight_col="area_int")

    df_errors.append(df_error)


In [None]:
# One horizontal boxplot panel per time period; the panels share both axes
n_periods = len(TIME_PERIODS)
fig, axs = plt.subplots(
    ncols=n_periods, sharex=True, sharey=True, figsize=(7.2, 2.2)
)
locator = ticker.MultipleLocator(10)

# Styling common to all panels
box_style = {
    "palette": colors_dict,
    "showfliers": False,
    "width": 0.5,
    "saturation": 1,
    "legend": False,
}

# Iterate over different time periods
for ax, period, df_error in zip(axs, TIME_PERIODS, df_errors):
    sns.boxplot(
        df_error, x="error", y="model_verbose", hue="model", ax=ax, **box_style
    )
    ax.xaxis.grid()
    ax.set(xlabel="RMSE (mm)", ylabel=None)
    ax.set_title(f"{period[0]}--{period[1]}")
    ax.xaxis.set_major_locator(locator)

# The x axis is shared, so setting the limits on the first panel is enough
axs[0].set(xlim=(0, 40))


Switch x and y axis, use color-coded legend instead of categorical axis labels:

In [None]:
# One vertical boxplot panel per time period; the panels share both axes
fig, axs = plt.subplots(
    ncols=len(TIME_PERIODS), sharex=True, sharey=True, figsize=(7.2, 2.2)
)
# Only the last panel lets seaborn build a legend; it is replaced by a
# figure-level legend below
i_legend = len(TIME_PERIODS) - 1

# Iterate over different time periods
for i, (ax, period, df_error) in enumerate(zip(axs, TIME_PERIODS, df_errors)):
    # Plot
    sns.boxplot(
        df_error,
        x="model_verbose",
        y="error",
        hue="model",
        palette=colors_dict,
        showfliers=False,
        width=0.5,
        saturation=1,
        legend=(i == i_legend),
        ax=ax,
    )
    ax.yaxis.grid()
    ax.set(ylabel="RMSE (mm)", xlabel=None)
    ax.set_title(period[0] + "--" + period[1])
    # The models are identified by color, so the categorical axis is hidden
    ax.xaxis.set_visible(False)

# The y axis is shared, so setting the limits on the first panel is enough
axs[0].set(ylim=(0, 35))

# Move legend outside the axes
leg = fig.legend(loc="outside right")
# The legend entries carry the `hue` values (model keys); show verbose names
for text in leg.texts:
    text.set_text(vars_verbose[text.get_text()])
# Remove the axes-level legend of the last panel, if seaborn created one
ax_legend = axs[-1].get_legend()
if ax_legend is not None:
    ax_legend.remove()

## Combine Everything

In [None]:
# Combined paper figure with three stacked subfigures:
# RMSE maps, NSE maps, and one boxplot panel per time period.
fig = plt.figure(figsize=(7.2, 6))
subfigs = fig.subfigures(nrows=3, hspace=0.05, height_ratios=[1, 1, 1.2])

axs_rmse = subfigs[0].subplots(1, 3, subplot_kw={"projection": ccrs.EqualEarth()})
axs_nse = subfigs[1].subplots(1, 3, subplot_kw={"projection": ccrs.EqualEarth()})
axs_box = subfigs[2].subplots(1, len(TIME_PERIODS), sharex=True, sharey=True)


### MAP PLOTS ###

axs_geo = [*axs_rmse, *axs_nse]

# RMSE Ours Pre-GRACE (its mappable is reused for the row's colorbar; all
# panels of a row share the same value range and colormap)
p_rmse = our_rmse_pre.dr.projplot_basins(spatial_dummy, ax=axs_rmse[0], **rmse_kwargs)
axs_rmse[0].set(title="DeepRec, 1980--2001")

# RMSE Ours GRACE era
our_rmse_post.dr.projplot_basins(spatial_dummy, ax=axs_rmse[1], **rmse_kwargs)
axs_rmse[1].set(title="DeepRec, 2002--2019")

# RMSE GRACE
csr_rmse_post.dr.projplot_basins(spatial_dummy, ax=axs_rmse[2], **rmse_kwargs)
axs_rmse[2].set(title="GRACE, 2002--2019")

# NSE Ours Pre-GRACE (mappable reused for the row's colorbar)
p_nse = our_nse_pre.dr.projplot_basins(spatial_dummy, ax=axs_nse[0], **nse_kwargs)
axs_nse[0].set(title="DeepRec, 1980--2001")

# NSE Ours GRACE era
our_nse_post.dr.projplot_basins(spatial_dummy, ax=axs_nse[1], **nse_kwargs)
axs_nse[1].set(title="DeepRec, 2002--2019")

# NSE GRACE
csr_nse_post.dr.projplot_basins(spatial_dummy, ax=axs_nse[2], **nse_kwargs)
axs_nse[2].set(title="GRACE, 2002--2019")

# Outline the basins on every map
for ax in axs_geo:
    gdf_basin_shapes.dr.projplot(ax=ax, **basinlines_kwargs)

# One colorbar per row, each requested from the subfigure that owns the axes
subfigs[0].colorbar(p_rmse, ax=axs_rmse, label="RMSE (mm)", **rmse_cbar_kwargs)
subfigs[1].colorbar(p_nse, ax=axs_nse, label="NSE", **nse_cbar_kwargs)


### BOXPLOTS ###

# Only the last panel lets seaborn build a legend; it is replaced by a
# subfigure-level legend below
i_legend = len(TIME_PERIODS) - 1

# Iterate over different time periods
for i, (ax, period, df_error) in enumerate(zip(axs_box, TIME_PERIODS, df_errors)):
    # Plot
    sns.boxplot(
        df_error,
        x="model_verbose",
        y="error",
        hue="model",
        palette=colors_dict,
        showfliers=False,
        width=0.5,
        saturation=1,
        ax=ax,
        legend=(i == i_legend),
    )
    ax.set(ylabel="RMSE (mm)", xlabel=None)
    ax.set_title(period[0] + "--" + period[1])
    # The models are identified by color, so the categorical axis is hidden
    ax.xaxis.set_visible(False)
    ax.yaxis.grid()

# The y axis is shared, so setting the limits on the first panel is enough
axs_box[0].set(ylim=(0, 35))

# Move legend outside the axes
leg = subfigs[2].legend(loc="outside right")
# The legend entries carry the `hue` values (model keys); show verbose names
for text in leg.texts:
    text.set_text(vars_verbose[text.get_text()])
# Remove the axes-level legend of the last panel, if seaborn created one
ax_legend = axs_box[-1].get_legend()
if ax_legend is not None:
    ax_legend.remove()

### LETTERS ###

# Panel letters: maps are labelled inside the axes (upper left) ...
for n, ax in enumerate(axs_geo):
    ax.text(
        0.0,
        1.0 - 0.10,
        string.ascii_lowercase[n],
        transform=ax.transAxes,
        size="x-large",
        weight="bold",
    )
# ... boxplots just above the axes, continuing the lettering after the maps
for n, ax in enumerate(axs_box):
    ax.text(
        0.0,
        1.0 + 0.05,
        string.ascii_lowercase[n + len(axs_geo)],
        ha="center",
        transform=ax.transAxes,
        size="x-large",
        weight="bold",
    )

fig.savefig(FIGURE_DIR / "bdgt_closure.pdf", backend="pgf")
