# Evaluation of sea ice concentration of the historical CMIP6 experiment

## Import libraries

In [None]:
import datetime
import warnings

import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

# Use the "seaborn-v0_8-notebook" matplotlib style sheet for every figure
plt.style.use("seaborn-v0_8-notebook")
# Silence all warnings issued from the cf_xarray module
warnings.filterwarnings("ignore", module="cf_xarray")

## Set parameters

In [None]:
# Time period to analyse (first and last year, both included in the requests).
# The assertion restricts the period to 1970-2020; the CMIP6 request defined
# further down additionally stops at 2014.
year_start = 1970
year_stop = 2020
assert year_start >= 1970 and year_stop <= 2020

# Sea ice concentration threshold: grid cells are compared against this value
# when counting where model and observations disagree (see compare_model_vs_obs)
sic_threshold = 30  # %

# CMIP6 historical models to evaluate; each entry is passed as the "model"
# key of the CMIP6 request
models = [
    "access_cm2",
    "access_esm1_5",
    "cams_csm1_0",
    "canesm5",
    "canesm5_canoe",
    "cmcc_cm2_hr4",
    "cmcc_cm2_sr5",
    "cmcc_esm2",
    "cnrm_cm6_1",
    "cnrm_cm6_1_hr",
    "cnrm_esm2_1",
    "e3sm_1_0",
    "e3sm_1_1",
    "e3sm_1_1_eca",
    "ec_earth3_aerchem",
    "ec_earth3_cc",
    "ec_earth3_veg_lr",
    "fgoals_f3_l",
    "fio_esm_2_0",
    "giss_e2_1_h",
    "hadgem3_gc31_ll",
    "hadgem3_gc31_mm",
    "inm_cm4_8",
    "inm_cm5_0",
    "ipsl_cm5a2_inca",
    "ipsl_cm6a_lr",
    "kiost_esm",
    "miroc6",
    "miroc_es2l",
    "mpi_esm1_2_hr",
    "mpi_esm1_2_lr",
    "mri_esm2_0",
    "nesm3",
    "norcpm1",
    "noresm2_mm",
    "taiesm1",
    "ukesm1_0_ll",
]

## Define request

In [None]:
# Two-digit month strings "01" ... "12", shared by both requests
all_months = [f"{month_number:02d}" for month_number in range(1, 13)]

# ERA5: years from max(year_start, 1940) up to year_stop, never including the
# current calendar year
era5_years = list(
    map(
        str,
        range(max(year_start, 1940), min(year_stop + 1, datetime.date.today().year)),
    )
)

# CMIP6 historical: years from max(year_start, 1850) up to min(year_stop, 2014)
cmip6_years = list(map(str, range(max(year_start, 1850), min(year_stop, 2014) + 1)))

# Each request is a (collection ID, request dictionary) pair
request_era5 = (
    "reanalysis-era5-single-levels-monthly-means",
    {
        "product_type": "monthly_averaged_reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": "sea_ice_cover",
        "year": era5_years,
        "month": all_months,
    },
)

request_cmip6_historical = (
    "projections-cmip6",
    {
        "format": "zip",
        "temporal_resolution": "monthly",
        "experiment": "historical",
        "variable": "sea_ice_area_percentage_on_ocean_grid",
        "year": cmip6_years,
        "month": all_months,
    },
)

## Define function to cache

In [None]:
def compare_model_vs_obs(ds, datasets_sat, sic_threshold, grid_cell_area):
    """Compute sea ice concentration diagnostics against reference products.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to evaluate. The concentration is looked up through cf_xarray
        as ``sea_ice_area_fraction``; if its ``units`` attribute is
        ``"(0 - 1)"`` the values are multiplied by 100.
    datasets_sat : dict
        Mapping from origin name to reference dataset. Each dataset must
        expose ``sea_ice_area_fraction`` and its ``standard_error`` through
        cf_xarray, and its time steps must be selectable in ``ds``.
    sic_threshold : float
        Concentration threshold (%). Cells are compared against it when
        counting where ``ds`` is above and the reference is not ("over"),
        and the opposite ("under").
    grid_cell_area : float
        Area of one grid cell in km2; it is converted to 10^6 km2 internally.

    Returns
    -------
    xr.Dataset
        One dataset with an ``origin`` dimension (one entry per reference
        product) and the variables ``siconc_bias``, ``siconc_rmse``,
        ``rms_sic_obs_error``, ``iiee``, ``siextent_bias`` and
        ``siarea_bias``, all reduced over the ``xc`` and ``yc`` dimensions.
    """
    # Attributes attached to each output variable (same order as the output)
    output_attrs = {
        "siconc_bias": {
            "standard_name": "sea_ice_concentration_bias",
            "units": "%",
            "long_name": "Sea ice concentration bias",
        },
        "siconc_rmse": {
            "standard_name": "sea_ice_concentration_rmse",
            "units": "%",
            "long_name": "Sea ice concentration root mean square error",
        },
        "rms_sic_obs_error": {
            "standard_name": "root_mean_square_sea_ice_concentration_observation_error",
            "units": "%",
            "long_name": "Root mean square sea ice concentration observation error",
        },
        "iiee": {
            "standard_name": "integrated_ice_edge_error",
            "units": "$10^6$km$^2$",
            "long_name": "Integrated ice edge error",
        },
        "siextent_bias": {
            "standard_name": "sea_ice_extent_bias",
            "units": "$10^6$km$^2$",
            "long_name": "Sea ice extent bias",
        },
        "siarea_bias": {
            "standard_name": "sea_ice_area_bias",
            "units": "$10^6$km$^2$",
            "long_name": "Sea ice area bias",
        },
    }

    # Put everything on the same calendar so time steps can be matched
    ds = ds.convert_calendar("standard", align_on="date")
    datasets_sat = {
        origin: ds_sat.convert_calendar("standard", align_on="date")
        for origin, ds_sat in datasets_sat.items()
    }

    grid_cell_area = grid_cell_area * 1.0e-6  # km2 -> 10^6 km2
    sic = ds.cf["sea_ice_area_fraction"]
    if sic.attrs.get("units", "") == "(0 - 1)":
        # Scale out-of-place so the data held by ``ds`` are left untouched
        sic = sic * 100

    dims = ("xc", "yc")
    datasets = []
    for origin, ds_sat in datasets_sat.items():
        # Get variables
        sic_obs = ds_sat.cf["sea_ice_area_fraction"]
        sic_obs_err = ds_sat.cf["sea_ice_area_fraction standard_error"]
        sic_model = sic.sel(time=sic_obs["time"])

        # Compute useful variables: number of cells where only the model
        # ("over") or only the reference ("under") exceeds the threshold
        sic_diff = sic_model - sic_obs
        over = ((sic_model > sic_threshold) & (sic_obs <= sic_threshold)).sum(dims)
        under = ((sic_model <= sic_threshold) & (sic_obs > sic_threshold)).sum(dims)

        # Compute output
        dataarrays = {
            "siconc_bias": sic_diff.mean(dims),
            "siconc_rmse": (sic_diff**2).mean(dim=dims) ** (1 / 2),
            "rms_sic_obs_error": (sic_obs_err**2).mean(dims) ** (1 / 2),
            "iiee": (over + under) * grid_cell_area,
            "siextent_bias": (over - under) * grid_cell_area,
            # Concentrations are in %, hence the division by 100
            "siarea_bias": (
                (sic_model.sum(dims) - sic_obs.sum(dims)) * grid_cell_area / 100
            ),
        }
        for name, da in dataarrays.items():
            da.attrs = dict(output_attrs[name])

        datasets.append(xr.Dataset(dataarrays).expand_dims(origin=[origin]))
    return xr.concat(datasets, "origin")


def interpolate_to_satellite_grid(obj, **regrid_kwargs):
    """Regrid ``obj`` onto the satellite grid of both hemispheres.

    For each hemisphere, one ESA-CCI sea ice concentration record
    (2002-06-01) is downloaded and stripped of its ``time`` dimension to serve
    as target grid for ``diagnostics.regrid``.

    Parameters
    ----------
    obj : xarray object
        Object to regrid.
    **regrid_kwargs
        Forwarded to ``diagnostics.regrid``.

    Returns
    -------
    xarray object
        The regridded objects concatenated along a new ``region`` dimension
        (``northern_hemisphere`` first, then ``southern_hemisphere``).
    """
    grid_request = {
        "version": "v2",
        "variable": "all",
        "format": "zip",
        "origin": "esa_cci",
        "cdr_type": "cdr",
        "year": "2002",
        "month": "06",
        "day": "01",
    }

    def regrid_to(region):
        # Only the grid of the downloaded record is needed, not its data in time
        target_grid = download.download_and_transform(
            "satellite-sea-ice-concentration", grid_request | {"region": region}
        ).drop_dims("time")
        regridded = diagnostics.regrid(obj, target_grid, **regrid_kwargs)
        return regridded.expand_dims(region=[region])

    hemispheres = ("northern_hemisphere", "southern_hemisphere")
    return xr.concat([regrid_to(region) for region in hemispheres], "region")


def get_monthly_interpolated_data(ds, add_stde, check_values, **regrid_kwargs):
    """Extract the sea ice concentration and reduce it to monthly means.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset; ``latitude``, ``longitude`` and
        ``sea_ice_area_fraction`` are looked up through cf_xarray.
    add_stde : bool
        If true, also add the ``sea_ice_area_fraction standard_error``
        variable, aggregated per month as a root mean square.
    check_values : bool
        If true, keep only the time steps for which ``any`` over all non-time
        dimensions of the concentration is true.
    **regrid_kwargs
        If given, the data are first interpolated with
        ``interpolate_to_satellite_grid`` using these arguments.

    Returns
    -------
    xr.Dataset
        Monthly dataset (resampled with frequency ``"MS"``).
    """
    # Grab the standard error before the dataset is reduced to three variables
    stde = ds.cf["sea_ice_area_fraction standard_error"] if add_stde else None

    ds = ds.cf[["latitude", "longitude", "sea_ice_area_fraction"]]
    ds = ds.drop_dims(set(ds.dims) & {"vertices", "bnds"})

    if regrid_kwargs:
        ds = interpolate_to_satellite_grid(ds, **regrid_kwargs)

    ds = ds.sortby("time").resample(time="MS").mean()
    ds["time"].attrs["long_name"] = "time"

    if stde is not None:
        with xr.set_options(keep_attrs=True):
            # Sort by time exactly as done for the concentration above, so
            # both variables are resampled from a time-ordered series
            monthly_stde = (stde.sortby("time") ** 2).resample(
                time="MS"
            ).mean() ** (1 / 2)
            ds = ds.merge(monthly_stde)

    if check_values:
        has_values = (
            ds.cf["sea_ice_area_fraction"].any(set(ds.dims) - {"time"}).compute()
        )
        ds = ds.where(has_values, drop=True)
    return ds


def get_satellite_data(time):
    """Download monthly satellite sea ice concentration for both hemispheres.

    Parameters
    ----------
    time : xr.DataArray
        Time coordinate; only its minimum and maximum year are used to bound
        the requests.

    Returns
    -------
    dict
        Mapping from product name (``"ESA-CCI"``, ``"EUMETSAT-OSI-SAF"``) to a
        dataset with a ``region`` dimension. Products for which
        ``download.update_request_date`` returns nothing are left out.
    """
    first_year = time.dt.year.min().values
    last_year = time.dt.year.max().values

    base_request = {
        "cdr_type": "cdr",
        "variable": "all",
        "version": "v2",
    }
    # (product name, origin, lowest year requested, highest year requested)
    products = (
        ("ESA-CCI", "esa_cci", 2002, 2017),
        ("EUMETSAT-OSI-SAF", "eumetsat_osi_saf", 1979, 2015),
    )
    satellite_requests = {
        label: download.update_request_date(
            base_request | {"origin": origin},
            start=f"{max(first_year, lower_year)}-01",
            stop=f"{min(last_year, upper_year)}-12",
            stringify_dates=True,
        )
        for label, origin, lower_year, upper_year in products
    }

    datasets_satellite = {}
    for name, requests in satellite_requests.items():
        if requests:
            per_region = []
            for region in ("northern_hemisphere", "southern_hemisphere"):
                print(f"{name=} {region=}")
                regional_requests = [
                    request | {"region": region} for request in requests
                ]
                ds_region = download.download_and_transform(
                    "satellite-sea-ice-concentration",
                    regional_requests,
                    chunks={"year": 1},
                    transform_func=get_monthly_interpolated_data,
                    transform_func_kwargs={"add_stde": True, "check_values": True},
                )
                per_region.append(ds_region.expand_dims(region=[region]))
            datasets_satellite[name] = xr.concat(per_region, "region")
    return datasets_satellite


def compute_sea_ice_evaluation_diagnostics(ds, sic_threshold, **regrid_kwargs):
    """Evaluate the sea ice concentration in ``ds`` against satellite products.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to evaluate; its ``time`` coordinate defines which satellite
        years are downloaded.
    sic_threshold : float
        Concentration threshold passed on to ``compare_model_vs_obs``.
    **regrid_kwargs
        Passed to ``get_monthly_interpolated_data``, which uses them to
        interpolate ``ds``.

    Returns
    -------
    xr.Dataset
        Output of ``compare_model_vs_obs``.
    """
    satellite_products = get_satellite_data(ds["time"])
    monthly_ds = get_monthly_interpolated_data(
        ds, add_stde=False, check_values=False, **regrid_kwargs
    )
    # NOTE(review): 25**2 is the grid cell area in km2, presumably for
    # 25 km x 25 km satellite grid cells - confirm against the product grid
    return compare_model_vs_obs(
        monthly_ds, satellite_products, sic_threshold, grid_cell_area=25**2
    )

## Utilities

In [None]:
def postprocess_dataset(ds):
    """Give variables and regions short display names and load the data.

    Every variable except ``iiee`` is renamed to a shortened, lower-cased
    version of its ``long_name`` attribute; ``iiee`` becomes ``IIEE``. Region
    labels starting with ``northern`` become ``Arctic``, all others
    ``Antarctic``.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a ``region`` coordinate and ``long_name`` attributes.

    Returns
    -------
    xr.Dataset
        Renamed dataset after calling ``compute``.
    """
    # Applied in this order to the lower-cased long name
    substitutions = (
        ("sea ice ", ""),
        ("concentration ", ""),
        ("observation", "obs"),
        ("root mean square", "RMS"),
    )

    new_names = {}
    for var, da in ds.data_vars.items():
        if var == "iiee":
            new_names[var] = var.upper()
            continue
        short_name = da.attrs["long_name"].lower()
        for old, new in substitutions:
            short_name = short_name.replace(old, new)
        new_names[var] = short_name
    ds = ds.rename(new_names)

    region_labels = []
    for region in ds["region"].values:
        if region.startswith("northern"):
            region_labels.append("Arctic")
        else:
            region_labels.append("Antarctic")
    ds["region"] = region_labels
    return ds.compute()


# Keyword arguments used for both the ERA5 and the CMIP6
# download.download_and_transform calls
kwargs = {
    # Function applied to the downloaded data
    "transform_func": compute_sea_ice_evaluation_diagnostics,
    # Arguments of the function above: the threshold, plus the options that
    # reach diagnostics.regrid through **regrid_kwargs
    "transform_func_kwargs": {
        "sic_threshold": sic_threshold,
        "method": "bilinear",
        "periodic": True,
        "ignore_degenerate": True,
    },
    # NOTE(review): presumably splits each request into 10-year pieces -
    # confirm against the download_and_transform documentation
    "chunks": {"year": 10},
    # Parameters to speed up IO
    "concat_dim": "time",
    "combine": "nested",
    "data_vars": "minimal",
    "coords": "minimal",
    "compat": "override",
    "drop_variables": ("type",),
}

## Download and transform ERA5

In [None]:
# Download ERA5, compute the diagnostics and tidy up names for plotting
era5_collection_id, era5_request = request_era5
ds_era5 = postprocess_dataset(
    download.download_and_transform(era5_collection_id, era5_request, **kwargs)
)

## Download and transform CMIP6

In [None]:
# Download each CMIP6 model separately, compute its diagnostics, and stack
# the results along a new "model" dimension
cmip6_collection_id, cmip6_base_request = request_cmip6_historical
datasets = []
for model in models:
    print(f"{model=}")
    ds = download.download_and_transform(
        cmip6_collection_id,
        cmip6_base_request | {"model": model},
        **kwargs,
    )
    ds_model = postprocess_dataset(ds)
    datasets.append(ds_model.expand_dims(model=[model]))
ds_cmip6 = xr.concat(datasets, "model")

## Define plotting function

In [None]:
def plot_timeseries(ds_era5, ds_cmip6, func=None, title=None, **kwargs):
    """Plot ERA5 and CMIP6 diagnostics as time series, one panel per case.

    Panels are arranged with one row per variable and one column per region.
    ERA5 is drawn as one line per origin; CMIP6 is drawn per origin as the
    median across models plus a band between the 25th and 75th percentiles.

    Parameters
    ----------
    ds_era5 : xr.Dataset
        ERA5 diagnostics with ``origin``, ``region`` and ``time`` dimensions
        and a ``units`` attribute on each variable.
    ds_cmip6 : xr.Dataset
        CMIP6 diagnostics with an additional ``model`` dimension.
    func : callable, optional
        If given, applied to both datasets as ``func(ds, **kwargs)`` before
        plotting.
    title : str, optional
        Figure title.
    **kwargs
        Arguments for ``func``; only allowed when ``func`` is given.

    Returns
    -------
    Facet grid returned by the first ``DataArray.plot`` call.
    """
    if func:
        ds_era5 = func(ds_era5, **kwargs)
        ds_cmip6 = func(ds_cmip6, **kwargs)
    else:
        assert not kwargs, f"{func=} but {kwargs=}"
    da_era5 = ds_era5.to_array()
    da_cmip6 = ds_cmip6.to_array()

    # Plot ERA5: the first origin creates the facet grid, the others are
    # added to the existing axes
    for i, (origin, da) in enumerate(da_era5.groupby("origin")):
        line_kwargs = {
            "color": f"C{i}",
            "label": f"ERA5 vs {origin}",
        }
        if not i:
            facet = da.plot(
                row="variable", col="region", hue="origin", sharey=False, **line_kwargs
            )
        else:
            for ax, sel_dict in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
                ax.plot(da["time"], da.sel(sel_dict), **line_kwargs)
    # CMIP6 colours continue the colour cycle after the ERA5 lines
    n_era5_lines = i + 1

    # Plot CMIP6
    da_quantiles = da_cmip6.quantile([1 / 4, 1 / 2, 3 / 4], dim="model")
    for ax, sel_dict in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
        for j, (origin, da) in enumerate(da_quantiles.sel(sel_dict).groupby("origin")):
            color = f"C{n_era5_lines + j}"
            ax.plot(
                da["time"],
                da.sel(quantile=1 / 2),
                label=f"CMIP6 vs {origin} median",
                zorder=2,
                color=color,
            )
            # Band between the first and third quartiles (interquartile range)
            ax.fill_between(
                da["time"],
                da.sel(quantile=1 / 4),
                da.sel(quantile=3 / 4),
                alpha=0.4,
                label=f"CMIP6 vs {origin} IQR",
                zorder=1,
                color=color,
            )
        ax.grid(linestyle=":")

    # Edit axs: label the first column with the units of each row's variable
    for ax, sel_dict in zip(facet.axs[:, 0], facet.name_dicts[:, 0]):
        variable = sel_dict["variable"]
        # Build a filtered copy so the dictionaries stored in the returned
        # facet grid are not modified
        coords_sel = {
            key: value for key, value in sel_dict.items() if key != "variable"
        }
        da = ds_era5.sel(coords_sel)[variable]
        ax.set_ylabel(f"[{da.attrs['units']}]")
    facet.axs[0, -1].legend(bbox_to_anchor=(1.1, 1))
    if title is not None:
        facet.fig.suptitle(title)

    return facet

## Plot timeseries

In [None]:
# Plot the diagnostics as computed (no extra processing function); the
# returned facet grid is assigned to "_" and not used further
_ = plot_timeseries(ds_era5, ds_cmip6)