# xch4 level 2 growth rates

## Import libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import xarray as xr
from c3s_eqc_automatic_quality_control import download
from xarray.groupers import BinGrouper

# Apply the "seaborn-v0_8-notebook" matplotlib style sheet to every figure
# created later in this notebook.
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Inclusive range of calendar years kept in the final growth-rate series.
# The download request pads this range by one year on each side.
year_start, year_stop = 2004, 2021

## Define request

In [None]:
collection_id = "satellite-methane"
request = {
    "processing_level": "level_2",
    "variable": "xch4",
    "sensor_and_algorithm": "merged_emma",
    "version": "4_5",
    # One extra year before ``year_start`` and after ``year_stop`` so that
    # values are available on both sides of the period of interest.
    "year": list(map(str, range(year_start - 1, year_stop + 2))),
    # Zero-padded month ("01".."12") and day ("01".."31") strings.
    "month": [format(month, "02d") for month in range(1, 13)],
    "day": [format(day, "02d") for day in range(1, 32)],
}

## Define functions to cache

In [None]:
def weight_dataset(obj):
    """Return ``obj.weighted(...)`` using ``|cos(latitude)|`` as weights.

    ``obj["latitude"]`` is converted from degrees to radians before the
    cosine is taken; the absolute value keeps the weights non-negative.
    """
    latitude_rad = np.deg2rad(obj["latitude"])
    cos_weights = np.abs(np.cos(latitude_rad))
    return obj.weighted(cos_weights)


def spatial_weighted_mean(obj, dim=None):
    """Return the mean of *obj* over *dim*, weighted by ``|cos(latitude)|``.

    *dim* is forwarded unchanged to the weighted ``mean`` call (``None`` by
    default) and the attributes of *obj* are kept on the result.
    """
    weighted = weight_dataset(obj)
    return weighted.mean(dim, keep_attrs=True)


def spatial_weighted_std(obj, dim=None):
    """Return the standard deviation of *obj* over *dim*, weighted by ``|cos(latitude)|``.

    *dim* is forwarded unchanged to the weighted ``std`` call (``None`` by
    default) and the attributes of *obj* are kept on the result.
    """
    weighted = weight_dataset(obj)
    return weighted.std(dim, keep_attrs=True)


def regrid(ds, d_lon, d_lat, lon1=180):
    """Bin *ds* onto a regular latitude/longitude grid and average each bin.

    Parameters
    ----------
    ds
        Object providing ``compute`` and ``groupby`` (an xarray object in this
        notebook) with ``latitude`` and/or ``longitude`` variables in degrees.
    d_lon, d_lat : number or None
        Bin width in degrees along longitude / latitude. ``None`` skips
        binning along that coordinate.
    lon1 : {180, 360}
        Upper longitude bound. The lower bound is -180 when ``lon1`` is 180
        and 0 when it is 360.

    Returns
    -------
    The binned object. Each bin holds the ``|cos(latitude)|``-weighted mean
    computed by ``spatial_weighted_mean``; the ``<coord>_bins`` dimensions are
    renamed back to ``<coord>`` and are labelled with the bin midpoints.

    Raises
    ------
    ValueError
        If ``lon1`` is neither 180 nor 360.
    """
    if lon1 not in (180, 360):
        raise ValueError(f"lon1 must be 180 or 360. {lon1=}")
    lon0 = -180 if lon1 == 180 else 0

    coords = {}
    for name, start, stop, step in zip(
        ["latitude", "longitude"],
        [-90, lon0],
        [90, lon1],
        [d_lat, d_lon],
    ):
        if step is None:
            continue
        edges = np.arange(start, stop + step, step)
        coords[name] = BinGrouper(
            edges,
            include_lowest=True,
            # Midpoints are derived from the edges themselves so there is
            # always exactly one label per bin. Two independent ``np.arange``
            # calls can disagree in length for non-integer steps because of
            # floating-point rounding at the end point.
            labels=(edges[:-1] + edges[1:]) / 2,
        )
    ds = ds.compute()  # Groupby map does not work with dask
    ds = ds.groupby(**coords).map(spatial_weighted_mean)
    # Drop any leftover per-point coordinate variables, then give the binned
    # dimensions the plain coordinate names.
    ds = ds.drop_vars(set(coords) & set(ds.variables)).rename(
        {f"{coord}_bins": coord for coord in coords}
    )
    return ds


def daily_regrid(ds):
    """Regrid *ds* day by day onto a 2x2 degree grid.

    Only ``xch4``, ``latitude`` and ``longitude`` are kept. The data are
    split into daily groups with ``resample(time="1D")``; each group is
    passed to ``regrid`` and given a length-1 ``time`` dimension holding the
    group label, and the pieces are concatenated along ``time``. A tqdm
    progress bar tracks the loop over days.
    """
    subset = ds[["xch4", "latitude", "longitude"]]
    daily_grids = [
        regrid(ds_day, d_lon=2, d_lat=2).expand_dims(time=[day])
        for day, ds_day in tqdm.tqdm(subset.resample(time="1D"))
    ]
    return xr.concat(daily_grids, "time")


def monthly_regrid_in_bands(ds, zonal_first):
    """Average *ds* per month into 20-degree latitude bands.

    The data are grouped with ``resample(time="1MS")``. When *zonal_first*
    is true, each group is first averaged over ``longitude`` (unweighted)
    before being binned by ``regrid``; otherwise all points of a band are
    averaged together in one step. Each result gets a length-1 ``time``
    dimension holding the group label, and the pieces are concatenated along
    ``time``.
    """

    def _to_bands(ds_month):
        # Optional zonal (longitude) mean before the latitude binning.
        if zonal_first:
            ds_month = ds_month.mean("longitude", keep_attrs=True)
        return regrid(ds_month, d_lat=20, d_lon=None)

    monthly_bands = [
        _to_bands(ds_month).expand_dims(time=[month_start])
        for month_start, ds_month in ds.resample(time="1MS")
    ]
    return xr.concat(monthly_bands, "time")


def compute_growth_rate(ds, zonal_first):
    """Compute the annual growth rate of ``xch4`` per latitude band.

    The monthly band means returned by ``monthly_regrid_in_bands`` are
    differenced over a centred window: for every time step the value six
    steps later minus the value six steps earlier, i.e. the change over
    twelve monthly steps.

    Parameters
    ----------
    ds
        Input passed to ``monthly_regrid_in_bands``.
    zonal_first : bool
        Forwarded to ``monthly_regrid_in_bands``.

    Returns
    -------
    The growth-rate array with ``units`` and ``long_name`` attributes set.
    Time steps whose window extends past either end of the series are
    expected to be NaN, since ``construct`` pads incomplete windows.

    Notes
    -----
    The difference spans one year only if every calendar month is present in
    the monthly series.
    """
    da = monthly_regrid_in_bands(ds, zonal_first)["xch4"]
    # A 13-sample centred window has its first and last elements exactly
    # 12 steps apart and is symmetric around the labelled time step. A
    # 12-sample window would span only 11 months and be off-centre.
    window = 12 + 1
    da = (
        da.rolling(time=window, center=True)
        .construct("window_dim")
        .isel(window_dim=[0, -1])
        .diff("window_dim")
        # Squeeze only the helper dimension so real length-1 dimensions
        # (e.g. a single latitude band) are preserved.
        .squeeze("window_dim")
    )
    # NOTE(review): the unit string is hard-coded; confirm it matches the unit
    # of the source ``xch4`` variable.
    da.attrs = {"units": "ppm/year", "long_name": "Growth Rate"}
    return da

## Download and transform

In [None]:
# Retrieve the data described by ``request`` from ``collection_id`` and hand
# ``daily_regrid`` to the helper as the transform function.
# NOTE(review): presumably ``chunks={"year": 1}`` splits the request into one
# piece per year and the transformed result is cached; confirm against the
# c3s_eqc_automatic_quality_control documentation.
ds = download.download_and_transform(
    collection_id,
    request,
    chunks={"year": 1},
    transform_func=daily_regrid,
)

## Compute growth rate

In [None]:
# Compute the growth rate with both averaging orders and stack the results
# along a new "method" dimension.
dataarrays = []
for zonal_first, method in ((True, "Zonal-first"), (False, "Standard")):
    growth = compute_growth_rate(ds, zonal_first=zonal_first)
    dataarrays.append(growth.expand_dims(method=[method]))
combined = xr.concat(dataarrays, "method")
# Keep only the years of interest (label-based slice on the time axis).
da = combined.sel(time=slice(str(year_start), str(year_stop)))

## Plot monthly growth rate

In [None]:
# One panel per latitude band (three per row), one line per method.
facet = da.plot(col="latitude", col_wrap=3, hue="method")
for ax in facet.axs.flat:
    ax.grid()
    # Vertical tick labels keep the time axis readable.
    plt.setp(ax.get_xticklabels(), rotation=90)

## Heat map

In [None]:
# Deviation of the growth rate from its cosine-latitude weighted mean.
# NOTE(review): ``spatial_weighted_mean`` is called with its default
# ``dim=None``, so the mean is taken over every dimension of ``da`` and one
# scalar is subtracted everywhere; confirm this is intended rather than a
# per-time-step mean over ``latitude``.
with xr.set_options(keep_attrs=True):
    anomaly = da - spatial_weighted_mean(da)
anomaly.attrs["long_name"] = f"Δ{anomaly.attrs['long_name']}"
# One heat map per method, with time on the x-axis.
facet = anomaly.plot(row="method", robust=True, x="time", figsize=(10, 10))
for ax in facet.axs.flat:
    plt.setp(ax.get_xticklabels(), rotation=90)

## Plot yearly mean growth rate

In [None]:
# Yearly statistics: cosine-latitude weighted mean and standard deviation over
# all months and latitude bands of each calendar year.
reduce_dims = ["time", "latitude"]
grouped = da.groupby("time.year")
da_mean = grouped.map(spatial_weighted_mean, dim=reduce_dims)
da_std = grouped.map(spatial_weighted_std, dim=reduce_dims)
# Transposed tables for the bar plot; the standard deviation is drawn as
# error bars.
df_mean, df_std = (stat.to_pandas().T for stat in (da_mean, da_std))
ax = df_mean.plot.bar(yerr=df_std)
ax.grid()
_ = ax.set_ylabel(f"{da_mean.long_name} [{da_mean.units}]")