# Impact of Drought on the Leaf Area Index in Yunnan Province, China

## Import libraries

In [None]:
import calendar
import os

import geopandas as gpd
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
import regionmask
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

# Set the CDSAPI_RC environment variable to a credentials file under the
# user's home directory (presumably read by the CDS API client - confirm).
# NOTE(review): the "couplet_victor" sub-directory is specific to one user's
# home layout; other users need to adapt this path.
cdsapi_rc_path = os.path.expanduser("~/couplet_victor/.cdsapirc")
os.environ["CDSAPI_RC"] = cdsapi_rc_path

# Matplotlib style sheet applied to every figure created afterwards
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Analysis period: first and last calendar year, both included in the requests
year_start, year_stop = 2007, 2013

# Administrative boundaries: remote zipped shapefile, the layer to read from
# it, and the province name to select
shapefile = "zip+https://geodata.ucdavis.edu/gadm/gadm4.1/shp/gadm41_CHN_shp.zip"
layer, province = "gadm41_CHN_1", "Yunnan"

# Bounding box used to cut the data down to the region of interest.
# The latitude bounds are given from the larger to the smaller value.
lon_slice, lat_slice = slice(97.53, 106.19), slice(29.23, 21.14)

## Define request

In [None]:
# Collection to query and the request keys shared by every monthly request
collection_id = "satellite-lai-fapar"
base_request = dict(
    variable="lai",
    sensor="vgt",
    horizontal_resolution="1km",
    product_version="v3",
    area=[53, 73, 18, 135],
)

# One request per calendar month from January of `year_start` to December of
# `year_stop`. The "ME" frequency yields month-end timestamps, so `.day` is
# the last day of each month and completes the 10/20/end-of-month triplet.
month_ends = pd.date_range(str(year_start), str(year_stop + 1), freq="ME")
requests = [
    {
        **base_request,
        "satellite": "spot",
        "year": str(month_end.year),
        "month": f"{month_end.month:02d}",
        "nominal_day": ["10", "20", str(month_end.day)],
    }
    for month_end in month_ends
]

## Define functions to cache

In [None]:
def get_gdf_province(shapefile, layer, province):
    """Load the boundary of one province from a shapefile layer in EPSG:4326.

    Parameters
    ----------
    shapefile : str
        Path or URL passed to ``geopandas.read_file``.
    layer : str
        Name of the layer to read from ``shapefile``.
    province : str
        Value compared for exact equality with the layer's ``NAME_1`` column.

    Returns
    -------
    geopandas.GeoDataFrame
        The rows whose ``NAME_1`` equals ``province``, reprojected to
        EPSG:4326.

    Raises
    ------
    ValueError
        If no row of the layer matches ``province``.
    """
    gdf = gpd.read_file(shapefile, layer=layer)
    selected = gdf[gdf["NAME_1"] == province]
    # Fail here with a clear message instead of handing an empty geometry
    # set to the clipping/masking steps that consume this result.
    if selected.empty:
        raise ValueError(
            f"No feature with NAME_1 == {province!r} found in layer {layer!r}"
        )
    return selected.to_crs("EPSG:4326")


def mask_dataset_province_and_filter(
    ds, shapefile, layer, province, filter_string, lon_slice, lat_slice
):
    """Clip a dataset to a province, apply quality filtering, and crop it.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding at least the ``LAI`` and ``retrieval_flag`` variables.
    shapefile, layer, province
        Forwarded to ``get_gdf_province`` to obtain the province boundary.
    filter_string : str
        Either ``"conservative"`` or ``"relaxed"``; selects the bit mask
        tested against ``retrieval_flag``.
    lon_slice, lat_slice : slice
        Longitude/latitude bounds forwarded to ``utils.regionalise``.

    Returns
    -------
    xarray.Dataset
        The result of ``utils.regionalise`` on the clipped dataset in which
        every pixel failing the quality test or having ``LAI < 0`` is masked.

    Raises
    ------
    KeyError
        If ``filter_string`` is not one of the two supported names.
    """
    # Bits of `retrieval_flag` that must all be unset for a pixel to be kept
    flag_bitmasks = {"conservative": 0xFC1, "relaxed": 0x9C1}

    # Province boundary, already expressed in EPSG:4326
    province_gdf = get_gdf_province(shapefile, layer, province)

    # NOTE(review): the `.rio` accessor presumably relies on rioxarray having
    # been imported somewhere else - confirm.
    # Datasets without a CRS are tagged as EPSG:4326 before clipping.
    if not ds.rio.crs:
        ds = ds.rio.write_crs("EPSG:4326")

    # Mask everything outside the province while keeping the original grid
    clipped = ds.rio.clip(
        province_gdf.geometry, province_gdf.crs, drop=False, invert=False
    )

    # Keep pixels with none of the selected flag bits set and non-negative LAI
    flags = clipped["retrieval_flag"].astype("int")
    passes_quality = (flags & flag_bitmasks[filter_string]) == 0
    has_valid_lai = clipped["LAI"] >= 0
    filtered = clipped.where(passes_quality & has_valid_lai)

    # Crop to the requested longitude/latitude window
    return utils.regionalise(filtered, lon_slice=lon_slice, lat_slice=lat_slice)

## Download and transform

In [None]:
# Arguments of the transform function that are identical for both filters
shared_kwargs = {
    "shapefile": shapefile,
    "layer": layer,
    "province": province,
    "lon_slice": lon_slice,
    "lat_slice": lat_slice,
}

# Download and transform the data once per quality filter, tagging each
# result with a length-one "filter" dimension so they can be concatenated.
datasets = []
for filter_string in ("conservative", "relaxed"):
    ds = download.download_and_transform(
        collection_id,
        requests,
        transform_func=mask_dataset_province_and_filter,
        transform_func_kwargs=shared_kwargs | {"filter_string": filter_string},
    )
    ds = ds.expand_dims(filter=[filter_string])
    datasets.append(ds)

# Single dataset with both filters stacked along "filter"
ds = xr.concat(datasets, dim="filter")

## Plot valid percentage

In [None]:
# Province outline and the corresponding mask on the dataset grid
gdf = get_gdf_province(shapefile, layer, province)
mask = regionmask.mask_geopandas(gdf, ds["longitude"], ds["latitude"])

# Number of non-null LAI pixels (per time step and filter) relative to the
# number of non-null cells in the province mask, expressed in percent
n_valid_pixels = ds["LAI"].count(["latitude", "longitude"])
n_province_cells = mask.notnull().sum()
valid_percentage = 100 * n_valid_pixels / n_province_cells

# Month number as an extra coordinate so it can drive the marker colour
month_of_time = valid_percentage["time"].dt.month
valid_percentage = valid_percentage.assign_coords(month=month_of_time)
valid_percentage.attrs = {"long_name": "Valid", "units": "%"}

# One scatter plot per filter, coloured by month
month_ticks = range(1, 13)
month_names = calendar.month_abbr[1:13]
for label, group in valid_percentage.groupby("filter"):
    da = group.squeeze(drop=True)
    scatter = da.plot.scatter(
        hue="month",
        s=20,
        alpha=0.7,
        edgecolors="black",
        linewidth=0.2,
        cmap=plt.cm.rainbow,
        figsize=(12, 6),
        cbar_kwargs={"label": "Month", "ticks": month_ticks},
    )
    # Replace the numeric colour-bar ticks by abbreviated month names
    scatter.colorbar.ax.set_yticklabels(month_names)
    plt.title(f"Valid LAI (%) Over {province} ({label.title()} Filter)")
    plt.grid()
    plt.show()

## Plot monthly mean

In [None]:
# Spatially weighted mean over the region, then a monthly weighted mean
# computed separately for every year
monthly_mean = (
    diagnostics.spatial_weighted_mean(ds)
    .groupby("time.year")
    .map(diagnostics.monthly_weighted_mean)
)

# Plain Python years: avoids iterating the DataArray element by element
# just to find its extrema and to test membership below.
years = monthly_mean["year"].values.tolist()
norm = mcolors.Normalize(vmin=min(years), vmax=max(years))

# Years drawn in red; every other year gets a shade of blue
highlighted_years = (2009, 2010)
colors = [
    cm.Reds(norm(year))
    if year in highlighted_years
    else cm.Blues(0.4 + 0.4 * norm(year))
    for year in years
]

# Relabel the variable and plot that very object, rather than indexing the
# dataset again and relying on both views sharing the same attributes.
da = monthly_mean["LAI"]
da.attrs = {"long_name": "Monthly Mean LAI"}
with plt.rc_context({"axes.prop_cycle": plt.cycler(color=colors)}):
    facet = da.plot(col="filter", hue="year", figsize=(14, 6))
for ax in facet.axs.flatten():
    ax.grid()
plt.show()