# Figure 12

Correlations between normalized parameter residuals from the synthetic retrieval experiment.

In [None]:
import warnings
from string import ascii_lowercase as abc

import cmcrameri.cm as cmc
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from lizard.mpltools import style
from lizard.writers.figure_to_file import write_figure
from scipy.stats import pearsonr

from si_clouds.io.readers.ancillary import read_ancillary_data
from si_clouds.io.readers.oem_result import read_oem_result_concat

In [None]:
# Ancillary dataset; the statistics cell further down reads its `era5_tcwv`
# and `era5_skt` variables along `time`.
ds_anc = read_ancillary_data()

In [None]:
# Read the concatenated OEM retrieval output of run "pub_r2_syn_amb_v1"
# (test_id "random"). Four objects are returned; the third is not used here.
# NOTE(review): the names suggest a-priori (`ds_a_amb`), retrieved
# (`ds_op_amb`) and synthetic-truth (`ds_syn_amb`) datasets -- confirm against
# `read_oem_result_concat`. The effect of `write=False` is defined there too.
ds_a_amb, ds_op_amb, _, ds_syn_amb = read_oem_result_concat(
    version="pub_r2_syn_amb_v1", test_id="random", write=False
)

In [None]:
# Restrict the three datasets to the coordinate labels they all share
# (inner join), so that the element-wise differences computed further down
# refer to the same samples.
ds_a_amb, ds_op_amb, ds_syn_amb = xr.align(
    ds_a_amb, ds_op_amb, ds_syn_amb, join="inner"
)

In [None]:
# Summary statistics (mean, standard deviation, quartiles) of the ERA5
# ancillary fields at the time steps of these simulations.
da_tcwv = ds_anc.era5_tcwv.sel(time=ds_a_amb.time)
print(da_tcwv.mean("time").item())
print(da_tcwv.std("time").item())
for q in (0.25, 0.75):
    print(da_tcwv.quantile(q, "time").item())

# Mean and quartiles of era5_skt are printed with 273.15 subtracted; the
# standard deviation is printed as is.
da_skt = ds_anc.era5_skt.sel(time=ds_a_amb.time)
print(da_skt.mean("time").item() - 273.15)
print(da_skt.std("time").item())
for q in (0.25, 0.75):
    print(da_skt.quantile(q, "time").item() - 273.15)

In [None]:
# Mean, median and quartiles of `dgf` over time, printed in that order.
da_dgf = ds_op_amb.dgf
for da_stat in (
    da_dgf.mean("time"),
    da_dgf.median("time"),
    da_dgf.quantile(0.25, "time"),
    da_dgf.quantile(0.75, "time"),
):
    print(da_stat)

In [None]:
# Turn the a-posteriori covariance matrix into a correlation matrix.
# Covariance at the iteration given by `conv_i` for every sample:
da_cov = ds_op_amb.unc_aposteriori.sel(update=ds_op_amb.conv_i)

# Diagonal of the covariance (the variances), once laid out along `x_vars1`
# and once along `x_vars2`, so that their product broadcasts to the full
# (x_vars1, x_vars2) matrix.
da_variance_1 = da_cov.sel(
    x_vars1=ds_op_amb.x_vars1,
    x_vars2=ds_op_amb.x_vars1,
)
da_variance_2 = da_cov.sel(
    x_vars1=ds_op_amb.x_vars2,
    x_vars2=ds_op_amb.x_vars2,
)

# corr_ij = cov_ij / sqrt(var_i * var_j), averaged over all samples
da_corr = (da_cov / np.sqrt(da_variance_1 * da_variance_2)).mean("time")

# show only the entries whose magnitude exceeds 0.1
da_corr.where(abs(da_corr) > 0.1).values

In [None]:
# Pearson correlation between the normalized residuals of every parameter
# pair. The parameter list holds the x_vars first, followed by the b_vars.
x_names = set(ds_op_amb.x_vars.values.tolist())
b_names = set(ds_op_amb.b_vars.values.tolist())
xb_vars = np.append(ds_op_amb.x_vars.values, ds_op_amb.b_vars.values)

# Standard deviation of each b parameter: square root of the time-mean
# diagonal element of `unc_b`. The "<name>_std" entries of the x parameters
# are expected to be present in `ds_a_amb` already.
for v in ds_a_amb.b_vars.values:
    ds_a_amb[v + "_std"] = np.sqrt(
        ds_a_amb.unc_b.sel(b_vars1=v, b_vars2=v).mean("time")
    )

print("Number of parameters:", len(xb_vars))


def _normalized_residual(name):
    """Return (estimate - synthetic value) / "<name>_std" for one parameter.

    The estimate is taken from ``ds_op_amb`` for x parameters and from
    ``ds_a_amb`` for b parameters; the normalisation always uses
    ``ds_a_amb[name + "_std"]``.

    Raises
    ------
    KeyError
        If ``name`` is neither an x nor a b parameter.
    """
    if name in x_names:
        estimate = ds_op_amb[name]
    elif name in b_names:
        estimate = ds_a_amb[name]
    else:
        raise KeyError(f"{name!r} is neither in x_vars nor in b_vars")
    return (estimate - ds_syn_amb[name]) / ds_a_amb[name + "_std"]


# Each residual is needed for many pairs, so compute it only once.
residuals = {name: _normalized_residual(name) for name in xb_vars}

# Unordered pairs that were already handled. A frozenset is used as key so
# that (a, b) and (b, a) coincide and differently split names cannot collide
# (as "a" + "bc" and "ab" + "c" would when concatenating strings).
done = set()
x1_lst = []
x2_lst = []
xb_var1_lst = []
xb_var2_lst = []
corr_lst = []
for xb_var1 in xb_vars:
    for xb_var2 in xb_vars:
        if xb_var1 == xb_var2:
            continue
        # pairs of two b parameters are not evaluated
        if xb_var1 in b_names and xb_var2 in b_names:
            continue
        pair = frozenset((xb_var1, xb_var2))
        if pair in done:
            continue
        done.add(pair)

        x1 = residuals[xb_var1]
        x2 = residuals[xb_var2]

        corr_lst.append(pearsonr(x1.values, x2.values)[0])
        x1_lst.append(x1)
        x2_lst.append(x2)
        xb_var1_lst.append(xb_var1)
        xb_var2_lst.append(xb_var2)

In [None]:
# Math-text label for every parameter name shown in the figure
labels = dict(
    t_as=r"$T_{as}$",
    t_si=r"$T_{si}$",
    depth_hoar_corr_length=r"$\xi_{DH}$",
    depth_hoar_density=r"$\rho_{DH}$",
    depth_hoar_thickness=r"$h_{DH}$",
    wind_slab_corr_length=r"$\xi_{WS}$",
    wind_slab_density=r"$\rho_{WS}$",
    wind_slab_thickness=r"$h_{WS}$",
    yi_fraction=r"$f_{yi}$",
    cwp=r"CWP",
    specularity=r"$s$",
)

# Colormap discretised at boundaries -1, -0.8, ..., 1 for the correlations
cmap = cmc.berlin
norm = mcolors.BoundaryNorm(np.arange(-1, 1.01, 0.2), cmap.N)

In [None]:
# Scatter plots of the normalized residuals for every parameter pair whose
# correlation magnitude reaches `threshold`, ordered from the most positive
# to the most negative correlation.
threshold = 0.1

fig, axes = plt.subplots(
    4, 4, figsize=(5, 5), layout="constrained", sharex=True, sharey=True
)

# The last two grid cells are switched off further down and make room for
# the colorbar, so only the remaining axes may receive data.
panel_axes = axes.flat[:-2]
n_cols = axes.shape[1]

# panel tags: letter = row, number = column, e.g. "(a1)"
for i_panel, ax in enumerate(panel_axes):
    row, col = divmod(i_panel, n_cols)
    ax.annotate(
        f"({abc[row]}{col + 1})",
        xy=(1, 1),
        xycoords="axes fraction",
        ha="right",
        va="top",
        color="k",
    )

# pairs to show, sorted by descending correlation coefficient
selected = [
    i_pair
    for i_pair in np.argsort(corr_lst)[::-1]
    if not abs(corr_lst[i_pair]) < threshold
]
if len(selected) > len(panel_axes):
    warnings.warn(
        f"{len(selected)} parameter pairs pass the threshold of {threshold}, "
        f"but only {len(panel_axes)} panels are available; the remaining "
        "pairs are not shown. Enlarge the subplot grid to include them."
    )

# zip() stops at the last available panel, so data never ends up in one of
# the hidden axes or beyond the grid.
for ax, i_pair in zip(panel_axes, selected):
    corr = corr_lst[i_pair]

    # every point of a panel is coloured by the correlation of its pair
    ax.scatter(
        x1_lst[i_pair],
        x2_lst[i_pair],
        c=np.repeat(corr, len(x1_lst[i_pair])),
        s=1,
        lw=0,
        cmap=cmap,
        norm=norm,
    )

    # names of the two parameters above the panel
    ax.annotate(
        labels[xb_var1_lst[i_pair]] + "~" + labels[xb_var2_lst[i_pair]],
        xy=(0.5, 1),
        xycoords="axes fraction",
        ha="center",
        va="bottom",
    )

    # annotate the correlation
    ax.annotate(
        f"{corr:.2f}",
        xy=(1, 0),
        xycoords="axes fraction",
        ha="right",
        va="bottom",
    )

ticks = np.arange(-3, 4, 2)
tick_labels = [rf"{tick}$\sigma$" for tick in ticks]
for ax in axes.flat:
    ax.set_ylim(-3, 3)
    ax.set_xlim(-3, 3)
    ax.set_aspect("equal")
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)
    ax.grid()

axes[-1, -2].set_axis_off()
axes[-1, -1].set_axis_off()

axes[-1, 0].set_xlabel("Norm. res.")
axes[-1, 0].set_ylabel("Norm. res.")

# The constrained layout is only resolved on draw; the axes positions read
# below are needed to place the colorbar inside the two empty grid cells.
plt.draw()
cax = fig.add_axes(
    (
        axes[-2, -2].get_position().x0,
        axes[-1, -2].get_position().y1 - 0.04,
        axes[-2, -1].get_position().x1 - axes[-2, -2].get_position().x0,
        0.02,
    )
)
# The colorbar is built from the shared norm and colormap directly, so it
# does not depend on any panel having been drawn.
fig.colorbar(
    plt.cm.ScalarMappable(norm=norm, cmap=cmap),
    cax=cax,
    label="Correlation coefficient",
    orientation="horizontal",
)

write_figure(fig, "paper/fig12.png", dpi=300, bbox_inches="tight")

plt.show()