# CARRA Single Level Reanalysis

## Import packages

In [None]:
import tempfile

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Select the "seaborn-v0_8-notebook" Matplotlib style sheet for the figures in this notebook.
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time: period to analyse, given as "YYYY-MM" strings. Both values are passed to
# download.update_request_date when the request is built.
start = "2019-01"
stop = "2020-12"

# Region: CARRA domain to request. Only the two names checked by the assert are accepted.
domain = "west_domain"
assert domain in ("east_domain", "west_domain")

# Variable: the single variable to request. It must be one of the names listed in
# the assert below, otherwise the cell fails before anything is downloaded.
variable = "2m_temperature"
assert variable in (
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "10m_wind_direction",
    "10m_wind_speed",
    "2m_relative_humidity",
    "2m_specific_humidity",
    "2m_temperature",
    "albedo",
    "cloud_base",
    "cloud_top",
    "fraction_of_snow_cover",
    "high_cloud_cover",
    "land_sea_mask",
    "low_cloud_cover",
    "mean_sea_level_pressure",
    "medium_cloud_cover",
    "orography",
    "percolation",
    "sea_ice_area_fraction",
    "sea_ice_surface_temperature",
    "sea_surface_temperature",
    "skin_temperature",
    "snow_albedo",
    "snow_density",
    "snow_depth_water_equivalent",
    "snow_on_ice_total_depth",
    "surface_pressure",
    "surface_roughness",
    "surface_roughness_length_for_heat",
    "surface_runoff",
    "total_cloud_cover",
    "total_column_graupel",
    "total_column_integrated_water_vapour",
    "visibility",
)

## Define request

In [None]:
# Dataset identifier, reused in the download calls and in the plot titles.
collection_id = "reanalysis-carra-single-levels"
# Request template: every field except the dates. Only the 12:00 analysis is requested.
request = {
    "domain": domain,
    "level_type": "surface_or_atmosphere",
    "variable": variable,
    "product_type": "analysis",
    "time": "12:00",
}
# Fill in the dates for the start..stop period.
# NOTE(review): presumably returns a list of requests that together cover the
# period, with date fields as strings (stringify_dates=True) — confirm against
# the c3s_eqc download API.
requests = download.update_request_date(
    request, start=start, stop=stop, stringify_dates=True
)

## Functions to cache

In [None]:
def get_da(ds):
    """Return the only data variable contained in ``ds``.

    The single-target unpacking raises ``ValueError`` when ``ds.data_vars``
    does not hold exactly one entry, so callers never silently pick one
    variable out of several.
    """
    [only_name] = ds.data_vars
    data_array = ds[only_name]
    return data_array


def rechunk(da, target_store):
    """Write ``da`` to a Zarr store and return it re-opened from that store.

    The array is chunked with ``"auto"`` sizes and its chunks are unified
    before it is written to ``target_store``. It is then opened again with
    the same chunk sizes, so the returned array reads from the store.

    Parameters
    ----------
    da:
        Named data array to rechunk (its name is used to select it again
        after the round trip).
    target_store:
        Location of the Zarr store to write to and read from.

    Returns
    -------
    The variable named ``da.name`` read back from ``target_store``.
    """
    chunked = da.chunk("auto")
    chunked = chunked.unify_chunks()
    chunked.to_zarr(target_store)

    chunk_sizes = dict(chunked.chunksizes)
    reopened = xr.open_dataset(target_store, chunks=chunk_sizes, engine="zarr")

    # Re-flag the original coordinates as coordinates, in case any of them
    # were read back as plain data variables, then pick the array by name.
    with_coords = reopened.set_coords(chunked.coords)
    return with_coords[chunked.name]


def compute_time_mean_and_linear_trend(ds):
    """Compute the time mean and the linear trend of the variable in ``ds``.

    The single data variable of ``ds`` is first rechunked through a temporary
    Zarr store. ``diagnostics.time_weighted_mean`` and
    ``diagnostics.time_weighted_linear_trend`` are then applied with
    ``weights=False``.

    Returns
    -------
    A dataset with the variables ``"mean"`` and ``"linear_trend"``, already
    computed.
    """
    da = get_da(ds)
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"{tmpdir=}")
        rechunked = rechunk(da, f"{tmpdir}/target.zarr")

        results = []
        for reduction in ("mean", "linear_trend"):
            print(f"{reduction=}")
            reducer = getattr(diagnostics, f"time_weighted_{reduction}")
            reduced = reducer(rechunked, weights=False)
            results.append(reduced.rename(reduction))

        merged = xr.merge(results)
        # Compute before leaving the ``with`` block: the rechunked data lives
        # in the temporary directory, which is deleted on exit.
        return merged.compute()


def compute_spatial_weighted_mean_and_std(ds):
    """Compute the spatial mean and standard deviation of the variable in ``ds``.

    ``diagnostics.spatial_weighted_mean`` and
    ``diagnostics.spatial_weighted_std`` are applied with ``weights=True`` to
    the single data variable of ``ds``.

    Returns
    -------
    A dataset with the variables ``"mean"`` and ``"std"``.
    """
    da = get_da(ds)
    reduced = [
        getattr(diagnostics, f"spatial_weighted_{stat}")(da, weights=True).rename(stat)
        for stat in ("mean", "std")
    ]
    return xr.merge(reduced)

## Compute time reductions

In [None]:
# Download the requested data and reduce it along time with
# compute_time_mean_and_linear_trend.
# NOTE(review): transform_chunks=False presumably applies the transform once to
# the whole dataset instead of to each downloaded chunk (a trend needs the full
# period) — confirm against the c3s_eqc download API.
# NOTE(review): chunks={"year": 1, "month": 1} presumably splits the download
# into one request per month — confirm.
ds_maps = download.download_and_transform(
    collection_id,
    requests,
    transform_func=compute_time_mean_and_linear_trend,
    transform_chunks=False,
    chunks={"year": 1, "month": 1},
)

## Compute spatial weighted reductions

In [None]:
# Download the requested data and reduce it in space with
# compute_spatial_weighted_mean_and_std.
# transform_chunks is not passed here, so the library default applies.
# NOTE(review): chunks={"year": 1, "month": 1} presumably splits the download
# into one request per month — confirm against the c3s_eqc download API.
ds_timeseries = download.download_and_transform(
    collection_id,
    requests,
    transform_func=compute_spatial_weighted_mean_and_std,
    chunks={"year": 1, "month": 1},
)

## Plot maps

In [None]:
# Lambert conformal projection centred on the mean longitude and latitude of the data.
projection = ccrs.LambertConformal(
    central_longitude=ds_maps["longitude"].mean().values,
    central_latitude=ds_maps["latitude"].mean().values,
)
# Draw one map per data variable in ds_maps.
for var, da in ds_maps.data_vars.items():
    plot_obj = plot.projected_map(da, projection=projection)
    # Adjust the gridliners attached to the axes: no inline x labels and no
    # rotation of the x labels.
    # NOTE(review): _gridliners is a private attribute and may not exist in
    # every version of the plotting backend — verify.
    for gl in plot_obj.axes._gridliners:
        gl.x_inline = False
        gl.xlabel_style = {"rotation": 0}
    # Title: dataset id with dashes replaced by spaces and title-cased, then the period.
    plt.title(f"{collection_id.replace('-', ' ').title()}\nFrom {start} to {stop}")
    plt.show()

## Plot timeseries

In [None]:
fig, ax = plt.subplots()
# Line: the "mean" variable of ds_timeseries.
ds_timeseries["mean"].plot(ax=ax, label="mean")
# Shaded band: one standard deviation on either side of the mean.
ax.fill_between(
    ds_timeseries["time"],
    ds_timeseries["mean"] - ds_timeseries["std"],
    ds_timeseries["mean"] + ds_timeseries["std"],
    alpha=0.25,
    label="mean ± std",
)
ax.grid()
# Pin the legend's centre-left point to the top-right corner of the axes, so
# the legend sits outside the plotting area.
ax.legend(loc="center left", bbox_to_anchor=(1, 1))
# Title: dataset id and domain with "-" and "_" replaced by spaces, title-cased.
# Assigning to _ keeps the notebook from echoing the returned object.
_ = ax.set_title(
    f"{collection_id}\n{domain}".replace("-", " ").replace("_", " ").title()
)