# XCH4 level 2 growth rates

## Import libraries

In [None]:
import flox.xarray
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from c3s_eqc_automatic_quality_control import diagnostics, download

# Apply the "seaborn-v0_8-notebook" matplotlib style to every figure drawn below.
plt.style.use(style="seaborn-v0_8-notebook")

## Define parameters

In [None]:
# First and last year (both inclusive) of the period to request.
year_start, year_stop = 2004, 2022

## Define request

In [None]:
collection_id = "satellite-methane"
# Level-2 merged EMMA XCH4 (version 4_5) for every year in
# [year_start, year_stop]; the same month list (01-12) and day list (01-31)
# is sent for every year.
request = dict(
    processing_level="level_2",
    variable="xch4",
    sensor_and_algorithm="merged_emma",
    version="4_5",
    year=list(map(str, range(year_start, year_stop + 1))),
    month=[str(month).zfill(2) for month in range(1, 13)],
    day=[str(day).zfill(2) for day in range(1, 32)],
)

## Define functions to cache

In [None]:
def arithmetic_unweighted_average(ds, d_lon, d_lat, lon1):
    """Bin ``ds`` onto a regular latitude/longitude grid with plain means.

    Every sample whose ``latitude``/``longitude`` falls in a grid cell
    contributes equally to that cell's mean; there is no area or count
    weighting.

    Parameters
    ----------
    ds : xarray.Dataset
        Data carrying ``latitude`` and ``longitude`` to bin on.
    d_lon, d_lat : int or float
        Grid spacing in degrees along longitude and latitude.
    lon1 : {180, 360}
        Eastern longitude bound: 180 gives a -180..180 grid, 360 a 0..360 grid.

    Returns
    -------
    xarray.Dataset
        Cell means on ``latitude``/``longitude`` dimensions whose coordinates
        are the cell centres. Every coordinate gets a ``standard_name`` attr
        equal to its own name.

    Raises
    ------
    ValueError
        If ``lon1`` is neither 180 nor 360.
    """
    if lon1 not in (180, 360):
        raise ValueError(f"lon1 must be 180 or 360. {lon1=}")
    lon0 = -180 if lon1 == 180 else 0

    coords = {}
    expected_groups = ()
    for name, start, stop, step in zip(
        ["latitude", "longitude"], [-90, lon0], [90, lon1], [d_lat, d_lon]
    ):
        breaks = np.arange(start, stop + step, step)
        # Derive the centres from the breaks so there is always exactly one
        # centre per interval. Two independent np.arange calls can disagree by
        # one element for non-integer steps because of floating-point rounding.
        coords[name] = (breaks[:-1] + breaks[1:]) / 2
        # IntervalIndex intervals are right-closed, (a, b]. Lowering the first
        # break by one step puts the grid's lower edge value (e.g. -90) into
        # the first cell instead of dropping it.
        groups = breaks.copy()
        groups[0] -= step
        expected_groups += (pd.IntervalIndex.from_breaks(groups),)

    ds = flox.xarray.xarray_reduce(
        ds, *coords, func="mean", expected_groups=expected_groups, keep_attrs=True
    )
    # The reduction yields "<coord>_bins" dimensions labelled by intervals;
    # restore the plain names and label them with the cell centres.
    ds = ds.rename({f"{coord}_bins": coord for coord in coords}).assign_coords(coords)
    for coord in ds.coords:
        ds[coord].attrs["standard_name"] = coord
    return ds


def monthly_regrid(ds, d_lon, d_lat, lon1=180):
    """Regrid ``ds`` month by month with ``arithmetic_unweighted_average``.

    ``longitude`` and ``latitude`` are promoted to coordinates. The data are
    grouped into month-start ("1MS") bins along ``time``, and each month is
    averaged onto a ``d_lon`` x ``d_lat`` degree grid whose eastern longitude
    bound is ``lon1``.
    """
    regrid_kwargs = {"d_lon": d_lon, "d_lat": d_lat, "lon1": lon1}
    with_coords = ds.set_coords(["longitude", "latitude"])
    per_month = with_coords.resample(time="1MS")
    return per_month.map(arithmetic_unweighted_average, **regrid_kwargs)

## Download and transform

In [None]:
# Retrieve the request and regrid every month onto a 1 x 1 degree grid with
# a -180..180 longitude range. chunks={"year": 1} presumably splits the
# retrieval into one-year pieces - confirm against the download helper.
ds = download.download_and_transform(
    collection_id,
    request,
    transform_func=monthly_regrid,
    transform_func_kwargs=dict(d_lon=1, d_lat=1, lon1=180),
    chunks=dict(year=1),
)

## Compute growth rate

In [None]:
def compute_growth_rate(da):
    """Return the annual growth rate of ``da`` for 20-degree latitude bands.

    ``da`` is first reduced with ``diagnostics.spatial_weighted_mean`` inside
    each 20-degree latitude band, giving a ``latitude_bins`` dimension. The
    growth rate at time step ``t`` is then the centred difference
    ``x[t + 6] - x[t - 6]``. Time steps without six neighbours on both sides
    are NaN.

    Assumes ``da`` has a regular monthly ``time`` axis (as produced by
    ``monthly_regrid``), so the 12-step difference spans exactly one year and
    cancels the seasonal cycle - TODO confirm for other inputs.
    """
    da = da.groupby_bins(
        "latitude",
        bins=range(-90, 91, 20),
    ).map(diagnostics.spatial_weighted_mean)
    # Six steps ahead minus six steps behind gives a 12-month difference.
    # The previous 12-sample centred rolling window only spanned t-6..t+5,
    # i.e. 11 months.
    half_year = 6
    return da.shift(time=-half_year) - da.shift(time=half_year)


# Annual XCH4 growth rate, one panel per 20-degree latitude band.
da = compute_growth_rate(ds["xch4"])
# Take the unit from the data rather than hard-coding it: the previous
# "ppm/year" does not fit XCH4, which is conventionally given in ppb (used as
# fallback when the attribute is missing).
xch4_units = ds["xch4"].attrs.get("units", "ppb")
da.attrs = {"units": f"{xch4_units}/year", "long_name": "Growth Rate"}
facet = da.plot(col="latitude_bins", col_wrap=3)
for ax in facet.axs.flat:
    ax.grid()
    # tick_params keeps the rotation even if tick labels are regenerated later.
    ax.tick_params(axis="x", labelrotation=90)