# Ozone mixing ratio

## Import libraries

In [None]:
import os

import cdsapi
import matplotlib.pyplot as plt
import pandas as pd
import requests
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

# Apply the "seaborn-v0_8-notebook" matplotlib style sheet to every figure
# produced in this notebook.
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Client configuration: point cdsapi at a credentials file, but respect a
# CDSAPI_RC the user has already exported instead of silently overriding it
# with this (user-specific) default path.
os.environ.setdefault(
    "CDSAPI_RC", os.path.expanduser("~/calmanti_sandro/.cdsapirc")
)

# Latitude bands (degrees north) averaged for the satellite comparison.
# They are paired by position with ``stations``: the plotting cell compares
# the ~55° band with EDT and the ~65° band with MBI.
lat_slices = [slice(54.5, 55.5), slice(64.5, 65.5)]

# Pressure level of the comparison (presumably hPa, as assumed by the
# mixing-ratio unit conversions)
pressure = 50

# WOUDC ozonesonde stations (GAW ids), one per latitude band above
stations = ["EDT", "MBI"]

## Set requests

In [None]:
# CDS collection holding the satellite ozone products.
collection_id = "satellite-ozone-v1"

# One base request per viewing geometry. Both ask for level-3 vertical
# profiles; limb sensors provide a concentration, nadir sensors a layer content.
request_dict = {
    "limb": dict(
        processing_level="level_3",
        variable="mole_concentration_of_ozone_in_air",
        vertical_aggregation="vertical_profiles_from_limb_sensors",
    ),
    "nadir": dict(
        processing_level="level_3",
        variable="mole_content_of_ozone_in_atmosphere_layer",
        vertical_aggregation="vertical_profiles_from_nadir_sensors",
    ),
}

## Functions to cache

In [None]:
def preprocess(ds):
    """Ensure *ds* has a ``time`` dimension.

    A dataset that already has ``time`` among its dimensions is returned
    unchanged. Otherwise a length-1 ``time`` dimension is added, labelled with
    the first day of the month given by the leading ``YYYYMM`` characters of
    the ``time_coverage_start`` global attribute.
    """
    if "time" not in ds.dims:
        yyyymm = ds.attrs["time_coverage_start"][:6]
        month_start = pd.to_datetime(yyyymm, format="%Y%m")
        ds = ds.expand_dims(time=[month_start])
    return ds


def compute_ozone_mixing_ratio(ds, lat_slices, pressure, rconst=8.314):
    """Return the O₃ mixing ratio (ppm) at one vertical level for several latitude bands.

    Parameters
    ----------
    ds : xr.Dataset
        One satellite ozone product. Its vertical axis is either an
        ``altitude`` dimension or a ``pressure`` coordinate.
    lat_slices : list of slice
        Latitude bands; each band is reduced to a single latitude entry.
    pressure : float
        Target level. Used as-is for ``pressure`` selection and as
        ``pressure * 1e3`` for ``altitude`` selection.
    rconst : float, optional
        Constant used to turn a concentration into a mixing ratio
        (presumably the molar gas constant in J mol-1 K-1).

    Returns
    -------
    xr.Dataset
        A single ``ozone_mixing_ratio`` variable (attrs ``units="ppm"``) with
        one entry per band, concatenated along ``latitude``.
    """
    # Select the vertical level.
    # NOTE(review): for altitude-based products the pressure value is reused
    # as an altitude (pressure * 1e3, presumably metres) — confirm units/intent.
    ds = (
        ds.sel(altitude=pressure * 1.0e3)
        if "altitude" in ds.dims
        else ds.sel(pressure=pressure)
    )

    dataarrays = []
    for lat_slice in lat_slices:
        ds_lat = ds.sel(latitude=lat_slice)

        # Average out every dimension except time/latitude/longitude.
        ds_lat = ds_lat.mean(
            set(ds_lat.dims) - {"time", "latitude", "longitude"},
            keep_attrs=True,
        )
        # Collapse a multi-latitude band to one entry: label it with the mean
        # latitude of the band and average with diagnostics.spatial_weighted_mean
        # (presumably an area-weighted lat/lon mean — confirm).
        # NOTE(review): when the band holds exactly one latitude this branch is
        # skipped, so a longitude dimension, if present, is not averaged.
        if ds_lat.sizes["latitude"] != 1:
            latitude = (
                ds_lat["latitude"].coarsen(latitude=ds_lat.sizes["latitude"]).mean()
            )
            ds_lat = diagnostics.spatial_weighted_mean(ds_lat)
            ds_lat = ds_lat.expand_dims(latitude=latitude)
            ds_lat["latitude"].attrs = latitude.attrs

        # Compute mixing ratio in ppm
        if "ozone_mixing_ratio" in ds_lat:
            # Already a mixing ratio; scale to ppm (assumes mol/mol — confirm).
            da = ds_lat["ozone_mixing_ratio"] * 1.0e6
        else:
            # x = c * R * T / p. Assuming c in mol m-3, T in K and p in hPa
            # (confirm), hPa -> Pa (x100) and mol/mol -> ppm (x1e6) give 1e4.
            da = (
                (ds_lat["ozone_concentration"] * rconst * ds_lat["temperature"])
                / ds_lat["pressure"]
            ) * 1e4
        # Load each band eagerly before concatenation.
        dataarrays.append(da.compute())

    da = xr.concat(dataarrays, "latitude")
    da.attrs = {"long_name": "O₃ mixing ratio", "units": "ppm"}
    return da.to_dataset(name="ozone_mixing_ratio")


def download_combined_dataset(collection_id, request):
    """Download, transform and stack every single-sensor product matching *request*.

    The CDS constraints of *collection_id* are walked
    sensor -> algorithm -> version -> year -> month. Each
    sensor/algorithm/version combination is requested for all of its
    available years and months and transformed with
    ``compute_ozone_mixing_ratio``, using the notebook-level ``lat_slices``
    and ``pressure`` parameters. The result is labelled
    ``"<sensor>-<version>"`` or ``"<sensor>-<version>-<algorithm>"``.

    Returns
    -------
    xr.Dataset
        The transformed products concatenated along a new ``product`` dimension.

    Raises
    ------
    ValueError
        If no product is left after the exclusions below.
    """
    # Merged products are excluded; only individual sensors are compared.
    excluded_sensors = {"cllg", "cmzm", "merged_np"}
    # Products whose files lack the time_coverage_start attribute that
    # ``preprocess`` relies on.
    excluded_products = {"omps-v0002-usask"}

    collection = cdsapi.Client(quiet=True).client.get_collection(collection_id)

    datasets = []
    for sensor in collection.apply_constraints(request)["sensor"]:
        if sensor in excluded_sensors:
            continue
        sensor_r = request | {"sensor": sensor}
        # When the constraints offer no algorithm (falsy result), iterate once
        # with None and send an empty algorithm selection.
        for algo in collection.apply_constraints(sensor_r)["algorithm"] or [None]:
            algo_r = sensor_r | {"algorithm": algo or []}
            for version in collection.apply_constraints(algo_r)["version"]:
                product = "-".join([sensor, version] + ([algo] if algo else []))
                # Check exclusions before issuing the per-year/per-month
                # constraint queries below, which would otherwise be wasted.
                if product in excluded_products:
                    continue
                print(f"{product=}")

                version_r = algo_r | {"version": version}
                request_list = []
                for year in collection.apply_constraints(version_r)["year"]:
                    year_r = version_r | {"year": year}
                    months = collection.apply_constraints(year_r)["month"]
                    request_list.append(year_r | {"month": months})

                ds = download.download_and_transform(
                    collection_id,
                    request_list,
                    preprocess=preprocess,
                    transform_func=compute_ozone_mixing_ratio,
                    transform_func_kwargs={
                        "lat_slices": lat_slices,
                        "pressure": pressure,
                    },
                )
                datasets.append(ds.expand_dims(product=[product]))

    if not datasets:
        raise ValueError(f"No products available for {collection_id=}, {request=}")
    return xr.concat(datasets, "product")

## Download and transform

In [None]:
# Download and transform all products for each viewing geometry.
# NOTE: ``sensor`` holds the request_dict key ("limb"/"nadir"), not an
# instrument name; the name is kept because the printed label is built from it.
datasets = {}
for sensor, request in request_dict.items():
    print(f"{sensor=}")
    datasets[sensor] = download_combined_dataset(
        collection_id=collection_id,
        request=request,
    ).compute()

## Download external data (WOUDC ozonesonde profiles)

In [None]:
def get_station(station_gaw_id, pressure, timeout=60):
    """Download a station's WOUDC ozonesonde profiles and return monthly O₃ mixing ratios.

    Every profile that reports the requested *pressure* level is reduced to
    the O₃ partial pressure at that level (duplicate levels are averaged) and
    stamped with the profile date. The series is then averaged into calendar
    months labelled by their first day.

    Parameters
    ----------
    station_gaw_id : str
        GAW identifier of the station (e.g. ``"EDT"``).
    pressure : float
        Level to extract, in the units of the WOUDC ``pressure`` field
        (presumably hPa — confirm).
    timeout : float, optional
        Seconds to wait for the WOUDC API. Without a timeout ``requests``
        can block indefinitely.

    Returns
    -------
    xr.Dataset
        A single ``ozone_mixing_ratio`` variable (ppm) on a monthly ``time``
        dimension.

    Raises
    ------
    requests.HTTPError
        If the WOUDC API answers with an error status.
    ValueError
        If no profile of the station reports the requested level.
    """

    def _as_float(value):
        # Missing values arrive as None (JSON null). Map them to NaN so the
        # arrays stay float64 instead of falling back to numpy object dtype.
        return float("nan") if value is None else float(value)

    print(f"{station_gaw_id=}")
    response = requests.get(
        url="https://api.woudc.org/collections/ozonesonde/items",
        params={"station_gaw_id": station_gaw_id, "f": "json", "limit": 100000},
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()

    dataarrays = []
    for feature in data["features"]:
        properties = feature["properties"]
        da = xr.DataArray(
            [_as_float(o3pp) for o3pp in properties["o3partialpressure"]],
            dims="pressure",
            coords={"pressure": [_as_float(p) for p in properties["pressure"]]},
        )
        if pressure not in da["pressure"]:
            continue  # this sounding does not report the requested level

        da = da.sel(pressure=pressure, drop=True)
        if "pressure" in da.dims:
            da = da.mean("pressure")  # Mean of duplicates
        time = pd.to_datetime(properties["timestamp_date"]).tz_localize(None)
        da = da.expand_dims(time=[time])
        dataarrays.append(da)

    if not dataarrays:
        raise ValueError(
            f"No ozonesonde profile of {station_gaw_id=} reports {pressure=}"
        )

    da = xr.concat(dataarrays, "time").sortby("time").resample(time="1MS").mean()
    # Partial pressure / air pressure is the volume mixing ratio. Assuming
    # mPa over hPa (confirm), the ratio is scaled by 1e-5 * 1e6 = 10 for ppm.
    da = (da / pressure) * 10
    da.attrs = {"long_name": "O₃ mixing ratio", "units": "ppm"}
    return da.to_dataset(name="ozone_mixing_ratio")


# Stack the per-station monthly series along a new "station" dimension,
# in the order given by ``stations``.
ds_stations = xr.concat(
    objs=[
        get_station(gaw_id, pressure=pressure).expand_dims(station=[gaw_id])
        for gaw_id in stations
    ],
    dim="station",
)

## Quick-look plots: satellite products vs. ozonesonde stations

In [None]:
def _station_for_latitude(latitude):
    """Return the station paired with the latitude band containing *latitude*.

    ``lat_slices`` and ``stations`` are paired by position. Matching on the
    band instead of an exact value (previously ``55``/``65``) keeps working
    when the band-averaged latitude is not exactly integral, e.g. on uneven
    grids or after floating-point round-off.
    """
    for lat_slice, station in zip(lat_slices, stations):
        if lat_slice.start <= latitude <= lat_slice.stop:
            return station
    raise NotImplementedError(f"{latitude=}")


(da_station,) = ds_stations.data_vars.values()
for sensor, ds in datasets.items():
    (da,) = ds.data_vars.values()
    # Drop products that have no data at all for this geometry.
    da = da.dropna("product", how="all")
    # Restrict the sonde series to the months spanned by the satellite data
    # (takes the first and last time values, i.e. assumes a sorted axis).
    da_station_cutout = da_station.sel(
        time=slice(*da["time"].dt.strftime("%Y-%m").values[[0, -1]])
    )
    # One viridis colour per product.
    cmap = plt.get_cmap("viridis", da.sizes["product"])
    colors = [cmap(i) for i in range(da.sizes["product"])]
    with plt.rc_context({"axes.prop_cycle": plt.cycler(color=colors)}):
        facet = da.plot(row="latitude", hue="product", figsize=(10, 6))
    for ax, sel_dict in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
        station = _station_for_latitude(sel_dict["latitude"])
        # Overlay the matching ozonesonde series in black.
        ax.plot(
            da_station_cutout["time"],
            da_station_cutout.sel(station=station),
            color="k",
            lw=1,
            label=station,
            zorder=1,
        )
        ax.legend()
        ax.grid()
    facet.fig.suptitle(f"{sensor = }", y=1.01)
    plt.show()