# Regional plots of XCO2 level 3 satellite data

## Import libraries

In [None]:
import math

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply matplotlib's seaborn "notebook" style to every figure below
# (matplotlib >= 3.6 exposes the old seaborn styles under the "seaborn-v0_8-" prefix).
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Month shown in the single-time regional maps
time = "2016-01"

# Period kept for annual means and timeseries (label-based, end inclusive)
time_range = slice("2015", "2020")

# Cells whose land fraction is below this threshold are masked out
min_land_fraction = 0.5  # Use None to switch off

# Regions: name -> longitude/latitude slices (degrees), built from
# (name, (lon_start, lon_stop), (lat_start, lat_stop)) tuples
regions = {
    name: {"lon_slice": slice(*lon_bounds), "lat_slice": slice(*lat_bounds)}
    for name, lon_bounds, lat_bounds in (
        ("global", (-180, 180), (-90, 90)),
        ("north_america", (-160, -60), (10, 90)),
        ("europe_africa", (-20, 80), (-5, 60)),
        ("asia", (70, 165), (-15, 80)),
    )
}

## Define request

In [None]:
# Collection identifier passed to download.download_and_transform
collection_id = "satellite-carbon-dioxide"

# Level-3 XCO2 from the merged Obs4MIPs product, version 4_5
request = dict(
    processing_level=["level_3"],
    variable="xco2",
    sensor_and_algorithm="merged_obs4mips",
    version=["4_5"],
)

## Define functions to cache

In [None]:
def convert_units(da):
    """Rescale XCH4/XCO2 variables to ppb/ppm.

    Variables whose name ends with ``_nobs`` are returned untouched.
    A variable whose name starts with ``xch4`` is multiplied by 1e9 and
    labelled ``ppb`` unless its ``units`` attribute already is ``ppb``;
    one starting with ``xco2`` is multiplied by 1e6 and labelled ``ppm``
    unless already ``ppm``. Attributes survive the multiplication
    (``keep_attrs=True``). Any other variable is returned as-is.

    Raises KeyError if a matching variable has no ``units`` attribute.
    """
    if da.name.endswith("_nobs"):
        return da

    # (name prefix, target units, multiplicative factor)
    conversions = (("xch4", "ppb", 1.0e9), ("xco2", "ppm", 1.0e6))
    with xr.set_options(keep_attrs=True):
        for prefix, target_units, factor in conversions:
            if not da.name.startswith(prefix):
                continue
            if da.attrs["units"] != target_units:
                da = da * factor
                da.attrs["units"] = target_units
            # A name can only match one prefix, so stop at the first hit.
            break
    return da


def mask_scale_and_regionalise(ds, min_land_fraction, lon_slice, lat_slice):
    """Mask, clean, rescale and spatially subset a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset. It must contain a ``land_fraction`` variable when
        ``min_land_fraction`` is not None.
    min_land_fraction : float or None
        Cells with ``land_fraction`` below this value become NaN in every
        variable. Use None to skip land masking.
    lon_slice, lat_slice : slice
        Bounds forwarded to ``utils.regionalise``.

    Returns
    -------
    xarray.Dataset
        The dataset after three steps: values equal to a variable's
        ``fill_value`` attribute are replaced by NaN (the attribute is
        dropped), ``convert_units`` is applied to every data variable, and
        the result is subset with ``utils.regionalise``.
    """
    if min_land_fraction is not None:
        ds = ds.where(ds["land_fraction"] >= min_land_fraction)
    else:
        # Shallow copy (attrs dicts are copied too) so that popping attributes
        # and replacing variables below never mutates the caller's dataset.
        ds = ds.copy()
    # Snapshot the names: variables are reassigned inside the loop.
    for var in list(ds.data_vars):
        da = ds[var]
        if (fill_value := da.attrs.pop("fill_value", None)) is not None:
            # dtype.type(...) accepts numpy scalars/arrays and plain Python
            # numbers alike; ``fill_value.astype`` fails on the latter.
            da = da.where(da != da.dtype.type(fill_value))
        ds[var] = convert_units(da)
    return utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)

## Download and transform

In [None]:
# Download, transform and cache one dataset per region, then keep only the
# analysis period.
datasets = {}
for region, region_kwargs in regions.items():
    print(f"{region=}")
    transform_kwargs = {"min_land_fraction": min_land_fraction} | region_kwargs
    ds = download.download_and_transform(
        collection_id,
        request,
        transform_func=mask_scale_and_regionalise,
        transform_func_kwargs=transform_kwargs,
    )
    ds = ds.sel(time=time_range)
    # Drop a "units" attribute that only says the quantity is dimensionless ("1").
    for da in ds.data_vars.values():
        if da.attrs.get("units") in ["1", 1]:
            da.attrs.pop("units")
    datasets[region] = ds

## Plot detrended anomalies

In [None]:
# Annual means minus the long-term mean, with a linear trend in "year" removed.
for region in ["global"]:
    ds = datasets[region]
    for variable in ["xco2"]:
        annual = diagnostics.annual_weighted_mean(ds[variable])
        overall = diagnostics.time_weighted_mean(ds[variable])
        with xr.set_options(keep_attrs=True):
            anomaly = annual - overall
            # Degree-1 least-squares fit along "year", evaluated back on the years
            coefficients = anomaly.polyfit("year", deg=1).polyfit_coefficients
            trend = xr.polyval(anomaly["year"], coefficients)
            da = anomaly - trend
        facet = plot.projected_map(da, col="year", col_wrap=3)
        facet.fig.suptitle(f"{region =} {variable =}")

## Plot annual averages

In [None]:
# One map per year of the annual weighted mean.
for region in ["global"]:
    ds = datasets[region]
    for variable in ["xco2_stderr"]:
        annual_mean = diagnostics.annual_weighted_mean(ds[variable])
        facet = plot.projected_map(annual_mean, col="year", col_wrap=3)
        facet.fig.suptitle(f"{region =} {variable =}")

## Plot single-time maps

In [None]:
# One figure per variable: a 2-row grid of regional maps at the chosen time.
for variable in ["xco2"]:
    # At least one column, so an empty ``datasets`` does not request a 0-wide grid.
    ncols = max(1, math.ceil(len(datasets) / 2))
    fig, axs = plt.subplots(
        2,
        ncols,
        subplot_kw={"projection": ccrs.PlateCarree()},
        figsize=(11, 6),
        squeeze=False,
    )
    flat_axs = axs.flatten()
    for ax, (region, ds) in zip(flat_axs, datasets.items()):
        da = ds[variable].sel(time=time)
        plot.projected_map(da, ax=ax, show_stats=False)
        ax.set_title(region)
    # Hide spare panels (e.g. an odd number of regions) instead of drawing
    # empty map frames.
    for ax in flat_axs[len(datasets):]:
        ax.set_visible(False)
    fig.suptitle(f"{variable =} {time = }")
    plt.show()

## Plot timeseries

In [None]:
# Per-region timeseries of the spatially weighted mean, with a +/- 1 spatial
# standard deviation band.
for variable in ["xco2", "xco2_stderr", "xco2_stddev", "xco2_nobs"]:
    means = []
    stds = []
    for region, ds in datasets.items():
        da = ds[variable]
        means.append(diagnostics.spatial_weighted_mean(da).expand_dims(region=[region]))
        stds.append(diagnostics.spatial_weighted_std(da).expand_dims(region=[region]))
    da_mean = xr.concat(means, "region")
    da_std = xr.concat(stds, "region")

    facet = da_mean.plot(col="region", col_wrap=2)
    for ax, sel_dict in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
        if sel_dict is None:
            # Empty slot in an incomplete last row of the facet grid;
            # ``sel(None)`` would return every region at once.
            continue
        mean = da_mean.sel(sel_dict)
        std = da_std.sel(sel_dict)
        ax.fill_between(
            # Use the plotted series' own time axis, not a leftover loop variable.
            mean["time"],
            # None of the plotted quantities can be negative, so floor the band at 0.
            (mean - std).clip(min=0),
            mean + std,
            alpha=0.5,
        )
        ax.grid()
    # ``FacetGrid.fig`` is deprecated in favour of ``figure`` (alongside ``axs``).
    facet.figure.suptitle(f"{variable = }")