# CARRA Single Level Reanalysis

## Import packages

In [None]:
import tempfile

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Apply the "seaborn-v0_8-notebook" matplotlib style sheet to the figures below
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time period ("YYYY-MM" strings used as start/stop of the data request)
start = "1991-01"
stop = "2020-12"

# Region: must be one of the domains listed here
allowed_domains = ("east_domain", "west_domain")
domain = "west_domain"
assert domain in allowed_domains

# Variable: must be one of the single-level variables listed here
allowed_variables = (
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "10m_wind_direction",
    "10m_wind_speed",
    "2m_relative_humidity",
    "2m_specific_humidity",
    "2m_temperature",
    "albedo",
    "cloud_base",
    "cloud_top",
    "fraction_of_snow_cover",
    "high_cloud_cover",
    "land_sea_mask",
    "low_cloud_cover",
    "mean_sea_level_pressure",
    "medium_cloud_cover",
    "orography",
    "percolation",
    "sea_ice_area_fraction",
    "sea_ice_surface_temperature",
    "sea_surface_temperature",
    "skin_temperature",
    "snow_albedo",
    "snow_density",
    "snow_depth_water_equivalent",
    "snow_on_ice_total_depth",
    "surface_pressure",
    "surface_roughness",
    "surface_roughness_length_for_heat",
    "surface_runoff",
    "total_cloud_cover",
    "total_column_graupel",
    "total_column_integrated_water_vapour",
    "visibility",
)
variable = "2m_temperature"
assert variable in allowed_variables

## Define request

In [None]:
# Collection to download from, and the request template shared by every
# download below; the date fields are filled in by update_request_date.
collection_id = "reanalysis-carra-single-levels"
request = dict(
    domain=domain,
    level_type="surface_or_atmosphere",
    variable=variable,
    product_type="analysis",
    time="12:00",
)
requests = download.update_request_date(
    request,
    start=start,
    stop=stop,
    stringify_dates=True,
)

## Functions to cache

In [None]:
def get_da(ds):
    """Return the single data variable contained in ``ds``.

    The unpacking below raises ``ValueError`` unless ``ds.data_vars`` holds
    exactly one name.
    """
    [only_name] = ds.data_vars
    return ds[only_name]


def rechunk(da, target_store=None):
    """Rechunk ``da`` and record the resulting chunk sizes in its encoding.

    ``forecast_reference_time`` is requested with chunk size ``-1`` and
    ``month`` with chunk size ``1``; every other dimension uses ``"auto"``.

    Parameters
    ----------
    da
        Array to rechunk.
    target_store
        Optional zarr store. When truthy, the rechunked array is written
        there and re-opened from it with the same chunk sizes, so the
        returned array is backed by the store.

    Returns
    -------
    The rechunked array, with ``encoding["chunksizes"]`` set to the largest
    chunk length along each dimension.
    """
    preferred = {"forecast_reference_time": -1, "month": 1}
    chunk_spec = {}
    for dim in da.dims:
        chunk_spec[dim] = preferred.get(dim, "auto")
    result = da.chunk(chunk_spec).unify_chunks()

    if target_store:
        # Round-trip through zarr so later work reads from the store
        result.to_zarr(target_store)
        reopened = xr.open_dataset(
            target_store,
            chunks=dict(result.chunksizes),
            engine="zarr",
        )
        result = reopened.set_coords(result.coords)[result.name]

    result.encoding["chunksizes"] = tuple(max(sizes) for sizes in result.chunks)
    return result


def rechunk_and_reduce(da, reduce_func, **kwargs):
    """Rechunk ``da`` through a temporary store, reduce it, and compute.

    ``reduce_func`` is called with the rechunked array and ``**kwargs``.
    The result is computed inside the ``with`` block, i.e. before the
    temporary directory holding the intermediate store is removed.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        staged = rechunk(da, scratch_dir)
        reduced = reduce_func(staged, **kwargs)
        return reduced.compute()


def compute_time_weighted_reduction(ds, monthly, func, **kwargs):
    """Apply ``func`` to the single variable of ``ds`` and return a dataset.

    Parameters
    ----------
    ds
        Dataset holding exactly one data variable.
    monthly
        When truthy, the reduction is applied separately to each
        ``forecast_reference_time.month`` group; otherwise it is applied to
        the whole array at once.
    func
        Reduction handed to ``rechunk_and_reduce`` as ``reduce_func``.
    **kwargs
        Extra keyword arguments forwarded to ``func``.

    Returns
    -------
    The reduced array, rechunked and converted to a dataset.
    """
    source = get_da(ds)
    if not monthly:
        reduced = rechunk_and_reduce(source, reduce_func=func, **kwargs)
    else:
        by_month = source.groupby("forecast_reference_time.month")
        reduced = by_month.map(rechunk_and_reduce, reduce_func=func, **kwargs)
    return rechunk(reduced).to_dataset()

## Compute time reductions

In [None]:
# Build one merged dataset of time reductions (mean, std, linear trend) for
# the whole period and one per calendar month, keyed by "monthly=False" and
# "monthly=True".
seconds_per_year = 60 * 60 * 24 * 365
time_reductions = (
    diagnostics.time_weighted_mean,
    diagnostics.time_weighted_std,
    diagnostics.time_weighted_linear_trend,
)

maps_datasets = {}
for monthly in (False, True):
    dataarrays = []
    for func in time_reductions:
        print(f"{monthly=} {func.__name__=}")
        transform_kwargs = {"monthly": monthly, "func": func, "weights": False}
        ds = download.download_and_transform(
            collection_id,
            requests,
            transform_func=compute_time_weighted_reduction,
            transform_func_kwargs=transform_kwargs,
            transform_chunks=False,
            chunks={"year": 1, "month": 1},
        )
        da = rechunk(get_da(ds)).rename(func.__name__)

        # Label such as "Linear Trend of 2m temperature"
        statistic = func.__name__.replace("time_weighted_", "").title()
        long_name = f"{statistic} of {variable}".replace("_", " ")

        units = da.attrs["units"]
        if func == diagnostics.time_weighted_linear_trend:
            # Rescale the trend from "s-1" to "year-1" using a 365-day year
            with xr.set_options(keep_attrs=True):
                da *= seconds_per_year
            units = units.replace("s-1", "year-1")

        da.attrs = {"long_name": long_name, "units": units}
        dataarrays.append(da)
    maps_datasets[f"{monthly=}"] = xr.merge(dataarrays)

## Compute spatial weighted reductions

In [None]:
# Spatially weighted mean and standard deviation, merged into a single
# dataset whose variables are named after the reduction functions.
spatial_reductions = (
    diagnostics.spatial_weighted_mean,
    diagnostics.spatial_weighted_std,
)

dataarrays = []
for func in spatial_reductions:
    print(f"{func.__name__=}")
    ds = download.download_and_transform(
        collection_id,
        requests,
        transform_func=func,
        transform_chunks=True,
        chunks={"year": 1, "month": 1},
    )
    reduced = get_da(ds)
    dataarrays.append(reduced.rename(func.__name__))
ds_timeseries = xr.merge(dataarrays)

## Plot maps

In [None]:
# One figure per reduction: a single map for the whole-period datasets, a
# grid of monthly panels (3 per row) for the monthly ones.
title = f"{collection_id.replace('-', ' ')}\nFrom {start} to {stop}".title()
for ds in maps_datasets.values():
    is_monthly = "month" in ds.dims
    # Lambert conformal projection centred on the mean longitude/latitude
    projection = ccrs.LambertConformal(
        central_longitude=ds["longitude"].mean().values,
        central_latitude=ds["latitude"].mean().values,
    )
    for var, da in ds.data_vars.items():
        plot_obj = plot.projected_map(
            da,
            projection=projection,
            col="month" if is_monthly else None,
            col_wrap=3,
        )

        # Faceted plots expose a grid of axes; single maps expose one axes
        if is_monthly:
            gridliners = [gl for ax in plot_obj.axs.flat for gl in ax._gridliners]
        else:
            gridliners = plot_obj.axes._gridliners
        for gl in gridliners:
            gl.x_inline = False
            gl.xlabel_style = {"rotation": -45}

        if is_monthly:
            plt.suptitle(title, y=1, va="bottom")
        else:
            plt.title(title)
        plt.show()

## Plot timeseries

In [None]:
# Plot the spatially weighted mean as a line, with a shaded band spanning
# one spatially weighted standard deviation either side of it.
fig, ax = plt.subplots()
ts_mean = ds_timeseries["spatial_weighted_mean"]
ts_std = ds_timeseries["spatial_weighted_std"]
ts_mean.plot(ax=ax, label="mean")
ax.fill_between(
    ds_timeseries["time"],
    # The band must use the standard deviation: mean ± mean would shade
    # from 0 to 2 * mean and contradict the "mean ± std" legend entry.
    ts_mean - ts_std,
    ts_mean + ts_std,
    alpha=0.25,
    label="mean ± std",
)
ax.grid()
ax.legend(loc="center left", bbox_to_anchor=(1, 1))
_ = ax.set_title(
    f"{collection_id}\n{domain}".replace("-", " ").replace("_", " ").title()
)