# CMIP6 sea ice thickness evaluation

## Import libraries

In [None]:
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

plt.style.use("seaborn-v0_8-notebook")
warnings.filterwarnings("ignore", module="cf_xarray")

## Set parameters

In [None]:
# Analysis period, both bounds inclusive.
year_start = 2002
year_stop = 2014
# Validate explicitly instead of with `assert` (which is skipped under
# `python -O`) and also reject an inverted period, which would otherwise
# produce an empty list of years in the download request.
# NOTE(review): the 2002-2014 window is presumably the overlap between the
# satellite record and the CMIP6 historical experiment - confirm.
if not 2002 <= year_start <= year_stop <= 2014:
    raise ValueError(
        "Expected 2002 <= year_start <= year_stop <= 2014, "
        f"got {year_start=} and {year_stop=}."
    )

# Choose CMIP6 historical models
models = [
    "access_cm2",
    "access_esm1_5",
    "canesm5",
    "cmcc_cm2_sr5",
    "cmcc_esm2",
    "cnrm_cm6_1",
    "cnrm_cm6_1_hr",
    "cnrm_esm2_1",
    "e3sm_1_0",
    "e3sm_1_1",
    "e3sm_1_1_eca",
    "ec_earth3_aerchem",
    "ec_earth3_cc",
    "ec_earth3_veg_lr",
    "hadgem3_gc31_ll",
    "ipsl_cm5a2_inca",
    "ipsl_cm6a_lr",
    "miroc6",
    "miroc_es2l",
    "mpi_esm1_2_hr",
    "mpi_esm1_2_lr",
    "nesm3",
    "norcpm1",
    "taiesm1",
    "ukesm1_0_ll",
]

## Define request

In [None]:
# Months requested for every year: January-April and October-December,
# formatted as zero-padded two-digit strings.
months = [f"{number:02d}" for number in (1, 2, 3, 4, 10, 11, 12)]
collection_id = "projections-cmip6"

# Years of the analysis period (inclusive), as strings.
years = list(map(str, range(year_start, year_stop + 1)))

# Request shared by all models; the "model" entry is added per model later.
request = {
    "format": "zip",
    "temporal_resolution": "monthly",
    "experiment": "historical",
    "variable": "sea_ice_thickness",
    "year": years,
    "month": months,
}

## Functions to cache

In [None]:
def get_satellite_data(time):
    """Download the satellite sea ice thickness data covering ``time``.

    The years spanned by ``time`` are clipped to 2002-2010 for "envisat" and
    to 2010-2020 for "cryosat_2". One request is issued per satellite that
    has at least one year left after clipping, and the results are
    concatenated along "time".

    Parameters
    ----------
    time
        Object exposing the ``.dt.year`` accessor (the caller in this file
        passes the "time" coordinate of a model dataset).

    Returns
    -------
    Concatenation along "time" of the downloaded satellite datasets.
    """
    first_year = time.dt.year.min().values
    last_year = time.dt.year.max().values

    # Years to request from each satellite, clipped to the requested period
    satellite_years = {
        "envisat": range(max(2002, first_year), min(2010, last_year) + 1),
        "cryosat_2": range(max(2010, first_year), min(2020, last_year) + 1),
    }

    datasets = [
        download.download_and_transform(
            "satellite-sea-ice-thickness",
            {
                "satellite": satellite,
                "version": "2_0",
                "cdr_type": "cdr",
                "variable": "all",
                "year": [str(year) for year in years],
                "month": [f"{month:02d}" for month in (1, 2, 3, 4, 10, 11, 12)],
            },
            chunks={"year": 1},
        )
        for satellite, years in satellite_years.items()
        if years  # skip satellites with no year inside the requested period
    ]
    return xr.concat(datasets, "time")


def regrid(obj, grid_out, **regrid_kwargs):
    """Trim null edge columns from ``obj`` and regrid it onto ``grid_out``.

    For every dimension whose name contains "x" or "lon", the first and then
    the last position along that dimension are inspected and dropped when the
    null test below evaluates to true. The trimmed object is then handed to
    ``diagnostics.regrid`` together with ``regrid_kwargs``.
    """
    # The candidate dimensions are taken from the input object once, up front
    horizontal_dims = [name for name in obj.dims if "x" in name or "lon" in name]
    for name in horizontal_dims:
        for edge in (0, -1):
            edge_is_null = obj.isel({name: edge}).isnull().all()
            # NOTE(review): when ``obj`` is a Dataset, the truth value of the
            # reduced Dataset may only reflect whether it has data variables,
            # not whether they are all null - confirm that edge columns are
            # not dropped unconditionally for Dataset inputs.
            if edge_is_null:
                obj = obj.drop_isel({name: edge})
    return diagnostics.regrid(obj, grid_out, **regrid_kwargs)


def compare_model_vs_satellite(ds, ds_satellite):
    """Compute spatially aggregated model-vs-satellite thickness statistics.

    Both inputs are converted to the "standard" calendar, and the model time
    axis is rebuilt from its year-month strings. The model thickness is then
    selected at the satellite time steps, and three statistics are reduced
    over the ("xc", "yc") dimensions.

    Parameters
    ----------
    ds
        Model dataset exposing a variable with standard name
        "sea_ice_thickness" (expected to share the satellite grid).
    ds_satellite
        Satellite dataset exposing "sea_ice_thickness" and
        "sea_ice_thickness standard_error".

    Returns
    -------
    xr.Dataset
        Variables "sithick_bias", "sithick_rmse" and "rms_sit_obs_error".
    """
    spatial_dims = ("xc", "yc")
    standard_name = "sea_ice_thickness"

    def root_mean_square(da):
        # Root of the spatial mean of the squared values
        return (da**2).mean(dim=spatial_dims) ** (1 / 2)

    # Put both datasets on a common calendar; model timestamps are rebuilt
    # from their "%Y-%m" representation
    ds = ds.convert_calendar("standard", align_on="date")
    ds["time"] = pd.to_datetime(ds["time"].dt.strftime("%Y-%m"))
    ds_satellite = ds_satellite.convert_calendar("standard", align_on="date")

    # Extract the variables through their CF standard names
    observed = ds_satellite.cf[standard_name]
    observed_error = ds_satellite.cf[f"{standard_name} standard_error"]
    modelled = ds.cf[standard_name].sel(time=observed["time"])

    # Model minus observation
    difference = modelled - observed

    bias = difference.mean(dim=spatial_dims)
    bias.attrs = {
        "standard_name": "sea_ice_thickness_bias",
        "units": "m",
        "long_name": "Sea ice thickness bias",
    }

    rmse = root_mean_square(difference)
    rmse.attrs = {
        "standard_name": "sea_ice_thickness_rmse",
        "units": "m",
        "long_name": "Sea ice thickness root mean square error",
    }

    rms_obs_error = root_mean_square(observed_error)
    rms_obs_error.attrs = {
        "standard_name": "root_mean_square_sea_ice_thickness_observation_error",
        "units": "m",
        "long_name": "Root mean square sea ice thickness observation error",
    }

    return xr.Dataset(
        {
            "sithick_bias": bias,
            "sithick_rmse": rmse,
            "rms_sit_obs_error": rms_obs_error,
        }
    )


def compute_sea_ice_thickness_diagnostics(ds, **regrid_kwargs):
    """Regrid a model dataset onto the satellite grid and compare the two.

    The satellite data matching the time axis of ``ds`` is fetched, ``ds`` is
    regridded onto the satellite "latitude"/"longitude" grid using
    ``regrid_kwargs``, and the comparison dataset is returned.
    """
    ds_satellite = get_satellite_data(ds["time"])
    target_grid = ds_satellite[["latitude", "longitude"]]
    ds_on_satellite_grid = regrid(ds, target_grid, **regrid_kwargs)
    return compare_model_vs_satellite(ds_on_satellite_grid, ds_satellite)

## Download and transform

In [None]:
# Download each model, reduce it to the diagnostics via the transform
# function, and stack the results along a new "model" dimension.
datasets = []
for model in models:
    print(f"{model=}")
    ds = download.download_and_transform(
        collection_id,
        request | {"model": model},
        chunks={"year": 1},
        transform_func=compute_sea_ice_thickness_diagnostics,
        # Forwarded to the transform function as regridding options
        transform_func_kwargs={
            "method": "nearest_s2d",
            "periodic": True,
            "ignore_degenerate": True,
        },
    )
    # Label the new dimension with the actual model name (the loop variable),
    # not the literal string "model", so that every entry of the concatenated
    # dataset can be identified.
    datasets.append(ds.expand_dims(model=[model]).compute())
ds_cmip6 = xr.concat(datasets, "model")

## Plot quantiles

In [None]:
# Quartiles of every diagnostic across the "model" dimension
ds_quantiles = ds_cmip6.quantile([1 / 4, 1 / 2, 3 / 4], dim="model", keep_attrs=True)
for da in ds_quantiles.data_vars.values():
    # Put the series on a month-start axis.
    # NOTE(review): presumably done so that months without data show up as
    # gaps in the plot - confirm.
    da = da.resample(time="1MS").mean()
    da.sel(quantile=1 / 2).drop_vars("quantile").plot(label="median")
    # Shade the band between the first and third quartiles
    plt.fill_between(
        da["time"],
        da.sel(quantile=1 / 4),
        da.sel(quantile=3 / 4),
        alpha=0.5,
        label="IQR",
    )
    plt.legend()
    plt.grid()
    plt.show()