# ERA5 vs GRUAN

## Import libraries

In [None]:
import os

import cdsapi
import numpy as np
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import download

# Set the CDSAPI_RC environment variable to a credentials file under a
# user-specific sub-directory of the home folder ("~" is expanded here).
# NOTE(review): presumably cdsapi reads CDSAPI_RC to locate its config file,
# so this must run before the client is created -- confirm. The path is
# hard-coded to one user's directory and will not exist on other machines.
os.environ["CDSAPI_RC"] = os.path.expanduser("~/ciardini_virginia/.cdsapirc")

## Define parameters

In [None]:
# First and last month of the comparison period, as "YYYY-MM" strings
start, stop = "2006-05", "2020-03"

# Stations to compare: station identifier -> coordinates of the site
stations = {"NYA": dict(latitude=78.92, longitude=11.93)}
assert isinstance(stations, dict)

# Pressure levels, kept as strings because they are passed verbatim in requests
levels = [
    str(level)
    for level in (
        100, 125, 150, 175, 200, 225, 250,
        300, 350, 400, 450, 500, 550, 600, 650, 700, 750,
        775, 800, 825, 850, 875, 900, 925, 950, 975, 1000,
    )
]

## Define requests

In [None]:
# GRUAN: base request (no time keys yet; they are added per month below).
# The request carries no station selection.
collection_id_gruan = "insitu-observations-gruan-reference-network"
request_gruan = {
    "version": "1_0_0",
    "variable": [
        "air_temperature",
        "relative_humidity",
        "air_pressure",
        "altitude",
        "eastward_wind_speed",
        "northward_wind_speed",
    ],
    "data_format": "netcdf",
}

# ERA5: base request for monthly means on the pressure levels defined above
collection_id_era5 = "reanalysis-era5-pressure-levels-monthly-means"
request_era5 = {
    "product_type": "monthly_averaged_reanalysis",
    "variable": [
        "temperature",
        "u_component_of_wind",
        "v_component_of_wind",
        "relative_humidity",
    ],
    "pressure_level": levels,
    "time": "00:00",
    "data_format": "grib",
    "download_format": "unarchived",
}

# Build requests: one GRUAN and one ERA5 request per month start in
# [start, stop]. Creating the client and calling apply_constraints talks to
# the CDS, so this cell needs valid credentials and network access.
client = cdsapi.Client()
requests_gruan = []
requests_era5 = []
for date in pd.date_range(start, stop, freq="1MS"):
    # GRUAN: ask the service which days are available for this month and
    # only keep months for which the returned "day" entry is non-empty.
    time_request = {"year": date.strftime("%Y"), "month": date.strftime("%m")}
    time_request["day"] = client.client.apply_constraints(
        collection_id_gruan, request_gruan | time_request
    )["day"]
    if time_request["day"]:
        requests_gruan.append(request_gruan | time_request)
    # ERA5: appended for every month, whether or not GRUAN has data.
    # NOTE(review): time_request still contains the "day" value obtained from
    # the GRUAN constraints (possibly empty), so it is sent with the ERA5
    # monthly-means request too -- confirm this is intended/harmless.
    requests_era5.append(request_era5 | time_request)

## Functions to cache

In [None]:
def _reorganize_dataset(ds):
    """Turn the records of a single observed variable into named variables.

    ``ds`` must contain exactly one distinct value in ``observed_variable``
    (the tuple-unpacking below raises ``ValueError`` otherwise). The generic
    ``observation_value`` variable is renamed after that value, variables
    whose name starts with ``uncertainty`` are prefixed with it, and
    variables whose name contains ``units`` or ``type`` are removed and
    stored as attributes instead.

    Returns the reorganized dataset.
    """
    # Rename
    (varname,) = set(ds["observed_variable"].values)
    ds = ds.rename(observation_value=str(varname)).drop_vars("observed_variable")
    # e.g. "uncertainty_value" -> "<varname>_uncertainty"
    ds = ds.rename(
        {
            var: "_".join([varname, var.replace("_value", "")])
            for var in ds.data_vars
            if var.startswith("uncertainty")
        }
    )
    # Update attrs
    # The loop iterates over the data variables of the dataset as it is at
    # this point; ``ds`` is rebound inside the loop without affecting the
    # iteration.
    for var, da in ds.data_vars.items():
        # NOTE(review): this sets the attribute on ``da``; it presumably
        # carries over to the returned dataset because xarray shares the
        # underlying variable objects -- verify.
        da.attrs["long_name"] = var.replace("_", " ").title()
        for string in ("units", "type"):
            if string in var:
                # Metadata variable: drop it and keep its (single) value as
                # an attribute of the variable it describes.
                ds = ds.drop_vars(var)
                (value,) = set(da.values)  # raises unless exactly one value
                # A variable named exactly "units"/"type" describes the main
                # variable; otherwise strip the "_units"/"_type" part of the
                # name to find the variable it belongs to.
                attrs_var = varname if var == string else var.replace("_" + string, "")
                ds[attrs_var].attrs[string] = value
    return ds


def reorganize_dataset(ds):
    """Reshape a long-format observation dataset into one variable per quantity.

    Byte-string variables are decoded to UTF-8 text in place on ``ds``. If the
    ``index`` dimension is empty the dataset is returned right after that
    step. Otherwise the records are grouped by ``observed_variable``, each
    group is passed through ``_reorganize_dataset`` and the results are
    merged into a single dataset.
    """
    # Decode bytes -> str, skipping everything that is not a byte string
    for name, array in ds.data_vars.items():
        if not np.issubdtype(array.dtype, np.bytes_):
            continue
        ds[name].values = np.char.decode(array.values, "utf-8")

    # Nothing to group when there are no records
    if ds.sizes["index"] == 0:
        return ds

    # One reorganized dataset per observed variable
    per_variable = [
        _reorganize_dataset(group) for _, group in ds.groupby("observed_variable")
    ]
    with xr.set_options(use_new_combine_kwarg_defaults=True):
        merged = xr.merge(per_variable)
    return merged


def compute_interpolated_insitu_profiles(ds, levels, variables):
    """Interpolate every in-situ profile in ``ds`` onto fixed pressure levels.

    Parameters
    ----------
    ds
        Raw observation dataset; it is first passed through
        ``reorganize_dataset``.
    levels
        Pressure values to interpolate each profile to.
    variables
        Names of the variables to keep; rows where any of them is missing
        are discarded.

    Returns
    -------
    The interpolated profiles concatenated along ``time``, with a ``station``
    coordinate on that dimension.

    NOTE(review): if every profile is skipped by the quality check the list
    passed to ``xr.concat`` is empty -- presumably that raises; confirm how
    callers should handle months without usable profiles.
    """
    ds = reorganize_dataset(ds)

    # Add time index
    ds["time"] = ("index", pd.to_datetime(ds["report_timestamp"]).values)

    # Collect one interpolated profile per (station, timestamp) pair
    profiles = []
    for station, ds_station in ds.groupby("primary_station_id"):
        for time, profile in ds_station.groupby("time"):
            # Organize the profile by altitude
            profile = profile.swap_dims(index="altitude")[variables]
            profile = profile.sortby("altitude")
            profile = profile.dropna("altitude", how="any", subset=variables)
            profile = profile.drop_duplicates("altitude")

            # Quality check: skip empty profiles and profiles with a gap
            # larger than 2000 between consecutive altitudes
            # (presumably metres -- TODO confirm units)
            if (
                not profile.sizes["altitude"]
                or (profile["altitude"].diff("altitude") > 2_000).any()
            ):
                continue

            # Interpolate
            if "pressure" not in profile:
                # Fallback: derive pressure from altitude with an exponential
                # profile (constants look like a sea-level pressure in hPa and
                # a scale height in metres -- TODO confirm)
                profile["pressure"] = 1013.25 * np.exp(-profile["altitude"] / 8434.5)
            profile = profile.swap_dims(altitude="pressure").drop_duplicates("pressure")
            try:
                profile = profile.interp(pressure=levels)
            except Exception:
                # Show the offending profile for debugging, then re-raise
                print(profile)
                raise
            profile["pressure"].attrs.update({"long_name": "Pressure", "units": "hPa"})

            # Append
            profile = profile.expand_dims(time=[time])
            profile = profile.assign_coords(station=("time", [station]))
            profiles.append(profile)
    return xr.concat(profiles, dim="time")


def select_nearest_station(ds, latitude, longitude):
    """Return the element of ``ds`` closest to the given coordinates.

    The selection is delegated to ``ds.sel`` with ``method="nearest"`` on the
    ``latitude`` and ``longitude`` coordinates.
    """
    target = {"latitude": latitude, "longitude": longitude}
    return ds.sel(method="nearest", **target)

## GRUAN

In [None]:
# Download the GRUAN requests and apply compute_interpolated_insitu_profiles
# as the transform function; the opened results are concatenated along "time".
# NOTE(review): chunks={"year": 1, "month": 1} presumably makes the helper
# download/transform one month at a time -- confirm against its documentation.
ds_gruan = download.download_and_transform(
    collection_id_gruan,
    requests_gruan,
    chunks={"year": 1, "month": 1},
    transform_func=compute_interpolated_insitu_profiles,
    transform_func_kwargs={
        # Target levels as sorted floats
        "levels": sorted(map(float, levels)),
        # Requested variable names, with "air_pressure" replaced by "pressure"
        "variables": sorted(
            [
                "pressure" if variable == "air_pressure" else variable
                for variable in request_gruan["variable"]
            ]
        ),
    },
    cached_open_mfdataset_kwargs={"concat_dim": "time", "combine": "nested"},
)

# Keep only the profiles whose station identifier is a key of `stations`
if stations is not None:
    ds_gruan = ds_gruan.where(
        ds_gruan["station"].isin(sorted(stations)).compute(), drop=True
    )
ds_gruan = ds_gruan.compute()

## ERA5

In [None]:
# For every station, download the ERA5 requests and reduce each one to the
# grid point nearest to the station; the station's latitude/longitude dict is
# passed straight through as keyword arguments of select_nearest_station.
datasets = []
for station, transform_func_kwargs in stations.items():
    print(f"{station=}")
    ds = download.download_and_transform(
        collection_id_era5,
        requests_era5,
        transform_func=select_nearest_station,
        transform_func_kwargs=transform_func_kwargs,
    )
    # Label the result with its station so the datasets can be stacked
    datasets.append(ds.expand_dims(station=[station]))
ds_era5 = xr.concat(datasets, "station").compute()

# Convert plev to hPa (mb)
# NOTE(review): dividing by 100 assumes plev is delivered in Pa -- confirm.
ds_era5["plev"] = ds_era5["plev"] / 100
ds_era5["plev"].attrs.update({"long_name": "Level", "units": "hPa"})