# Sea ice diagnostics for different CMIP6 experiments

## Import libraries

In [None]:
import datetime

import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Time
year_start = 1850
year_stop = 2100
# The requests defined further down clamp years to 1850-2100; an inverted
# range would only yield empty year lists, so reject it up front.
assert 1850 <= year_start <= year_stop <= 2100, "invalid time range"

# Sea Ice Concentration Threshold
sic_threshold = 30  # %

# Experiments
experiments = ["historical", "ssp1_2_6", "ssp2_4_5", "ssp3_7_0", "ssp5_8_5"]
valid_experiments = {
    "historical",
    "ssp1_1_9",
    "ssp1_2_6",
    "ssp2_4_5",
    "ssp3_7_0",
    "ssp4_3_4",
    "ssp4_6_0",
    "ssp5_8_5",
}
assert experiments, "at least one experiment is required"
# Subset check: every requested experiment must be known. A plain
# intersection would accept a list that contains a single valid name.
assert set(experiments) <= valid_experiments, (
    f"unknown experiments: {sorted(set(experiments) - valid_experiments)}"
)

# Chunks for download
chunks = {"year": 10}

## Define models for each experiment

In [None]:
# Mapping from experiment name to the list of model names that are requested
# for that experiment. The download cell iterates models_dict[experiment] and
# passes each entry as the "model" field of the CMIP6 request.
models_dict = {
    "historical": [
        "access_cm2",
        "access_esm1_5",
        "cams_csm1_0",
        "canesm5",
        "canesm5_canoe",
        "cmcc_cm2_hr4",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "cnrm_esm2_1",
        "e3sm_1_0",
        "e3sm_1_1",
        "e3sm_1_1_eca",
        "ec_earth3_aerchem",
        "ec_earth3_cc",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fio_esm_2_0",
        "giss_e2_1_h",
        "hadgem3_gc31_ll",
        "hadgem3_gc31_mm",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm5a2_inca",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_hr",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "norcpm1",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp1_1_9": [
        "canesm5",
        "ec_earth3",
        "ec_earth3_veg",
        "ec_earth3_veg_lr",
        "fgoals_g3",
        "gfdl_esm4",
        "ipsl_cm6a_lr",
        "miroc6",
        "miroc_es2l",
        "mri_esm2_0",
        "ukesm1_0_ll",
    ],
    "ssp1_2_6": [
        "access_cm2",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "fio_esm_2_0",
        "gfdl_esm4",
        "hadgem3_gc31_ll",
        "hadgem3_gc31_mm",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm5a2_inca",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "noresm2_lm",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp2_4_5": [
        "access_cm2",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "cnrm_esm2_1",
        "ec_earth3_cc",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "fio_esm_2_0",
        "hadgem3_gc31_ll",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "noresm2_lm",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp3_7_0": [
        "access_cm2",
        "canesm5",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "ec_earth3_aerchem",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "gfdl_esm4",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm5a2_inca",
        "ipsl_cm6a_lr",
        "miroc6",
        "miroc_es2l",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "noresm2_lm",
        "noresm2_mm",
        "taiesm1",
        "ukesm1_0_ll",
    ],
    "ssp4_3_4": [
        "canesm5",
        "ec_earth3",
        "fgoals_g3",
        "ipsl_cm6a_lr",
        "miroc6",
        "mri_esm2_0",
        "ukesm1_0_ll",
    ],
    "ssp4_6_0": [
        "canesm5",
        "fgoals_g3",
        "ipsl_cm6a_lr",
        "miroc6",
        "mri_esm2_0",
    ],
    "ssp5_8_5": [
        "access_cm2",
        "canesm5_canoe",
        "cmcc_cm2_sr5",
        "cmcc_esm2",
        "cnrm_cm6_1",
        "cnrm_cm6_1_hr",
        "e3sm_1_1",
        "ec_earth3_cc",
        "ec_earth3_veg_lr",
        "fgoals_f3_l",
        "fgoals_g3",
        "fio_esm_2_0",
        "gfdl_esm4",
        "hadgem3_gc31_ll",
        "hadgem3_gc31_mm",
        "inm_cm4_8",
        "inm_cm5_0",
        "ipsl_cm6a_lr",
        "kiost_esm",
        "miroc6",
        "mpi_esm1_2_lr",
        "mri_esm2_0",
        "nesm3",
        "noresm2_lm",
        "noresm2_mm",
        "ukesm1_0_ll",
    ],
}

## Define request

In [None]:
# Zero-padded month strings, "01" through "12".
all_months = [f"{month:02d}" for month in range(1, 13)]

# ERA5 request: years are clamped to start no earlier than 1940 and to stop
# before the current calendar year (the range end is exclusive).
request_era5 = (
    "reanalysis-era5-single-levels-monthly-means",
    {
        "product_type": "monthly_averaged_reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": "sea_ice_cover",
        "year": list(
            map(
                str,
                range(
                    max(year_start, 1940),
                    min(year_stop + 1, datetime.date.today().year),
                ),
            )
        ),
        "month": all_months,
    },
)


def _cmip6_request(first_year, last_year):
    """Return the (collection id, request) pair for an inclusive year span.

    NOTE: the "experiment" entry is only a placeholder. The download cell
    merges the actual experiment (and the model) into the request before
    submitting it.
    """
    return (
        "projections-cmip6",
        {
            "format": "zip",
            "temporal_resolution": "monthly",
            "experiment": "historical",
            "variable": "sea_ice_area_percentage_on_ocean_grid",
            "year": [str(year) for year in range(first_year, last_year + 1)],
            "month": all_months,
        },
    )


# Historical runs: years clamped to 1850-2014.
request_cmip6_historical = _cmip6_request(
    max(year_start, 1850), min(year_stop, 2014)
)

# Scenario runs: years clamped to 2015-2100.
request_cmip6_projections = _cmip6_request(
    max(year_start, 2015), min(year_stop, 2100)
)

## Define function to cache

In [None]:
def compute_sea_ice_diagnostics(ds, sic_threshold, grid_cell_area):
    """
    Compute sea ice extent and sea ice area from sea ice concentration.

    Parameters
    ----------
    ds: xr.Dataset
        Dataset holding exactly one of the variables "ice_conc" or "siconc"
    sic_threshold: float
        Sea ice concentration threshold (%)
    grid_cell_area: float
        Area assigned to every grid cell

    Returns
    -------
    xr.Dataset
        Dataset with siextent and siarea, summed over every dimension except
        time and multiplied by 1.0e-6
    """
    # Exactly one of the two known variable names must be present
    (varname,) = {"ice_conc", "siconc"}.intersection(ds.data_vars)
    sic = ds[varname]

    # Concentration is either a fraction, flagged by units "(0 - 1)", or a
    # percentage: scale the threshold and the per-cell area accordingly
    if sic.attrs.get("units", "") == "(0 - 1)":
        threshold = sic_threshold / 100
        area_factor = grid_cell_area
    else:
        threshold = sic_threshold
        area_factor = grid_cell_area / 100

    # Extent counts the full cell above the threshold; area weights each cell
    extent = xr.where(sic > threshold, grid_cell_area, 0).rename("siextent")
    area = (sic * area_factor).rename("siarea")

    # Merge, reduce every non-time dimension, and rescale
    merged = xr.merge([extent, area])
    reduce_dims = set(merged.dims) - {"time"}
    totals = merged.sum(reduce_dims) * 1.0e-6
    for name in totals.data_vars:
        totals[name].attrs = {
            "standard_name": name.replace("si", "sea_ice_", 1),
            "units": "$10^6km^2$",
            "long_name": name.replace("si", "Sea ice ", 1),
        }
    return totals


def get_interpolated_sea_ice_extent_and_area(ds, sic_threshold, **regrid_kwargs):
    """
    Interpolate to 25x25km grid and return sea ice diagnostics.

    The input is first sorted by time and averaged onto month-start
    timestamps. Without regrid kwargs the diagnostics are computed directly
    on that dataset; otherwise it is regridded onto the northern and southern
    hemisphere grids of the satellite sea ice concentration product, and the
    two results are concatenated along a "region" dimension.

    Parameters
    ----------
    ds: xr.Dataset
        Dataset to process
    sic_threshold: float
        Sea ice concentration threshold (%)
    regrid_kwargs: Any
        xesmf regrid kwargs

    Returns
    -------
    xr.Dataset
        Dataset with siextent and siarea ($10^6km^2$)
    """
    # Parameters
    grid_cell_area = 25**2  # km2

    # Monthly means on month-start timestamps
    monthly = ds.sortby("time").resample(time="MS").mean()
    monthly["time"].attrs["long_name"] = "time"

    # Nothing to regrid: diagnostics on the input grid
    if not regrid_kwargs:
        return compute_sea_ice_diagnostics(monthly, sic_threshold, grid_cell_area)

    # Request used only to obtain the target grid of each hemisphere
    grid_collection = "satellite-sea-ice-concentration"
    grid_request = {
        "version": "v2",
        "variable": "all",
        "format": "zip",
        "origin": "esa_cci",
        "cdr_type": "cdr",
        "year": "2002",
        "month": "06",
        "day": "01",
    }

    per_region = []
    for region in ("northern_hemisphere", "southern_hemisphere"):
        # Target grid: the time dimension is dropped from the downloaded file
        target_grid = download.download_and_transform(
            grid_collection,
            grid_request | {"region": region},
            transform_func=lambda ds: ds.drop_dims("time"),
        )

        # Regrid, then reduce to the diagnostics of this hemisphere
        regridded = diagnostics.regrid(monthly, target_grid, **regrid_kwargs)
        diagnosed = compute_sea_ice_diagnostics(
            regridded, sic_threshold, grid_cell_area
        )

        # Label the result with its region so both can be concatenated
        per_region.append(diagnosed.expand_dims(region=[region]))
    return xr.concat(per_region, "region")

## Download ERA5 and CMIP6 models

In [None]:
def postprocess_dataset(ds):
    """
    Harmonise a diagnostics dataset before comparing and plotting.

    The calendar is converted to "standard", every variable is renamed to its
    long_name attribute, and the region labels are replaced by "Arctic" (for
    labels starting with "northern") or "Antarctic" (any other label).
    """
    ds = ds.convert_calendar("standard", align_on="date")

    # Use each variable's long_name as its name
    new_names = {}
    for var, da in ds.data_vars.items():
        new_names[var] = da.attrs["long_name"]
    ds = ds.rename(new_names)

    # Hemisphere identifiers -> polar region names
    region_labels = []
    for region in ds["region"].values:
        if region.startswith("northern"):
            region_labels.append("Arctic")
        else:
            region_labels.append("Antarctic")
    ds["region"] = region_labels
    return ds


# Keyword arguments shared by every download_and_transform call below.
# NOTE(review): transform_func_kwargs is presumably forwarded to
# transform_func as keyword arguments, so everything except sic_threshold
# would end up in its **regrid_kwargs - confirm against the download helper.
common_kwargs = {
    "chunks": chunks,
    "transform_func": get_interpolated_sea_ice_extent_and_area,
    "transform_func_kwargs": {
        "sic_threshold": sic_threshold,
        "method": "bilinear",
        "periodic": True,
        "ignore_degenerate": True,
    },
    "combine": "nested",
    "concat_dim": "time",
    "drop_variables": ("type",),
}

# ERA5 reference dataset
ds_era5 = postprocess_dataset(
    download.download_and_transform(*request_era5, **common_kwargs)
)

# One dataset per experiment, with all its models stacked along "model"
datasets_cmip6 = {}
for experiment in experiments:
    # Historical and scenario runs use request templates with different years
    request = (
        request_cmip6_historical
        if experiment == "historical"
        else request_cmip6_projections
    )
    # Skip experiments that have no years inside the selected time range
    if not request[1]["year"]:
        continue
    tmp_datasets = []
    for model in models_dict[experiment]:
        print(f"{experiment=} {model=}")
        # The merge overrides the template's "experiment" and adds "model"
        ds = download.download_and_transform(
            request[0],
            request[1] | {"experiment": experiment, "model": model},
            **common_kwargs,
        )
        tmp_datasets.append(postprocess_dataset(ds).expand_dims(model=[model]))
    # Load the stacked result into memory
    datasets_cmip6[experiment] = xr.concat(tmp_datasets, "model").compute()

## Plot yearly timeseries

In [None]:
resample_freq = "Y"

# Per experiment: average over each resampling period, then take the minimum,
# lower tertile, median, upper tertile and maximum across models.
quantiles = {
    experiment: ds.resample(time=resample_freq)
    .mean()
    .quantile([0, 1 / 3, 1 / 2, 2 / 3, 1], dim="model")
    .to_array()
    for experiment, ds in datasets_cmip6.items()
}

# One panel per (variable, region) pair, initialised with the ERA5 line
facet = (
    ds_era5.resample(time=resample_freq)
    .mean()
    .to_array()
    .plot(row="variable", col="region", label="ERA5")
)
for ax, sel_dict in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
    # Start at C1 so the experiment colours differ from the ERA5 line
    for i, (experiment, da_quantiles) in enumerate(quantiles.items(), start=1):
        color = f"C{i}"
        da = da_quantiles.sel(sel_dict)
        ax.plot(
            da["time"],
            da.sel(quantile=1 / 2),
            color=color,
            label=f"CMIP6 {experiment:^10} median",
            zorder=2,
        )
        ax.fill_between(
            da["time"],
            da.sel(quantile=1 / 3),
            da.sel(quantile=2 / 3),
            color=color,
            alpha=0.4,
            label=f"CMIP6 {experiment:^10} tertiles",
            zorder=1,
        )
        ax.fill_between(
            da["time"],
            da.sel(quantile=0),
            da.sel(quantile=1),
            color=color,
            alpha=0.2,
            label=f"CMIP6 {experiment:^10} range",
            zorder=0,
        )
    # Enable the grid explicitly and once per panel: a bare ax.grid() toggles
    # visibility, so calling it once per experiment made the result depend on
    # whether the number of experiments is odd or even.
    ax.grid(True)

# Label the y axis of the first column with the units of the plotted variable
for ax, sel_dict in zip(facet.axs[:, 0], facet.name_dicts[:, 0]):
    # Read instead of pop, so the facet's own name_dicts are left untouched
    variable = sel_dict["variable"]
    coords = {key: value for key, value in sel_dict.items() if key != "variable"}
    da = ds_era5.sel(coords)[variable]
    ax.set_ylabel(f"[{da.attrs['units']}]")

_ = facet.axs[0, -1].legend(bbox_to_anchor=(1.1, 1))