# Effect of climate change on sea ice concentration

## Import libraries

In [None]:
import warnings

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import download

# Silence warnings whose originating module matches "cf_xarray", then select
# the notebook plotting style for all figures below.
warnings.filterwarnings(action="ignore", module="cf_xarray")
plt.style.use(style="seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Period to analyse, as inclusive calendar years
year_start, year_stop = 1979, 2023

# Concentration threshold (%) at or above which a grid cell counts towards
# the sea ice extent
sic_threshold = 15

## Define requests

In [None]:
collection_id = "satellite-sea-ice-concentration"

conc_request = {
    "cdr_type": "cdr",
    "variable": "all",
    "version": "v2",
}

# Per product: request overrides and the (first, last) year window it is
# requested for. The EUMETSAT-OSI-SAF interim CDR (ICDR) covers the years
# after the CDR; only the CDR is available for ESA-CCI.
product_specs = {
    "EUMETSAT-OSI-SAF (CDR)": ({"origin": "eumetsat_osi_saf"}, 1979, 2015),
    "EUMETSAT-OSI-SAF (ICDR)": (
        {"cdr_type": "icdr", "origin": "eumetsat_osi_saf"},
        2016,
        2023,
    ),
    "ESA-CCI (CDR)": ({"origin": "esa_cci"}, 2002, 2017),
}

request_dict = {}
for product, (overrides, first_year, last_year) in product_specs.items():
    # Clip the user-selected period to this product's year window.
    start_year = max(year_start, first_year)
    stop_year = min(year_stop, last_year)
    if start_year > stop_year:
        # No overlap with the selected period: skip the product rather than
        # build a request whose start date lies after its stop date.
        continue
    request_dict[product] = download.update_request_date(
        conc_request | overrides,
        start=f"{start_year}-01",
        stop=f"{stop_year}-12",
        stringify_dates=True,
    )

## Define the transform function (its results are cached)

In [None]:
def compute_siconc_time_series(ds, sic_threshold):
    """Reduce a gridded sea ice concentration dataset to time series.

    Parameters
    ----------
    ds : xarray.Dataset
        Gridded product with ``xc``/``yc`` coordinates and variables whose
        standard names are ``sea_ice_area_fraction`` and
        ``sea_ice_area_fraction standard_error`` (looked up via ``.cf``).
    sic_threshold : float
        Concentration in % at or above which a grid cell counts as ice when
        computing the extent.

    Returns
    -------
    xarray.Dataset
        ``extent`` and ``area`` (10^6 km^2) summed over ``xc``/``yc``, and
        ``rms_error`` (%) as the root mean square of the concentration
        standard error over ``xc``/``yc``.
    """
    ds = ds.convert_calendar("standard", align_on="date")

    # Grid cell area from the spacing of the first two points along each
    # axis. NOTE(review): assumes a regular grid with xc/yc in km.
    dims = ("xc", "yc")
    dx, dy = (np.abs(np.diff(ds[dim][:2].values))[0] for dim in dims)
    grid_cell_area = dx * dy * 1.0e-6  # 10^6 km2

    # Concentration and its uncertainty, converted to % when stored as a
    # fraction. Rebind instead of ``*=`` so arrays possibly shared with
    # ``ds`` are not modified in place.
    sic = ds.cf["sea_ice_area_fraction"]
    sic_error = ds.cf["sea_ice_area_fraction standard_error"]
    if sic.attrs.get("units", "") == "(0 - 1)":
        sic = sic * 100
        sic_error = sic_error * 100

    dataarrays = {}

    # Extent: area of all cells at or above the threshold. NaN cells compare
    # False and therefore count as open water.
    is_ice = sic >= sic_threshold
    dataarrays["extent"] = grid_cell_area * is_ice.sum(dim=dims)
    dataarrays["extent"].attrs = {
        "standard_name": "sea_ice_extent",
        "units": "$10^6$km$^2$",
        "long_name": "Sea ice extent",
    }

    # Area: each cell weighted by its concentration (% -> fraction).
    dataarrays["area"] = grid_cell_area * 0.01 * sic.sum(dim=dims)
    dataarrays["area"].attrs = {
        "standard_name": "sea_ice_area",
        "units": "$10^6$km$^2$",
        "long_name": "Sea ice area",
    }

    # RMS of the per-cell concentration standard error.
    dataarrays["rms_error"] = np.sqrt((sic_error**2).mean(dim=dims))
    dataarrays["rms_error"].attrs = {
        "standard_name": "root_mean_square sea_ice_area_fraction standard_error",
        "units": "%",
        "long_name": "Root mean square sea ice area fraction standard error",
    }

    return xr.Dataset(dataarrays)

## Download and transform

In [None]:
# Hemispheres to process for every product, and raw variables not needed by
# the transform.
regions = ("northern_hemisphere", "southern_hemisphere")
variables_to_drop = (
    "raw_ice_conc_values",
    "smearing_standard_error",
    "algorithm_standard_error",
    "status_flag",
)

datasets = []
for product, product_requests in request_dict.items():
    for region in regions:
        print(f"{product = }, {region = }")
        ds_region = download.download_and_transform(
            collection_id,
            [{**request, "region": region} for request in product_requests],
            transform_func=compute_siconc_time_series,
            transform_func_kwargs={"sic_threshold": sic_threshold},
            chunks={"year": 1},
            drop_variables=variables_to_drop,
        )
        # Tag each result so all products/regions can be merged on new dims.
        datasets.append(ds_region.expand_dims(region=[region], product=[product]))
ds = xr.merge(datasets)

### Plotting functions

In [None]:
def rearrange_year_vs_dayofyear(ds):
    """Split a daily ``time`` axis into ``year`` and ``dayofyear`` axes.

    The data are converted to the ``noleap`` calendar first, so day-of-year
    labels run from 1 to at most 365 in every year. The ``time`` dimension
    is replaced by the two new dimensions.
    """
    noleap = ds.convert_calendar("noleap")
    time_parts = noleap["time"].dt
    labelled = noleap.assign_coords(
        year=("time", time_parts.year.values),
        dayofyear=("time", time_parts.dayofyear.values),
    )
    # Build a (year, dayofyear) MultiIndex on ``time``, then unstack it.
    return labelled.set_index(time=("year", "dayofyear")).unstack("time")


def compute_yearly_extremes(da, reduction, min_samples=150, remove_outliers=True):
    """Reduce a time series to one extreme value per year.

    Parameters
    ----------
    da : xarray.DataArray
        Series with a ``time`` dimension and a ``long_name`` attribute.
    reduction : str
        Name of the per-year groupby reduction, e.g. ``"min"`` or ``"max"``.
    min_samples : int
        Years with ``min_samples`` or fewer valid samples are dropped
        (the test is ``count > min_samples``).
    remove_outliers : bool
        If True, also drop years outside the Tukey fences
        (1.5 x IQR beyond the quartiles). The quartiles are computed across
        the years that passed the sample-count filter.

    Returns
    -------
    xarray.DataArray
        Yearly extremes indexed by ``year``, rejected years removed, with
        ``long_name`` prefixed by the title-cased reduction name.
    """
    grouped = da.groupby("time.year")
    mask = grouped.count() > min_samples
    yearly = getattr(grouped, reduction)(keep_attrs=True)
    if remove_outliers:
        # Exclude under-sampled years so they cannot distort the quartiles;
        # quantile also needs the core dim in a single dask chunk.
        candidates = yearly.where(mask).chunk(year=-1)
        quartiles = candidates.quantile([0.25, 0.75], "year")
        q1 = quartiles.sel(quantile=0.25, drop=True)
        q3 = quartiles.sel(quantile=0.75, drop=True)
        delta = 1.5 * (q3 - q1)
        mask = mask & (yearly >= q1 - delta) & (yearly <= q3 + delta)
    # drop=True requires a concrete (non-lazy) mask.
    yearly = yearly.where(mask.compute(), drop=True)
    # assign_attrs returns a copy, so no attrs dict is mutated in place.
    return yearly.assign_attrs(
        long_name=" ".join([reduction.title(), yearly.attrs["long_name"]])
    )


def plot_against_dayofyear(
    ds,
    cmap="viridis",
    **kwargs,
):
    """Plot every year of each variable against day of year.

    Parameters
    ----------
    ds : xarray.Dataset
        Series with a ``time`` dimension; every variable needs a ``units``
        attribute (used as the y label).
    cmap : str or matplotlib.colors.Colormap
        Colormap sampled to give each year its own colour.
    **kwargs
        Forwarded to ``DataArray.plot``; they override the defaults below.

    Returns
    -------
    xarray.plot.FacetGrid
        One row per variable, with a shared ``year`` colorbar.
    """
    defaults = {
        "row": "variable",
        "x": "time",
        "hue": "year",
        "add_legend": False,
        "figsize": (8, 8),
    }
    kwargs = defaults | kwargs

    ds = rearrange_year_vs_dayofyear(ds)
    ds = ds.dropna("year", how="all")

    da = ds.to_array()
    # Map each day of year onto a date in the non-leap year 2001 so the x axis
    # can show day/month. Derived from the actual dayofyear values, so it
    # also works when some days are missing from every year.
    dates = pd.Timestamp("2001-01-01") + pd.to_timedelta(
        da["dayofyear"].values - 1, unit="D"
    )
    da = da.assign_coords(time=("dayofyear", dates))

    # Sample the colormap explicitly: ``.colors`` exists only on
    # ListedColormap, so this also supports e.g. "coolwarm".
    colors = plt.get_cmap(cmap)(np.linspace(0, 1, da.sizes["year"]))
    with plt.rc_context({"axes.prop_cycle": plt.cycler(color=colors)}):
        facet = da.plot(**kwargs)

    for ax, sel_dict in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
        ax.set_ylabel(ds[sel_dict["variable"]].attrs["units"])
        ax.grid()
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%d %b"))
        ax.xaxis.set_tick_params(rotation=45)

    scalar_mappable = plt.cm.ScalarMappable(
        cmap=cmap,
        norm=plt.Normalize(vmin=da["year"].min(), vmax=da["year"].max()),
    )
    facet.fig.colorbar(scalar_mappable, ax=facet.axs, label="year")
    return facet

## Plot yearly extremes

In [None]:
# One figure per region and reduction: yearly min/max extent, one line per product.
for region, extent_region in ds["extent"].groupby("region"):
    for reduction in ["min", "max"]:
        yearly_extremes = compute_yearly_extremes(extent_region, reduction=reduction)
        yearly_extremes.plot(marker="^", hue="product")
        plt.grid()
        plt.show()

## Plot against day of year

In [None]:
# ``Dataset.drop`` is deprecated; ``drop_vars`` is the supported replacement.
# ``rms_error`` is excluded because its units (%) differ from extent/area.
for product, ds_product in ds.drop_vars("rms_error").groupby("product"):
    for region, ds_region in ds_product.groupby("region"):
        facet = plot_against_dayofyear(ds_region)
        facet.fig.suptitle(f"{product=} {region=}", y=1.01)
        plt.show()