# Spatial trends and anomalies of xco2 and xch4 from satellite

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot
from xarrayMannKendall import Mann_Kendall_test

# Notebook-wide matplotlib appearance: seaborn notebook style, thin hatch lines.
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams.update({"hatch.linewidth": 0.5})

## Define parameters

In [None]:
# Choose variable (xch4 or xco2)
variable = "xco2"
# Validate explicitly instead of with `assert`, which is skipped under
# `python -O` and would let an unsupported variable through silently.
if variable not in ("xch4", "xco2"):
    raise ValueError(f"variable must be 'xch4' or 'xco2', got {variable!r}")

## Define request

In [None]:
# Two-element request: (collection id, request parameters). The collection
# is chosen from `variable`; anything other than "xco2" selects methane.
request = (
    "satellite-methane" if variable != "xco2" else "satellite-carbon-dioxide",
    dict(
        processing_level="level_3",
        variable=variable,
        sensor_and_algorithm="merged_obs4mips",
        version="4.4",
        format="zip",
    ),
)

## Define functions

In [None]:
def get_da(ds):
    """Return the single ``xch4`` or ``xco2`` variable found in ``ds``.

    Raises ``ValueError`` (from the tuple unpacking) unless exactly one of
    the two names is present in ``ds.data_vars``.
    """
    candidates = {"xch4", "xco2"}.intersection(ds.data_vars)
    (name,) = candidates
    return ds[name]


def convert_units(da):
    """Return ``da`` expressed in ppb (``xch4``) or ppm (``xco2``).

    The array is rescaled only when its name is ``xch4`` or ``xco2`` and its
    ``units`` attribute differs from the target unit; otherwise ``da`` is
    returned as is. The input is never modified: a conversion produces a new
    array with the same attributes and an updated ``units``.

    Parameters
    ----------
    da : xarray.DataArray
        Array whose ``name`` selects the target unit. For ``xch4``/``xco2``
        it must carry a ``units`` attribute.

    Returns
    -------
    xarray.DataArray

    Raises
    ------
    KeyError
        If ``da`` is named ``xch4`` or ``xco2`` but has no ``units`` attribute.
    """
    # name -> (target units, scale factor applied when units differ)
    targets = {"xch4": ("ppb", 1.0e9), "xco2": ("ppm", 1.0e6)}
    if da.name not in targets:
        return da

    units, factor = targets[da.name]
    if da.attrs["units"] == units:
        return da

    # Multiply out of place so the caller's array and attrs stay untouched.
    with xr.set_options(keep_attrs=True):
        scaled = da * factor
    return scaled.assign_attrs(units=units)


def compute_seasonal_timeseries(ds):
    """Compute per-year seasonal means of the xch4/xco2 variable in ``ds``.

    Returns a dataset holding the unit-converted result of applying
    ``diagnostics.seasonal_weighted_mean`` to each ``year`` group.
    """
    # Label every time step with the calendar year of the *next* step
    # (shift -1), so that December falls in the same group as the following
    # January and February: D(year-1) J(year) F(year).
    shifted_year = ds["time"].dt.year.shift(time=-1).astype(int)
    da = get_da(ds).assign_coords(year=shifted_year)
    # Drop the first two steps and the last one (the leading JF and trailing
    # D), so the series becomes [MAM, JJA, SON, DJF, ..., SON].
    # NOTE(review): assumes monthly data running January..December - confirm.
    trimmed = da.isel(time=slice(2, -1))
    seasonal = trimmed.groupby("year").map(diagnostics.seasonal_weighted_mean)
    return convert_units(seasonal).to_dataset()


def compute_statistics(ds):
    """Compute spatially weighted statistics of the xch4/xco2 variable.

    The variable is passed through ``diagnostics.spatial_weighted_statistics``,
    unit-converted and returned as a dataset.
    """
    stats = diagnostics.spatial_weighted_statistics(get_da(ds))
    converted = convert_units(stats)
    return converted.to_dataset()


def compute_monthly_anomalies(ds):
    """Return unit-converted anomalies relative to the monthly climatology.

    For each calendar month, the mean over all time steps of that month is
    subtracted from the xch4/xco2 variable. Attributes are kept. The result
    is a DataArray, not a dataset.
    """
    da = get_da(ds)
    with xr.set_options(keep_attrs=True):
        by_month = da.groupby("time.month")
        anomalies = by_month - by_month.mean()
    return convert_units(anomalies)


def compute_mann_kendall_trend(da, **kwargs):
    """Run ``Mann_Kendall_test`` on ``da`` and return the computed result.

    ``da`` is expected to have ``time``, ``longitude`` and ``latitude``
    coordinates. Extra keyword arguments are forwarded to
    ``Mann_Kendall_test``. In the output, ``x``/``y`` are renamed back to
    ``longitude``/``latitude``.
    """
    spatial_names = {"x": "longitude", "y": "latitude"}
    test = Mann_Kendall_test(
        da, coords_name={"time": "time", **spatial_names}, **kwargs
    )
    result = test.compute()
    return result.rename(spatial_names)


def compute_seasonal_detrended_anomaly(da, **kwargs):
    """Remove a polynomial time trend from ``da`` and average by season and year.

    Keyword arguments (e.g. ``deg``) are forwarded to ``da.polyfit`` along
    ``time``. The fitted polynomial is evaluated at the original time steps
    and subtracted; ``diagnostics.seasonal_weighted_mean`` is then applied to
    each calendar-year group of the residual.
    """
    coefficients = da.polyfit("time", **kwargs).polyfit_coefficients
    fitted = xr.polyval(da["time"], coefficients)
    residual = da - fitted
    yearly_groups = residual.groupby("time.year")
    return yearly_groups.map(diagnostics.seasonal_weighted_mean)


def compute_trends(ds):
    """Compute the Mann-Kendall trend and detrended seasonal anomalies.

    Returns the dataset produced by ``compute_mann_kendall_trend`` on the
    monthly anomalies (alpha 0.05, Theil slopes), with descriptive attributes
    on ``trend`` and an extra ``detrended_anomaly`` variable holding the
    linearly detrended seasonal anomalies.
    """
    anomalies = compute_monthly_anomalies(ds)

    # Mann-Kendall test on the monthly anomalies
    result = compute_mann_kendall_trend(anomalies, alpha=0.05, method="theilslopes")
    long_name = anomalies.attrs["long_name"]
    units = anomalies.attrs["units"]
    result["trend"].attrs = {
        "long_name": f"Trend of anomalies of {long_name}",
        # Labelled per month; NOTE(review): presumes monthly time steps.
        "units": f"{units}/month",
    }

    # Seasonal anomalies after removing a degree-1 (linear) fit
    detrended = compute_seasonal_detrended_anomaly(anomalies, deg=1)
    detrended.attrs = {
        "long_name": f"Detrended anomalies of {long_name}",
        "units": str(units),
    }

    result["detrended_anomaly"] = detrended
    return result

## Global annual variability

In [None]:
# Download the data and reduce it to seasonal means per year.
ds_seasonal = download.download_and_transform(
    request[0],
    request[1],
    transform_func=compute_seasonal_timeseries,
)
# Grid of maps: one row per year, one column per season.
_ = plot.projected_map(
    ds_seasonal[variable],
    row="year",
    col="season",
    robust=True,
    projection=ccrs.Robinson(),
)

## Global mean values

In [None]:
# Download the data and reduce it to spatially weighted statistics.
ds_stats = download.download_and_transform(*request, transform_func=compute_statistics)
fig, ax = plt.subplots()
mean = ds_stats[variable].sel(diagnostic="mean")
std = ds_stats[variable].sel(diagnostic="std")
# One line per diagnostic except "std", which is drawn as a band instead.
ds_stats[variable].drop_sel(diagnostic="std").plot(hue="diagnostic", ax=ax)
# Shaded band: mean +/- one standard deviation.
ax.fill_between(ds_stats["time"], mean + std, mean - std, color="k", alpha=0.25)
ax.grid()

## Global trends

In [None]:
# Download the data and compute the trend, its p-value and detrended anomalies.
ds_trend = download.download_and_transform(*request, transform_func=compute_trends)

# Keyword arguments for the two map layers (colour shading and hatching).
shading_kwargs = dict(robust=True, projection=ccrs.Robinson())
hatches_kwargs = {
    "plot_func": "contourf",
    "show_stats": False,
    "cmap": "none",
    "add_colorbar": False,
}

# Colour shading: the trend itself.
plot.projected_map(ds_trend["trend"], **shading_kwargs)
# Hatching: p-values split into two bands at 0.05; the 0-0.05 band is left
# blank and the 0.05-1 band is hatched.
_ = plot.projected_map(
    ds_trend["p"],
    **hatches_kwargs,
    levels=[0, 0.05, 1],
    hatches=["", "/" * 5],
)

## Detrended seasonal anomalies

In [None]:
# Grid of maps of the detrended anomalies: one row per year, one column per season.
_ = plot.projected_map(
    ds_trend["detrended_anomaly"],
    row="year",
    col="season",
    robust=True,
    projection=ccrs.Robinson(),
)