# Tropopause

## Import packages

In [None]:
import calendar
import os
import pathlib

import cdsapi
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import statsmodels.tsa.seasonal
import xarray as xr
from c3s_eqc_automatic_quality_control import download
from scipy import stats

# Apply the "seaborn-v0_8-notebook" matplotlib style sheet to the figures below
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time period
start = "2006-05"
stop = "2020-03"

# Stations
stations = ["TEN", "LIN", "NYA"]  # Use None to analyse all stations
assert isinstance(stations, list | None)

# Directory for csv files
csv_dir = "./csv_files"

# CDS credentials
os.environ["CDSAPI_RC"] = os.path.expanduser("~/ciardini_virginia/.cdsapirc")

## Define request

In [None]:
# Collection to query and the request keys shared by every monthly request
collection_id = "insitu-observations-gruan-reference-network"
request = {
    "variable": ["air_temperature", "relative_humidity", "air_pressure", "altitude"],
    "format": "netcdf",
}

client = cdsapi.Client()
requests = []
# One candidate request per month start between `start` and `stop`
for date in pd.date_range(start, stop, freq="1MS"):
    time_request = {"year": date.strftime("%Y"), "month": date.strftime("%m")}
    # Apply the collection constraints to this year/month and keep the
    # resulting "day" entry
    constrained = client.client.apply_constraints(
        collection_id, {**request, **time_request}
    )
    time_request["day"] = constrained["day"]
    if not time_request["day"]:
        # No day left after applying the constraints: skip this month
        continue
    requests.append({**request, **time_request})

## Functions to cache

In [None]:
def _reorganize_dataset(ds):
    """Reshape the dataset of a single observed variable.

    ``ds`` must contain exactly one distinct value in ``observed_variable``
    (unpacking fails otherwise). ``observation_value`` is renamed after that
    variable, ``uncertainty*`` variables are prefixed with it, and variables
    whose name contains "units" or "type" are dropped and stored as
    attributes of the variable they describe.
    """
    standard_names = {
        "pressure": "Pressure",
        "air_temperature": "Temperature",
        "altitude": "Altitude",
        "relative_humidity": "Relative",
    }

    # Rename
    (varname,) = set(ds["observed_variable"].values)
    ds = ds.rename(observation_value=str(varname))
    ds = ds.drop_vars("observed_variable")

    renames = {}
    for name in ds.data_vars:
        if name.startswith("uncertainty"):
            renames[name] = "_".join([varname, name.replace("_value", "")])
    ds = ds.rename(renames)

    # Update attrs
    # NOTE: the loop walks the variables of the dataset as it was before any
    # drop below; `ds` is rebound to the reduced dataset while iterating.
    for name, da in ds.data_vars.items():
        if name in standard_names:
            da.attrs["standard_name"] = standard_names[name]
        for key in ("units", "type"):
            if key not in name:
                continue
            ds = ds.drop_vars(name)
            # The metadata variable must hold a single distinct value
            (value,) = set(da.values)
            target = varname if name == key else name.replace("_" + key, "")
            ds[target].attrs[key] = value
    return ds


def reorganize_dataset(ds, stations):
    """Decode, filter by station and split the dataset per observed variable.

    Byte-string variables are decoded as UTF-8 in place. When ``stations`` is
    not None, only rows whose ``primary_station_id`` is in ``stations`` are
    kept. Returns an empty ``xr.Dataset`` when no row is left along "index";
    otherwise the merge of the per-variable datasets produced by
    ``_reorganize_dataset``.
    """
    for name, da in ds.data_vars.items():
        if np.issubdtype(da.dtype, np.bytes_):
            ds[name].values = np.char.decode(da.values, "utf-8")

    if stations is not None:
        keep = ds["primary_station_id"].isin(stations)
        ds = ds.where(keep, drop=True)

    if ds.sizes["index"] == 0:
        return xr.Dataset()

    per_variable = [
        _reorganize_dataset(group) for _, group in ds.groupby("observed_variable")
    ]
    return xr.merge(per_variable)


def calculate_tropopause(ds):
    """Compute the lapse-rate tropopause altitude of a single profile.

    ``ds`` is indexed by "index" and must provide "air_temperature",
    "altitude" and "pressure". The temperature is interpolated on a regular
    0.1-step altitude grid, and the lowest level is returned where

    * the local lapse rate is <= 2 and the pressure is <= 500, and
    * the mean lapse rate between that level and every interpolated level up
      to 2 above it is also <= 2.

    Returns a 0-dimensional ``xr.DataArray`` carrying the "long_name" and
    "units" attributes defined below. Its value is the altitude as a float,
    or ``None`` when the profile spans less than 2, nothing survives the
    interpolation, or no level satisfies the criteria.
    """
    attrs = {"long_name": "WMO Lapse-Rate Tropopause", "units": "km"}

    # Use altitude as the dimension, keep only the needed variables, and drop
    # levels where altitude or pressure is missing
    ds = ds.swap_dims(index="altitude")[["air_temperature", "altitude", "pressure"]]
    ds = ds.where(
        (ds["altitude"].notnull() & ds["pressure"].notnull()).compute(), drop=True
    )
    # Scale altitude by 1e-3 (presumably m -> km, matching attrs["units"];
    # TODO confirm against the input units)
    ds["altitude"] = ds["altitude"] * 1.0e-3
    # Scale pressure by 1e-2 (presumably Pa -> hPa; TODO confirm) and store
    # its log10; it is converted back with 10** after interpolation
    ds["pressure"] = np.log10(ds["pressure"] * 1.0e-2)
    # Remove duplicated altitude and pressure values, then sort by altitude
    ds = (
        ds.drop_duplicates("altitude")
        .swap_dims(altitude="pressure")
        .drop_duplicates("pressure")
        .swap_dims(pressure="altitude")
        .sortby("altitude")
    )

    if not ds.sizes["altitude"] or (ds["altitude"][-1] - ds["altitude"][0]) < 2:
        # Column must be at least 2km
        return xr.DataArray(None, attrs=attrs)

    # Interpolate temperature (cubic) on a regular grid from 0.1 to 40 in
    # steps of 0.1, attach the interpolated pressure as a coordinate, and drop
    # grid levels where the temperature is NaN
    interp_altitude = np.arange(0.1, 40.1, 0.1)
    temp = ds["air_temperature"].interp(altitude=interp_altitude, method="cubic")
    temp = temp.assign_coords(pressure=10 ** ds["pressure"].interp_like(temp))
    temp = temp.dropna("altitude")
    if not temp.sizes["altitude"]:
        return xr.DataArray(None, attrs=attrs)

    # Lapse rate -dT/dz between consecutive grid levels, labelled with the
    # lower level of each pair
    diff_kwargs = {"dim": "altitude", "label": "lower"}
    lapse_rate = -temp.diff(**diff_kwargs) / temp["altitude"].diff(**diff_kwargs)
    # Discard the top 2 of the profile, so every remaining candidate has at
    # least 2 of profile above it
    lapse_rate = lapse_rate.sel(altitude=slice(None, lapse_rate["altitude"].max() - 2))

    # Candidate levels: lapse rate <= 2 and pressure <= 500; they are visited
    # in the order of the altitude coordinate
    mask = (lapse_rate <= 2) & (lapse_rate["pressure"] <= 500)
    valid_altitude = lapse_rate["altitude"].where(mask.compute(), drop=True)
    for bottom in valid_altitude:
        temp_bottom = temp.sel(altitude=bottom)
        # Levels within 2 above the candidate, excluding the candidate itself
        temp_above = temp.sel(altitude=slice(bottom, bottom + 2)).drop_sel(
            altitude=bottom
        )
        # Mean lapse rate between the candidate and each level above it.
        # NOTE: this rebinds `lapse_rate`; the candidates were already
        # extracted into `valid_altitude`, so the loop is unaffected.
        lapse_rate = (temp_bottom - temp_above) / (
            temp_above["altitude"] - temp_bottom["altitude"]
        )
        if (lapse_rate <= 2).all():
            # First candidate meeting the criterion is the tropopause
            return xr.DataArray(float(bottom.values), attrs=attrs)
    return xr.DataArray(None, attrs=attrs)


def compute_tropopause_altitude(ds, stations):
    """Compute one tropopause altitude per report and return them as a dataset.

    The dataset is reorganized with ``reorganize_dataset`` and split by
    "report_id". Each report is reduced with ``calculate_tropopause``; every
    non-uncertainty variable that holds a single distinct value within the
    report is attached as a coordinate along "time". The result is indexed by
    the timezone-naive "report_timestamp", sorted by time, and returned as a
    dataset with a single variable named "tropopause".
    """
    ds = reorganize_dataset(ds, stations).compute()

    per_report = []
    for report_id, report in ds.groupby(ds["report_id"]):
        coords = {"report_id": ("time", [report_id])}
        for name, variable in report.data_vars.items():
            if "uncertainty" not in name:
                distinct = set(variable.values)
                # Only variables that are constant over the report can
                # become per-report coordinates
                if len(distinct) == 1:
                    coords[name] = ("time", list(distinct))
        tropopause = calculate_tropopause(report).expand_dims("time")
        per_report.append(tropopause.assign_coords(coords))

    combined = xr.concat(per_report, "time")
    timestamps = pd.to_datetime(combined["report_timestamp"]).tz_localize(None)
    combined = combined.assign_coords(time=timestamps)
    return combined.sortby("time").to_dataset(name="tropopause")

## Download and transform

In [None]:
# Download the requests and reduce them with compute_tropopause_altitude.
# NOTE(review): stations are passed sorted, presumably so that the same
# selection always yields the same transform arguments - confirm.
ds = download.download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_func=compute_tropopause_altitude,
    transform_func_kwargs={"stations": sorted(stations) if stations else stations},
).compute()

# Save the full result; `parents=True` lets csv_dir be a nested path whose
# parent directories do not exist yet
csv_dir_path = pathlib.Path(csv_dir)
csv_dir_path.mkdir(parents=True, exist_ok=True)
ds.to_pandas().to_csv(csv_dir_path / "ds_data.csv")

## Plot monthly timeseries

In [None]:
fig, ax = plt.subplots(figsize=[9, 4])
for station, da in ds["tropopause"].groupby("primary_station_id"):
    # Monthly mean and standard deviation for this station
    da_resampled = da.resample(time="MS")
    df_mean = da_resampled.mean().to_pandas()
    df_std = da_resampled.std().to_pandas()

    # Save both statistics, then draw the mean with the std as error bars
    df_mean.to_csv(csv_dir_path / f"monthly_mean_lrt_{station.lower()}.csv")
    df_std.to_csv(csv_dir_path / f"monthly_lrt_std_{station.lower()}.csv")
    df_mean.plot(yerr=df_std, marker=".", label=station)

# Labels are taken from the attributes of the last plotted station
ax.set(
    title=f"{da.attrs['long_name']} (LRT)",
    xlabel="",
    ylabel=f"Altitude [{da.attrs['units']}]",
)
ax.grid()
_ = ax.legend()

## Plot seasonality

In [None]:
fig, ax = plt.subplots(figsize=[9, 4])
for station, da in ds["tropopause"].groupby("primary_station_id"):
    # Mean and standard deviation per calendar month for this station
    grouped = da.groupby("time.month")
    df_mean = grouped.mean().to_pandas()
    df_std = grouped.std().to_pandas()
    df_mean.plot(ax=ax, yerr=df_std, marker=".", label=station)

# Decorate the axes once, after all stations are drawn. These calls do not
# depend on the station, and `ax.grid()` without arguments is a toggle, so it
# should not be repeated per iteration.
ax.set_xticks(range(1, 13), calendar.month_abbr[1:13])
ax.set_xlabel("")
ax.set_ylabel(f"Altitude [{da.attrs['units']}]")
ax.set_title("Annual cycle of mean LRT")
ax.grid()
_ = ax.legend()

## Plot fit

In [None]:
fig, ax = plt.subplots(figsize=[9, 5])
for i, (station, da) in enumerate(ds["tropopause"].groupby("primary_station_id")):
    # Only these two stations are fitted
    if station not in ["LIN", "NYA"]:
        continue

    # Seasonal decomposition of the monthly-mean anomalies (gaps are filled by
    # interpolation along time before decomposing)
    df_mean = da.resample(time="MS").mean().interpolate_na("time").to_pandas()
    df_anom = df_mean - df_mean.mean()
    decomposition = statsmodels.tsa.seasonal.seasonal_decompose(
        df_anom, model="additive"
    )
    decomposition.trend.to_csv(csv_dir_path / f"intra-annual_{station}.csv")
    decomposition.seasonal.to_csv(csv_dir_path / f"seasonal_{station}.csv")

    # Linear fit of the trend component against the month index, skipping the
    # months where the trend is undefined
    x_trend = np.arange(decomposition.trend.size)[~decomposition.trend.isnull()]
    y_trend = decomposition.trend.dropna()
    res = stats.linregress(x_trend, y_trend)
    fit = pd.Series(res.intercept + res.slope * x_trend, index=y_trend.index)

    # Two-sided Student's t quantile for the 95% confidence intervals
    # (negative for q < 0.5, hence the abs() below)
    q = 0.05 / 2
    deg_freedom = len(x_trend) - 2
    ts = scipy.stats.t.ppf(q, deg_freedom)

    # Plot
    decomposition.trend.plot(ax=ax, color=f"C{i}", ls="-", label=station)
    fit.plot(ax=ax, color=f"C{i}", ls="--", label=f"linear fit for {station}")

    # Text
    print(f"{station}:")
    pad = 17
    print(f"{'Equation':>{pad}}: {res.slope:+.4f}x {res.intercept:+.4f}")
    print(f"{'R^2':>{pad}}: {res.rvalue**2:.2f}")
    # Build the value first: a conditional expression binds looser than `+`,
    # so concatenating the label inline drops it in the `else` branch
    pvalue_text = "< 0.001" if res.pvalue < 0.001 else f"{res.pvalue:.4f}"
    print(f"{'p-value':>{pad}}: {pvalue_text}")
    print(f"{'slope (95%)':>{pad}}: {res.slope:+.4f} ± {abs(ts * res.stderr):.4f}")
    print(
        f"{'intercept (95%)':>{pad}}: {res.intercept:+.4f} ± {abs(ts * res.intercept_stderr):.4f}"
    )

# Plot settings
ax.set_title("Linear trends")
ax.set_ylabel(f"Altitude anomaly [{da.attrs['units']}]")
ax.set_xlabel("")
ax.grid()
_ = ax.legend()