# Figure 7

Cloud water path (CWP) retrieval skill based on synthetic experiments.

In [None]:
from string import ascii_lowercase as abc

import cmcrameri.cm as cmc
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import xarray as xr
from lizard.mpltools import style
from lizard.readers.band_pass import read_band_pass
from lizard.writers.figure_to_file import write_figure

from si_clouds.io.readers.ancillary import read_ancillary_data
from si_clouds.io.readers.oem_result import read_oem_result_concat

In [None]:
# Band-pass information for the "HAMP" instrument key, read with lizard's
# band-pass reader.
ds_bp = read_band_pass("HAMP")
# Project ancillary data; later cells read `era5_tcwv`, `era5_skt`,
# `ix_sea_ice` and `ix_central_arctic` from this dataset.
ds_anc = read_ancillary_data()

In [None]:
def rmse(group):
    """Root-mean-square error of the retrieved CWP within one group.

    Parameters
    ----------
    group
        Group handed over by ``groupby_bins(...).apply``. It must provide
        ``unstack()`` and, afterwards, the variables ``cwp_true`` and
        ``cwp_pred``.

    Returns
    -------
    Square root of the mean over ``"time"`` of the squared difference
    ``cwp_true - cwp_pred``.
    """
    cwp = group.unstack()
    residual = cwp["cwp_true"] - cwp["cwp_pred"]
    mean_squared_error = (residual**2).mean("time")
    return np.sqrt(mean_squared_error)


def prmse(group):
    """Percentage root-mean-square error of the retrieved CWP in one group.

    The mean squared error is normalised by the mean of the squared true
    CWP before the square root is taken; the result is scaled by 100.

    Parameters
    ----------
    group
        Group handed over by ``groupby_bins(...).apply``. It must provide
        ``unstack()`` and, afterwards, the variables ``cwp_true`` and
        ``cwp_pred``.

    Returns
    -------
    ``100 * sqrt(mean((true - pred)**2) / mean(true**2))`` with both means
    taken over ``"time"``.
    """
    cwp = group.unstack()
    cwp_true = cwp["cwp_true"]
    residual = cwp_true - cwp["cwp_pred"]

    mean_squared_error = (residual**2).mean("time")
    mean_squared_truth = (cwp_true**2).mean("time")

    return np.sqrt(mean_squared_error / mean_squared_truth) * 100


def bias(group):
    """Mean signed difference between true and retrieved CWP in one group.

    The sign convention is ``cwp_true - cwp_pred``, i.e. positive values
    mean the retrieval is lower than the truth.

    Parameters
    ----------
    group
        Group handed over by ``groupby_bins(...).apply``. It must provide
        ``unstack()`` and, afterwards, the variables ``cwp_true`` and
        ``cwp_pred``.

    Returns
    -------
    Mean over ``"time"`` of ``cwp_true - cwp_pred``.
    """
    cwp = group.unstack()
    residual = cwp["cwp_true"] - cwp["cwp_pred"]
    return residual.mean("time")


def mae(group):
    """Mean absolute error of the retrieved CWP within one group.

    Parameters
    ----------
    group
        Group handed over by ``groupby_bins(...).apply``. It must provide
        ``unstack()`` and, afterwards, the variables ``cwp_true`` and
        ``cwp_pred``.

    Returns
    -------
    Mean over ``"time"`` of ``|cwp_true - cwp_pred|``.
    """
    cwp = group.unstack()
    residual = cwp["cwp_true"] - cwp["cwp_pred"]
    absolute_error = np.abs(residual)
    return absolute_error.mean("time")

In [None]:
# Retrieval run on the synthetic experiment. The reader returns four
# datasets; the third one is not needed here.
# NOTE(review): judging by the names these are a-priori, optimal-estimate and
# synthetic-truth datasets -- confirm against read_oem_result_concat.
ds_a_main, ds_op_main, _, ds_syn_main = read_oem_result_concat(
    version="pub_r2_syn_22_183_v1",
    test_id="random",
    write=False,
)

# Retrieval run with the "clearsky" version tag; only the first two returned
# datasets are kept.
ds_a_fu, ds_op_fu, _, _ = read_oem_result_concat(
    version="pub_r2_clearsky_v1",
    test_id="",
    write=False,
)

In [None]:
# Mean of the `conv` variable of the synthetic run
# (presumably the fraction of converged retrievals -- TODO confirm).
print(ds_a_main.conv.mean())

In [None]:
# Restrict the three synthetic-run datasets to their common index labels so
# that later element-wise comparisons line up. `join="inner"` is xarray's
# default for `align` and is only spelled out here for the reader.
ds_a_main, ds_op_main, ds_syn_main = xr.align(
    ds_a_main,
    ds_op_main,
    ds_syn_main,
    join="inner",
)

In [None]:
# ERA5 `era5_tcwv` and `era5_skt` restricted to the time steps selected by
# the ancillary `ix_sea_ice` indexer.
tcwv_sea_ice = ds_anc.era5_tcwv.sel(time=ds_anc.ix_sea_ice)
skt_sea_ice = ds_anc.era5_skt.sel(time=ds_anc.ix_sea_ice)

# era5_tcwv: mean, standard deviation, lower and upper quartile
print(tcwv_sea_ice.mean("time").item())
print(tcwv_sea_ice.std("time").item())
for quartile in (0.25, 0.75):
    print(tcwv_sea_ice.quantile(quartile, "time").item())

# era5_skt: same statistics; 273.15 is subtracted from the location measures
# (presumably K -> degC), while the standard deviation needs no offset.
print(skt_sea_ice.mean("time").item() - 273.15)
print(skt_sea_ice.std("time").item())
for quartile in (0.25, 0.75):
    print(skt_sea_ice.quantile(quartile, "time").item() - 273.15)

In [None]:
# Overview of the ERA5 conditions at the time steps used in the synthetic
# simulations (the `time` coordinate of ds_a_main).
tcwv_sim = ds_anc.era5_tcwv.sel(time=ds_a_main.time)
skt_sim = ds_anc.era5_skt.sel(time=ds_a_main.time)

# era5_tcwv: mean, standard deviation, lower and upper quartile
print(tcwv_sim.mean("time").item())
print(tcwv_sim.std("time").item())
for quartile in (0.25, 0.75):
    print(tcwv_sim.quantile(quartile, "time").item())

# era5_skt: same statistics; 273.15 is subtracted from the location measures
# (presumably K -> degC), while the standard deviation needs no offset.
print(skt_sim.mean("time").item() - 273.15)
print(skt_sim.std("time").item())
for quartile in (0.25, 0.75):
    print(skt_sim.quantile(quartile, "time").item() - 273.15)

In [None]:
# Collect true (synthetic) and retrieved CWP on a common time axis.
ds_cwp_skill = xr.Dataset()
ds_cwp_skill.coords["time"] = ds_syn_main.time
ds_cwp_skill["cwp_true"] = ds_syn_main.cwp
ds_cwp_skill["cwp_pred"] = ds_op_main.cwp

# Bin edges 0, 50, ..., 500 scaled by 1e-3 (the same unit as `cwp_true`;
# the figure later multiplies by 1e3 and labels the axis g m^-2).
bins_cwp = np.arange(0, 501, 50) * 1e-3
bins_center_cwp = 0.5 * (bins_cwp[:-1] + bins_cwp[1:])

# Evaluate every skill metric per bin of true CWP. The group-by is rebuilt
# for each metric because the dataset grows with every assignment.
for metric_name, metric_fn in (
    ("rmse", rmse),
    ("prmse", prmse),
    ("bias", bias),
    ("mae", mae),
):
    ds_cwp_skill[metric_name] = ds_cwp_skill.groupby_bins(
        ds_cwp_skill["cwp_true"], bins=bins_cwp
    ).apply(metric_fn)

# Bin centres used as x-values when plotting the metrics.
ds_cwp_skill["cwp_true_bin"] = bins_center_cwp

In [None]:
# RMSE from real clear-sky observations.
# The retrieved CWP itself is squared, i.e. it is compared against zero.
cwp_clearsky = ds_op_fu.cwp
rmse_real_all = np.sqrt(np.mean(cwp_clearsky**2)).item()

# Same quantity, restricted to the samples selected by the ancillary
# `ix_central_arctic` indexer (observations far from the sea ice edge).
in_central_arctic = ds_anc.ix_central_arctic.sel(time=ds_op_fu.time)
cwp_central_arctic = ds_op_fu.cwp.sel(time=in_central_arctic)
rmse_real_200km = np.sqrt(np.mean(cwp_central_arctic**2)).item()

print(rmse_real_200km)

In [None]:
def _plot_with_halo(ax, x, y, **line_kwargs):
    """Draw a black line on top of a thicker white copy of itself.

    The white underlay keeps the line readable on the coloured histogram.
    Extra keyword arguments (line style, label) apply to the black line only.
    """
    ax.plot(x, y, color="white", linewidth=2.5)
    ax.plot(x, y, color="k", **line_kwargs)


# Layout: one tall histogram panel on the left, RMSE and PRMSE stacked on the
# right.
fig, axes = plt.subplot_mosaic(
    [["hist", "rmse"], ["hist", "prmse"]],
    figsize=(7, 4),
    layout="constrained",
    gridspec_kw=dict(width_ratios=[2, 1]),
)

# 2d histogram of true and retrieved cwp
# NOTE: `bins_cwp` / `bins_center_cwp` are rebound here with 25-wide bins on
# the x1e3-scaled values (axis label: g m^-2); the skill cell above defines
# them with 50-wide bins scaled by 1e-3.
bins_cwp = np.arange(0, 501, 25)
bins_center_cwp = (bins_cwp[1:] + bins_cwp[:-1]) / 2
# Both axes share `bins_cwp`, so the second returned edge array is identical
# to the first and is discarded.
cwp_hist, bin_edge, _ = np.histogram2d(
    ds_syn_main.cwp * 1e3,
    ds_op_main.cwp * 1e3,
    bins=bins_cwp,
    density=True,
)
# Empty cells are set to NaN so that pcolormesh leaves them blank.
cwp_hist[cwp_hist == 0] = np.nan

# compute the median and quartiles of the retrieved cwp for bins of true cwp
cwp_pred_by_true_bin = ds_op_main.cwp.groupby_bins(
    ds_syn_main.cwp * 1e3, bins=bins_cwp
)
da_median = cwp_pred_by_true_bin.median() * 1e3
da_q25 = cwp_pred_by_true_bin.quantile(0.25) * 1e3
da_q75 = cwp_pred_by_true_bin.quantile(0.75) * 1e3

# histogram2d returns an array indexed [x, y]; pcolormesh expects [y, x].
im = axes["hist"].pcolormesh(
    bin_edge,
    bin_edge,
    cwp_hist.T,
    cmap=cmc.batlow,
    vmin=0,
)

# plot lines of median etc.
_plot_with_halo(
    axes["hist"], bins_center_cwp, da_median, linestyle="--", label="Median"
)
_plot_with_halo(
    axes["hist"], bins_center_cwp, da_q25, linestyle=":", label="Q25, Q75"
)
_plot_with_halo(axes["hist"], bins_center_cwp, da_q75, linestyle=":")
_plot_with_halo(axes["hist"], [0, 600], [0, 600], label="1:1")

axes["hist"].set_aspect("equal")

axes["hist"].set_xlim(0, 500)
axes["hist"].set_ylim(0, 500)
axes["hist"].set_xlabel("True CWP [g m$^{-2}$]")
axes["hist"].set_ylabel("Retrieved CWP [g m$^{-2}$]")

axes["hist"].legend(loc="upper left", frameon=True, bbox_to_anchor=(0, 0.82))

# cwp skill
axes["rmse"].plot(
    ds_cwp_skill.cwp_true_bin * 1e3,
    ds_cwp_skill.rmse * 1e3,
    color="k",
    marker=".",
)
axes["prmse"].plot(
    ds_cwp_skill.cwp_true_bin * 1e3,
    ds_cwp_skill.prmse,
    color="k",
    marker=".",
)

# cwp rmse from real clear-sky data (plotted at a true CWP of zero)
axes["rmse"].scatter(
    0,
    rmse_real_all * 1e3,
    color=cmc.batlow(0.25),
    label="All",
    s=50,
    lw=0,
)
axes["rmse"].scatter(
    0,
    rmse_real_200km * 1e3,
    color=cmc.batlow(0.75),
    label="C. Arctic",
    s=50,
    lw=0,
)

axes["rmse"].legend(loc="upper left", frameon=True)

axes["rmse"].set_ylabel("RMSE [g m$^{-2}$]")
axes["prmse"].set_ylabel("PRMSE [%]")
axes["prmse"].set_xlabel("True CWP [g m$^{-2}$]")

axes["rmse"].set_xlim(0, 500)
axes["rmse"].set_ylim(0, 250)

axes["prmse"].set_xlim(0, 500)
axes["prmse"].set_ylim(0, 150)

# common grid and x-tick spacing for all three panels
for ax in [axes["rmse"], axes["prmse"], axes["hist"]]:
    ax.grid()
    ax.xaxis.set_major_locator(mticker.MultipleLocator(100))
    ax.xaxis.set_minor_locator(mticker.MultipleLocator(25))

# panel labels (a), (b), (c) in the order the axes were created
for i, ax in enumerate(axes.values()):
    ax.annotate(
        f"({abc[i]})",
        xy=(0.71, 1),
        xycoords="axes fraction",
        ha="center",
        va="top",
    )

axes["hist"].yaxis.set_major_locator(mticker.MultipleLocator(100))
axes["hist"].yaxis.set_minor_locator(mticker.MultipleLocator(25))

# Draw once so the layout is resolved before the histogram axes position is
# read to place the colour bar inset.
plt.draw()

hist_position = axes["hist"].get_position()
cax = fig.add_axes(
    [
        hist_position.x0 + 0.02,
        hist_position.y1 - 0.03,
        0.25,
        0.03,
    ]
)
fig.colorbar(im, cax=cax, label="Frequency", orientation="horizontal")

write_figure(
    fig,
    "paper/fig07.png",
    dpi=300,
    bbox_inches="tight",
)

plt.show()