# Spatial trends and anomalies of XCO$_2$ and XCH$_4$ from satellite observations

## Import packages

In [None]:
import calendar

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils
from xarrayMannKendall import Mann_Kendall_test

# Apply the notebook-wide matplotlib style first, then override the hatch
# line width (used by the significance hatching below).
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams.update({"hatch.linewidth": 0.5})

## Define parameters

In [None]:
# Choose variable: "xch4" or "xco2", optionally with a "_nobs", "_stderr"
# or "_stddev" suffix.
variable = "xco2"
valid_variables = {
    f"{prefix}{suffix}"
    for prefix in ("xch4", "xco2")
    for suffix in ("", "_nobs", "_stderr", "_stddev")
}
# Explicit check instead of `assert`, which is skipped when Python runs with -O.
if variable not in valid_variables:
    raise ValueError(
        f"Invalid variable {variable!r}; choose one of {sorted(valid_variables)}"
    )

# Minimum value of land fraction used for masking
min_land_fraction = 0.5  # None: Do not apply mask

# Choose a time period
year_start = 2012
year_stop = 2021

# Define region for analysis
lon_slice = slice(-180, 180)
lat_slice = slice(-90, 90)

## Define request

In [None]:
# (collection name, request parameters). The collection is picked from the
# species prefix; the request variable drops any "_nobs"/"_stderr"/"_stddev"
# suffix.
request = (
    "satellite-methane" if not variable.startswith("xco2") else "satellite-carbon-dioxide",
    dict(
        processing_level="level_3",
        variable=variable.partition("_")[0],
        sensor_and_algorithm="merged_obs4mips",
        version="4.4",
        format="zip",
    ),
)

## Functions to cache

In [None]:
# Keyword arguments forwarded to every transform function (and on to get_da).
transform_func_kwargs = dict(
    min_land_fraction=min_land_fraction,
    variable=variable,
    year_start=year_start,
    year_stop=year_stop,
    lon_slice=lon_slice,
    lat_slice=lat_slice,
)


def get_da(
    ds, min_land_fraction, variable, year_start, year_stop, lon_slice, lat_slice
):
    """Select *variable* from *ds* for a period and region, optionally land-masked.

    Parameters
    ----------
    ds : xarray.Dataset
        Source dataset; must contain ``land_fraction`` when masking is enabled.
    min_land_fraction : float or None
        Cells whose ``land_fraction`` is below this value are set to NaN.
        ``None`` disables masking.
    variable : str
        Name of the variable to extract.
    year_start, year_stop : int
        Inclusive range of years to keep.
    lon_slice, lat_slice : slice
        Region passed to ``utils.regionalise``.
    """
    period = slice(str(year_start), str(year_stop))
    selected = utils.regionalise(
        ds[variable].sel(time=period), lon_slice=lon_slice, lat_slice=lat_slice
    )
    if min_land_fraction is None:
        return selected
    is_land = ds["land_fraction"] >= min_land_fraction
    return selected.where(is_land)


def convert_units(da):
    """Return *da* expressed in ppb (``xch4*``) or ppm (``xco2*``).

    Observation counts (names ending in ``_nobs``) and variables of other
    species are returned unchanged. A variable whose ``units`` attribute
    already matches the target is returned unchanged. Otherwise the values
    are scaled (by 1e9 for ppb, by 1e6 for ppm) and ``units`` is updated.
    NOTE(review): assumes any other unit string denotes a dimensionless mole
    fraction; confirm against the product metadata.

    Unlike an in-place ``*=``, this never modifies the caller's array or its
    ``attrs``; a new DataArray is returned when a conversion happens.

    Raises
    ------
    KeyError
        If a convertible variable has no ``units`` attribute.
    """
    if da.name.endswith("_nobs"):
        return da

    if da.name.startswith("xch4"):
        target_units, factor = "ppb", 1.0e9
    elif da.name.startswith("xco2"):
        target_units, factor = "ppm", 1.0e6
    else:
        return da

    if da.attrs["units"] == target_units:
        return da

    with xr.set_options(keep_attrs=True):
        converted = da * factor
    # Build a fresh attrs dict so the input's attributes are never touched.
    converted.attrs = {**converted.attrs, "units": target_units}
    return converted


def compute_seasonal_timeseries(ds, **get_da_kwargs):
    """Compute seasonal means for every year of the selected period.

    December is attributed to the following year, so the DJF season of year
    ``y`` is made of December ``y - 1`` and January/February ``y``. The
    January/February of the first year and the December of the last year are
    dropped because their DJF winter is not complete within the data. The
    result is therefore [MAM, JJA, SON, DJF, ..., SON] for a series spanning
    whole calendar years.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing the variable selected through ``get_da_kwargs``.
    **get_da_kwargs
        Selection/masking options forwarded to ``get_da``.

    Returns
    -------
    xarray.Dataset
        Seasonal means grouped by ``year``, passed through ``convert_units``.
    """
    da = get_da(ds, **get_da_kwargs)

    # Label from the *selected* array's own time axis (not ds["time"]), so
    # labels match da one-to-one and no NaN has to be cast to int.
    time = da["time"]
    year = time.dt.year
    month = time.dt.month
    season_year = year + (month == 12)

    # Jan/Feb of the first calendar year lack their preceding December, and a
    # trailing December lacks its following Jan/Feb: drop both.
    first_year, last_year = int(year[0]), int(year[-1])
    incomplete_djf = ((season_year == first_year) & (month <= 2)) | (
        season_year > last_year
    )
    da = da.assign_coords(year=season_year).isel(time=~incomplete_djf.values)

    da = da.groupby("year").map(diagnostics.seasonal_weighted_mean)
    return convert_units(da).to_dataset()


def compute_statistics(ds, **get_da_kwargs):
    """Return spatially weighted statistics of the selected variable.

    The data are selected and masked by ``get_da``, then reduced by
    ``diagnostics.spatial_weighted_statistics``. Units are converted with
    ``convert_units`` and the result is wrapped in a Dataset.
    """
    selected = get_da(ds, **get_da_kwargs)
    stats = diagnostics.spatial_weighted_statistics(selected)
    return convert_units(stats).to_dataset()


def compute_monthly_anomalies(ds, **get_da_kwargs):
    """Return anomalies of the selected variable w.r.t. its monthly climatology.

    For every calendar month, the mean over all years of that month is
    subtracted. Attributes are kept through the groupby arithmetic so that
    ``convert_units`` can read ``units``.
    """
    selected = get_da(ds, **get_da_kwargs)
    with xr.set_options(keep_attrs=True):
        by_month = selected.groupby("time.month")
        climatology = by_month.mean()
        anomalies = by_month - climatology
    return convert_units(anomalies)


def compute_mann_kendall_trend(da, **mann_kendall_kwargs):
    """Run ``Mann_Kendall_test`` on *da* and return its computed output.

    The test is told that *da*'s ``time``, ``longitude`` and ``latitude``
    dimensions play the roles of ``time``, ``x`` and ``y``. Its output is
    renamed from ``x``/``y`` back to ``longitude``/``latitude``.

    Parameters
    ----------
    da : xarray.DataArray
        Input with ``time``, ``longitude`` and ``latitude`` dimensions.
    **mann_kendall_kwargs
        Extra options forwarded to ``Mann_Kendall_test``.
    """
    coords_name = dict(time="time", x="longitude", y="latitude")
    test = Mann_Kendall_test(da, coords_name=coords_name, **mann_kendall_kwargs)
    ds_trend = test.compute()
    spatial_names = {
        role: dim for role, dim in coords_name.items() if role != "time"
    }
    return ds_trend.rename(spatial_names)


def compute_seasonal_detrended_anomaly(da, **polyfit_kwargs):
    """Remove a polynomial trend along ``time`` and average by year and season.

    Seasons use the same convention as ``compute_seasonal_timeseries``:
    December is attributed to the following year, so DJF of year ``y`` is
    December ``y - 1`` plus January/February ``y``. Previously a calendar-year
    grouping mixed Jan/Feb and Dec of the *same* year into one DJF.
    The leading Jan/Feb and a trailing Dec, whose winter is incomplete, are
    dropped.

    Parameters
    ----------
    da : xarray.DataArray
        Time series with a ``time`` dimension.
    **polyfit_kwargs
        Options forwarded to ``DataArray.polyfit`` (e.g. ``deg``).

    Returns
    -------
    xarray.DataArray
        Detrended values averaged per ``year`` by
        ``diagnostics.seasonal_weighted_mean``.
    """
    coefficients = da.polyfit("time", **polyfit_kwargs).polyfit_coefficients
    da_detrended = da - xr.polyval(da["time"], coefficients)

    time = da_detrended["time"]
    year = time.dt.year
    month = time.dt.month
    # December belongs to the DJF season of the following year.
    season_year = year + (month == 12)

    first_year, last_year = int(year[0]), int(year[-1])
    incomplete_djf = ((season_year == first_year) & (month <= 2)) | (
        season_year > last_year
    )
    da_detrended = da_detrended.assign_coords(year=season_year).isel(
        time=~incomplete_djf.values
    )
    return da_detrended.groupby("year").map(diagnostics.seasonal_weighted_mean)


def compute_anomaly_trends(ds, **get_da_kwargs):
    """Compute Mann-Kendall trends and detrended seasonal anomalies.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset passed to ``compute_monthly_anomalies``.
    **get_da_kwargs
        Selection/masking options forwarded to ``get_da``.

    Returns
    -------
    xarray.Dataset
        Output of ``compute_mann_kendall_trend``, masked where the anomalies
        never have data, plus a ``detrended_anomaly`` variable.

    Raises
    ------
    KeyError
        If the anomalies carry no ``units`` attribute.
    """
    da_anomaly = compute_monthly_anomalies(ds, **get_da_kwargs)
    name = da_anomaly.attrs.get("long_name", da_anomaly.name)
    units = da_anomaly.attrs["units"]

    # Theil-Sen slope with Mann-Kendall significance at alpha = 0.05; grid
    # cells whose anomalies are missing at every time step are masked out.
    ds_mann_kendall = compute_mann_kendall_trend(
        da_anomaly, alpha=0.05, method="theilslopes"
    ).where(da_anomaly.notnull().any("time"))
    ds_mann_kendall["trend"].attrs = {
        "long_name": f"Trend of anomalies of {name}",
        # Presumably a per-time-step slope, i.e. per month for monthly input.
        "units": f"{units}/month",
    }

    # Remove a linear (deg=1) trend, then average by year and season.
    da_detrended = compute_seasonal_detrended_anomaly(da_anomaly, deg=1)
    da_detrended.attrs = {
        "long_name": f"Detrended anomalies of {name}",
        "units": f"{units}",
    }

    ds_mann_kendall["detrended_anomaly"] = da_detrended
    return ds_mann_kendall


def compute_coverage(da, missing_values=1.0e20, dim="time"):
    """Return the fraction of valid entries of *da* along *dim*.

    An entry is valid when it is not NaN and, unless *missing_values* is
    ``None``, differs from the *missing_values* sentinel. NaN is checked
    explicitly because ``NaN != missing_values`` is True. A sentinel-only
    comparison would therefore count fill values that were decoded to NaN as
    valid data.

    Parameters
    ----------
    da : xarray.DataArray
        Data to inspect.
    missing_values : float or None
        Raw fill value to treat as missing; ``None`` checks NaN only.
    dim : str
        Dimension along which coverage is computed.

    Returns
    -------
    xarray.DataArray
        Valid count along *dim* divided by the length of *dim*.
    """
    valid = da.notnull()
    if missing_values is not None:
        valid = valid & (da != missing_values)
    return valid.sum(dim) / da.sizes[dim]

## Coverage

In [None]:
ds = download.download_and_transform(*request)
# Show coverage for the species of the selected variable (e.g. "xch4" for
# "xch4_stderr"), matching the product downloaded by `request`, instead of a
# hard-coded "xco2" that does not exist in the methane product.
species = variable.split("_")[0]
da = ds[species].groupby("time.month").map(compute_coverage)
da["month"] = [calendar.month_name[month] for month in da["month"].values]
species_label = {"xco2": "XCO$_2$", "xch4": "XCH$_4$"}[species]
da.attrs = {"long_name": f"{species_label} Data Coverage"}
facet = plot.projected_map(
    da,
    col="month",
    col_wrap=3,
    cmap="Greens",
    cbar_kwargs={"orientation": "horizontal"},
    levels=np.arange(0, 1.1, 0.1),
)

## Global annual variability

In [None]:
ds_seasonal = download.download_and_transform(
    *request,
    transform_func=compute_seasonal_timeseries,
    transform_func_kwargs=transform_func_kwargs,
)
# Grid of maps: one row per year, one column per season.
_ = plot.projected_map(
    ds_seasonal[variable],
    row="year",
    col="season",
    robust=True,
    projection=ccrs.Robinson(),
)

## Global mean values

In [None]:
ds_stats = download.download_and_transform(
    *request,
    transform_func=compute_statistics,
    transform_func_kwargs=transform_func_kwargs,
)
fig, ax = plt.subplots()
# Every diagnostic except the standard deviation is drawn as a line...
ds_stats[variable].drop_sel(diagnostic="std").plot(hue="diagnostic", ax=ax)
# ...and mean +/- one standard deviation is drawn as a grey band.
mean = ds_stats[variable].sel(diagnostic="mean")
std = ds_stats[variable].sel(diagnostic="std")
ax.fill_between(ds_stats["time"], mean + std, mean - std, color="k", alpha=0.25)
ax.grid()

## Global trends

In [None]:
ds_trend = download.download_and_transform(
    *request,
    transform_func=compute_anomaly_trends,
    transform_func_kwargs=transform_func_kwargs,
)

# Trend map.
plot.projected_map(ds_trend["trend"], projection=ccrs.Robinson(), robust=True)
# Hatching from the p-value: the [0, 0.05] band is left blank and the
# [0.05, 1] band is hatched. The hatching is presumably drawn on the current
# (trend) axes.
plot.projected_map(
    ds_trend["p"],
    plot_func="contourf",
    levels=[0, 0.05, 1],
    hatches=["", "/////"],
    cmap="none",
    add_colorbar=False,
    show_stats=False,
)

## Detrended seasonal anomalies

In [None]:
# Detrended anomalies: one row per year, one column per season.
_ = plot.projected_map(
    ds_trend["detrended_anomaly"],
    row="year",
    col="season",
    robust=True,
    projection=ccrs.Robinson(),
)