# CMIP6 sea ice thickness evaluation

## Import libraries

In [None]:
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

plt.style.use("seaborn-v0_8-notebook")
warnings.filterwarnings("ignore", module="cf_xarray")

## Set parameters

In [None]:
# First and last year of the analysis period (both included in the request).
year_start, year_stop = 2002, 2014
assert 2002 <= year_start and year_stop <= 2014

# Regions for which the diagnostics are computed; only the names checked
# below are accepted.
sea_masks = [
    "arctic",
    "transpolar_shipping_route",
    "northern_sea_shipping_route",
]
assert set(sea_masks).issubset(
    ("arctic", "transpolar_shipping_route", "northern_sea_shipping_route")
)

# CMIP6 historical models to evaluate. Each identifier is sent as the "model"
# entry of the download request, so the diagnostics are computed once per
# model (and per sea mask). Remove entries to restrict the evaluation.
models = [
    "access_cm2",
    "access_esm1_5",
    "cams_csm1_0",
    "canesm5",
    "canesm5_canoe",
    "cmcc_cm2_hr4",
    "cmcc_cm2_sr5",
    "cmcc_esm2",
    "cnrm_cm6_1",
    "cnrm_cm6_1_hr",
    "cnrm_esm2_1",
    "e3sm_1_0",
    "e3sm_1_1",
    "e3sm_1_1_eca",
    "ec_earth3_aerchem",
    "ec_earth3_cc",
    "ec_earth3_veg_lr",
    "fgoals_f3_l",
    "giss_e2_1_h",
    "hadgem3_gc31_ll",
    "hadgem3_gc31_mm",
    "inm_cm4_8",
    "inm_cm5_0",
    "ipsl_cm5a2_inca",
    "ipsl_cm6a_lr",
    "miroc6",
    "miroc_es2l",
    "mpi_esm1_2_hr",
    "mpi_esm1_2_lr",
    "mri_esm2_0",
    "nesm3",
    "norcpm1",
    "taiesm1",
    "ukesm1_0_ll",
]

## Define request

In [None]:
# Collection identifier and base request handed to
# download.download_and_transform; the "model" entry is added per download.
collection_id = "projections-cmip6"
request = {
    "format": "zip",
    "temporal_resolution": "monthly",
    "experiment": "historical",
    "variable": "sea_ice_thickness",
    # Every year of the analysis period, as strings
    "year": list(map(str, range(year_start, year_stop + 1))),
    # Zero-padded month numbers: January-April and October-December
    "month": [str(month).zfill(2) for month in (1, 2, 3, 4, 10, 11, 12)],
}

## Functions to cache

In [None]:
def apply_sea_mask(obj, sea_mask):
    """Mask ``obj`` outside the region identified by ``sea_mask``.

    Parameters
    ----------
    obj : xarray object
        Data to mask. For the shipping-route masks it must expose
        ``"longitude"`` and ``"latitude"`` (compared against limits that look
        like degrees east / degrees north -- NOTE(review): confirm units).
    sea_mask : str
        ``"arctic"`` or ``"antarctic"`` (no masking), or
        ``"transpolar_shipping_route"`` / ``"northern_sea_shipping_route"``
        (union of longitude/latitude boxes defined below).

    Returns
    -------
    ``obj`` itself for ``"arctic"``/``"antarctic"``; otherwise
    ``obj.where(mask)``, where ``mask`` is 1 inside any box and 0 elsewhere.

    Raises
    ------
    ValueError
        If ``sea_mask`` is not one of the names above.
    """
    # Validate first so that unmasked regions and bad names never touch the
    # coordinates of ``obj``.
    if sea_mask in ("arctic", "antarctic"):
        return obj
    if sea_mask not in ("transpolar_shipping_route", "northern_sea_shipping_route"):
        raise ValueError(f"{sea_mask=}")

    # Convert longitude: negative values are shifted by 360 so they can be
    # compared against the 0-360 limits used below.
    lon = obj["longitude"].where(obj["longitude"] >= 0, obj["longitude"] + 360)
    lat = obj["latitude"]

    if sea_mask == "transpolar_shipping_route":
        # Define approximate GODAE limits
        limits = {
            "Chuckchi_Sea": {
                "lon_min": 180.0,
                "lon_max": 200.0,
                "lat_min": 66.0,
                "lat_max": 90,
            },
            "Siberian_Laptev_Kara_Seas": {
                "lon_min": 35.0,
                "lon_max": 180.0,
                "lat_min": 83.0,
                "lat_max": 90,
            },
            "Barents_Sea": {
                "lon_min": 5.0,
                "lon_max": 35.0,
                "lat_min": 68.0,
                "lat_max": 90,
            },
        }
    else:
        limits = {
            "Chuckchi_Sea": {
                "lon_min": 177.0,
                "lon_max": 192.0,
                "lat_min": 66.0,
            },
            "Siberian_Sea": {
                "lon_min": 141.0,
                "lon_max": 177.0,
                "lat_min": 68.0,
            },
            "Laptev_Sea": {
                "lon_min": 96.0,
                "lon_max": 141.0,
                "lat_min": 70.0,
            },
            "Kara_Sea": {
                "lon_min": 65.0,
                "lon_max": 96.0,
                "lat_min": 66.0,
            },
            "Barents_Sea": {
                "lon_min": 35.0,
                "lon_max": 65.0,
                "lat_min": 68.0,
            },
        }
        # The northern boundary is a polyline: within each sea it goes
        # linearly from lat_verts[k] at lon_max to lat_verts[k + 1] at
        # lon_min, raised by lat_buffer. The pairing relies on the insertion
        # order of ``limits``.
        lat_verts = [71, 73, 77, 82, 77, 71]
        lat_buffer = 2
        for sea_limits, y0, y1 in zip(limits.values(), lat_verts[:-1], lat_verts[1:]):
            x0 = sea_limits["lon_max"]
            x1 = sea_limits["lon_min"]
            sea_limits["lat_max"] = y0 + lat_buffer + (y1 - y0) * (lon - x0) / (x1 - x0)

    # Union of all boxes: start from zeros and set 1 wherever a box matches.
    mask = xr.zeros_like(lon)
    for sea_limits in limits.values():
        mask = xr.where(
            (lon >= sea_limits["lon_min"])
            & (lon <= sea_limits["lon_max"])
            & (lat >= sea_limits["lat_min"])
            & (lat <= sea_limits["lat_max"]),
            1,
            mask,
        )
    return obj.where(mask)


def get_satellite_data(time):
    """Download satellite sea ice thickness for the years spanned by ``time``.

    ``time`` must support ``.dt.year`` (e.g. an xarray time coordinate). Its
    first and last year clip the per-satellite year ranges below; satellites
    whose clipped range is empty are skipped. The downloaded datasets are
    concatenated along ``"time"`` in the order envisat, cryosat_2.
    """
    first_year = time.dt.year.min().values
    last_year = time.dt.year.max().values

    # Years requested from each satellite, clipped to the span of ``time``
    satellite_years = {
        "envisat": range(max(2002, first_year), min(2010, last_year) + 1),
        "cryosat_2": range(max(2010, first_year), min(2020, last_year) + 1),
    }

    per_satellite = [
        download.download_and_transform(
            "satellite-sea-ice-thickness",
            {
                "satellite": satellite,
                "version": "2_0",
                "cdr_type": "cdr",
                "variable": "all",
                "year": [str(year) for year in years],
                "month": [f"{month:02d}" for month in [1, 2, 3, 4, 10, 11, 12]],
            },
            chunks={"year": 1},
        )
        for satellite, years in satellite_years.items()
        if years  # an empty range means no overlap with ``time``
    ]
    return xr.concat(per_satellite, "time")


def regrid(obj, grid_out, **regrid_kwargs):
    """Drop all-NaN edge columns from ``obj`` and regrid it onto ``grid_out``.

    For every dimension whose name contains ``"x"`` or ``"lon"``, the first
    and the last index are removed when the data there is entirely null. The
    result is passed to ``diagnostics.regrid`` together with ``regrid_kwargs``.

    For a Dataset, an edge counts as all-NaN when every data variable that
    spans the dimension is entirely null there; variables without that
    dimension are ignored. A Dataset with no variable spanning the dimension
    is left untouched.
    """
    # Remove nan columns
    for dim in [dim for dim in obj.dims if "x" in dim or "lon" in dim]:
        for i in (0, -1):
            if isinstance(obj, xr.Dataset):
                # ``bool(Dataset)`` only tells whether the Dataset has data
                # variables, so ``if ds.isnull().all():`` is truthy for any
                # non-empty Dataset. Reduce each variable to a scalar instead.
                spanning = [var for var in obj.data_vars.values() if dim in var.dims]
                edge_is_nan = bool(spanning) and all(
                    bool(var.isel({dim: i}).isnull().all()) for var in spanning
                )
            else:
                edge_is_nan = bool(obj.isel({dim: i}).isnull().all())
            if edge_is_nan:
                obj = obj.drop_isel({dim: i})
    return diagnostics.regrid(obj, grid_out, **regrid_kwargs)


def compare_model_vs_satellite(ds, ds_satellite, sea_mask):
    """Compute model-minus-satellite sea ice thickness statistics.

    Parameters
    ----------
    ds : xarray.Dataset
        Model data containing a variable with CF standard name
        ``sea_ice_thickness``; expected on the same ``("xc", "yc")`` grid as
        ``ds_satellite``.
    ds_satellite : xarray.Dataset
        Satellite data containing ``sea_ice_thickness`` and
        ``sea_ice_thickness standard_error`` variables.
    sea_mask : str
        Region name forwarded to ``apply_sea_mask`` for both datasets.

    Returns
    -------
    xarray.Dataset
        ``sithick_bias``, ``sithick_rmse`` and ``rms_sit_obs_error``, each an
        unweighted reduction over ``xc`` and ``yc``.
    """
    # Homogenize time: both datasets use the standard calendar and the model
    # time stamps are reset to the first day of their month (presumably to
    # match the satellite time stamps used for the selection below).
    ds = apply_sea_mask(ds.convert_calendar("standard", align_on="date"), sea_mask)
    ds["time"] = pd.to_datetime(ds["time"].dt.strftime("%Y-%m"))
    ds_satellite = apply_sea_mask(
        ds_satellite.convert_calendar("standard", align_on="date"), sea_mask
    )

    # Get variables
    dims = ("xc", "yc")
    std_name = "sea_ice_thickness"
    sit = ds.cf[std_name]
    sit_obs = ds_satellite.cf[std_name]
    sit_obs_err = ds_satellite.cf[f"{std_name} standard_error"]
    # Keep only the model time steps that have a satellite counterpart
    sit_model = sit.sel(time=sit_obs["time"])

    # Compute useful variables
    sit_diff = sit_model - sit_obs

    # Compute output
    dataarrays = {}
    dataarrays["sithick_bias"] = sit_diff.mean(dims)
    dataarrays["sithick_bias"].attrs = {
        "standard_name": "sea_ice_thickness_bias",
        "units": "m",
        "long_name": "Sea ice thickness bias",
    }

    dataarrays["sithick_rmse"] = (sit_diff**2).mean(dim=dims) ** (1 / 2)
    dataarrays["sithick_rmse"].attrs = {
        "standard_name": "sea_ice_thickness_rmse",
        "units": "m",
        "long_name": "Sea ice thickness root mean square error",
    }

    dataarrays["rms_sit_obs_error"] = (sit_obs_err**2).mean(dims) ** (1 / 2)
    dataarrays["rms_sit_obs_error"].attrs = {
        "standard_name": "root_mean_square_sea_ice_thickness_observation_error",
        # The RMS of the thickness standard error has the units of the
        # thickness itself (metres, as for the bias and RMSE), not percent.
        "units": "m",
        "long_name": "Root mean square sea ice thickness observation error",
    }

    return xr.Dataset(dataarrays)


def compute_sea_ice_thickness_diagnostics(ds, sea_mask, **regrid_kwargs):
    """Regrid model data onto the satellite grid and compare the two.

    Satellite data is fetched for the years covered by ``ds["time"]``, ``ds``
    is regridded onto the satellite ``latitude``/``longitude`` (with
    ``regrid_kwargs`` forwarded to ``regrid``), and the comparison statistics
    for ``sea_mask`` are returned.
    """
    satellite = get_satellite_data(ds["time"])
    target_grid = satellite[["latitude", "longitude"]]
    regridded = regrid(ds, target_grid, **regrid_kwargs)
    return compare_model_vs_satellite(regridded, satellite, sea_mask)

## Download and transform

In [None]:
# One diagnostics dataset per (model, sea mask) pair: results are stacked
# along "sea_mask" for each model, loaded into memory, then stacked along
# "model".
datasets = []
for model in models:
    tmp_datasets = []
    for sea_mask in sea_masks:
        print(f"{model=} {sea_mask=}")
        ds = download.download_and_transform(
            collection_id,
            {**request, "model": model},
            chunks={"year": 1},
            transform_func=compute_sea_ice_thickness_diagnostics,
            transform_func_kwargs=dict(
                sea_mask=sea_mask,
                method="nearest_s2d",
                periodic=True,
                ignore_degenerate=True,
            ),
        )
        # Label the result so it can be concatenated along both dimensions
        ds = ds.expand_dims(sea_mask=[sea_mask], model=[model])
        tmp_datasets.append(ds)
    datasets.append(xr.concat(tmp_datasets, "sea_mask").compute())
ds = xr.concat(datasets, "model")