# Sea ice diagnostics for different CMIP6 models (historical experiment) and ERA5

## Import libraries

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Time period to analyse (year_stop is inclusive: the requests below use
# range(year_start, year_stop + 1))
year_start = 1985
year_stop = 1986

# Sea Ice Concentration Threshold: grid cells whose concentration is above
# this value are counted towards the sea ice extent
sic_threshold = 30  # %

# Models: each entry is used as the "model" value of the "projections-cmip6"
# request. The order is preserved when the datasets are concatenated.
models = [
    "access_cm2",
    "access_esm1_5",
    "cams_csm1_0",
    "cmcc_cm2_hr4",
    "cmcc_cm2_sr5",
    "cmcc_esm2",
    "cnrm_cm6_1",
    "cnrm_cm6_1_hr",
    "cnrm_esm2_1",
    "canesm5",
    "canesm5_canoe",
    "e3sm_1_0",
    "e3sm_1_1",
    "e3sm_1_1_eca",
    "ec_earth3_aerchem",
    "ec_earth3_cc",
    "ec_earth3_veg_lr",
    "fgoals_f3_l",
    "fio_esm_2_0",
    "giss_e2_1_h",
    "hadgem3_gc31_ll",
    "hadgem3_gc31_mm",
    "inm_cm4_8",
    "inm_cm5_0",
    "ipsl_cm5a2_inca",
    "ipsl_cm6a_lr",
    "kiost_esm",
    "miroc_es2l",
    "miroc6",
    "mpi_esm1_2_hr",
    "mpi_esm1_2_lr",
    "mri_esm2_0",
    "nesm3",
    "norcpm1",
    "noresm2_mm",
    "taiesm1",
    "ukesm1_0_ll",
]

# Chunks for download: forwarded as `chunks` to download.download_and_transform
# (presumably the request is split into one chunk per year - TODO confirm)
chunks = {"year": 1}

## Define request

In [None]:
# Time selection shared by the reanalysis and the simulation requests:
# every month of every year in [year_start, year_stop]
common_request = {
    "year": list(map(str, range(year_start, year_stop + 1))),
    "month": [str(month).zfill(2) for month in range(1, 13)],
}

# ERA5 request: (collection id, request parameters)
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    dict(
        product_type="monthly_averaged_reanalysis",
        format="netcdf",
        time="00:00",
        variable="sea_ice_cover",
        **common_request,
    ),
)

# CMIP6 request template: (collection id, request parameters).
# The "model" entry is filled in for each model at download time.
request_sim = (
    "projections-cmip6",
    dict(
        format="zip",
        temporal_resolution="monthly",
        experiment="historical",
        variable="sea_ice_area_percentage_on_ocean_grid",
        **common_request,
    ),
)

## Define functions to cache

In [None]:
def transform_grid(ds):
    """
    Prepare a dataset to be used as target grid.

    Bounds are added for longitude and latitude when the dataset does not
    already provide them, then the "time" dimension is dropped.

    Parameters
    ----------
    ds: xr.Dataset
        Dataset describing the grid

    Returns
    -------
    xr.Dataset
    """
    for name in ("longitude", "latitude"):
        if name in ds.cf.bounds:
            # Bounds already available for this coordinate
            continue
        ds = ds.cf.add_bounds(name)
    grid = ds.drop_dims("time")
    return grid


def interpolate_and_get_sea_ice_diagnostics(ds, sic_threshold):
    """
    Interpolate to 25x25km grid and return sea ice diagnostics.

    For each hemisphere the dataset is regridded onto the grid of the
    "satellite-sea-ice-concentration" product, then sea ice extent and
    sea ice area are computed and summed over all dimensions except time.

    Parameters
    ----------
    ds: xr.Dataset
        Dataset to process. It must provide the variable "siconc" and a
        "time" coordinate. Its "time" coordinate is overwritten in place.
    sic_threshold: float
        Sea ice concentration threshold (%)

    Returns
    -------
    xr.Dataset
        Dataset with variables "siextent" and "siarea" (km2) along the
        dimensions "region" and "time".
    """
    # Grid for interpolation
    collection_id_grid = "satellite-sea-ice-concentration"
    request_grid = {
        "version": "v2",
        "variable": "all",
        "format": "zip",
        "origin": "esa_cci",
        "cdr_type": "cdr",
        "year": "2002",
        "month": "06",
        "day": "01",
    }
    grid_cell_area = 25**2  # km2

    # Same time for all datasets: snap every timestamp to the first of the month
    ds["time"] = pd.to_datetime(ds["time"].dt.strftime("%Y-%m-01"))

    # siconc is either a fraction (units "(0 - 1)") or a percentage.
    # sic_full_cover is the siconc value of a fully ice-covered cell.
    sic_is_normalized = ds["siconc"].attrs.get("units", "") == "(0 - 1)"
    sic_full_cover = 1 if sic_is_normalized else 100
    if sic_is_normalized:
        # The threshold is given in %, convert it to the units of siconc
        sic_threshold /= 100

    # Loop over north and south
    datasets = []
    for region in ("northern_hemisphere", "southern_hemisphere"):
        # Build a dedicated request instead of mutating the shared template
        grid_out = download.download_and_transform(
            collection_id_grid,
            {**request_grid, "region": region},
            transform_func=transform_grid,
        )

        # Regrid
        ds_region = diagnostics.regrid(ds, grid_out, method="conservative")

        # Compute sea ice diagnostics:
        # extent counts the full area of every cell above the threshold,
        # area weights each cell by its ice-covered fraction
        siextent = xr.where(ds_region["siconc"] > sic_threshold, grid_cell_area, 0)
        siarea = ds_region["siconc"] * (grid_cell_area / sic_full_cover)
        ds_region = xr.merge([siextent.rename("siextent"), siarea.rename("siarea")])

        # Sum and append
        ds_region = ds_region.sum(set(ds_region.dims) - {"time"})
        datasets.append(ds_region.expand_dims(region=[region]))
    ds = xr.merge(datasets)

    # Add attributes ("siextent" -> "sea_ice_extent" / "Sea ice extent", ...)
    for var in ds.data_vars:
        ds[var].attrs = {
            "standard_name": var.replace("si", "sea_ice_", 1),
            "units": "km2",
            "long_name": var.replace("si", "Sea ice ", 1),
        }
    return ds

## Download and combine datasets

In [None]:
# Download ERA5 and every CMIP6 model, reduce each of them to the sea ice
# diagnostics, and stack the results along a new "model" dimension
datasets = []
for model in ["ERA5"] + models:
    print(f"Downloading and processing {model}")
    if model == "ERA5":
        request_model = request_era
    else:
        # Build a fresh request for this model rather than writing into the
        # shared request_sim template, so that no stale "model" entry is
        # left behind after the loop or when the cell is re-run
        request_model = (request_sim[0], {**request_sim[1], "model": model})
    ds = download.download_and_transform(
        *request_model,
        chunks=chunks,
        transform_func=interpolate_and_get_sea_ice_diagnostics,
        transform_func_kwargs={
            "sic_threshold": sic_threshold,
        },
        drop_variables=("type",),
    )
    datasets.append(ds.expand_dims(model=[model]))
ds = xr.concat(datasets, "model")

## Plot variables

In [None]:
# One figure per diagnostic: a panel for each region, a line for each model
for varname in ds.data_vars:
    da = ds[varname]
    da.plot(col="region", hue="model")
    plt.show()