# Satellite sea ice type

## Import libraries

In [None]:
import cartopy.crs as ccrs
import cmocean
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import download, plot

# Matplotlib style applied to every figure produced by this notebook.
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Define time periods as slice(start, stop) year bounds; get_requests turns
# each into the winter seasons whose October falls in [start, stop).
periods = [slice(first_year, first_year + 9) for first_year in (1991, 2001, 2011)]

# Define path to NERSC data: root directory of the daily 25 km Arctic sea ice
# age files (v2p1), organised in one sub-directory per year.
NERSC_PATH = "/data/wp5/mangini_fabio/nersc_ice_age_v2p1"

## Define requests

In [None]:
collection_id = "satellite-sea-ice-edge-type"


def get_requests(year_start, year_stop):
    """Build the CDS requests for the winter seasons starting in [year_start, year_stop).

    Two requests are returned. The first asks for October-December of every
    start year. The second asks for January-April of the following year.
    Both request days 01-31 for every month. Presumably the server skips
    non-existent dates; confirm against the CDS API.
    """
    season_starts = range(year_start, year_stop)
    base = {
        "cdr_type": "cdr",
        "variable": "sea_ice_type",
        "region": "northern_hemisphere",
        "version": "3_0",
        "day": [f"{d:02d}" for d in range(1, 32)],
    }
    autumn = {
        "year": [str(y) for y in season_starts],
        "month": [f"{m:02d}" for m in range(10, 13)],
    }
    spring = {
        "year": [str(y + 1) for y in season_starts],
        "month": [f"{m:02d}" for m in range(1, 5)],
    }
    return [base | autumn, base | spring]


# One set of requests spanning every period, used for the time series.
request_timeseries = get_requests(
    year_start=min(p.start for p in periods),
    year_stop=max(p.stop for p in periods),
)

# Per-period requests keyed by a "start-stop" label, used for the maps.
requests_periods = {
    f"{p.start}-{p.stop}": get_requests(p.start, p.stop) for p in periods
}

## Define functions to cache

In [None]:
def get_nersc_data(ds_cds, nersc_path):
    """Open the NERSC sea ice age files matching the time steps of ``ds_cds``.

    One file per distinct date in ``ds_cds["time"]`` is read from
    ``{nersc_path}/%Y/arctic25km_sea_ice_age_v2p1_%Y%m%d.nc``. The files are
    concatenated along ``time`` in chronological order. The NERSC ``x``/``y``
    dimensions are renamed to ``xc``/``yc``, and the CDS grid coordinates
    (``xc``, ``yc``, ``longitude``, ``latitude``) are attached. This assumes
    the NERSC files use the same grid as ``ds_cds``.

    Parameters
    ----------
    ds_cds : xarray.Dataset
        CDS dataset providing the time steps and the grid coordinates.
    nersc_path : str or None
        Root directory of the NERSC files; ``None`` falls back to ``NERSC_PATH``.

    Returns
    -------
    xarray.Dataset
        Lazily opened NERSC data on the CDS grid.
    """
    if nersc_path is None:
        nersc_path = NERSC_PATH

    paths = ds_cds["time"].dt.strftime(
        f"{nersc_path}/%Y/arctic25km_sea_ice_age_v2p1_%Y%m%d.nc"
    )
    # De-duplicate, then sort. With combine="nested" the files are concatenated
    # in the order given, and a bare set has arbitrary (hash-randomized) order.
    # The paths differ only in %Y/...%Y%m%d, so lexical order is chronological.
    ds_nersc = xr.open_mfdataset(
        sorted(set(paths.values)),
        concat_dim="time",
        combine="nested",
        data_vars="minimal",
        coords="minimal",
        compat="override",
    )
    ds_nersc = ds_nersc.rename(x="xc", y="yc")
    ds_nersc = ds_nersc.assign_coords(
        {coord: ds_cds[coord] for coord in ("xc", "yc", "longitude", "latitude")}
    )
    return ds_nersc


def get_nersc_multiyear_ice(
    ds_cds,
    use_fyi,
    age_threshold,
    conc_threshold,
    nersc_ice_age_path,
):
    """Return a boolean multi-year-ice mask derived from the NERSC ice age product.

    Exactly one criterion must be selected:

    * ``age_threshold`` set: cells whose ``sia`` exceeds the threshold
      (``sia`` is presumably sea ice age in years; confirm with the product).
    * ``conc_threshold`` set: the summed ``conc_2yi``, ``conc_3yi``, ...
      concentration exceeds the threshold. With ``use_fyi``, the total
      (multi-year + ``conc_1yi``) concentration must exceed the threshold
      instead, and the multi-year concentration must exceed ``conc_1yi``.

    Parameters
    ----------
    ds_cds : xarray.Dataset
        CDS dataset whose time steps and grid select the NERSC files.
    use_fyi : bool
        Use the first-year-ice comparison; only valid with ``conc_threshold``.
    age_threshold : float or None
        Ice age threshold.
    conc_threshold : float or None
        Concentration threshold.
    nersc_ice_age_path : str or None
        Root directory of the NERSC files; forwarded to ``get_nersc_data``.

    Returns
    -------
    xarray.DataArray
        Boolean mask on the CDS grid.

    Raises
    ------
    ValueError
        If the combination of criteria is invalid. This is checked before any
        file is opened.
    """
    # Explicit checks instead of ``assert`` so they also run under ``python -O``.
    if age_threshold is not None:
        if use_fyi:
            raise ValueError("`use_fyi` cannot be combined with `age_threshold`.")
        if conc_threshold is not None:
            raise ValueError(
                "`age_threshold` and `conc_threshold` are mutually exclusive."
            )
    elif conc_threshold is None:
        raise ValueError("One of `age_threshold` or `conc_threshold` must be set.")

    ds_nersc = get_nersc_data(ds_cds, nersc_ice_age_path)

    if age_threshold is not None:
        return ds_nersc["sia"] > age_threshold

    # Total multi-year concentration: conc_2yi + conc_3yi + ... up to the first
    # missing age class. Build a new array instead of using ``+=``: the
    # DataArray returned by ``ds_nersc["conc_2yi"]`` can wrap the dataset's own
    # Variable, so an in-place add would also overwrite the dataset's conc_2yi.
    conc_myi = ds_nersc["conc_2yi"]
    n = 3
    while (varname := f"conc_{n}yi") in ds_nersc.variables:
        conc_myi = conc_myi + ds_nersc[varname]
        n += 1

    if not use_fyi:
        return conc_myi > conc_threshold

    # Multi-year ice when total (multi-year + first-year) concentration exceeds
    # the threshold AND multi-year ice dominates first-year ice.
    conc_fyi = ds_nersc["conc_1yi"]
    return ((conc_myi + conc_fyi) > conc_threshold) & (conc_myi > conc_fyi)


def compute_spatial_sum(da, grid_cell_area, dim=("xc", "yc")):
    """Return ``grid_cell_area`` times the sum of ``da`` over ``dim``.

    Applied to a boolean mask, this gives the area covered by the ``True``
    cells, in the units of ``grid_cell_area``.
    """
    cell_total = da.sum(dim=dim)
    return grid_cell_area * cell_total


def get_classification_mask(ds, use_ambiguous):
    """Return a boolean mask of cells the CDS product flags as multi-year ice.

    Code 3 is always included. With ``use_ambiguous``, every code ``>= 3`` is
    included; presumably the higher codes are the ambiguous class.
    """
    ice_type = ds["ice_type"]
    if use_ambiguous:
        return ice_type >= 3
    return ice_type == 3


def compute_sea_ice_evaluation_diagnostics(
    ds,
    use_ambiguous,
    use_fyi,
    age_threshold,
    conc_threshold,
):
    """Compare CDS multi-year ice against the NERSC ice age product.

    Both products are reduced to boolean multi-year-ice masks on the CDS grid.
    Each mask is then summed over ``xc``/``yc`` into an area.

    Parameters
    ----------
    ds : xarray.Dataset
        CDS sea ice type data with ``ice_type`` and an ``xc`` coordinate of
        constant spacing. Unpacking the spacing raises ``ValueError`` otherwise.
    use_ambiguous : bool
        Forwarded to ``get_classification_mask``.
    use_fyi, age_threshold, conc_threshold
        Forwarded to ``get_nersc_multiyear_ice``.

    Returns
    -------
    xarray.Dataset
        The dataset contains four variables, all in $10^6$km$^2$:

        * ``mysi_extent``: CDS extent.
        * ``mysi_extent_age``: NERSC extent.
        * ``mysi_extent_bias``: CDS-only area minus NERSC-only area.
        * ``iite``: CDS-only area plus NERSC-only area.
    """
    # Square cells whose side is the xc step. The 1e-6 factor gives 10^6 km2,
    # assuming xc is in km.
    (step,) = set(ds["xc"].diff("xc").values)
    cell_area = (step**2) * 1.0e-6

    cds_myi = get_classification_mask(ds, use_ambiguous)
    nersc_myi = get_nersc_multiyear_ice(
        ds,
        use_fyi=use_fyi,
        age_threshold=age_threshold,
        conc_threshold=conc_threshold,
        nersc_ice_age_path=None,
    )

    def area(mask):
        return compute_spatial_sum(mask, cell_area)

    extent_cds = area(cds_myi)
    extent_nersc = area(nersc_myi)
    cds_only = area(cds_myi & ~nersc_myi)
    nersc_only = area(~cds_myi & nersc_myi)

    units = "$10^6$km$^2$"
    catalogue = (
        (
            "mysi_extent",
            extent_cds,
            "multi_year_sea_ice_extent",
            "Multi-year sea ice extent",
        ),
        (
            "mysi_extent_age",
            extent_nersc,
            "multi_year_sea_ice_extent_age",
            "Multi-year sea ice extent from NERSC sea ice age product",
        ),
        (
            "mysi_extent_bias",
            cds_only - nersc_only,
            "multi_year_sea_ice_extent_bias",
            "Multi-year sea ice extent bias",
        ),
        (
            "iite",
            cds_only + nersc_only,
            "integrated_ice_type_error",
            "Integrated ice type error",
        ),
    )

    dataarrays = {}
    for key, values, standard_name, long_name in catalogue:
        values.attrs = {
            "standard_name": standard_name,
            "units": units,
            "long_name": long_name,
        }
        dataarrays[key] = values
    return xr.Dataset(dataarrays)


def compute_multiyear_ice_percentage(da):
    """Return, per calendar month, the percentage of time steps where ``da`` is set.

    ``da`` is expected to be a boolean mask with a ``time`` dimension. The
    result is wrapped in a dataset under the variable name ``percentage``.
    """

    def _percent_of_time(group):
        return 100 * group.sum("time") / group.sizes["time"]

    percentage = da.groupby("time.month").map(_percent_of_time)
    percentage.attrs = {
        "units": "%",
        "long_name": "Multi-year sea ice percentage",
    }
    return percentage.to_dataset(name="percentage")


def compute_cds_multiyear_ice_percentage(ds, use_ambiguous):
    """Monthly percentage of time each cell is multi-year ice in the CDS product."""
    return compute_multiyear_ice_percentage(
        get_classification_mask(ds, use_ambiguous=use_ambiguous)
    )


def compute_nersc_multiyear_ice_percentage(ds, use_fyi, age_threshold, conc_threshold):
    """Monthly percentage of time each cell is multi-year ice according to NERSC.

    ``ds`` only selects the dates and grid; the classification criteria are
    forwarded to ``get_nersc_multiyear_ice`` with the default NERSC path.
    """
    mask = get_nersc_multiyear_ice(
        ds,
        nersc_ice_age_path=None,
        use_fyi=use_fyi,
        age_threshold=age_threshold,
        conc_threshold=conc_threshold,
    )
    return compute_multiyear_ice_percentage(mask)

## Download and transform data

In [None]:
# NERSC classification criteria shared by the maps and the time series (see
# get_nersc_multiyear_ice). A cell is multi-year ice when the NERSC total
# (multi-year + first-year) concentration exceeds 0.15 and multi-year ice
# exceeds first-year ice. The 0.15 is presumably a fraction; confirm the units.
nersc_kwargs = {
    "use_fyi": True,
    "age_threshold": None,
    "conc_threshold": 0.15,
}
# Extra keyword arguments for download.download_and_transform. The concat
# options mirror those used to open the NERSC files in get_nersc_data.
# NOTE(review): presumably "chunks" splits the requests by year; confirm in
# c3s_eqc_automatic_quality_control.
kwargs = {
    "chunks": {"year": 1},
    "concat_dim": "time",
    "combine": "nested",
    "data_vars": "minimal",
    "coords": "minimal",
    "compat": "override",
}

# NERSC Maps
# Monthly NERSC multi-year ice percentage, one entry per period label.
# transform_chunks=False: the percentage aggregates over time, so presumably
# the transform must see the whole period at once rather than per chunk.
datasets = []
for period, requests in requests_periods.items():
    print(f"NERSC Maps: {period=}")
    ds = download.download_and_transform(
        collection_id,
        requests,
        transform_func=compute_nersc_multiyear_ice_percentage,
        transform_func_kwargs=nersc_kwargs,
        transform_chunks=False,
        **kwargs,
    )
    datasets.append(ds.expand_dims(period=[period]))
ds_nersc = xr.concat(datasets, "period")

# CDS Maps
# Same maps from the CDS ice type classification, with and without the
# ambiguous class counted as multi-year ice; merged over (use_ambiguous, period).
datasets = []
for period, requests in requests_periods.items():
    for use_ambiguous in (True, False):
        print(f"CDS Maps: {period=} {use_ambiguous=}")
        ds = download.download_and_transform(
            collection_id,
            requests,
            transform_func=compute_cds_multiyear_ice_percentage,
            transform_func_kwargs={"use_ambiguous": use_ambiguous},
            transform_chunks=False,
            **kwargs,
        )
        datasets.append(ds.expand_dims(use_ambiguous=[use_ambiguous], period=[period]))
ds_cds = xr.merge(datasets)

# Timeseries
# Per-time-step extent/bias diagnostics over the full span of all periods.
# .compute() loads each result into memory before concatenation.
datasets = []
for use_ambiguous in (True, False):
    print(f"Timeseries: {use_ambiguous=}")
    ds = download.download_and_transform(
        collection_id,
        request_timeseries,
        transform_func=compute_sea_ice_evaluation_diagnostics,
        transform_func_kwargs=nersc_kwargs | {"use_ambiguous": use_ambiguous},
        **kwargs,
    )
    datasets.append(ds.expand_dims(use_ambiguous=[use_ambiguous]).compute())
ds_timeseries = xr.concat(datasets, "use_ambiguous")

## Plotting functions

In [None]:
def rearrange_year_vs_monthday(ds):
    """Reshape a daily time series into a (year, monthday) grid.

    The calendar is first converted to "noleap", so 29 February is dropped.
    ``year`` is the season start year: October-December keep their year,
    January-September are assigned to the previous year. ``monthday`` is
    the "MM-DD" string. The ``time`` dimension is unstacked into both.
    """
    ds = ds.convert_calendar("noleap")
    time = ds["time"]
    calendar_year = time.dt.year
    season_year = calendar_year.where(time.dt.month >= 10, calendar_year - 1)
    new_coords = {
        "year": ("time", season_year.values),
        "monthday": ("time", time.dt.strftime("%m-%d").values),
    }
    return (
        ds.assign_coords(new_coords)
        .set_index(time=("year", "monthday"))
        .unstack("time")
    )


def plot_against_monthday(ds, cmap="viridis", **kwargs):
    """Plot each variable of ``ds`` against day of the season, one line per year.

    The data are reshaped with ``rearrange_year_vs_monthday``. Each season
    (October to the following September) becomes one line, coloured by its
    start year. The lines are drawn against a reference date axis running
    from October 2001 to September 2002. By default there is one facet row
    per variable.

    Parameters
    ----------
    ds : xarray.Dataset
        Variables with a ``time`` dimension and ``units``/``long_name`` attrs.
    cmap : str or matplotlib.colors.Colormap
        Colormap for the year lines and the colorbar.
    **kwargs
        Overrides for the ``DataArray.plot`` defaults (e.g. ``col=...``).

    Returns
    -------
    xarray.plot.FacetGrid
    """
    defaults = {
        "row": "variable",
        "x": "time",
        "hue": "year",
        "add_legend": False,
        "figsize": (10, 10),
        "sharey": False,
    }
    ds = rearrange_year_vs_monthday(ds)

    da = ds.to_array()
    # Put each "MM-DD" on a reference season (Oct-Dec 2001, Jan-Sep 2002) so
    # the x axis is chronological across New Year. 2002 is not a leap year,
    # which matches the noleap calendar (no "02-29" can appear).
    time = pd.to_datetime(
        [
            f"200{'1' if int(monthday[:2]) >= 10 else '2'}-{monthday}"
            for monthday in da["monthday"].values
        ]
    )
    da = da.assign_coords(time=("monthday", time)).sortby("time")

    # One colour per year, sampled by calling the colormap. Unlike the
    # ``.colors`` attribute, this works for segmented colormaps too, and for
    # Colormap objects (for which get_cmap ignores the lut argument).
    n_years = da.sizes["year"]
    colormap = plt.get_cmap(cmap, n_years)
    colors = [colormap(i / max(n_years - 1, 1)) for i in range(n_years)]
    with plt.rc_context({"axes.prop_cycle": plt.cycler(color=colors)}):
        facet = da.plot(**(defaults | kwargs))

    for ax in facet.axs.flat:
        ax.grid()
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%d %b"))
        ax.xaxis.set_tick_params(rotation=45)

    # Label each row with the units of the variable that row actually shows.
    for ax, sel_dict in zip(facet.axs[:, 0], facet.name_dicts[:, 0]):
        ax.set_ylabel(ds[sel_dict["variable"]].attrs["units"])

    scalar_mappable = plt.cm.ScalarMappable(
        cmap=cmap,
        norm=plt.Normalize(vmin=da["year"].min(), vmax=da["year"].max()),
    )
    facet.fig.colorbar(scalar_mappable, ax=facet.axs, label="year")
    for label in facet.row_labels:
        # xarray may leave entries as None (e.g. when faceting on one dimension).
        if label is None:
            continue
        *_, variable = label.get_text().split()
        long_name = ds[variable].attrs["long_name"].replace("from", "from\n")
        label.set_text(long_name)
    return facet


def plot_maps(da, coastline_color="limegreen", **kwargs):
    """Draw polar-stereographic maps faceted by period (rows) and month (columns).

    Only October-April are shown, in season order. The domain is cropped to
    +/-2500 in ``xc`` and ``yc``; these are presumably km, matching the CDS
    grid. Coastlines are drawn on every panel. ``kwargs`` override the
    ``plot.projected_map`` defaults.
    """
    half_width = 2.5e3
    winter_months = [10, 11, 12, 1, 2, 3, 4]
    # yc is taken as descending, hence the reversed slice bounds.
    subset = da.sel(
        xc=slice(-half_width, half_width),
        yc=slice(half_width, -half_width),
        month=winter_months,
    )
    plot_kwargs = {
        "row": "period",
        "col": "month",
        "projection": ccrs.Stereographic(central_latitude=90.0),
        "cmap": cmocean.cm.ice,
    }
    plot_kwargs.update(kwargs)
    facet = plot.projected_map(subset, **plot_kwargs)
    for axis in facet.axs.flat:
        axis.coastlines(color=coastline_color, lw=1)
    return facet

## Plot timeseries

In [None]:
# One row per diagnostic variable (default row="variable"), one column per
# use_ambiguous setting.
facet = plot_against_monthday(ds_timeseries, col="use_ambiguous")
plt.show()

## Plot NERSC

In [None]:
# NERSC multi-year ice percentage: rows are periods, columns are months.
facet = plot_maps(ds_nersc["percentage"])
plt.show()

## Plot CDS

In [None]:
# One figure per use_ambiguous setting. With squeeze=False, each group keeps
# use_ambiguous as a length-1 dimension.
# NOTE(review): confirm plot.projected_map tolerates that extra length-1 dim.
for use_ambiguous, ds in ds_cds.groupby("use_ambiguous", squeeze=False):
    facet = plot_maps(ds["percentage"])
    facet.fig.suptitle(f"{use_ambiguous=}")
    plt.show()