# Reference Upper-Air Network for trend analysis

## Import libraries

In [None]:
import os

import cdsapi
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import download

plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Analysis window as "YYYY-MM" bounds (both ends are used as monthly range limits).
# Alternative bounds left by the author: "2006-05" (start) and "2020-04" (stop).
start, stop = "2016-02", "2016-02"

# Identifier of the station whose soundings are kept
station = "NYA"

# Location of the CDS API credentials file
os.environ.update({"CDSAPI_RC": os.path.expanduser("~/ciardini_virginia/.cdsapirc")})

## Define requests

In [None]:
# Collection to query and the part of the request shared by every month
collection_id = "insitu-observations-gruan-reference-network"
request = {
    "variable": [
        "air_temperature",
        "relative_humidity",
        "air_pressure",
        "altitude",
    ],
    "format": "netcdf",
}

client = cdsapi.Client(sleep_max=1)

# Build one request per month, keeping only months for which the
# constraints endpoint reports at least one day.
requests = []
for month_start in pd.date_range(start, stop, freq="1MS"):
    month_request = request | {
        "year": month_start.strftime("%Y"),
        "month": month_start.strftime("%m"),
    }
    # NOTE(review): apply_constraints presumably returns the values still
    # valid for each request key given the selection — confirm against cdsapi.
    available_days = client.client.apply_constraints(collection_id, month_request)[
        "day"
    ]
    if not available_days:
        continue
    requests.append(month_request | {"day": available_days})

## Define transform functions

In [None]:
def compute_specific_humidity(ds):
    """Estimate specific humidity from pressure, temperature and relative humidity.

    Reads ``ds["pressure"]``, ``ds["air_temperature"]`` and
    ``ds["relative_humidity"]``. Pressure is scaled by 0.01 and 273.15 is
    subtracted from the temperature (assumes inputs in Pa and K, with
    relative humidity in percent — TODO confirm against the data source).

    Returns the result of ``622 * RH * e_sat / (100 * p)`` with its ``attrs``
    replaced by a ``long_name`` of "Specific Humidity" and ``units`` of "g/kg".
    """
    t_celsius = ds["air_temperature"] - 273.15
    p_hpa = ds["pressure"] * 0.01

    # Magnus-type saturation vapour pressure (coefficients 6.112, 17.67, 243.5)
    exponent = (17.67 * t_celsius) / (t_celsius + 243.5)
    e_sat = 6.112 * np.exp(exponent)

    da = 622 * ds["relative_humidity"] * e_sat / (100 * p_hpa)
    da.attrs = {"long_name": "Specific Humidity", "units": "g/kg"}
    return da


def compute_saturation_vapor_pressure(ds):
    """Return the saturation vapour pressure derived from ``ds["air_temperature"]``.

    273.15 is subtracted from the temperature (assumes Kelvin input — TODO
    confirm) and the Magnus-type expression
    ``6.112 * exp(17.67 * T / (T + 243.5))`` is evaluated on the result.
    With these coefficients the output is presumably in hPa.
    """
    t_celsius = ds["air_temperature"] - 273.15
    exponent = (17.67 * t_celsius) / (t_celsius + 243.5)
    return 6.112 * np.exp(exponent)


def compute_integrated_water_vapour(ds):
    """Sum water-vapour density times layer thickness over the ``altitude`` dimension.

    Reads ``ds["air_temperature"]``, ``ds["relative_humidity"]`` and
    ``ds["altitude"]``. Vapour pressure is the saturation value scaled by
    relative humidity / 100; density follows ``e * 18.015 / (10 * 8.3145 * T)``
    (presumably hPa, g/mol, J/(mol K) and K, giving kg/m³ — TODO confirm).
    Layer thickness is the first difference of ``altitude`` along itself.

    Returns the altitude-summed product with its ``attrs`` replaced by a
    ``long_name`` of "Integrated Water Vapour" and ``units`` of "kg/m²".
    """
    saturation_pressure = compute_saturation_vapor_pressure(ds)
    vapour_pressure = saturation_pressure * (ds["relative_humidity"]) / 100
    vapour_density = (vapour_pressure * 18.015) / (
        10 * 8.3145 * ds["air_temperature"]
    )
    layer_thickness = ds["altitude"].diff("altitude")
    da = (vapour_density * layer_thickness).sum("altitude")
    da.attrs = {"long_name": "Integrated Water Vapour", "units": "kg/m²"}
    return da


def reorganize_dataset(ds):
    """Reshape a single-variable group so its columns are named after the variable.

    ``ds`` must contain exactly one distinct value in ``observed_variable``
    (otherwise the single-element unpacking raises ``ValueError``). The
    function:

    * renames ``observation_value`` to that variable name and drops
      ``observed_variable``;
    * prefixes every data variable starting with ``"uncertainty"`` with the
      variable name (removing any ``"_value"`` substring);
    * sets a human-readable ``standard_name`` attribute on known variables;
    * converts data variables whose name contains ``"units"`` or ``"type"``
      into attributes of the variable they describe. Each such variable must
      hold a single distinct value, otherwise ``ValueError`` is raised.

    Returns the reorganized dataset.
    """
    standard_names = {
        "pressure": "Pressure",
        "air_temperature": "Temperature",
        "altitude": "Altitude",
        "relative_humidity": "Relative Humidity",
    }

    # Rename
    (varname,) = set(ds["observed_variable"].values)
    ds = ds.rename(observation_value=str(varname)).drop_vars("observed_variable")
    ds = ds.rename(
        {
            var: "_".join([varname, var.replace("_value", "")])
            for var in ds.data_vars
            if var.startswith("uncertainty")
        }
    )
    # Update attrs
    # The loop iterates over the data variables of the dataset as it was
    # before any drop below; ``ds`` is rebound to a new object on each drop.
    for var, da in ds.data_vars.items():
        if var in standard_names:
            da.attrs["standard_name"] = standard_names[var]
        for string in ("units", "type"):
            if string in var:
                ds = ds.drop_vars(var)
                (value,) = set(da.values)
                # A bare "units"/"type" column describes the main variable;
                # "<name>_units"/"<name>_type" describes "<name>".
                attrs_var = varname if var == string else var.replace("_" + string, "")
                ds[attrs_var].attrs[string] = value
    return ds


def test(ds, station):
    """Turn raw observations into regularly gridded profiles for one station.

    Parameters
    ----------
    ds : xarray.Dataset
        Observations along an ``index`` dimension, with at least
        ``primary_station_id``, ``observed_variable`` and
        ``report_timestamp`` variables.
    station : str
        Value of ``primary_station_id`` to keep.

    Returns
    -------
    xarray.Dataset
        Profiles of air temperature, relative humidity and specific humidity
        interpolated to altitudes 50, 100, ..., 30000 and concatenated along
        ``time``, plus ``integrated_water_vapour``. An empty dataset is
        returned when the station has no observations or when no profile
        passes the gap check.
    """
    # Decode byte-string variables so they compare equal to ``str`` values
    for var, da in ds.data_vars.items():
        if np.issubdtype(da.dtype, np.bytes_):
            ds[var].values = np.char.decode(da.values, "utf-8")
    ds = ds.where(ds["primary_station_id"] == station, drop=True)
    if not ds.sizes["index"]:
        return xr.Dataset()

    # One reorganized dataset per observed variable, merged back together
    ds = xr.merge(
        [reorganize_dataset(group) for _, group in ds.groupby("observed_variable")]
    )

    # Global attrs
    # Variables holding a single distinct value are moved to dataset attributes.
    for var, da in ds.data_vars.items():
        if len(set(da.values)) == 1:
            ds = ds.drop_vars(var)
            ds.attrs[var] = da[0].values.tolist()

    # Add variables
    ds["specific_humidity"] = compute_specific_humidity(ds)
    ds["time"] = ("index", pd.to_datetime(ds["report_timestamp"]).values)

    # Compute profiles
    subset = ["air_temperature", "relative_humidity", "specific_humidity", "time"]
    profiles = []
    for time, profile in ds.groupby("time"):
        profile = profile.swap_dims(index="altitude")[subset]
        profile = profile.sortby("altitude").dropna(
            "altitude", how="any", subset=subset
        )
        # Skip soundings with a gap larger than 2000 between consecutive levels
        if (profile["altitude"].diff("altitude") > 2_000).any():
            continue

        profile = profile.interp(altitude=range(50, 30_001, 50))
        profiles.append(profile.expand_dims(time=[time]))

    # xr.concat raises on an empty list; mirror the "no data" return above
    if not profiles:
        return xr.Dataset()

    ds_profiles = xr.concat(profiles, "time")
    ds_profiles["integrated_water_vapour"] = compute_integrated_water_vapour(
        ds_profiles
    )
    return ds_profiles

## Download and transform

In [None]:
# Download every monthly request and apply `test` to each result,
# forwarding the selected station as a keyword argument.
transform_kwargs = {"station": station}
ds = download.download_and_transform(
    collection_id,
    requests,
    transform_func=test,
    transform_func_kwargs=transform_kwargs,
)

## Plot profiles

In [None]:
# One figure per altitude-dependent variable: every profile, then the
# time-mean and the mean ± one standard deviation drawn on top in black.
plot_kwargs = {"y": "altitude"}
envelope = ((-1, None), (+1, "mean ± std"))  # only the upper bound gets a legend entry
for var, da in ds.data_vars.items():
    if "altitude" in da.dims:
        da.plot(hue="time", add_legend=False, **plot_kwargs)

        mean = da.mean("time", keep_attrs=True)
        std = da.std("time", keep_attrs=True)
        for sign, label in envelope:
            bound = mean + std * sign
            bound.plot(color="k", linestyle="--", label=label, **plot_kwargs)
        mean.plot(color="k", linestyle="-", label="mean", **plot_kwargs)

        plt.legend()
        plt.grid()
        plt.show()

## Plot timeseries

In [None]:
# One figure per variable that does not depend on altitude.
# plt.show() closes each figure so several variables are not overlaid
# on the same axes (same pattern as the profile plots).
for da in ds.drop_dims("altitude").data_vars.values():
    da.plot(marker="o")
    plt.grid()
    plt.show()