## Import packages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import download

# Apply the "seaborn-v0_8-notebook" Matplotlib style sheet to the figures
# produced by this notebook.
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time range of the analysis; both values are handed to
# download.update_request_date when the request is built.
# NOTE(review): strings look like "YYYY-MM"; whether `stop` is inclusive
# depends on update_request_date — confirm.
start = "2016-02"
stop = "2016-03"

## Define request

In [None]:
# Identifier of the collection passed to download.download_and_transform.
collection_id = "insitu-observations-gruan-reference-network"
# Base request without any time selection.
# NOTE(review): calculate_tropopause also reads "air_pressure", which is not
# listed under "variable" here — confirm the returned data always includes it.
request = {
    "format": "csv-lev.zip",
    "variable": ["air_temperature", "altitude"],
}
# NOTE(review): presumably expands `request` into one or more requests that
# cover the start/stop range — confirm against update_request_date.
requests = download.update_request_date(request, start=start, stop=stop)

## Functions to cache

In [None]:
def calculate_tropopause(ds):
    """Compute the WMO lapse-rate tropopause altitude of a single profile.

    The profile is re-indexed by altitude, interpolated onto a regular 0.1 km
    grid between 0.1 and 40 km and scanned from the bottom up. The tropopause
    is the lowest level with pressure <= 500 where the lapse rate is <= 2 per
    km and where the average lapse rate between that level and every level
    within the 2 km above it is also <= 2 per km.

    Parameters
    ----------
    ds : xarray.Dataset
        One profile along the ``index`` dimension providing ``altitude``,
        ``air_temperature`` and ``air_pressure``.
        NOTE(review): the unit conversions below assume altitude in metres
        and pressure in pascals — confirm against the data source.

    Returns
    -------
    xarray.DataArray
        0-d float array holding the tropopause altitude in km. The value is
        NaN when the profile has too few valid levels or no level fulfils
        the criterion.
    """
    attrs = {"long_name": "WMO Lapse-Rate Tropopause", "units": "km"}
    # Use a float NaN rather than None for "no tropopause": None yields an
    # object-dtype array that would contaminate the float results once the
    # per-profile values are concatenated.
    missing = xr.DataArray(np.nan, attrs=attrs)

    # sort and drop
    ds = ds.swap_dims(index="altitude").drop_vars("index").sortby("altitude")
    ds = (
        ds.where(ds["altitude"].notnull())
        .dropna("altitude", how="any")
        .drop_duplicates("altitude")
    )
    # Cubic interpolation needs at least four levels; shorter (or empty)
    # profiles cannot be analysed.
    if ds.sizes["altitude"] < 4:
        return missing

    # convert units
    ds["altitude"] = ds["altitude"] * 1.0e-3
    ds["air_pressure"] = ds["air_pressure"] * 1.0e-2

    # interpolate: temperature with a cubic spline, pressure linearly in
    # log10 space
    interp_altitude = np.arange(0.1, 40.1, 0.1)
    temp = ds["air_temperature"].interp(altitude=interp_altitude, method="cubic")
    temp = temp.assign_coords(
        air_pressure=10 ** np.log10(ds["air_pressure"]).interp(altitude=interp_altitude)
    )
    # Grid levels outside the observed altitude range are NaN after interp.
    temp = temp.dropna("altitude")
    # At least two grid levels are required to form a lapse rate.
    if temp.sizes["altitude"] < 2:
        return missing

    # compute lapse rate between consecutive grid levels, labelled with the
    # lower level of each pair
    diff_kwargs = {"dim": "altitude", "label": "lower"}
    lapse_rate = -temp.diff(**diff_kwargs) / temp["altitude"].diff(**diff_kwargs)
    # Keep only levels that have 2 km of profile above them.
    lapse_rate = lapse_rate.sel(altitude=slice(None, lapse_rate["altitude"].max() - 2))

    # mask and loop over valid lapse rates, lowest candidate first
    mask = (lapse_rate <= 2) & (lapse_rate["air_pressure"] <= 500)
    valid_altitude = lapse_rate["altitude"].where(mask.compute(), drop=True)
    for bottom in valid_altitude:
        temp_bottom = temp.sel(altitude=bottom)
        temp_above = temp.sel(altitude=slice(bottom, bottom + 2)).drop_sel(
            altitude=bottom
        )
        # Average lapse rate from the candidate level to each level above it.
        mean_lapse_rate = (temp_bottom - temp_above) / (
            temp_above["altitude"] - temp_bottom["altitude"]
        )
        if (mean_lapse_rate <= 2).all():
            return xr.DataArray(float(bottom.values), attrs=attrs)
    return missing


def compute_tropopause_altitude(ds):
    """Build a time series of tropopause altitudes, one value per report.

    ``ds`` is grouped by ``report_id`` and ``calculate_tropopause`` is
    applied to every group. Each data variable that takes a single distinct
    value within a report is attached to the result as a coordinate on a
    length-one ``time`` dimension. The per-report results are concatenated
    along ``time``, the ``time`` coordinate is set from ``report_timestamp``
    (time zone removed), and the series is sorted chronologically.

    Returns an ``xarray.Dataset`` with the single variable ``tropopause``.
    """
    per_report = []
    for report_id, profile in ds.groupby(ds["report_id"]):
        constant_coords = {"report_id": ("time", [report_id])}
        for name, variable in profile.data_vars.items():
            distinct = set(variable.values)
            # Only variables that are constant within the report become
            # per-report coordinates.
            if len(distinct) != 1:
                continue
            constant_coords[name] = ("time", list(distinct))
        tropopause = calculate_tropopause(profile)
        tropopause = tropopause.expand_dims("time")
        per_report.append(tropopause.assign_coords(constant_coords))

    series = xr.concat(per_report, "time")
    timestamps = pd.to_datetime(series["report_timestamp"]).tz_localize(None)
    series = series.assign_coords(time=timestamps)
    return series.sortby("time").to_dataset(name="tropopause")

## Download and compute tropopause

In [None]:
# Download the requested data and reduce it with compute_tropopause_altitude;
# the result is bound to `ds` and plotted in the following section.
# NOTE(review): `chunks` presumably splits the download into one piece per
# year and month, with the transformed result cached per piece — confirm
# against download.download_and_transform.
ds = download.download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_func=compute_tropopause_altitude,
)

## Plot tropopause for each station

In [None]:
# Draw one tropopause series per station, grouped by the "station_name"
# coordinate and labelled with the station name for the legend.
for station, da in ds["tropopause"].groupby("station_name"):
    da.plot(label=station)
plt.grid()
plt.legend()