# Clip HYRAS data to catchment extent

In [None]:
import os
import time
import json
import rioxarray
import pandas as pd
import geopandas as gpd
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import dask
import contextily as ctx
# import exactextract only if WEIGHTED_STATISTICS is True
if True:
    from exactextract import exact_extract

In [None]:
# Define IDs and variables
# CAMELS catchment IDs to process (keys of `id_mapping`) and the HYRAS
# variables to clip for each of them (keys of `var_mapping`).
CAMELS_IDS = list(range(1, 6))
HYRAS_VARIABLES = [
    "Precipitation",
    "RadiationGlobal",
    "Humidity",
    "TemperatureMean",
    "TemperatureMax",
    "TemperatureMin",
]

In [None]:
# Define paths
# Input/output locations, relative to the notebook's working directory
INPUT_PATH = "../input_data"                 # root folder of the HYRAS NetCDF inputs
CATCHMENT_PATH = "../../../catchments_harz"  # folder holding the catchment shapefiles
RESULT_PATH = "../output_data"               # per-catchment plots/ and data/ go here

# Processing switches
WEIGHTED_STATISTICS = True  # True: compute the statistics with exactextract
SAVE_NETCDF = True          # also write the clipped NetCDF per variable (can be very slow)

In [None]:
# Wall-clock timestamp taken before processing starts (paired with `end_time` after the main loop)
start_time: float = time.time()

In [None]:
# Mapping variables to datasets
# HYRAS variable -> NetCDF variable name and input path. Each variable lives in
# a sub-folder named after it; RadiationGlobal is a glob over several files.
var_mapping = {
    name: dict(variable_name=nc_name, datapath=f"{INPUT_PATH}/hyras/{name}/{file_pattern}")
    for name, nc_name, file_pattern in (
        ("Humidity", "hurs", "hurs_hyras_5_1951_2020_v5-0_de.nc"),
        ("Precipitation", "pr", "pr_hyras_1_1931_2020_v5-0_de.nc"),
        ("RadiationGlobal", "rsds", "*.nc"),
        ("TemperatureMax", "tasmax", "tasmax_hyras_5_1951_2020_v5-0_de.nc"),
        ("TemperatureMin", "tasmin", "tasmin_hyras_5_1951_2020_v5-0_de.nc"),
        ("TemperatureMean", "tas", "tas_hyras_5_1951_2020_v5-0_de.nc"),
    )
}

# CAMELS ID -> catchment shapefile (relative to CATCHMENT_PATH); IDs are
# assigned in the order listed, starting at 1.
id_mapping = {
    catchment_id: dict(shapefile=f"{reservoir}_reservoir_catchment.shp", catchment_id=catchment_id)
    for catchment_id, reservoir in enumerate(("innerste", "oker", "ecker", "soese", "grane"), start=1)
}

# Collects non-fatal problems (e.g. failed basemap downloads) during processing.
# NOTE: this name shadows the stdlib `warnings` module inside this notebook.
warnings = []

In [None]:
# For every catchment and every HYRAS variable:
#   1. open the HYRAS file(s) and clip the grid to the catchment polygon,
#   2. plot the first time step of the clipped grid over the catchment outline,
#   3. aggregate the clipped grid to one value per time step and statistic,
#   4. plot yearly means of those statistics,
#   5. write the plots, the time series CSV and (optionally) the clipped NetCDF.
# Non-fatal problems (e.g. basemap download failures) are appended to `warnings`.
for camels_id in CAMELS_IDS:
    # The catchment geometry does not depend on the variable, so read,
    # reproject and save it once per catchment instead of once per variable.
    catchment_info = id_mapping[camels_id]
    shapefile_path = os.path.join(CATCHMENT_PATH, catchment_info["shapefile"])
    # Reproject to EPSG:3034, the CRS assigned to the HYRAS data below
    catchment = gpd.read_file(shapefile_path).to_crs(epsg=3034)

    data_dir = f"{RESULT_PATH}/{camels_id}/data"
    plot_dir = f"{RESULT_PATH}/{camels_id}/plots"
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    # Keep a copy of the reprojected catchment outline next to the results
    catchment.to_file(f"{data_dir}/catchment.geojson", driver="GeoJSON")

    for hyras_variable in HYRAS_VARIABLES:
        print(f"Processing {hyras_variable} for CAMELS_ID {camels_id}")

        # NetCDF variable name and file path (or glob pattern) for this variable
        variable = var_mapping[hyras_variable]["variable_name"]
        data_path = var_mapping[hyras_variable]["datapath"]

        # Open lazily with dask; only the clipped subset is loaded into memory.
        # Keep the handle returned by open_mfdataset so its files are closed
        # even if one of the steps below raises.
        ds_raw = xr.open_mfdataset(data_path, combine="by_coords", chunks="auto")
        try:
            ds = ds_raw.unify_chunks()

            # The precipitation file starts in 1931 (see its file name); cut it
            # to 1951 so it covers the same period as the other variables.
            if hyras_variable == "Precipitation":
                ds = ds.sel(time=slice("1951", None))

            # Assign the HYRAS grid CRS (EPSG:3034) so rioxarray can clip the data
            ds.rio.write_crs("EPSG:3034", inplace=True)

            # The bounds variables cause problems with xarray here; drop them
            # when present instead of failing on files that lack them.
            ds = ds.drop_vars(["time_bnds", "x_bnds", "y_bnds"], errors="ignore")

            ds.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True)

            # all_touched=True keeps every cell that is at least partially inside
            # the catchment. Loading the clipped subset into memory yielded the
            # fastest computation times.
            ds_clipped = ds.rio.clip(catchment.geometry, all_touched=True).load()
        finally:
            ds_raw.close()

        # --- Spatial plot: first time step of the clipped grid ---
        fig_spatial, ax = plt.subplots(figsize=(16, 7))
        ds_clipped[variable].isel(time=0).plot(alpha=1, ax=ax, cmap="viridis")
        # Catchment outline drawn over the raster: thick black border, no fill
        catchment.plot(ax=ax, color="none", edgecolor="black", linewidth=3)

        # The basemap needs an internet connection and can be slow; a failure is
        # recorded in `warnings` but does not stop the run.
        try:
            ctx.add_basemap(ax, crs=ds_clipped.rio.crs.to_string(), source=ctx.providers.OpenTopoMap)
        except Exception as e:
            print(f"Basemap loading not succesfull: {e}")
            warnings.append(f"Basemap loading not succesfull: {e}")

        # Pad the axis limits by 20 % of their span on each side
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()
        ax.set_xlim(xmin - 0.2 * (xmax - xmin), xmax + 0.2 * (xmax - xmin))
        ax.set_ylim(ymin - 0.2 * (ymax - ymin), ymax + 0.2 * (ymax - ymin))
        ax.set_title(f"{hyras_variable} clipped to catchment {camels_id}")

        # --- Aggregate to a time series ---
        # The grid_mapping attribute and the crs_HYRAS variable cause problems
        # with xarray in the steps below, so remove both.
        ds_clipped[variable].attrs.pop("grid_mapping", None)
        ds_clipped = ds_clipped.drop_vars("crs_HYRAS", errors="ignore")

        # Statistics computed per time step; they also name the CSV columns
        statistics = ["mean", "min", "median", "max", "stdev"]

        if WEIGHTED_STATISTICS:
            # Weighted statistics via exactextract
            df_weighted = exact_extract(ds_clipped[variable], catchment, statistics, output="pandas")
            time_index = ds_clipped.time.values
            n_steps = len(time_index)

            # Transposed, the result has one row per (statistic, time step).
            # NOTE(review): this assumes the rows are ordered statistic-major
            # (all time steps of "mean", then all of "min", ...) and that the
            # catchment file contains a single polygon (one column) - confirm
            # against the exactextract output format for multi-band input.
            df_long = df_weighted.T
            if len(df_long) != len(statistics) * n_steps:
                raise ValueError(
                    f"Unexpected exactextract output for {hyras_variable}, catchment {camels_id}: "
                    f"{len(df_long)} rows, expected {len(statistics) * n_steps}"
                )
            stat_frames = [
                df_long.iloc[i * n_steps:(i + 1) * n_steps]
                .set_axis(time_index, axis=0)
                .set_axis([f"{variable}_{stat}"], axis=1)
                for i, stat in enumerate(statistics)
            ]
            df_timeseries = pd.concat(stat_frames, axis=1)
        else:
            # Unweighted fallback: every clipped cell counts equally.
            # NOTE(review): assumes cells outside the polygon are NaN after
            # rio.clip, so the NaN-skipping reductions below ignore them.
            cells = ds_clipped[variable]
            spatial_dims = ("x", "y")
            reductions = {
                "mean": cells.mean(dim=spatial_dims),
                "min": cells.min(dim=spatial_dims),
                "median": cells.median(dim=spatial_dims),
                "max": cells.max(dim=spatial_dims),
                "stdev": cells.std(dim=spatial_dims),  # population std (ddof=0)
            }
            df_timeseries = pd.DataFrame(
                {f"{variable}_{stat}": reductions[stat].to_pandas() for stat in statistics}
            )

        # --- Timeseries plot: yearly means; stdev in its own lower panel ---
        fig_timeseries = plt.figure(figsize=(10, 7))
        gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1])
        yearly = df_timeseries.groupby(pd.Grouper(freq='Y')).mean()

        # Upper panel: every statistic except stdev
        ax0 = fig_timeseries.add_subplot(gs[0])
        yearly.drop(columns=[f"{variable}_stdev"]).plot(ax=ax0, lw=2, legend=False)
        ax0.set_title(f"{hyras_variable} yearly mean timeseries for catchment {camels_id}\n")
        ax0.xaxis.set_visible(False)  # the time axis is shown on the lower panel only

        # Lower panel: stdev
        ax1 = fig_timeseries.add_subplot(gs[1])
        yearly[f"{variable}_stdev"].plot(ax=ax1, lw=2, color='orange', legend=False)

        # One shared legend for both panels, placed below the figure
        lines = [*ax0.get_lines(), *ax1.get_lines()]
        labels = [line.get_label() for line in lines]
        fig_timeseries.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0), ncol=len(lines))
        fig_timeseries.tight_layout()

        # --- Save results ---
        fig_spatial.savefig(f"{plot_dir}/{hyras_variable}_catchment_clipped.png", dpi=300, bbox_inches="tight")
        fig_timeseries.savefig(f"{plot_dir}/{hyras_variable}_timeseries.png", dpi=300, bbox_inches="tight")
        # The figures are on disk now; close them so two new figures per
        # iteration do not pile up in memory across the whole loop.
        plt.close(fig_spatial)
        plt.close(fig_timeseries)

        df_timeseries.to_csv(f"{data_dir}/{camels_id}_{hyras_variable}.csv")

        if SAVE_NETCDF:
            ds_clipped.to_netcdf(f"{data_dir}/{camels_id}_{hyras_variable}_clipped.nc")
end_time = time.time()