# Explore CWatM data

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

from pathlib import Path

from tqdm.notebook import tqdm
import pandas as pd
import xarray as xr

import plotly.express as px

import src.data.cwatm_data as cwatm_data

In [2]:
PROCESSED_DATA_FOLDER_PATH = Path("../data/processed")

## Load CWatM data

In [None]:
def _read_netcdf_as_dataframe(file_name):
    """Load one NetCDF file from the ``CWatM_data`` folder as a DataFrame.

    The dataset is opened in a ``with`` block so the underlying file is closed
    as soon as its content has been converted to a DataFrame.
    """
    with xr.open_dataset(PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data", file_name)) as dataset:
        return dataset.to_dataframe()


all_df = _read_netcdf_as_dataframe("all.nc")
forcings_df = _read_netcdf_as_dataframe("forcings.nc")
outputs_df = _read_netcdf_as_dataframe("outputs.nc")

In [None]:
all_df

In [None]:
forcings_df

In [None]:
outputs_df

### Process the data

In [None]:
all_land_df = cwatm_data.process_inputs_df(all_df)
all_land_df

In [None]:
forcings_land_df = forcings_df.loc[all_land_df.index]
forcings_land_df

In [None]:
outputs_land_df = outputs_df.loc[all_land_df.index]
outputs_land_df

In [None]:
# Write the three land-only frames to parquet files in the CWatM_data folder.
cwatm_data_folder = PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data")
for parquet_name, land_df in (("all_land.parquet", all_land_df),
                              ("forcings_land.parquet", forcings_land_df),
                              ("outputs_land.parquet", outputs_land_df)):
    land_df.to_parquet(cwatm_data_folder.joinpath(parquet_name))

## Load CWatM `_land` data

In [3]:
# Read the three land-only parquet files back in from the CWatM_data folder.
cwatm_data_folder = PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data")
all_land_df, forcings_land_df, outputs_land_df = (
    pd.read_parquet(cwatm_data_folder.joinpath(parquet_name))
    for parquet_name in ("all_land.parquet", "forcings_land.parquet", "outputs_land.parquet")
)

In [None]:
all_land_df.describe()

In [None]:
forcings_land_df.describe()

In [None]:
outputs_land_df.describe()

## Visualize data

In [7]:
data_df = pd.concat((all_land_df, forcings_land_df, outputs_land_df), axis=1)

In [None]:
# Column groups of the combined frame. Each group takes every column of its
# source frame; the names listed in comments after each assignment are the
# remains of an earlier hand-picked list and are kept for reference only.
INPUTS_COLUMNS = list(all_land_df.columns)
# Earlier manual subset of inputs:
#     "porosity",
#     "firstStorDepth",
#     "percolationImp",
#     "tanslope",
#     "maxRootDepth_forest",
#     "maxRootDepth_grassland"
FORCINGS_COLUMNS = list(forcings_land_df.columns)
# Earlier manual subset of forcings:
#     "pr",
#     "tas",
#     "tasmax",
#     "tasmin",
#     "ps",
#     "rlds",
#     "rsds",
#     "sfcwind",
#     "hurs",
#     "huss",
OUTPUTS_COLUMNS = list(outputs_land_df.columns)
# Earlier manual subset of outputs:
#     "evap-total",
#     "potevap",
#     "qr",
#     "qtot"

# Debug aid: uncomment to work on the first 1000 rows only.
# data_df = data_df.iloc[:1000]

# Displayed as the cell output: (rows, columns) of the combined frame.
data_df.shape

In [214]:
from itertools import product
from tqdm import tqdm

import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle
import matplotlib.patheffects as mpe

def _plot_rolling_mean(ax, df, x_col, y_col, window, color, label, outline):
    """Draw the centered rolling mean of ``y_col`` against ``x_col`` on ``ax``.

    Rows are ordered by ``x_col`` on a sorted copy, so the frame passed in
    keeps its original row order.
    """
    ordered_df = df.sort_values(x_col, ascending=True)
    ax.plot(ordered_df[x_col],
            ordered_df[y_col].rolling(window=window, center=True).mean(),
            c=color,
            path_effects=[outline],
            label=label,
            )


def _format_measures(measures_values):
    """Return the absolute pearson / spearman / MIC values, one per line, with 2 decimals."""
    import numpy as np  # local import: numpy is not imported at notebook level

    return "\n".join(f"{np.round(abs(measures_values[name]), 2):.2f}"
                     for name in ("pearson", "spearman", "MIC"))


def display_individual_scatterplots(df: pd.DataFrame,
                                    dst_path: Path,
                                    valid_x,
                                    valid_y,
                                    regions_df = None,
                                    regions_2x2 = True,
                                    measures_df = None,
                                    measures_regions_df_dict = None,
                                    regions_color_palette = None,
                                    rolling_window: int = 3000,
                                    ) -> None:
    """Save one scatterplot PNG per (x column, y column) combination.

    Each figure is written to ``dst_path / f"{input_col}_{output_col}.png"``
    at 300 dpi and then closed. ``df`` is never modified.

    Parameters
    ----------
    df
        Frame holding every column named in ``valid_x`` and ``valid_y``.
    dst_path
        Output directory; created (with parents) when it does not exist.
    valid_x, valid_y
        Iterables of column names; every pair of their product is plotted.
    regions_df
        Optional frame with a ``"region"`` column. Rows of ``df`` are assigned
        to a region by matching index labels; labels absent from ``df`` are
        ignored. When ``None`` a single global scatter with a rolling-mean
        line is drawn.
    regions_2x2
        With regions: ``True`` draws one scatter panel per region on a 2x2
        grid (so at most four regions fit); ``False`` overlays all regions on
        one axis together with global and per-region rolling-mean lines.
    measures_df
        Optional frame indexed by (input, output) with ``"pearson"``,
        ``"spearman"`` and ``"MIC"`` entries, printed below the plot.
    measures_regions_df_dict
        Optional mapping region -> frame shaped like ``measures_df``; used
        only in the single-axis region layout.
    regions_color_palette
        Optional mapping region -> color; defaults to the four
        wet/dry x warm/cold colors.
    rolling_window
        Number of rows in the centered rolling-mean window.
    """
    # savefig does not create missing directories.
    dst_path.mkdir(parents=True, exist_ok=True)

    marker_style = MarkerStyle(marker=".",
                               fillstyle="full")
    outline = mpe.withStroke(linewidth=4, foreground='white')
    if regions_color_palette is None:
        palette = {"wet warm": "#018571",
                   "dry warm": "#a6611a",
                   "wet cold": "#80cdc1",
                   "dry cold": "#dfc27d"}
    else:
        palette = regions_color_palette

    # Region membership does not depend on the plotted columns: select the
    # rows of each region once instead of once per column combination.
    regions = []
    region_frames = {}
    if regions_df is not None:
        regions = regions_df["region"].unique()
        for region in regions:
            region_index = regions_df[regions_df["region"] == region].index
            region_frames[region] = df.loc[df.index.isin(region_index)]

    combinations = product(valid_x, valid_y)

    for input_col, output_col in tqdm(list(combinations), desc="Computing input-output combinations"):

        title = f"Input '{input_col}' - Output '{output_col}'"

        if regions_df is None:

            fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
            fig.suptitle(title)

            axis.scatter(x=df[input_col], y=df[output_col],
                         marker=marker_style,
                         color="black",
                         edgecolor="none",
                         s=30,
                         alpha=0.25,
                         )

            axis.set_xlabel(input_col)
            axis.set_ylabel(output_col)

            # The leading underscore in the label keeps the line out of the legend.
            _plot_rolling_mean(axis, df, input_col, output_col,
                               window=rolling_window, color="black",
                               label=f"_global", outline=outline)
            text_axis = axis

        elif regions_2x2:

            fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 6),
                                     sharex=True, sharey=True, constrained_layout=False)
            axes = axes.flatten()
            fig.suptitle(title)

            for i, region in enumerate(regions):
                region_data_df = region_frames[region]

                axes[i].scatter(x=region_data_df[input_col], y=region_data_df[output_col],
                                label=region,
                                c=palette[region],
                                marker=marker_style,
                                s=10,
                                edgecolor="none",
                                alpha=0.25,
                                )

                axes[i].set_xlabel(input_col)
                axes[i].set_ylabel(output_col)
                axes[i].label_outer()

            # The measures text needs a single Axes to attach to; the
            # bottom-right panel is used. NOTE(review): the text offsets below
            # were tuned for the single-axis layout, so placement in this
            # layout may need adjusting.
            text_axis = axes[-1]

        else:

            fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
            fig.suptitle(title)

            _plot_rolling_mean(axis, df, input_col, output_col,
                               window=rolling_window, color="black",
                               label=f"_global", outline=outline)

            for region in regions:
                region_data_df = region_frames[region]

                axis.scatter(x=region_data_df[input_col], y=region_data_df[output_col],
                             label=region,
                             c=palette[region],
                             marker=marker_style,
                             s=10,
                             edgecolor="none",
                             alpha=0.25,
                             )

                _plot_rolling_mean(axis, region_data_df, input_col, output_col,
                                   window=rolling_window, color=palette[region],
                                   label=f"_{region}", outline=outline)

            axis.set_xlabel(input_col)
            axis.set_ylabel(output_col)
            axis.label_outer()
            text_axis = axis

        # Adjust layout and display the plots
        fig.tight_layout()
        if regions_df is not None:
            fig.subplots_adjust(bottom=0.13)
            legend = fig.legend(markerscale=3,
                                loc="lower left",
                                ncol=4)

            # Legend markers inherit the scatter alpha; make them opaque.
            for legobj in legend.legend_handles:
                legobj.set_alpha(1)

        if measures_df is not None:

            fig.subplots_adjust(bottom=0.165)

            # Fixed position (axes coordinates) and manual alignment
            y_position = -0.139
            x_label = 0.734
            x_value = x_label + 0.005  # Slightly shifted to the right for alignment

            # Add labels
            text_axis.text(x_label, y_position, "Pearson:\nSpearman:\nMIC:",
                           transform=text_axis.transAxes, va='center', ha='right', fontsize=10, color="black")

            # Add global values with manual alignment
            text_axis.text(x_value, y_position,
                           _format_measures(measures_df.loc[input_col, output_col]),
                           transform=text_axis.transAxes, va='center', ha='left', fontsize=10, color="black")

            # Per-region values are only drawn in the single-axis region
            # layout, and only when region measures were supplied.
            if regions_df is not None and not regions_2x2 and measures_regions_df_dict is not None:
                for i, region in enumerate(regions):
                    region_measures = measures_regions_df_dict[region].loc[input_col, output_col]
                    text_axis.text(x_value + 0.055 * (1 + i), y_position,
                                   _format_measures(region_measures),
                                   transform=text_axis.transAxes, va='center', ha='left',
                                   fontsize=10, color=palette[region])

        fig.savefig(dst_path.joinpath(f"{input_col}_{output_col}.png"), dpi=300)

        # Close this figure explicitly so figures do not pile up across the loop.
        plt.close(fig)


### Global

In [None]:
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS
                                )

### Gnann regions

In [None]:
RAW_DATA_FOLDER_PATH = Path("../data/raw")

# Keep the coordinates plus the single domain column, which becomes "region".
domains_csv_path = RAW_DATA_FOLDER_PATH.joinpath("ISIMIP_2b_aggregated_variables", "domains.csv")
domains_df = pd.read_csv(domains_csv_path)[["lon", "lat", "domain_days_below_1_0.08_aridity_netrad"]]

# Index the region labels by (lon, lat).
regions_df = (
    domains_df
    .rename(columns={"domain_days_below_1_0.08_aridity_netrad": "region"})
    .set_index(["lon", "lat"])
)

regions_df

In [None]:
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots_regions_2x2_Gnann"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS,
                                regions_df=regions_df,
                                regions_2x2=True
                                )

In [None]:
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots_regions_Gnann"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS,
                                regions_df=regions_df,
                                regions_2x2=False
                                )

### Global with measures

In [None]:
CWATM_MEASURES_FOLDER = Path("../data/processed", "bivariate_metrics", "CWatM")


def _read_measures(file_name):
    """Read one measures CSV from CWATM_MEASURES_FOLDER, indexed by (input, output)."""
    return pd.read_csv(CWATM_MEASURES_FOLDER.joinpath(file_name),
                       index_col=["input", "output"])


# Global measures plus one frame per region.
measures_df = _read_measures("measures_global.csv")
measures_dc_df = _read_measures("measures_dry cold.csv")
measures_dw_df = _read_measures("measures_dry warm.csv")
measures_wc_df = _read_measures("measures_wet cold.csv")
measures_ww_df = _read_measures("measures_wet warm.csv")

# Keys are the region labels used to look up the per-region measures.
measures_regions_df_dict = dict(zip(
    ("dry cold", "dry warm", "wet cold", "wet warm"),
    (measures_dc_df, measures_dw_df, measures_wc_df, measures_ww_df),
))

In [None]:
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots_with_measures"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS,
                                measures_df=measures_df
                                )

### Gnann regions with measures

In [None]:
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots_regions_Gnann_with_measures"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS,
                                regions_df=regions_df,
                                regions_2x2=False,
                                measures_df=measures_df,
                                measures_regions_df_dict=measures_regions_df_dict
                                )

### Chanleng regions with measures

In [216]:
# NOTE: this cell rebinds `regions_df`, `measures_df` and
# `measures_regions_df_dict`, replacing the Gnann versions assigned in earlier
# cells. Re-run the Gnann cells before re-drawing the Gnann figures.
# Region labels come from a CSV indexed by (lon, lat).
regions_df = pd.read_csv("../data/processed/CWatM_data/chanleng_regions.csv", index_col=["lon", "lat"])

# Same folder and global measures file as the Gnann section; repeated so this
# cell can be run without it.
CWATM_MEASURES_FOLDER = Path("../data/processed", "bivariate_metrics", "CWatM")

measures_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_global.csv"),
                          index_col=["input", "output"])
# One measures file per chanleng region, each indexed by (input, output).
measures_1_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_chanleng_1.csv"),
                          index_col=["input", "output"])
measures_2_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_chanleng_2.csv"),
                          index_col=["input", "output"])
measures_3_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_chanleng_3.csv"),
                          index_col=["input", "output"])
measures_4_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_chanleng_4.csv"),
                          index_col=["input", "output"])

# Integer keys 1-4, the same keys the color palette in the plotting call below
# uses; presumably they equal the values of the CSV's "region" column -
# TODO confirm.
measures_regions_df_dict = {
    1: measures_1_df,
    2: measures_2_df,
    3: measures_3_df,
    4: measures_4_df,
}

In [None]:
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots_regions_chanleng_with_measures"),
                                valid_x=["chanleng"],
                                valid_y=OUTPUTS_COLUMNS,
                                regions_df=regions_df,
                                regions_2x2=False,
                                measures_df=measures_df,
                                measures_regions_df_dict=measures_regions_df_dict,
                                regions_color_palette={1: "red",
                                                       2: "blue",
                                                       3: "green",
                                                       4: "purple"},
                                )

## Interactive

In [None]:
from src.visualization import visualize

# Plot a subsample only. The sample goes into its own variable so `data_df`
# stays complete and re-running this cell does not shrink the data further;
# the fixed seed makes the figure reproducible.
INTERACTIVE_SAMPLE_FRACTION = 0.2
INTERACTIVE_SAMPLE_SEED = 42

data_sample_df = data_df.sample(frac=INTERACTIVE_SAMPLE_FRACTION,
                                random_state=INTERACTIVE_SAMPLE_SEED)

visualize.plot_scatter_with_dropdown(df=data_sample_df,
                                     default_x="pr",
                                     default_y="potevap",
                                     valid_x=sorted(INPUTS_COLUMNS + FORCINGS_COLUMNS),
                                     valid_y=sorted(OUTPUTS_COLUMNS))