# Sea ice diagnostics for different CMIP6 experiments

## Import libraries

In [None]:
import datetime
import warnings

import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

# Apply matplotlib's "seaborn-v0_8-notebook" style to every figure in this notebook
plt.style.use("seaborn-v0_8-notebook")
# Silence warnings issued from the cf_xarray module
warnings.filterwarnings("ignore", module="cf_xarray")

## Set parameters

In [None]:
# Time
year_start = 1850
year_stop = 2100
# Requested years must lie within the CMIP6 historical + scenario span
# (1850-2100) and form a non-empty period.
assert 1850 <= year_start <= year_stop <= 2100
assert not year_start % 10

# Sea Ice Concentration Threshold
sic_threshold = 30  # %

# Models
experiments = ["historical", "ssp1_2_6", "ssp2_4_5", "ssp3_7_0", "ssp5_8_5"]
# Every requested experiment must be a known one (subset check): an unknown
# experiment would otherwise only fail later when looking up its model list.
assert set(experiments) <= {
    "historical",
    "ssp1_1_9",
    "ssp1_2_6",
    "ssp2_4_5",
    "ssp3_7_0",
    "ssp4_3_4",
    "ssp4_6_0",
    "ssp5_8_5",
}

## Define models for each experiment

In [None]:
# CMIP6 models to download for each experiment. The download loop looks up
# models_dict[experiment] for every entry of `experiments`, so each requested
# experiment needs a key here.
models_dict = {
    "historical": [
        "access_cm2",
        "access_esm1_5",
        "cams_csm1_0",
        "canesm5",
        "canesm5_canoe",
        "cmcc_cm2_hr4",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "cnrm_esm2_1",
        "e3sm_1_0",
        "e3sm_1_1",
        "e3sm_1_1_eca",
        "ec_earth3_aerchem",
        "ec_earth3_cc",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fio_esm_2_0",
        "giss_e2_1_h",
        "hadgem3_gc31_ll",
        "hadgem3_gc31_mm",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm5a2_inca",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_hr",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "norcpm1",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp1_1_9": [
        "canesm5",
        "ec_earth3",
        "ec_earth3_veg",
        "ec_earth3_veg_lr",
        "fgoals_g3",
        "gfdl_esm4",
        "ipsl_cm6a_lr",
        "miroc6",
        "miroc_es2l",
        "mri_esm2_0",
        "ukesm1_0_ll",
    ],
    "ssp1_2_6": [
        "access_cm2",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "fio_esm_2_0",
        "gfdl_esm4",
        "hadgem3_gc31_ll",
        "hadgem3_gc31_mm",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm5a2_inca",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "noresm2_lm",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp2_4_5": [
        "access_cm2",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "cnrm_esm2_1",
        "ec_earth3_cc",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "fio_esm_2_0",
        "hadgem3_gc31_ll",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "noresm2_lm",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp3_7_0": [
        "access_cm2",
        "canesm5",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "ec_earth3_aerchem",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "gfdl_esm4",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm5a2_inca",
        "ipsl_cm6a_lr",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "noresm2_lm",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp4_3_4": [
        "canesm5",
        "ec_earth3",
        "fgoals_g3",
        "ipsl_cm6a_lr",
        "miroc6",
        "mri_esm2_0",
        "ukesm1_0_ll",
    ],
    "ssp4_6_0": [
        "canesm5",
        "fgoals_g3",
        "ipsl_cm6a_lr",
        "miroc6",
        "mri_esm2_0",
    ],
    "ssp5_8_5": [
        "access_cm2",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "e3sm_1_1",
        "ec_earth3_cc",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "fio_esm_2_0",
        "gfdl_esm4",
        "hadgem3_gc31_ll",
        "hadgem3_gc31_mm",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "noresm2_lm",
        "noresm2_mm",
        "ukesm1_0_ll",
    ],
}

## Define request

In [None]:
all_months = [str(month).zfill(2) for month in range(1, 13)]

# ERA5 monthly means: available from 1940; the current calendar year is
# excluded (the range stops before today's year).
request_era5 = (
    "reanalysis-era5-single-levels-monthly-means",
    {
        "product_type": "monthly_averaged_reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": "sea_ice_cover",
        "year": list(
            map(
                str,
                range(
                    max(year_start, 1940),
                    min(year_stop + 1, datetime.date.today().year),
                ),
            )
        ),
        "month": all_months,
    },
)

# CMIP6 historical run: 1850-2014 (clipped to the requested period).
request_cmip6_historical = (
    "projections-cmip6",
    {
        "format": "zip",
        "temporal_resolution": "monthly",
        "experiment": "historical",
        "variable": "sea_ice_area_percentage_on_ocean_grid",
        "year": list(map(str, range(max(year_start, 1850), min(year_stop, 2014) + 1))),
        "month": all_months,
    },
)

# CMIP6 scenarios: 2015-2100 (clipped to the requested period). The
# "experiment" value is a placeholder: the download loop overrides it with
# each scenario name.
request_cmip6_projections = (
    "projections-cmip6",
    {
        "format": "zip",
        "temporal_resolution": "monthly",
        "experiment": "historical",
        "variable": "sea_ice_area_percentage_on_ocean_grid",
        "year": list(map(str, range(max(year_start, 2015), min(year_stop, 2100) + 1))),
        "month": all_months,
    },
)

# EUMETSAT OSI SAF climate data record: 1979-2015.
request_eumetsat = (
    "satellite-sea-ice-concentration",
    download.update_request_date(
        {
            "cdr_type": "cdr",
            "origin": "eumetsat_osi_saf",
            "variable": "all",
            "version": "v2",
        },
        start=f"{max(year_start, 1979)}-01",
        stop=f"{min(year_stop, 2015)}-12",
        stringify_dates=True,
    ),
)

# ESA CCI climate data record: 2002-2017.
request_esa = (
    "satellite-sea-ice-concentration",
    download.update_request_date(
        {
            "cdr_type": "cdr",
            "origin": "esa_cci",
            "variable": "all",
            "version": "v2",
        },
        start=f"{max(year_start, 2002)}-01",
        stop=f"{min(year_stop, 2017)}-12",
        stringify_dates=True,
    ),
)

## Define functions to cache

In [None]:
def compute_extent_and_area_from_sic(ds, sic_threshold, grid_cell_area):
    """
    Compute sea ice extent and area from sea ice concentration.

    Parameters
    ----------
    ds: xr.Dataset
        Dataset holding a variable with CF standard name
        ``sea_ice_area_fraction`` and horizontal dimensions ``xc``/``yc``.
    sic_threshold: float
        Sea ice concentration threshold (%) above which a cell counts
        towards the extent.
    grid_cell_area: float
        Area of each grid cell (km2).

    Returns
    -------
    xr.Dataset
        ``siextent`` and ``siarea`` summed over ``xc``/``yc``, in 10^6 km2.
    """
    sic = ds.cf["sea_ice_area_fraction"]

    # Concentration is either a fraction (0-1) or a percentage (0-100).
    # "(0 - 1)" and the CF-canonical dimensionless unit "1" both denote
    # fractions; anything else is treated as a percentage.
    sic_is_normalized = sic.attrs.get("units", "") in ("(0 - 1)", "1")
    threshold = sic_threshold / 100 if sic_is_normalized else sic_threshold
    area_factor = grid_cell_area if sic_is_normalized else grid_cell_area / 100

    # Extent: full cell area where concentration exceeds the threshold
    # (NaN compares False, so masked cells contribute 0).
    siextent = xr.where(sic > threshold, grid_cell_area, 0)
    # Area: cell area weighted by the ice fraction.
    siarea = sic * area_factor

    # Merge, sum over the grid and convert km2 -> 10^6 km2
    ds = xr.merge([siextent.rename("siextent"), siarea.rename("siarea")])
    ds = ds.sum(("xc", "yc")) * 1.0e-6
    for var in ds.data_vars:
        ds[var].attrs = {
            "standard_name": var.replace("si", "sea_ice_", 1),
            "units": "$10^6$km$^2$",
            "long_name": var.replace("si", "Sea ice ", 1),
        }
    return ds


def interpolate_to_satellite_grid(obj, **regrid_kwargs):
    """
    Regrid ``obj`` onto the ESA-CCI satellite grid of both hemispheres.

    Each target grid comes from a single ESA-CCI sea ice concentration
    file (2002-06-01) for that hemisphere, with its time dimension dropped.

    Parameters
    ----------
    obj: xr.Dataset
        Data to regrid.
    **regrid_kwargs: Any
        Forwarded to ``diagnostics.regrid``.

    Returns
    -------
    Regridded data concatenated along a new ``region`` dimension
    ("northern_hemisphere", "southern_hemisphere").
    """
    template_request = {
        "version": "v2",
        "variable": "all",
        "format": "zip",
        "origin": "esa_cci",
        "cdr_type": "cdr",
        "year": "2002",
        "month": "06",
        "day": "01",
    }

    def regrid_to(region):
        # Only the grid of the template file is needed, not its values over time
        target_grid = download.download_and_transform(
            "satellite-sea-ice-concentration", template_request | {"region": region}
        ).drop_dims("time")
        regridded = diagnostics.regrid(obj, target_grid, **regrid_kwargs)
        return regridded.expand_dims(region=[region])

    return xr.concat(
        [regrid_to(region) for region in ("northern_hemisphere", "southern_hemisphere")],
        "region",
    )


def compute_interpolated_sea_ice_extent_and_area(
    ds, sic_threshold, *, grid_cell_area=25**2, **regrid_kwargs
):
    """
    Optionally interpolate to the 25x25km satellite grid and return sea ice diagnostics.

    Parameters
    ----------
    ds: xr.Dataset
        Dataset to process
    sic_threshold: float
        Sea ice concentration threshold (%)
    grid_cell_area: float, optional
        Area of each cell of the grid the diagnostics are computed on (km2).
        Defaults to 25x25 km.
    **regrid_kwargs: Any
        xesmf regrid kwargs. When empty, no interpolation is performed and
        ``ds`` is assumed to already be on a grid with cells of
        ``grid_cell_area``.

    Returns
    -------
    xr.Dataset
        Dataset with siarea and siextent (10^6 km2)
    """
    ds = ds.cf[["latitude", "longitude", "sea_ice_area_fraction"]]

    # Monthly resample (month-start labels)
    ds = ds.sortby("time").resample(time="MS").mean()
    ds["time"].attrs["long_name"] = "time"

    if regrid_kwargs:
        ds = interpolate_to_satellite_grid(ds, **regrid_kwargs)

    return compute_extent_and_area_from_sic(
        ds, sic_threshold, grid_cell_area=grid_cell_area
    )

## Utilities

In [None]:
def postprocess_dataset(ds):
    """
    Prepare a diagnostics dataset for plotting.

    Converts to the standard calendar, renames each data variable to its
    ``long_name`` attribute, relabels the ``region`` coordinate as
    "Arctic" (regions starting with "northern") or "Antarctic" (all others),
    and loads the result into memory.
    """

    def pole_label(region):
        return "Arctic" if region.startswith("northern") else "Antarctic"

    converted = ds.convert_calendar("standard", align_on="date")
    long_names = {
        name: variable.attrs["long_name"]
        for name, variable in converted.data_vars.items()
    }
    result = converted.rename(long_names)
    result["region"] = list(map(pole_label, result["region"].values))
    return result.compute()


# Keyword arguments shared by every download_and_transform call below
common_kwargs = dict(
    transform_func=compute_interpolated_sea_ice_extent_and_area,
    # Parameters to speed up IO
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
    drop_variables=("type",),
)
transform_func_kwargs = dict(sic_threshold=sic_threshold)
# Regridding options; merged into transform_func_kwargs for gridded inputs
# (ERA5, CMIP6) so the transform interpolates them to the satellite grid.
interpolation_kwargs = dict(
    method="bilinear",
    periodic=True,
    ignore_degenerate=True,
)

## Download and transform ERA5

In [None]:
# ERA5 is regridded to the satellite grid (interpolation kwargs are passed)
ds_era5 = download.download_and_transform(
    *request_era5,
    transform_func_kwargs=transform_func_kwargs | interpolation_kwargs,
    chunks={"year": 10},
    **common_kwargs,
)
ds_era5 = postprocess_dataset(ds_era5)

## Download and transform satellite data

In [None]:
# Satellite products are already on the target grid: no interpolation kwargs
datasets_satellite = {}
for name, (collection_id, requests) in {
    "ESA-CCI": request_esa,
    "EUMETSAT-OSI-SAF": request_eumetsat,
}.items():
    per_region = []
    for region in ("northern_hemisphere", "southern_hemisphere"):
        print(f"{name=} {region=}")
        regional_requests = [req | {"region": region} for req in requests]
        ds = download.download_and_transform(
            collection_id,
            regional_requests,
            transform_func_kwargs=transform_func_kwargs,
            chunks={"year": 1},
            **common_kwargs,
        )
        # Missing months are filled with 0: mask them, then drop those time steps
        ds = ds.where(ds != 0).dropna("time")
        per_region.append(ds.expand_dims(region=[region]))
    datasets_satellite[name] = postprocess_dataset(xr.concat(per_region, "region"))

## Download and transform CMIP6

In [None]:
datasets_cmip6 = {}
for experiment in experiments:
    collection_id, base_request = (
        request_cmip6_historical
        if experiment == "historical"
        else request_cmip6_projections
    )
    # Skip experiments whose period does not overlap [year_start, year_stop]
    if not base_request["year"]:
        continue
    per_model = []
    for model in models_dict[experiment]:
        print(f"{experiment=} {model=}")
        ds = download.download_and_transform(
            collection_id,
            base_request | {"experiment": experiment, "model": model},
            transform_func_kwargs=transform_func_kwargs | interpolation_kwargs,
            chunks={"year": 10},
            **common_kwargs,
        )
        per_model.append(postprocess_dataset(ds).expand_dims(model=[model]))
    datasets_cmip6[experiment] = xr.concat(per_model, "model")

## Define plotting function

In [None]:
def plot_timeseries(
    ds_era5, datasets_satellite, datasets_cmip6, func, title=None, **kwargs
):
    """
    Plot ERA5, satellite and CMIP6 sea ice diagnostics on a variable x region grid.

    Parameters
    ----------
    ds_era5: xr.Dataset
        ERA5 diagnostics; also the source of the y-axis units.
    datasets_satellite: dict[str, xr.Dataset]
        Satellite diagnostics keyed by product name.
    datasets_cmip6: dict[str, xr.Dataset]
        CMIP6 diagnostics keyed by experiment, with a ``model`` dimension.
        Each experiment is drawn as its median and interquartile range.
    func: callable
        Applied as ``func(ds, **kwargs)`` to every dataset before plotting.
    title: str, optional
        Figure title.
    **kwargs: Any
        Extra keyword arguments forwarded to ``func``.

    Returns
    -------
    The facet grid returned by ``DataArray.plot(row=..., col=...)``.
    """
    # One color per line family: ERA5, each satellite product, each experiment
    colors = (f"C{i}" for i in range(len(datasets_satellite) + len(datasets_cmip6) + 1))

    # Get dataarrays
    da_era5 = func(ds_era5, **kwargs).to_array()
    dataarrays_satellite = {
        k: func(ds, **kwargs).to_array() for k, ds in datasets_satellite.items()
    }
    dataarrays_cmip6 = {
        k: func(ds, **kwargs).to_array() for k, ds in datasets_cmip6.items()
    }

    # Plot ERA5
    facet = da_era5.plot(
        row="variable",
        col="region",
        label="ERA5",
        color=next(colors),
        ls="--",
        zorder=20,
    )
    # Each panel paired with the selection (variable, region) it shows
    panels = list(zip(facet.axs.flatten(), facet.name_dicts.flatten()))

    # Plot satellites; styles cycle so no product is dropped if there are more than two
    satellite_styles = (("--", 11), ("-", 10))
    for i, (satellite, da_satellite) in enumerate(dataarrays_satellite.items()):
        ls, zorder = satellite_styles[i % len(satellite_styles)]
        color = next(colors)
        for ax, sel_dict in panels:
            da = da_satellite.sel(sel_dict)
            ax.plot(da["time"], da, label=satellite, color=color, ls=ls, zorder=zorder)

    # Plot CMIP6 (experiments with no data are skipped)
    quantiles = {
        experiment: da.quantile([1 / 4, 1 / 2, 3 / 4], dim="model")
        for experiment, da in dataarrays_cmip6.items()
        if da.size
    }
    for experiment, da_quantiles in quantiles.items():
        color = next(colors)
        for ax, sel_dict in panels:
            da = da_quantiles.sel(sel_dict)
            ax.plot(
                da["time"],
                da.sel(quantile=1 / 2),
                color=color,
                label=f"CMIP6 {experiment} median",
                zorder=2,
            )
            ax.fill_between(
                da["time"],
                da.sel(quantile=1 / 4),
                da.sel(quantile=3 / 4),
                color=color,
                alpha=0.4,
                label=f"CMIP6 {experiment} IQR",
                zorder=1,
            )

    # Edit axs: grid on every panel, even when no CMIP6 data is plotted
    for ax, _ in panels:
        ax.grid(linestyle=":")
    for ax, sel_dict in zip(facet.axs[:, 0], facet.name_dicts[:, 0]):
        # Copy before dropping "variable" so facet.name_dicts is left intact
        selection = {k: v for k, v in sel_dict.items() if k != "variable"}
        da = ds_era5.sel(selection)[sel_dict["variable"]]
        ax.set_ylabel(f"[{da.attrs['units']}]")
    facet.axs[0, -1].legend(bbox_to_anchor=(1.1, 1))
    if title is not None:
        facet.fig.suptitle(title)
    return facet

## Plot sliced timeseries

In [None]:
# One figure per 20-year window, selected by time
for start, stop in (("1985", "2004"), ("2005", "2024")):
    time_slice = slice(start, stop)
    plot_timeseries(
        ds_era5,
        datasets_satellite,
        datasets_cmip6,
        func=lambda ds, indexers: ds.sel(**indexers),
        indexers={"time": time_slice},
        title=f"{start} - {stop}",
    )
    plt.show()

## Plot yearly maxima and minima

In [None]:
def full_year_only_resample(ds, reduction):
    """
    Reduce ``ds`` to yearly values with the resample method named ``reduction``
    (e.g. "max" or "min"), keeping only years with 12 time steps.
    """
    yearly_values = getattr(ds.resample(time="Y"), reduction)()
    # A year is kept only if it contains 12 time steps (i.e. all months)
    is_complete_year = ds["time"].resample(time="Y").count() == 12
    return yearly_values.where(is_complete_year, drop=True)


# Yearly maxima and minima, with ERA5 restricted to 1980 onwards
for reduction in ["max", "min"]:
    plot_timeseries(
        ds_era5.sel(time=slice("1980", None)),
        datasets_satellite,
        datasets_cmip6,
        func=full_year_only_resample,
        title=f"Yearly {reduction}ima",
        reduction=reduction,
    )
    plt.show()