# Evaluation of sea ice concentration of CARRA

## Import libraries

In [None]:
import warnings

import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

# Ignore warnings raised from modules matching "cf_xarray"
warnings.filterwarnings("ignore", module="cf_xarray")
# Use the seaborn "notebook" style sheet for every figure in this notebook
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Time period to evaluate ("YYYY-MM")
start = "1978-10"
stop = "2022-12"

# Region: CARRA domain to evaluate
domain = "west_domain"
# Validate explicitly: `assert` statements are stripped when Python runs with -O
if domain not in ("east_domain", "west_domain"):
    raise ValueError(f"{domain=} must be 'east_domain' or 'west_domain'")

# Sea Ice Concentration Threshold used by the ice extent / ice edge diagnostics
sic_threshold = 30  # %

## Define request

In [None]:
# Map each product label to its (collection ID, base request) pair.
# No dates are set here: they are added with download.update_request_date
# before the data is downloaded.
request_dict = {
    "ERA5": (
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": "sea_ice_cover",
            # Hourly fields
            "time": [f"{hour:02d}:00" for hour in range(24)],
        },
    ),
    "CARRA": (
        "reanalysis-carra-single-levels",
        {
            "domain": domain,
            "level_type": "surface_or_atmosphere",
            "variable": "sea_ice_area_fraction",
            "product_type": "analysis",
            # 3-hourly fields
            "time": [f"{hour:02d}:00" for hour in range(0, 24, 3)],
        },
    ),
}

## Define function to cache

In [None]:
def compare_model_vs_obs(ds, datasets_sat, sic_threshold, grid_cell_area):
    """Compare model sea ice concentration against satellite products.

    Parameters
    ----------
    ds : xr.Dataset
        Model data exposing the CF standard name ``sea_ice_area_fraction``.
        Values are rescaled to percent when the ``units`` attribute is
        ``"(0 - 1)"``.
    datasets_sat : dict
        Mapping ``origin -> xr.Dataset``. Each dataset provides either
        ``sea_ice_area_fraction`` (with its ``standard_error``) or
        ``sea_ice_classification``.
    sic_threshold : float
        Concentration threshold (%) used to flag over/underestimated cells.
    grid_cell_area : float
        Area of a single grid cell in km2.

    Returns
    -------
    xr.Dataset
        Diagnostics concatenated along a new ``origin`` dimension, or an empty
        dataset when ``datasets_sat`` is empty. ``iiee`` and ``siextent_bias``
        are always computed; ``siconc_bias``, ``siconc_rmse``,
        ``rms_sic_obs_error`` and ``siarea_bias`` only for concentration
        products.
    """
    ds = ds.convert_calendar("standard", align_on="date")
    datasets_sat = {
        origin: ds_sat.convert_calendar("standard", align_on="date")
        for origin, ds_sat in datasets_sat.items()
    }

    grid_cell_area = grid_cell_area * 1.0e-6  # km2 -> 10^6 km2
    sic = ds.cf["sea_ice_area_fraction"]
    if sic.attrs.get("units", "") == "(0 - 1)":
        # Rescale out of place so that the caller's data is never modified
        sic = sic * 100

    # Horizontal dimensions the diagnostics are reduced over
    dims = ("xc", "yc")
    datasets = []
    for origin, ds_sat in datasets_sat.items():
        sic_model = sic.sel(time=ds_sat["time"])
        dataarrays = {}
        if "sea_ice_area_fraction" not in ds_sat.cf.standard_names:
            # Get variables
            si_class = ds_sat.cf["sea_ice_classification"]

            # Compute variables: classes 1, 2, 3 are mapped to the
            # concentration ranges 0-30 %, 30-70 % and 70-100 %
            sic_obs_lower = (
                si_class.where(si_class != 3, 70)
                .where(si_class != 2, 30)
                .where(si_class != 1, 0)
            )
            sic_obs_upper = (
                si_class.where(si_class != 3, 100)
                .where(si_class != 2, 70)
                .where(si_class != 1, 30)
            )
            over = (sic_model > sic_threshold) & (sic_obs_upper <= sic_threshold)
            under = (sic_model <= sic_threshold) & (sic_obs_lower >= sic_threshold)
        else:
            # Get variables
            sic_obs = ds_sat.cf["sea_ice_area_fraction"]
            sic_obs_err = ds_sat.cf["sea_ice_area_fraction standard_error"]

            # Compute variables
            sic_diff = sic_model - sic_obs
            over = (sic_model > sic_threshold) & (sic_obs <= sic_threshold)
            under = (sic_model <= sic_threshold) & (sic_obs > sic_threshold)

            # Compute output
            dataarrays["siconc_bias"] = sic_diff.mean(dims)
            dataarrays["siconc_bias"].attrs = {
                "standard_name": "sea_ice_concentration_bias",
                "units": "%",
                "long_name": "Sea ice concentration bias",
            }

            dataarrays["siconc_rmse"] = (sic_diff**2).mean(dim=dims) ** (1 / 2)
            dataarrays["siconc_rmse"].attrs = {
                "standard_name": "sea_ice_concentration_rmse",
                "units": "%",
                "long_name": "Sea ice concentration root mean square error",
            }

            dataarrays["rms_sic_obs_error"] = (sic_obs_err**2).mean(dims) ** (1 / 2)
            dataarrays["rms_sic_obs_error"].attrs = {
                "standard_name": "root_mean_square_sea_ice_concentration_observation_error",
                "units": "%",
                "long_name": "Root mean square sea ice concentration observation error",
            }

            # Sum the difference rather than the two fields separately: cells
            # missing in only one of them would otherwise count towards a
            # single total and distort the bias
            dataarrays["siarea_bias"] = sic_diff.sum(dims) * grid_cell_area / 100
            dataarrays["siarea_bias"].attrs = {
                "standard_name": "sea_ice_area_bias",
                "units": "$10^6$km$^2$",
                "long_name": "Sea ice area bias",
            }

        # Compute common output: number of over/underestimated cells
        over = over.sum(dims)
        under = under.sum(dims)
        dataarrays["iiee"] = (over + under) * grid_cell_area
        dataarrays["iiee"].attrs = {
            "standard_name": "integrated_ice_edge_error",
            "units": "$10^6$km$^2$",
            "long_name": "Integrated ice edge error",
        }

        dataarrays["siextent_bias"] = (over - under) * grid_cell_area
        dataarrays["siextent_bias"].attrs = {
            "standard_name": "sea_ice_extent_bias",
            "units": "$10^6$km$^2$",
            "long_name": "Sea ice extent bias",
        }
        datasets.append(xr.Dataset(dataarrays).expand_dims(origin=[origin]))
    return xr.concat(datasets, "origin") if datasets else xr.Dataset()


def add_bounds(obj):
    """Add longitude/latitude bounds to ``obj`` where they are not yet defined."""
    existing_bounds = obj.cf.bounds
    missing = [
        name for name in ("longitude", "latitude") if name not in existing_bounds
    ]
    return obj.cf.add_bounds(missing)


def interpolate_to_satellite_grid(obj, satellite_id, domain, **regrid_kwargs):
    """Regrid ``obj`` onto the grid of a satellite sea ice product.

    Parameters
    ----------
    obj : xarray object
        Data to regrid; must have a ``longitude`` variable.
    satellite_id : str
        Either ``"satellite-sea-ice-concentration"`` or
        ``"satellite-sea-ice-edge-type"``.
    domain : str or None
        CARRA domain. When truthy, the target grid is cropped to the cells
        where the interpolated CARRA land-sea mask is not null, and the
        result is masked accordingly.
    **regrid_kwargs
        Forwarded to ``diagnostics.regrid``; ``method`` is required.

    Raises
    ------
    ValueError
        If ``satellite_id`` is not one of the two supported collections.
    """
    # Wrap longitudes from the 0..360 convention to -180..180
    lon = obj["longitude"]
    if (lon > 180).any():
        with xr.set_options(keep_attrs=True):
            obj["longitude"] = (lon + 180) % 360 - 180

    # A single day is requested: only the grid is used, time is dropped below
    if satellite_id == "satellite-sea-ice-concentration":
        product_request = {
            "version": "v2",
            "variable": "all",
            "origin": "esa_cci",
            "cdr_type": "cdr",
        }
    elif satellite_id == "satellite-sea-ice-edge-type":
        product_request = {
            "variable": "sea_ice_edge",
            "cdr_type": "cdr",
            "version": "2_0",
        }
    else:
        raise ValueError(f"{satellite_id=}")
    request = {
        "region": "northern_hemisphere",
        "format": "zip",
        "year": "2002",
        "month": "06",
        "day": "01",
        **product_request,
    }
    target_grid = download.download_and_transform(satellite_id, request)
    target_grid = target_grid.drop_dims("time")

    if domain:
        lsm = get_carra_lsm(satellite_id=satellite_id, domain=domain)
        target_grid["__mask__"] = lsm.notnull()
        target_grid = target_grid.where(target_grid["__mask__"].compute(), drop=True)

    # Bounds are added to both grids for the conservative methods
    if regrid_kwargs["method"].startswith("conservative"):
        obj = add_bounds(obj)
        target_grid = add_bounds(target_grid)

    regridded = diagnostics.regrid(obj, target_grid, **regrid_kwargs)
    return regridded.where(target_grid["__mask__"]) if domain else regridded


def get_carra_lsm(satellite_id, domain):
    """Return the CARRA land-sea mask interpolated onto a satellite grid.

    Parameters
    ----------
    satellite_id : str
        Satellite collection whose grid the mask is interpolated to.
    domain : str
        CARRA domain to download.

    Returns
    -------
    xr.DataArray
        The ``lsm`` variable of a single analysis (1990-09-01 00:00),
        bilinearly interpolated with unmapped cells set to NaN, squeezed and
        stripped of its non-index coordinates.
    """
    collection_id = "reanalysis-carra-single-levels"
    request = {
        "domain": domain,
        "level_type": "surface_or_atmosphere",
        "variable": "land_sea_mask",
        "product_type": "analysis",
        "year": "1990",
        "month": "09",
        "day": "01",
        "time": "00:00",
    }
    return (
        download.download_and_transform(
            collection_id,
            request,
            transform_func=interpolate_to_satellite_grid,
            transform_func_kwargs={
                "satellite_id": satellite_id,
                # No domain: interpolate_to_satellite_grid would otherwise
                # call back into this function to build its mask
                "domain": None,
                "method": "bilinear",
                "unmapped_to_nan": True,
            },
        )["lsm"]
        .squeeze()
        .reset_coords(drop=True)
    )


def get_interpolated_data(
    ds, add_stde, check_values, time_freq, domain, satellite_id, **regrid_kwargs
):
    """Select the sea ice variable, put it on the satellite grid and resample it.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset exposing ``sea_ice_area_fraction`` or, failing that,
        ``sea_ice_classification`` as CF standard name.
    add_stde : bool
        Whether to also return the root mean square of
        ``sea_ice_area_fraction standard_error`` over each resampling window.
    check_values : bool
        Whether to drop the time steps where the sea ice variable is null or
        zero everywhere.
    time_freq : str
        Resampling frequency passed to ``resample(time=...)``.
    domain, satellite_id
        CARRA domain and satellite collection defining the mask/target grid.
    **regrid_kwargs
        When given, the data is regridded with ``interpolate_to_satellite_grid``;
        otherwise it is only masked with the CARRA land-sea mask.
    """
    stde = ds.cf["sea_ice_area_fraction standard_error"] if add_stde else None

    if "sea_ice_area_fraction" in ds.cf.standard_names:
        cf_var = "sea_ice_area_fraction"
    else:
        cf_var = "sea_ice_classification"
    ds = ds.cf[["latitude", "longitude", cf_var]]
    ds = ds.drop_dims({"vertices", "bnds"}.intersection(ds.dims))

    if regrid_kwargs:
        ds = interpolate_to_satellite_grid(
            ds, satellite_id=satellite_id, domain=domain, **regrid_kwargs
        )
    else:
        # Data already on the satellite grid: only crop to the CARRA domain
        lsm = get_carra_lsm(satellite_id=satellite_id, domain=domain)
        valid = lsm.notnull().compute()
        ds = ds.where(valid, drop=True)
        if add_stde:
            stde = stde.where(valid, drop=True)

    ds = ds.sortby("time").resample(time=time_freq).mean()
    ds["time"].attrs["long_name"] = "time"

    if add_stde:
        with xr.set_options(keep_attrs=True):
            # Errors are averaged in quadrature over each window
            stde_rms = (stde**2).resample(time=time_freq).mean() ** (1 / 2)
            ds = ds.merge(stde_rms)
    if check_values:
        has_data = ds.cf[cf_var].notnull() & (ds.cf[cf_var] != 0)
        spatial_dims = set(has_data.dims) - {"time"}
        ds = ds.sel(time=has_data.any(spatial_dims))
    return ds


def get_satellite_data(time, time_freq, domain):
    """Download the satellite products overlapping the years spanned by ``time``.

    Parameters
    ----------
    time : xr.DataArray
        Time coordinate of the model data.
    time_freq : str
        Resampling frequency forwarded to ``get_interpolated_data``.
    domain : str
        CARRA domain forwarded to ``get_interpolated_data``.

    Returns
    -------
    dict
        Mapping ``product name -> xr.Dataset`` restricted to
        ``time.min()``..``time.max()``. Products whose request list is empty
        are skipped.
    """
    first_year = time.dt.year.min().values
    last_year = time.dt.year.max().values

    conc_request = {
        "cdr_type": "cdr",
        "variable": "all",
        "version": "v2",
        "region": "northern_hemisphere",
    }
    edge_request = {
        "format": "zip",
        "region": "northern_hemisphere",
        "variable": "sea_ice_edge",
        "cdr_type": "cdr",
        "version": "2_0",
    }

    # name -> (base request, earliest year, latest year) the request is
    # clipped to (presumably the coverage of each product)
    products = {
        "SEA ICE EDGE": (edge_request, 1978, 2020),
        "ESA-CCI": (conc_request | {"origin": "esa_cci"}, 2002, 2017),
        "EUMETSAT-OSI-SAF": (
            conc_request | {"origin": "eumetsat_osi_saf"},
            1979,
            2015,
        ),
    }
    satellite_requests = {
        name: download.update_request_date(
            base_request,
            start=f"{max(first_year, year_min)}-01",
            stop=f"{min(last_year, year_max)}-12",
            stringify_dates=True,
        )
        for name, (base_request, year_min, year_max) in products.items()
    }

    datasets_satellite = {}
    for name, requests in satellite_requests.items():
        if not requests:
            continue
        is_edge = name == "SEA ICE EDGE"
        suffix = "edge-type" if is_edge else "concentration"
        collection_id = "satellite-sea-ice-" + suffix
        ds = download.download_and_transform(
            collection_id,
            requests,
            chunks={"year": 1},
            transform_func=get_interpolated_data,
            transform_func_kwargs={
                "satellite_id": collection_id,
                # Only the concentration products are asked for a standard error
                "add_stde": not is_edge,
                "check_values": True,
                "time_freq": time_freq,
                "domain": domain,
            },
        )
        datasets_satellite[name] = ds.sel(time=slice(time.min(), time.max()))
    return datasets_satellite


def compute_sea_ice_evaluation_diagnostics(
    ds, sic_threshold, time_freq, domain, **regrid_kwargs
):
    """Evaluate a reanalysis against all satellite sea ice products.

    Parameters
    ----------
    ds : xr.Dataset
        Reanalysis data; a ``forecast_reference_time`` dimension, if any, is
        renamed to ``time``.
    sic_threshold : float
        Threshold forwarded to ``compare_model_vs_obs``.
    time_freq : str
        Resampling frequency applied to model and satellite data.
    domain : str
        CARRA domain used to mask the data.
    **regrid_kwargs
        Forwarded to ``get_interpolated_data`` for the reanalysis data.

    Returns
    -------
    xr.Dataset
        Diagnostics of every satellite product concatenated along ``origin``.
    """
    # Drop the non-dimension coordinates that cf_xarray associates with "time"
    extra_time_coords = set(ds.cf.coordinates["time"]) - set(ds.dims)
    ds = ds.reset_coords(extra_time_coords, drop=True)
    if "forecast_reference_time" in ds.dims:
        ds = ds.rename(forecast_reference_time="time")

    # Satellite data
    datasets_sat = get_satellite_data(ds["time"], time_freq=time_freq, domain=domain)
    edge_datasets = {k: v for k, v in datasets_sat.items() if k == "SEA ICE EDGE"}
    conc_datasets = {k: v for k, v in datasets_sat.items() if k != "SEA ICE EDGE"}

    # One comparison per satellite grid:
    # (collection ID, grid cell area - presumably km2, datasets on that grid)
    grids = (
        ("satellite-sea-ice-concentration", 25**2, conc_datasets),
        ("satellite-sea-ice-edge-type", 12.5**2, edge_datasets),
    )
    datasets = []
    for satellite_id, grid_cell_area, datasets_on_grid in grids:
        ds_interpolated = get_interpolated_data(
            ds,
            add_stde=False,
            check_values=False,
            time_freq=time_freq,
            domain=domain,
            satellite_id=satellite_id,
            **regrid_kwargs,
        )
        ds_comparison = compare_model_vs_obs(
            ds_interpolated,
            datasets_on_grid,
            sic_threshold,
            grid_cell_area=grid_cell_area,
        )
        # Skip the empty dataset returned when no product is on this grid
        if ds_comparison.dims:
            datasets.append(ds_comparison)
    return xr.concat(datasets, "origin")

## Download and transform

In [None]:
# Options forwarded to download.download_and_transform to speed up IO
kwargs = {
    "concat_dim": "forecast_reference_time",
    "combine": "nested",
    "data_vars": "minimal",
    "coords": "minimal",
    "compat": "override",
    "drop_variables": ("type",),
}

# Compute the diagnostics of each product, then stack them along "product"
datasets = []
for product, (collection_id, request) in request_dict.items():
    print(f"{product=}")
    requests = download.update_request_date(
        request, start=start, stop=stop, stringify_dates=True
    )
    ds_product = download.download_and_transform(
        collection_id,
        requests,
        transform_func=compute_sea_ice_evaluation_diagnostics,
        transform_func_kwargs=dict(
            sic_threshold=sic_threshold,
            time_freq="D",
            domain=domain,
            method="conservative",
        ),
        chunks={"year": 1, "month": 1},
        **kwargs,
    )
    datasets.append(ds_product.expand_dims(product=[product]))
ds = xr.concat(datasets, "product")

## Plot timeseries

In [None]:
# One figure per diagnostic, one line per (product, origin) pair
for var, da in ds.data_vars.items():
    # The observation error is only drawn as a shaded band around other variables
    if var == "rms_sic_obs_error":
        continue

    fig, ax = plt.subplots()
    palette = iter([f"C{i}" for i in range(10)])
    with_error_band = var in ("siconc_bias", "siconc_rmse")
    for product, da_product in da.groupby("product"):
        for origin, da_origin in da_product.groupby("origin"):
            # A colour is consumed even when nothing is drawn, so a given
            # pair gets the same colour in every figure
            color = next(palette)
            series = da_origin.squeeze()
            if series.notnull().any():
                series.plot(ax=ax, color=color, label=f"{product} vs {origin}")
                if with_error_band:
                    rms = ds["rms_sic_obs_error"].sel(product=product, origin=origin)
                    ax.fill_between(
                        series["time"],
                        series - rms,
                        series + rms,
                        color=color,
                        alpha=0.2,
                        label=f"{product} vs {origin}\n± {origin} RMSE",
                    )
    ax.set_title(f"{domain.replace('_', ' ').title()}: From {start} to {stop}")
    ax.grid()
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
    plt.show()