# Figure 10

HAMP retrieval and satellite observations before, during, and after the rain-on-snow (ROS) event of 12–14 March.

In [None]:
from string import ascii_lowercase

import cartopy.crs as ccrs
import cmcrameri.cm as cmc
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
from lizard import ac3airlib
from lizard.mpltools import style
from lizard.readers.amsr2_sic import read_amsr2_sic
from lizard.readers.gpm_l1c import flag_gpml1c, read_gpm_l1c_swath
from lizard.readers.mira import read_mira
from lizard.readers.wales import read_wales
from lizard.writers.figure_to_file import write_figure

from si_clouds.helpers.sigmoid import sigmoid
from si_clouds.io.readers.ancillary import read_ancillary_data
from si_clouds.io.readers.oem_result import read_oem_result_concat
from si_clouds.io.readers.sensitivity import read_sensitivity_params

In [None]:
# Region of interest; padded in the satellite-map section before being passed
# to the GPM L1C swath reader (`roi=` argument).
# NOTE(review): the values look like [lon_min, lon_max, lat_min, lat_max], but
# the padding there is [roi[0] - 30, roi[2] + 60, roi[1] - 5, roi[3] + 5] --
# confirm which element order the reader expects.
roi = [
    -60,
    30,
    75,
    90,
]

In [None]:
# Ancillary time series used throughout this notebook (lat/lon, validity and
# clear-sky masks, ERA5 tclw, KT-19 brightness temperature, dist_sic_0_50) and
# the fitted sigmoid parameters that are unpacked into `sigmoid` to draw the
# "CWP det." curve. Evaluated left to right, i.e. in the original order.
ds_anc, popt = read_ancillary_data(), read_sensitivity_params()

In [None]:
test_id = ""
write = False

segment_ids = {
    "pre ROS": {
        "flight_id": "HALO-AC3_HALO_RF02",
        "segment_id": "HALO-AC3_HALO_RF02_hl11",
        "color": cmc.batlow(0.2),
        "version": "pub_r2_case_7_v1",
    },
    "ROS": {
        "flight_id": "HALO-AC3_HALO_RF03",
        "segment_id": "HALO-AC3_HALO_RF03_hl09",
        "color": cmc.batlow(0.5),
        "version": "pub_r2_case_5_v1",
    },
    "post ROS 1": {
        "flight_id": "HALO-AC3_HALO_RF04",
        "segment_id": "HALO-AC3_HALO_RF04_hl08",
        "color": cmc.batlow(0.7),
        "version": "pub_r2_case_1_v1",
    },
    "post ROS 2": {
        "flight_id": "HALO-AC3_HALO_RF04",
        "segment_id": "HALO-AC3_HALO_RF04_hl11",
        "color": cmc.batlow(0.8),
        "version": "pub_r2_case_8_v1",
    },
}

# Radar (MIRA) and lidar (WALES) files are read once per flight: several
# segments can come from the same flight (both post-ROS segments are RF04), and
# re-reading the whole flight for each of them only repeats expensive I/O.
_radar_by_flight = {}
_lidar_by_flight = {}

# add data for each of the segments
for name, seg in segment_ids.items():
    seg["t0"], seg["t1"] = ac3airlib.segment_times(seg["segment_id"])
    time_window = slice(seg["t0"], seg["t1"])

    ds_a_segment, ds_op_segment, _, _ = read_oem_result_concat(
        seg["version"], test_id, write=write
    )

    # remove times where the retrieval was not valid (ze or 2m temperature)
    ds_a_segment = ds_a_segment.sel(
        time=ds_anc.ix_retrieval_valid.sel(time=ds_a_segment.time)
    )
    ds_op_segment = ds_op_segment.sel(
        time=ds_anc.ix_retrieval_valid.sel(time=ds_op_segment.time)
    )

    # a priori ("ds_a") and optimal ("ds_op") state plus ancillary data,
    # restricted to the segment time window
    seg["ds_a"] = ds_a_segment.sel(time=time_window)
    seg["ds_op"] = ds_op_segment.sel(time=time_window)
    seg["ds_anc"] = ds_anc.sel(time=time_window)

    flight_id = seg["flight_id"]
    if flight_id not in _radar_by_flight:
        _radar_by_flight[flight_id] = read_mira(flight_id)
        _lidar_by_flight[flight_id] = read_wales(flight_id, product="bsrgl")

    # radar and wales data for the segment time window
    seg["ds_rad"] = _radar_by_flight[flight_id].sel(time=time_window)
    seg["ds_wales"] = _lidar_by_flight[flight_id].sel(time=time_window)

# drop the per-flight cache names; only the per-segment selections are used below
del _radar_by_flight, _lidar_by_flight

In [None]:
# start and end time of every segment, as returned by ac3airlib.segment_times
for segment_name, segment in segment_ids.items():
    print(segment_name, segment["t0"], segment["t1"])

In [None]:
# Fraction of converged retrievals ("conv" averaged over time) per case.
conv_pre = segment_ids["pre ROS"]["ds_a"]["conv"]
conv_ros = segment_ids["ROS"]["ds_a"]["conv"]
print(conv_pre.mean().item())
print(conv_ros.mean().item())

# both post-ROS segments are pooled into one fraction
conv_post = np.concatenate(
    [
        segment_ids[key]["ds_a"]["conv"].values.ravel()
        for key in ("post ROS 1", "post ROS 2")
    ]
)
print(conv_post.mean())

# during the rain event: ROS segment split at 83°N (south first, then north)
ros_lat = ds_anc.lat.sel(time=segment_ids["ROS"]["ds_a"].time)
print(conv_ros.sel(time=ros_lat < 83).mean().item())
print(conv_ros.sel(time=ros_lat >= 83).mean().item())

In [None]:
# assign segments to the figure columns (0: before, 1: during, 2: after ROS)
event_dct = {0: ["pre ROS"], 1: ["ROS"], 2: ["post ROS 1", "post ROS 2"]}

# satellite snapshot: one granule per map column (enumerated in that order)
snapshot_def = {
    "instrument": "SSMIS",
    "satellite": "DMSP-F16",
    "granules": [
        "094933",
        "094948",
        "094962",
    ],
    "swath": "S4",
    # SSMIS channel 17 is the 91.655 GHz V-pol channel (SSMIS has no 89 GHz
    # channel); this matches the "$T_b$ (91V)" colorbar label of the figure
    "channel": 17,
    "label": "91 GHz (V-pol)",
}

# map extent in cartopy `set_extent` order: [lon_min, lon_max, lat_min, lat_max]
extent = [-17, 17, 77, 90]
# latitude range shared by all along-track (non-map) panels
lat_min = 79
lat_max = 87.5

# Panel layout: a row of satellite maps on top, then one row per along-track
# quantity; every row has three columns (12, 13 and 14 March).
along_track_rows = [
    "ze",
    "wales",
    "cwp",
    "t_as",
    "t_si",
    "wind_slab_corr_length",
    "depth_hoar_corr_length",
    "wind_slab_thickness",
]
map_names = tuple(f"im_{col}_map" for col in range(3))
mosaic = [list(map_names)] + [
    [f"{row}_{col}" for col in range(3)] for row in along_track_rows
]

fig, axes = plt.subplot_mosaic(
    mosaic,
    gridspec_kw=dict(height_ratios=[1.7] + [1] * len(along_track_rows)),
    figsize=(6, 8),
    layout="constrained",
    # only the map row is drawn in a north polar stereographic projection
    per_subplot_kw={map_names: {"projection": ccrs.NorthPolarStereo()}},
)
# only the bottom row keeps its latitude tick labels and x-axis label
x_label_name = "wind_slab_thickness"

# panel labels "(a1)", "(a2)", ...: letter = row, digit = column. This relies
# on `axes` yielding the mosaic row by row with three panels per row.
for panel_index, (name, ax) in enumerate(axes.items()):
    row, col = divmod(panel_index, 3)
    is_map = "map" in name
    ax.annotate(
        f"({ascii_lowercase[row]}{col+1})",
        xy=(0.005 if is_map else 0.995, 0.99),
        xycoords="axes fraction",
        ha="left" if is_map else "right",
        va="top",
        color="k",
    )

# column titles above the satellite maps
column_titles = {
    "im_0_map": "12 March (before ROS)",
    "im_1_map": "13 March (during ROS)",
    "im_2_map": "14 March (after ROS)",
}
for map_name, title in column_titles.items():
    axes[map_name].annotate(
        title,
        xy=(0.5, 1.02),
        xycoords="axes fraction",
        ha="center",
        va="bottom",
    )

# y tick labels only in the first column, x tick labels only in the bottom
# row, and a common latitude axis for every along-track panel
for name, ax in axes.items():
    if name[-1] in ("1", "2"):
        ax.set_yticklabels([])
    if not name.endswith("map") and x_label_name not in name:
        ax.set_xticklabels([])
    if "map" not in name:
        ax.set_xlim(lat_min, lat_max)
        ax.xaxis.set_major_locator(mticker.MultipleLocator(2))
        ax.xaxis.set_minor_locator(mticker.MultipleLocator(1))
    if x_label_name in name:
        ax.set_xlabel("Latitude [°N]")

# radar reflectivity: MIRA Zg in dBZ; invalid values in white, areas without
# data show the gray axes background
batlow_nan = cmc.batlow.copy()
batlow_nan.set_bad(color="white")
height_max = 7000  # upper height limit [m]; also used by the WALES panels
for i, events in event_dct.items():
    ax = axes[f"ze_{i}"]
    ax.set_facecolor("gray")
    for event in events:
        seg = segment_ids[event]
        ds_rad = seg["ds_rad"].sel(
            time=slice(seg["t0"], seg["t1"]),
            height=slice(0, height_max),
        )
        im_ze = ax.pcolormesh(
            ds_rad.lat,
            ds_rad.height * 1e-3,
            10 * np.log10(ds_rad.Zg).T,
            cmap=batlow_nan,
            vmin=-35,
            vmax=20,
        )
    # identical height axis in every column; otherwise the unlabelled columns
    # fall back to matplotlib's default y limits and tick locator
    ax.set_ylim(0, height_max * 1e-3)
    ax.yaxis.set_major_locator(mticker.MultipleLocator(5))
    ax.yaxis.set_minor_locator(mticker.MultipleLocator(1))
axes["ze_0"].set_ylabel("Ht. [km]")

# wales backscatter ratio (only samples with flags == 0), log colour scale
cmap_bsr = cmc.davos_r.copy()
cmap_bsr.set_bad("gray")
norm_bsr = mcolors.LogNorm(1, 200)
for i, events in event_dct.items():
    ax = axes[f"wales_{i}"]
    ax.set_facecolor("gray")
    for event in events:
        seg = segment_ids[event]
        # slice(height_max, 0): the altitude coordinate is presumably stored
        # in descending order -- TODO confirm against read_wales
        ds_wales = seg["ds_wales"].sel(
            time=slice(seg["t0"], seg["t1"]),
            altitude=slice(height_max, 0),
        )
        im_wales = ax.pcolormesh(
            ds_wales.latitude,
            ds_wales.altitude * 1e-3,
            ds_wales.backscatter_ratio.where(ds_wales["flags"] == 0).T,
            cmap=cmap_bsr,
            norm=norm_bsr,
            shading="nearest",
        )
    # identical height axis in every column (same range as the radar panels)
    ax.set_ylim(0, height_max * 1e-3)
    ax.yaxis.set_major_locator(mticker.MultipleLocator(5))
    ax.yaxis.set_minor_locator(mticker.MultipleLocator(1))
axes["wales_0"].set_ylabel("Ht. [km]")


# state space: a priori (gray) and optimal (coral) estimates with ±1 std
# bands for every retrieved variable, plus reference data in the CWP and
# T_as rows
x_vars = [
    "cwp",
    "t_as",
    "t_si",
    "wind_slab_corr_length",
    "depth_hoar_corr_length",
    "wind_slab_thickness",
]
labels = [
    r"CWP [g m$^{-2}$]",
    r"$T_{as}$ [K]",
    r"$T_{si}$ [K]",
    r"$\xi_{WS}$ [mm]",
    r"$\xi_{DH}$ [mm]",
    r"$h_{WS}$ [cm]",
]
# per-variable y limits and tick spacing; stored values are multiplied by
# `factor` to match the units given in `labels`
y_min = [0, 250, 250, 0.05, 0.15, 10]
y_max = [400, 273, 273, 0.22, 0.6, 30]
loc_maj = [200, 10, 10, 0.1, 0.2, 10]
loc_min = [50, 2, 2, 0.025, 0.05, 2]
factor = [1e3, 1, 1, 1, 1, 1e2]
kwds = dict(s=5, lw=0, marker=".")

# KT-19 brightness temperatures are divided by this value before plotting
# NOTE(review): presumably an assumed surface emissivity -- confirm
kt19_emissivity = 0.995

for i, events in event_dct.items():
    for event in events:
        seg = segment_ids[event]
        ds_a = seg["ds_a"]
        ds_op = seg["ds_op"]
        seg_anc = seg["ds_anc"]
        # latitude of the a priori / optimal samples; identical for all rows,
        # so it is looked up once per segment instead of once per variable
        lat_a = seg_anc.lat.sel(time=ds_a.time)
        lat_op = seg_anc.lat.sel(time=ds_op.time)
        estimates = (
            (lat_a, ds_a, "gray", "A priori"),
            (lat_op, ds_op, "coral", "Optimal"),
        )

        for i_var, x_var in enumerate(x_vars):
            ax = axes[f"{x_var}_{i}"]
            scale = factor[i_var]
            # both uncertainty bands first so the points are drawn on top
            for lat, ds, color, _ in estimates:
                ax.fill_between(
                    lat,
                    (ds[x_var] - ds[x_var + "_std"]) * scale,
                    (ds[x_var] + ds[x_var + "_std"]) * scale,
                    color=color,
                    linewidths=0,
                    alpha=0.25,
                )
            for lat, ds, color, label in estimates:
                ax.scatter(
                    lat,
                    ds[x_var] * scale,
                    color=color,
                    label=label,
                    **kwds,
                )

        # plot cwp sensitivity (sigmoid of dist_sic_0_50 with fitted popt)
        axes[f"cwp_{i}"].scatter(
            lat_a,
            sigmoid(seg_anc.dist_sic_0_50.sel(time=ds_a.time), *popt) * 1e3,
            color=cmc.batlow(0),
            label="CWP det.",
            **kwds,
        )

        # plot era5 (tclw)
        axes[f"cwp_{i}"].scatter(
            lat_a,
            seg_anc.era5_tclw.sel(time=ds_a.time) * 1e3,
            color=cmc.batlow(0.5),
            label="ERA5",
            **kwds,
        )

        # plot kt19, clear-sky samples only (mask taken from the segment's
        # own ancillary data)
        axes[f"t_as_{i}"].scatter(
            lat_a,
            seg_anc.kt19_bt.where(seg_anc.ix_clear_sky_kt19).sel(
                time=ds_a.time
            )
            / kt19_emissivity,
            color=cmc.batlow(0.25),
            label="KT-19",
            **kwds,
        )

# y-axis setup once per panel (identical in every column); y labels only in
# the first column
for i in event_dct:
    for i_var, x_var in enumerate(x_vars):
        ax = axes[f"{x_var}_{i}"]
        ax.yaxis.set_major_locator(mticker.MultipleLocator(loc_maj[i_var]))
        ax.yaxis.set_minor_locator(mticker.MultipleLocator(loc_min[i_var]))
        ax.set_ylim(y_min[i_var], y_max[i_var])
        if i == 0:
            ax.set_ylabel(labels[i_var])

# plot satellite images
# directory of the daily gridded AMSR2 sea-ice concentration files
# NOTE(review): machine-specific absolute path -- consider a config constant
sic_path = "/data/obs/campaigns/halo-ac3/auxiliary/sea_ice/daily_grid/"

# colour scales are the same for every map, so they are built once
cmap_tb = cmc.batlowK
norm_tb = mcolors.BoundaryNorm(np.arange(170, 271, 10), cmap_tb.N)
# binary scale for the retrieval convergence flag (boundaries 0, 0.5, 1)
cmap_conv = cmc.grayC
norm_conv = mcolors.BoundaryNorm([0, 0.5, 1], cmap_conv.N)

for i, granule in enumerate(snapshot_def["granules"]):
    ax = axes[f"im_{i}_map"]
    ax.set_facecolor("lightgray")
    ax.set_extent(extent)
    ax.coastlines(color="white", linewidth=0.5)
    ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=["left", "bottom", "right"],
        xlocs=mticker.FixedLocator(np.arange(-180, 180, 10)),
        ylocs=mticker.FixedLocator(np.arange(70, 90, 2)),
        x_inline=False,
        y_inline=False,
        rotate_labels=False,
        linewidth=0.25,
        color="#1F7298",
        alpha=0.5,
        zorder=5,
        xlabel_style={"size": 9},
        ylabel_style={"size": 9},
    )

    # NOTE(review): this padding pairs roi[2] with +60 and roi[1] with -5,
    # i.e. it reads `roi` as [lon_min, lat_min, lon_max, lat_max] -- confirm
    # against the reader's expected order and the definition of `roi`.
    ds_sat = read_gpm_l1c_swath(
        instrument=snapshot_def["instrument"],
        satellite=snapshot_def["satellite"],
        granule=granule,
        roi=[roi[0] - 30, roi[2] + 60, roi[1] - 5, roi[3] + 5],
        swath=snapshot_def["swath"],
        add_index=True,
    )
    ds_sat = flag_gpml1c(ds_sat)

    # mean scan time over the `x` entries that have any pixel inside the
    # HALO area (80-87°N, 15°W-15°E)
    ix_region = (
        (ds_sat.lat > 80)
        & (ds_sat.lat < 87)
        & (ds_sat.lon > -15)
        & (ds_sat.lon < 15)
    )
    da_sat_time = ds_sat.scan_time.sel(x=ix_region.any("y")).mean(("x", "y"))
    time_str = da_sat_time.dt.strftime("%H:%M").item()
    print(time_str)

    im_tb = ax.scatter(
        ds_sat.lon,
        ds_sat.lat,
        c=ds_sat.tb.sel(channel=snapshot_def["channel"]),
        transform=ccrs.PlateCarree(),
        cmap=cmap_tb,
        norm=norm_tb,
        lw=0,
        s=1,
    )

    # plot the 15 % sea ice concentration contour (conventional ice edge) of
    # the overpass day
    ds_sic = read_amsr2_sic(
        pd.Timestamp(da_sat_time.dt.date.item()),
        path=sic_path,
    )

    # project lon/lat into the map's own CRS and contour in that same CRS,
    # so the contour stays correct if the map projection is changed
    xy_sic = ax.projection.transform_points(
        ccrs.PlateCarree(), ds_sic.lon.values, ds_sic.lat.values
    )
    x_sic = xy_sic[..., 0].reshape(ds_sic.lon.shape)
    y_sic = xy_sic[..., 1].reshape(ds_sic.lat.shape)
    ax.contour(
        x_sic,
        y_sic,
        ds_sic.sic,
        levels=[15],
        colors="k",
        linewidths=1,
        transform=ax.projection,
    )

    # plot the HALO track, coloured by retrieval convergence
    for event in event_dct[i]:
        seg = segment_ids[event]
        track_time = seg["ds_a"].time
        im_conv = ax.scatter(
            seg["ds_anc"].lon.sel(time=track_time),
            seg["ds_anc"].lat.sel(time=track_time),
            c=seg["ds_a"].conv,
            s=5,
            lw=0,
            cmap=cmap_conv,
            norm=norm_conv,
            transform=ccrs.PlateCarree(),
        )

def _colorbar_axes(x_ref, y_ref, side):
    """Add and return a 0.01-wide axes (figure fraction) for a colorbar.

    The figure is drawn first so that the constrained-layout positions read
    here are up to date. The new axes spans the vertical extent of ``y_ref``.
    With ``side="right"`` it starts 0.01 to the right of ``x_ref``; otherwise
    it starts 0.03 to the left of ``x_ref``.
    """
    plt.draw()
    x_box = x_ref.get_position()
    y_box = y_ref.get_position()
    if side == "right":
        left = x_box.x1 + 0.01
    else:
        left = x_box.x0 - 0.03
    return fig.add_axes((left, y_box.y0, 0.01, y_box.y1 - y_box.y0))


# satellite brightness temperature, right of the map row
fig.colorbar(
    im_tb,
    cax=_colorbar_axes(axes["ze_2"], axes["im_2_map"], "right"),
    label="$T_b$ (91V) [K]",
)

# retrieval convergence along the HALO track, left of the map row
cbar = fig.colorbar(
    im_conv,
    cax=_colorbar_axes(axes["ze_0"], axes["im_0_map"], "left"),
    label="Converged",
    ticks=[0.25, 0.75],
    ticklocation="left",
)
cbar.set_ticks([0.25, 0.75], labels=["No", "Yes"])
cbar.set_ticks([], minor=True)

# radar reflectivity, right of its row
fig.colorbar(
    im_ze,
    cax=_colorbar_axes(axes["ze_2"], axes["ze_2"], "right"),
    label="$Z_e$ [dBZ]",
    ticks=[-30, -15, 0, 15],
)

# lidar backscatter ratio, right of its row
fig.colorbar(
    im_wales,
    cax=_colorbar_axes(axes["wales_2"], axes["wales_2"], "right"),
    label="BSR",
    ticks=[1, 100, 200],
)

plt.draw()
# Legend entries from the first-column CWP and T_as panels. Keying on the
# label text drops repeated entries (one per segment/panel): the first
# occurrence fixes the order, the last handle seen for a label is kept.
legend_entries = {}
for source_ax in (axes["cwp_0"], axes["t_as_0"]):
    for handle, text in zip(*source_ax.get_legend_handles_labels()):
        legend_entries[text] = handle

# place the legend at the top right corner of the last CWP panel
cwp_box = axes["cwp_2"].get_position()
fig.legend(
    legend_entries.values(),
    legend_entries.keys(),
    loc="upper left",
    frameon=False,
    markerscale=6,
    ncol=1,
    bbox_to_anchor=(cwp_box.x1 - 0.02, cwp_box.y1),
)

write_figure(
    fig,
    "paper/fig10.png",
    dpi=300,
    bbox_inches="tight",
)

plt.show()