# Analyse measures of dependence from CWatM data

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append("..")

from pathlib import Path

from tqdm import tqdm
import numpy as np
import pandas as pd
import dataframe_image as dfi
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

import src.data.cwatm_data as cwatm_data
import src.visualization.visualize as visualize

In [2]:
# Folder holding the bivariate dependence-measure CSVs computed from CWatM output.
CWATM_MEASURES_FOLDER = Path("../data/processed") / "bivariate_metrics" / "CWatM"

# Regional subsets; the names match the suffixes of the per-region measure files.
REGIONS = ["dry cold", "dry warm", "wet cold", "wet warm"]

## Load data

In [None]:
# Load the precomputed dependence measures for the global domain and for each
# of the four regions. Every table is indexed by the (input, output) pair.
measures_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_global.csv"),
                          index_col=["input", "output"])
measures_dc_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_dry cold.csv"),
                          index_col=["input", "output"])
measures_dw_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_dry warm.csv"),
                          index_col=["input", "output"])
measures_wc_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_wet cold.csv"),
                          index_col=["input", "output"])
measures_ww_df = pd.read_csv(CWATM_MEASURES_FOLDER.joinpath("measures_wet warm.csv"),
                          index_col=["input", "output"])

# Display the global table as the cell output.
measures_df

In [4]:
import json

# Mapping from CWatM input variable name to its variable group. It is used to
# colour the input labels in the rank line plots.
# JSON text is UTF-8 by specification, so decode it explicitly. Otherwise the
# platform's locale encoding is used, which differs e.g. on Windows.
with open("../data/processed/CWatM_data/variable_group_dict.json", "r", encoding="utf-8") as fp:
    variable_group_dict = json.load(fp)

## Process measures results

This processing consists of:
1. Dropping uninteresting variable pairs (those with a `NaN` Pearson correlation or a MIC of `0`)
2. Selecting the relevant measures for this study:
    - pearson
    - spearman
    - MIC
    - MAS
    - MEV
3. Computing the rank for each measure
4. Computing the p-value of the top-scoring 10% of variable pairs (by MIC)
    - Using the shuffled (permuted) measures

With the above data as a basis, we continue by:

1. Using the Benjamini–Hochberg procedure to control the false discovery rate (FDR) at `alpha = 0.05` (for MIC)
2. On the statistically significant variable pairs:
    - Computing the non-linearity score MIC - pearson^2

Tables and visuals:
- Table of the top-scoring 10% of variable pairs with the 5 measures and MIC - pearson^2, including rank and p-value.
- (Pearson) correlation matrix of the 5 measures and MIC-pearson^2
- Scatter plot of p-values and the Benjamini and Hochberg line

In [None]:
# Apply the shared processing step to every measures table (presumably the
# dropping/selection described in the markdown above; see
# cwatm_data.process_measures_df).
measures_df = cwatm_data.process_measures_df(measures_df)
measures_dc_df = cwatm_data.process_measures_df(measures_dc_df)
measures_dw_df = cwatm_data.process_measures_df(measures_dw_df)
measures_wc_df = cwatm_data.process_measures_df(measures_wc_df)
measures_ww_df = cwatm_data.process_measures_df(measures_ww_df)

# Display the processed global table.
measures_df

In [None]:
# Rank the pairs for each measure, for the global table and every region.
ranks_measures_df = cwatm_data.compute_ranks_df(measures_df)
ranks_measures_dc_df = cwatm_data.compute_ranks_df(measures_dc_df)
ranks_measures_dw_df = cwatm_data.compute_ranks_df(measures_dw_df)
ranks_measures_wc_df = cwatm_data.compute_ranks_df(measures_wc_df)
ranks_measures_ww_df = cwatm_data.compute_ranks_df(measures_ww_df)

# Display the global ranks.
ranks_measures_df

In [None]:
# Compute p-values of the actual measures against the measures stored under the
# "shuffled" folder, once per region name. This is presumably a permutation null
# distribution; see cwatm_data.compute_p_values_complete.
p_values_measures_df = cwatm_data.compute_p_values_complete(
    shuffled_data_path=CWATM_MEASURES_FOLDER.joinpath("shuffled"),
    region="global",
    actual_df=measures_df
)
p_values_measures_dc_df = cwatm_data.compute_p_values_complete(
    shuffled_data_path=CWATM_MEASURES_FOLDER.joinpath("shuffled"),
    region="dry cold",
    actual_df=measures_dc_df
)
p_values_measures_dw_df = cwatm_data.compute_p_values_complete(
    shuffled_data_path=CWATM_MEASURES_FOLDER.joinpath("shuffled"),
    region="dry warm",
    actual_df=measures_dw_df
)
p_values_measures_wc_df = cwatm_data.compute_p_values_complete(
    shuffled_data_path=CWATM_MEASURES_FOLDER.joinpath("shuffled"),
    region="wet cold",
    actual_df=measures_wc_df
)
p_values_measures_ww_df = cwatm_data.compute_p_values_complete(
    shuffled_data_path=CWATM_MEASURES_FOLDER.joinpath("shuffled"),
    region="wet warm",
    actual_df=measures_ww_df
)

# Display the global p-values.
p_values_measures_df

In [None]:
# Benjamini-Hochberg FDR control on the MIC p-values at alpha = 0.05, per table.
# The first return value is used further below as a row mask on the measures
# tables. The second is the data drawn by plot_benjamini_hochberg_results.
significant_p_values_series, benjamini_hochberg_data = cwatm_data.control_FDR_benjamini_hochberg(
    p_values_series_in=p_values_measures_df["MIC_p-value"],
    alpha=0.05
)
significant_p_values_dc_series, benjamini_hochberg_dc_data = cwatm_data.control_FDR_benjamini_hochberg(
    p_values_series_in=p_values_measures_dc_df["MIC_p-value"],
    alpha=0.05
)
significant_p_values_dw_series, benjamini_hochberg_dw_data = cwatm_data.control_FDR_benjamini_hochberg(
    p_values_series_in=p_values_measures_dw_df["MIC_p-value"],
    alpha=0.05
)
significant_p_values_wc_series, benjamini_hochberg_wc_data = cwatm_data.control_FDR_benjamini_hochberg(
    p_values_series_in=p_values_measures_wc_df["MIC_p-value"],
    alpha=0.05
)
significant_p_values_ww_series, benjamini_hochberg_ww_data = cwatm_data.control_FDR_benjamini_hochberg(
    p_values_series_in=p_values_measures_ww_df["MIC_p-value"],
    alpha=0.05
)

# Display the global significance mask.
significant_p_values_series

In [None]:
def plot_benjamini_hochberg_results(bh_data,
                                    title: str,
                                    fig = None,
                                    axes = None
                                    ):
    """Plot sorted p-values against the Benjamini-Hochberg critical line.

    Args:
        bh_data: Mapping read here via the keys ``"data"``, ``"alpha"`` and
            ``"threshold_p_value"``. ``"data"`` is a table with ``rank``,
            ``p_value`` and ``bh_critical_value`` columns.
        title: Text appended to the figure title.
        fig, axes: Optional existing figure and single Axes to draw on. A new
            10x6 figure is created unless both are given.

    Returns:
        The figure that was drawn on.
    """
    if axes is None or fig is None:
        fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 6))

    table = bh_data["data"]

    # Observed p-values as crosses, BH critical values as a solid red line.
    axes.plot(table['rank'], table['p_value'], marker='x', linestyle='none', label='P-values')
    axes.plot(table['rank'], table['bh_critical_value'], color='red', label='BH Critical Value')

    # Dashed reference at the nominal alpha level.
    alpha = bh_data["alpha"]
    axes.axhline(y=alpha, color='grey', linestyle='--', label=f'Alpha = {alpha}')

    # Redraw every p-value at or below the BH threshold in green, on top.
    is_significant = table['p_value'] <= bh_data['threshold_p_value']
    axes.scatter(table['rank'][is_significant],
                 table['p_value'][is_significant],
                 color='green', label='Significant', zorder=5)

    axes.grid(True)
    axes.legend()
    fig.supxlabel('Rank of P-value', fontsize=12)
    fig.supylabel('P-value', fontsize=12)
    fig.suptitle(f'Benjamini-Hochberg Procedure - {title}', fontsize=16)

    return fig


# Plot the Benjamini-Hochberg results for the global table and each region,
# then save each figure at 300 dpi (the target folder must already exist).
fig = plot_benjamini_hochberg_results(benjamini_hochberg_data, title="Global")
fig_dc = plot_benjamini_hochberg_results(benjamini_hochberg_dc_data, title="Dry Cold")
fig_dw = plot_benjamini_hochberg_results(benjamini_hochberg_dw_data, title="Dry Warm")
fig_wc = plot_benjamini_hochberg_results(benjamini_hochberg_wc_data, title="Wet Cold")
fig_ww = plot_benjamini_hochberg_results(benjamini_hochberg_ww_data, title="Wet Warm")

fig.savefig(f"../reports/figures/CWatM_data/plot_benjamini_hochberg_global.png", dpi=300)
fig_dc.savefig(f"../reports/figures/CWatM_data/plot_benjamini_hochberg_dc.png", dpi=300)
fig_dw.savefig(f"../reports/figures/CWatM_data/plot_benjamini_hochberg_dw.png", dpi=300)
fig_wc.savefig(f"../reports/figures/CWatM_data/plot_benjamini_hochberg_wc.png", dpi=300)
fig_ww.savefig(f"../reports/figures/CWatM_data/plot_benjamini_hochberg_ww.png", dpi=300)

In [9]:

# Keep only the FDR-significant pairs (the BH output is used as a row mask,
# presumably boolean and aligned on (input, output)). Then add the
# non-linearity score and recompute ranks on that subset, per table.
significant_measures_df = measures_df[significant_p_values_series]
significant_measures_df = cwatm_data.compute_non_linearity(significant_measures_df)
significant_ranks_measures_df = cwatm_data.compute_ranks_df(significant_measures_df)

significant_measures_dc_df = measures_dc_df[significant_p_values_dc_series]
significant_measures_dc_df = cwatm_data.compute_non_linearity(significant_measures_dc_df)
significant_ranks_measures_dc_df = cwatm_data.compute_ranks_df(significant_measures_dc_df)

significant_measures_dw_df = measures_dw_df[significant_p_values_dw_series]
significant_measures_dw_df = cwatm_data.compute_non_linearity(significant_measures_dw_df)
significant_ranks_measures_dw_df = cwatm_data.compute_ranks_df(significant_measures_dw_df)

significant_measures_wc_df = measures_wc_df[significant_p_values_wc_series]
significant_measures_wc_df = cwatm_data.compute_non_linearity(significant_measures_wc_df)
significant_ranks_measures_wc_df = cwatm_data.compute_ranks_df(significant_measures_wc_df)

significant_measures_ww_df = measures_ww_df[significant_p_values_ww_series]
significant_measures_ww_df = cwatm_data.compute_non_linearity(significant_measures_ww_df)
significant_ranks_measures_ww_df = cwatm_data.compute_ranks_df(significant_measures_ww_df)

## Visualize results

In [None]:
# Let notebook output show up to 500 DataFrame rows before truncating.
pd.options.display.max_rows = 500

In [None]:


# top_10percent_table_dc_df = significant_measures_dc_df.join([significant_ranks_measures_dc_df, p_values_measures_dc_df],
#                                                             how="left")

# top_10percent_table_dc_df = top_10percent_table_dc_df[["pearson",   "pearson_rank",  # "pearson_p-value",  
#                                                     "spearman",  "spearman_rank", # "spearman_p-value",
#                                                     "MIC",       "MIC_rank",      # "MIC_p-value",
#                                                     "MAS",       "MAS_rank",      # "MAS_p-value",
#                                                     "MEV",       "MEV_rank",      # "MEV_p-value",
#                                                     "MIC - p^2", "MIC - p^2_rank"
#                                                     ]]

# top_10percent_table_dc_df = top_10percent_table_dc_df.convert_dtypes()

# top_10percent_table_dc_df.sort_values("MIC", ascending=False).head(n=int(0.1 * len(top_10percent_table_dc_df))).round(3)

### 1. The top inputs for each metric

#### 1.1 Separate tables for each region

In [None]:
def make_top_n_inputs_by_measure_for_each_output(input_df: pd.DataFrame,
                                                 top_n: int = 6,
                                                 with_value: bool = True,
                                                 outputs=None,
                                                 stats=None):
    """List, for every output, the ``top_n`` inputs with the highest value of each measure.

    Args:
        input_df: Measures table indexed by (input, output) with one column per
            measure.
        top_n: Number of ranked inputs listed per output.
        with_value: If True, add a "value" sub-column holding the measure value
            rounded to 2 decimals, stored as a string.
        outputs: Output variables to report. Defaults to
            ["evap-total", "potevap", "tws", "qtot", "qr"].
        stats: Measure columns to rank by (descending). Defaults to
            ["pearson", "spearman", "MIC"].

    Returns:
        Object-dtype DataFrame indexed by (output, rank) with rank starting at 1.
        Its columns are (measure, subcol), where subcol is "input" and optionally
        "value". Cells stay NaN when an output has fewer than ``top_n`` inputs
        or is absent from ``input_df`` (e.g. no significant pairs in a region).
    """
    if outputs is None:
        outputs = ["evap-total", "potevap", "tws", "qtot", "qr"]
    if stats is None:
        stats = ["pearson", "spearman", "MIC"]
    subcols = ["input", "value"] if with_value else ["input"]

    # from_product keeps the nested (outer, inner) order and also copes with top_n == 0.
    index = pd.MultiIndex.from_product([outputs, range(1, top_n + 1)], names=["output", "rank"])
    columns = pd.MultiIndex.from_product([stats, subcols], names=["measure", "subcol"])
    result_df = pd.DataFrame(index=index, columns=columns)

    # After significance filtering an output can have no rows at all; xs() would
    # raise KeyError for it, so such outputs are left as empty (NaN) rows.
    available_outputs = set(input_df.index.get_level_values("output"))

    for output in outputs:
        if output not in available_outputs:
            continue
        df_output = input_df.xs(output, level="output")
        for stat in stats:
            top = df_output.sort_values(by=stat, ascending=False).iloc[:top_n]
            for rank, (input_var, value) in enumerate(zip(top.index.values, top[stat].values), start=1):
                result_df.loc[(output, rank), (stat, "input")] = input_var
                if with_value:
                    result_df.loc[(output, rank), (stat, "value")] = str(np.round(value, 2))

    return result_df


def bold_cells_global(row):
    """Return one CSS string per cell of ``row`` for ``Styler.apply(axis="columns")``.

    Cells equal to "huss" are bolded when the row belongs to the "evap-total"
    output, i.e. when the row label is a tuple whose first element is
    "evap-total". All other cells get an empty style.

    Args:
        row: A table row as a pandas Series; ``row.name`` is the row label.

    Returns:
        List of CSS strings, aligned with the cells of ``row``.
    """
    # Debug prints removed: Styler calls this once per row, so they flooded stdout.
    is_evap_total_row = isinstance(row.name, tuple) and row.name[0] == "evap-total"
    # Iterate positionally so duplicate column labels cannot yield a Series.
    return ["font-weight: bold;" if is_evap_total_row and value == "huss" else ""
            for value in row.tolist()]


def style_table_for_export(df: pd.DataFrame,
                           regions: bool = False,
                           bold_cells_function = None
                           ):
    """Build a pandas Styler with group separators for exporting a top-inputs table.

    A thick bottom border is drawn under the last row of every ``output`` group.
    The same border is drawn on the level-0 (output) header cell of the group's
    first row, which presumably spans the whole group.
    When ``regions`` is True, thick right borders are drawn after the last MIC
    column of the "dry cold", "dry warm" and "wet cold" column groups.

    Args:
        df: Table indexed by (output, rank). With ``regions=True`` its columns
            must contain the (region, "MIC", "value") labels (or the
            (region, "MIC", "input") labels when there is no "value" subcol).
        regions: Whether ``df`` has a leading region column level to separate.
        bold_cells_function: Optional row-wise callback returning one CSS string
            per cell (e.g. ``bold_cells_global``).

    Returns:
        The configured Styler.
    """

    styled_df = df.style

    if bold_cells_function is not None:
        # Apply the bold formatting
        # NOTE(review): Styler.apply documents extra kwargs as forwarded to the
        # callback; `result_type` is not a callback parameter, so this appears to
        # rely on pandas handing it to DataFrame.apply internally - confirm for
        # the pinned pandas version.
        styled_df = styled_df.apply(bold_cells_function, axis="columns", result_type="expand")

    # Horizontal separators per output group. The return value is not reassigned:
    # this relies on set_table_styles updating this Styler in place, and
    # overwrite=False accumulates the rules across groups. axis=1 makes the dict
    # keys row (index) labels.
    for _, group_df in df.groupby("output"):
        styled_df.set_table_styles({group_df.index[-1]: [{'selector': '', 'props': 'border-bottom: 2px solid black;'}],
                                    group_df.index[0]:  [{'selector': '.level0', 'props': 'border-bottom: 2px solid black;'}],}, 
                                    overwrite=False, axis=1)

    if regions:
        # Add vertical lines for column group boundaries
        # Right border after the last MIC column of every region except the
        # final one; the global group is currently disabled.
        if "value" in df.columns.get_level_values("subcol"):
            column_group_labels = pd.Series({
                # ("global", "MIC", "value"): ("global", "MIC", "value"),
                ("dry cold", "MIC", "value"): ("dry cold", "MIC", "value"),
                ("dry warm", "MIC", "value"): ("dry warm", "MIC", "value"),
                ("wet cold", "MIC", "value"): ("wet cold", "MIC", "value"),
            })
        else:
            column_group_labels = pd.Series({
                # ("global", "MIC", "input"): ("global", "MIC", "input"),
                ("dry cold", "MIC", "input"): ("dry cold", "MIC", "input"),
                ("dry warm", "MIC", "input"): ("dry warm", "MIC", "input"),
                ("wet cold", "MIC", "input"): ("wet cold", "MIC", "input"),
            })
            
        column_styles = []

        # Each label is its own group here, so this loops once per listed column.
        for _, col_group in column_group_labels.groupby(column_group_labels):
            last_column = col_group.index[-1]  # Get the last column in the group
            column_idx = df.columns.get_loc(last_column)  # Get the index of the column
            column_styles.append({'selector': f'th.col{column_idx}, td.col{column_idx}',
                                  'props': 'border-right: 2px solid black;'})

        # Apply vertical line styles
        styled_df = styled_df.set_table_styles(column_styles, overwrite=False)

    return styled_df


# Top inputs per output and measure for the globally significant pairs:
# once with input names only, once with the rounded measure values.
top_inputs_by_measure_for_each_output = make_top_n_inputs_by_measure_for_each_output(
    significant_measures_df,
    with_value=False
)

top_inputs_by_measure_for_each_output_with_value = make_top_n_inputs_by_measure_for_each_output(
    significant_measures_df,
    with_value=True
)

# Save each table as CSV and as a styled PNG (via Chrome); bold_cells_global
# highlights "huss" in the evap-total rows.
top_inputs_by_measure_for_each_output.to_csv(
    "../data/processed/bivariate_metrics/CWatM/summary_tables/top_inputs_by_measure_for_each_output_global.csv"
)
dfi.export(style_table_for_export(top_inputs_by_measure_for_each_output,
                                  bold_cells_function=bold_cells_global),
           "../reports/tables/CWatM_data/top_inputs_by_measure_for_each_output_global.png",
           table_conversion="chrome",
           dpi=300)

top_inputs_by_measure_for_each_output_with_value.to_csv(
    "../data/processed/bivariate_metrics/CWatM/summary_tables/top_inputs_by_measure_for_each_output_global_with_value.csv"
)
dfi.export(style_table_for_export(top_inputs_by_measure_for_each_output_with_value,
                                  bold_cells_function=bold_cells_global),
           "../reports/tables/CWatM_data/top_inputs_by_measure_for_each_output_global_with_value.png",
           table_conversion="chrome",
           dpi=300)

# display(style_table_for_export(top_inputs_by_measure_for_each_output))

#### 1.2 Single table for all regions

In [272]:
def make_top_n_inputs_by_measure_for_each_output_with_regions(global_df: pd.DataFrame,
                                                              region_dc_df: pd.DataFrame,
                                                              region_dw_df: pd.DataFrame,
                                                              region_wc_df: pd.DataFrame,
                                                              region_ww_df: pd.DataFrame,
                                                              top_n: int = 6,
                                                              with_value: bool = True,
                                                              top_column: str = "measure"):
    """Combine the per-region top-``top_n`` input tables into one wide table.

    Each regional measures table is summarised with
    ``make_top_n_inputs_by_measure_for_each_output``. The summaries are placed
    side by side under an extra "region" column level.

    Args:
        global_df: Global measures table. Currently unused (the global column
            group is disabled); kept for call-site compatibility.
        region_dc_df, region_dw_df, region_wc_df, region_ww_df: Measures tables
            for the dry-cold, dry-warm, wet-cold and wet-warm regions.
        top_n: Number of ranked inputs per output.
        with_value: Whether to include the rounded measure value per input.
        top_column: Outermost column level, "measure" or "region".

    Returns:
        DataFrame indexed by (output, rank). Measures are ordered
        pearson/spearman/MIC and regions dry cold/dry warm/wet cold/wet warm.

    Raises:
        NotImplementedError: If ``top_column`` is neither "measure" nor "region".
    """
    regional_frames = {
        "dry cold": region_dc_df,
        "dry warm": region_dw_df,
        "wet cold": region_wc_df,
        "wet warm": region_ww_df,
    }

    # Summarise each region and prefix its columns with a "region" level.
    labelled_tables = [
        pd.concat(
            [make_top_n_inputs_by_measure_for_each_output(frame, top_n=top_n, with_value=with_value)],
            keys=[region_label], names=["region"], axis=1,
        )
        for region_label, frame in regional_frames.items()
    ]
    combined = labelled_tables[0].join(labelled_tables[1:])

    if top_column == "measure":
        level_order = ["measure", "region", "subcol"]
    elif top_column == "region":
        level_order = ["region", "measure", "subcol"]
    else:
        raise NotImplementedError(f"top_column '{top_column}' not implemented")

    combined = (combined
                .reorder_levels(level_order, axis="columns")
                .sort_index(axis="columns", level=top_column))

    # Impose the presentation order of measures and regions.
    combined = combined.reindex(columns=["pearson", "spearman", "MIC"], level="measure")
    return combined.reindex(columns=list(regional_frames), level="region")

# Regional top-6 tables with the region as the outermost column level: once
# with the rounded values, once with input names only.
top_inputs_by_measure_for_each_output_with_regions_and_value = make_top_n_inputs_by_measure_for_each_output_with_regions(
    significant_measures_df,
    significant_measures_dc_df,
    significant_measures_dw_df,
    significant_measures_wc_df,
    significant_measures_ww_df,
    top_n=6,
    with_value=True,
    top_column="region"
)

top_inputs_by_measure_for_each_output_with_regions = make_top_n_inputs_by_measure_for_each_output_with_regions(
    significant_measures_df,
    significant_measures_dc_df,
    significant_measures_dw_df,
    significant_measures_wc_df,
    significant_measures_ww_df,
    top_n=6,
    with_value=False,
    top_column="region"
)


# Save both as CSV and as styled PNGs with vertical region separators.
top_inputs_by_measure_for_each_output_with_regions_and_value.to_csv(
    "../data/processed/bivariate_metrics/CWatM/summary_tables/top_inputs_by_measure_for_each_output_regions_with_value.csv"
)
dfi.export(style_table_for_export(top_inputs_by_measure_for_each_output_with_regions_and_value,
                                  regions=True),
           "../reports/tables/CWatM_data/top_inputs_by_measure_for_each_output_regions_with_value.png",
           table_conversion="chrome",
           dpi=300)

top_inputs_by_measure_for_each_output_with_regions.to_csv(
    "../data/processed/bivariate_metrics/CWatM/summary_tables/top_inputs_by_measure_for_each_output_regions.csv"
)
dfi.export(style_table_for_export(top_inputs_by_measure_for_each_output_with_regions,
                                  regions=True),
           "../reports/tables/CWatM_data/top_inputs_by_measure_for_each_output_regions.png",
           table_conversion="chrome",
           dpi=300)

# display(style_table_for_export(top_inputs_by_measure_for_each_output_with_regions_and_value))
# display(style_table_for_export(top_inputs_by_measure_for_each_output_with_regions))

#### 1.3 Line plots of top-ranking inputs

In [None]:
def plot_values_and_rank(input_df,
                         region,
                         n_top = 15):
    """Plot values and ranks of several measures for the top ``n_top`` pairs by MIC.

    Thin wrapper around ``visualize.plot_measure_values_and_rank``. It plots
    pearson, spearman, MIC, "MIC - p^2" and MAS, sorted by MIC, and titles the
    figure with the upper-cased ``region``.

    Args:
        input_df: Measures table; a copy is handed to the plotting helper.
        region: Region name used as the figure title.
        n_top: Number of top pairs to show.

    Returns:
        The figure produced by the plotting helper.
    """
    shown_measures = ["pearson", "spearman", "MIC", "MIC - p^2", "MAS"]
    figure = visualize.plot_measure_values_and_rank(
        # Pass a copy so the caller's frame is presumably protected from mutation.
        measures_df=input_df.copy(),
        measures=shown_measures,
        sort_values_by="MIC",
        n_top=n_top,
    )
    figure.suptitle(region.upper(), x=0.38)
    return figure

# Number of top pairs (sorted by MIC) shown in the values-and-rank figure.
N_TOP = 15

# Only the global figure is drawn; the regional variants are kept commented out.
measures_values_and_rank_fig = plot_values_and_rank(significant_measures_df, region="global", n_top=N_TOP)
# plot_values_and_rank(significant_measures_dc_df, region="dry cold")
# plot_values_and_rank(significant_measures_dw_df, region="dry warm")
# plot_values_and_rank(significant_measures_wc_df, region="wet cold")
# plot_values_and_rank(significant_measures_ww_df, region="wet warm")

# measures_values_and_rank_fig.savefig(f"../reports/figures/CWatM_data/measures_values_and_rank_{N_TOP}.png", dpi=300)


In [246]:

def prepare_data_for_line_plot(top_inputs_by_measure_for_each_output_df, 
                               measure_level_name='measure', 
                               input_level_name='subcol', 
                               rank_name='rank', 
                               output_name='output'):
    
    # Ensure that the input DataFrame has multi-level columns
    if not isinstance(top_inputs_by_measure_for_each_output_df.columns, pd.MultiIndex):
        raise ValueError("DataFrame must have MultiIndex columns.")
        
    # Similarly, ensure that the index is multi-level and contains something like (output, rank)
    if not isinstance(top_inputs_by_measure_for_each_output_df.index, pd.MultiIndex):
        raise ValueError("DataFrame must have MultiIndex rows for this transformation.")
    
    # Identify the levels in the column MultiIndex
    col_levels = top_inputs_by_measure_for_each_output_df.columns.names
    if measure_level_name not in col_levels:
        raise ValueError(f"'{measure_level_name}' not found in column levels: {col_levels}")
        
    if input_level_name not in col_levels:
        raise ValueError(f"'{input_level_name}' not found in column levels: {col_levels}")
    
    # Identify the levels in the row MultiIndex
    row_levels = top_inputs_by_measure_for_each_output_df.index.names
    if rank_name not in row_levels or output_name not in row_levels:
        raise ValueError(f"Row index must contain levels '{output_name}' and '{rank_name}'. Found: {row_levels}")
    
    stacked_df = top_inputs_by_measure_for_each_output_df.stack(level=measure_level_name, future_stack=True)
    
    if isinstance(stacked_df, pd.Series):
        # If it's a Series, we reset_index to turn it into a DataFrame
        long_df = stacked_df.reset_index()
        # long_df = long_df.rename(columns={0: input_level_name})
    else:
        long_df = stacked_df.reset_index()
    
    long_df.columns.name = None

    return long_df


def _fill_measure_gaps(x_vals, y_vals, measures_vals, placeholder_rank):
    """Insert off-chart placeholder points where an input skips a measure position.

    Between consecutive x positions that are more than one apart, one point per
    missing integer position is inserted at ``placeholder_rank`` with measure
    value 0.0. The connecting line then leaves the visible rank range instead of
    cutting straight across the panel.

    Args:
        x_vals: Ascending integer measure positions of one input (length >= 1).
        y_vals: Ranks aligned with ``x_vals``.
        measures_vals: Measure values aligned with ``x_vals``.
        placeholder_rank: Rank assigned to the inserted points.

    Returns:
        Tuple of numpy arrays ``(x, y, measure_values)`` including placeholders.
    """
    new_x, new_y, new_measures = [], [], []
    for x, y, value, next_x in zip(x_vals[:-1], y_vals[:-1], measures_vals[:-1], x_vals[1:]):
        new_x.append(x)
        new_y.append(y)
        new_measures.append(value)
        # Empty unless next_x is at least two positions further along.
        for missing_x in range(x + 1, next_x):
            new_x.append(missing_x)
            new_y.append(placeholder_rank)
            new_measures.append(0.0)
    new_x.append(x_vals[-1])
    new_y.append(y_vals[-1])
    new_measures.append(measures_vals[-1])
    return np.array(new_x), np.array(new_y), np.array(new_measures)


def _add_rank_plot_legends(fig, ax, variable_group_color_mapping):
    """Attach the variable-group and measure-value legends below ``ax`` on ``fig``."""
    from matplotlib.legend_handler import HandlerPatch
    from matplotlib.lines import Line2D
    from matplotlib.patches import Circle

    class HandlerCircle(HandlerPatch):
        """Legend handler that draws Circle handles as circles rather than boxes."""

        def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
            center = (xdescent + width / 2, ydescent + height / 2)
            radius = min(width, height) / 2
            patch = Circle(center, radius)
            self.update_prop(patch, orig_handle, legend)
            patch.set_transform(trans)
            return [patch]

    group_handles = [
        Line2D([0], [0], marker='s', color='none', label=group,
               markerfacecolor=color, markersize=6)
        for group, color in variable_group_color_mapping.items()
    ]
    group_legend = ax.legend(handles=group_handles, title="Variable Groups",
                             markerscale=1.5,
                             loc="lower left",
                             ncol=3,
                             bbox_to_anchor=(0.05, -0.9),
                             )
    # The second ax.legend() call below replaces this one on the Axes; adding it
    # to the figure keeps it drawn.
    fig.add_artist(group_legend)

    # Labels match the marker masks in plot_line_measures_ranks (lower bounds inclusive).
    value_handles = [
        Circle((0, 0), radius=0.5, facecolor="black", edgecolor="black", label="≥ 0.7"),
        Circle((0, 0), radius=0.5, facecolor="grey", edgecolor="black", label="≥ 0.6, < 0.7"),
        Circle((0, 0), radius=0.5, facecolor="white", edgecolor="black", label="< 0.6"),
    ]
    value_legend = ax.legend(handles=value_handles, title="Measure value",
                             markerscale=1.5,
                             loc="lower right",
                             bbox_to_anchor=(0.95, -0.9),
                             handler_map={Circle: HandlerCircle()}
                             )
    fig.add_artist(value_legend)


def plot_line_measures_ranks(long_df: pd.DataFrame,
                             max_rank: int = None,
                             variable_group_dict = None,
                             fig = None,
                             axes = None
                             ):
    """Plot, per output, how the top-ranked inputs move across the measures.

    There is one panel per output. The x axis is the measure (pearson, spearman,
    MIC) and the y axis is the input's rank for that measure. Points of the same
    input are joined by a dashed grey line. The marker fill encodes the measure
    value: white < 0.6, grey in [0.6, 0.7), black >= 0.7. The input name is
    written next to its last point, coloured by its variable group.

    Args:
        long_df: Long table with "output", "input", "measure", "rank" and "value"
            columns, e.g. ``prepare_data_for_line_plot`` applied to a table built
            with ``with_value=True``.
        max_rank: Largest rank shown. It is also used for the y-limits, for the
            placeholder points of inputs that skip a measure, and in the title.
            If None, all ranks are shown and the title uses the number of
            distinct ranks.
        variable_group_dict: Optional mapping input name -> variable group (a key
            of the colour mapping below). Inputs not in it are labelled in black.
        fig, axes: Optional figure and sequence of Axes (one per output) to draw
            on. A new 8x12 figure is created unless both are given.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If ``long_df`` contains no outputs.
    """
    outputs = long_df["output"].unique().tolist()
    if not outputs:
        raise ValueError("long_df contains no outputs to plot")

    output_title_mapping = {
        "evap-total": "Total evapotranspiration (evap-total)",
        "potevap": "Potential evapotranspiration (potevap)",
        "tws": "Total water storage (tws)",
        "qtot": "Total runoff (qtot)",
        "qr": "Ground water recharge (qr)"
    }
    variable_group_color_mapping = {
        "Ground Water": "darkblue",
        "Land Cover": "orange",
        "Land Surface": "purple",
        "Routing": "blue",
        "Soil": "brown",
        "Forcings": "red"
    }

    measures = ["pearson", "spearman", "MIC"]
    # Measures are strings; map them to integer x positions.
    measure_positions = {m: i for i, m in enumerate(measures)}
    ranks = sorted(long_df["rank"].unique())

    if fig is None or axes is None:
        fig, axes = plt.subplots(nrows=len(outputs), ncols=1, figsize=(8, 12),
                                 sharex=True)
    # plt.subplots returns a bare Axes for a single row (and callers may pass one).
    # Wrap only that case, so a passed-in array of Axes is indexed directly.
    if not np.iterable(axes):
        axes = [axes]

    for i, output in enumerate(outputs):
        ax = axes[i]
        output_df = long_df[long_df["output"] == output]

        # Each input gets its own line/points.
        for input_val, group_df in output_df.groupby("input"):
            group_df = group_df.assign(measure_order=group_df["measure"].map(measure_positions))
            group_df = group_df.sort_values('measure_order')

            x_vals = group_df["measure_order"].values
            y_vals = group_df["rank"].values
            measures_vals = group_df["value"].astype(float).values

            # An input missing from a middle measure would otherwise be joined by a
            # line crossing the panel; route it through off-chart placeholders.
            if max_rank is not None and len(x_vals) > 1:
                x_vals, y_vals, measures_vals = _fill_measure_gaps(
                    x_vals, y_vals, measures_vals, placeholder_rank=max_rank * 2)

            if variable_group_dict is not None and input_val in variable_group_dict:
                variable_color = variable_group_color_mapping[variable_group_dict[input_val]]
            else:
                variable_color = "black"

            ax.plot(x_vals, y_vals, marker=None, color="grey", linestyle="--", alpha=0.5)

            # Marker fill encodes the measure value.
            hollow_mask = measures_vals < 0.6
            grey_mask = (0.6 <= measures_vals) & (measures_vals < 0.7)
            solid_mask = 0.7 <= measures_vals
            for mask, face_color in ((hollow_mask, "white"), (grey_mask, "grey"), (solid_mask, "black")):
                ax.plot(x_vals[mask], y_vals[mask],
                        marker="o", linestyle="", label=None,
                        color="black", markerfacecolor=face_color)

            # Label the input at its last point, if that point is within the shown ranks.
            if len(x_vals) > 0 and (max_rank is None or bool(y_vals[-1] < max_rank + 1)):
                ax.text(x_vals[-1] + 0.03, y_vals[-1] + 0.0, input_val,
                        va='center', ha='left', fontsize=8, color=variable_color)

        ax.set_xticks(range(len(measures)))
        ax.set_xticklabels(measures)
        ax.set_yticks(ranks)
        if max_rank is not None:
            ax.set_ylim(0.5, max_rank + 0.5)
        # Idempotent inversion: with row-shared y axes (multi-region figure),
        # unconditional invert_yaxis() calls would toggle each other off.
        if not ax.yaxis_inverted():
            ax.invert_yaxis()
        ax.set_title(output_title_mapping.get(output, output))

    last_ax = axes[len(outputs) - 1]
    # Leave room right of the last measure for the input labels.
    last_ax.set_xlim(None, len(measures) - 1 + 0.85)

    _add_rank_plot_legends(fig, last_ax, variable_group_color_mapping)

    fig.supxlabel("Measure of dependence", y=0.105)
    fig.supylabel("Rank")
    shown_ranks = max_rank if max_rank is not None else len(ranks)
    fig.suptitle(f"Ranking of the top {shown_ranks} inputs for each output", fontsize=15)
    fig.tight_layout()

    return fig

def plot_line_measures_ranks_for_regions(long_dc_df: pd.DataFrame,
                                         long_dw_df: pd.DataFrame,
                                         long_wc_df: pd.DataFrame,
                                         long_ww_df: pd.DataFrame,
                                         max_rank: int
                                         ):
    """Plot the per-output rank lines of the four climate regions side by side.

    Each region's long-format frame is drawn by ``plot_line_measures_ranks`` into
    its own column of a shared 5 x 4 grid (Dry Cold, Dry Warm, Wet Cold, Wet Warm,
    left to right). The grid then gets figure-level axis labels, a title, and a
    region caption above each column.

    Parameters
    ----------
    long_dc_df, long_dw_df, long_wc_df, long_ww_df : pd.DataFrame
        Long-format rank data of each region (as produced by
        ``prepare_data_for_line_plot``).
    max_rank : int
        Number of ranks shown; forwarded to ``plot_line_measures_ranks`` and used
        in the title.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Grid column order: each entry pairs a region's data with its column caption.
    region_columns = [
        (long_dc_df, "Dry Cold"),
        (long_dw_df, "Dry Warm"),
        (long_wc_df, "Wet Cold"),
        (long_ww_df, "Wet Warm"),
    ]

    fig, axes = plt.subplots(nrows=5, ncols=4, figsize=(8 * 4, 12),
                             sharex="col", sharey="row")

    # NOTE: relies on the notebook-level ``variable_group_dict`` loaded from JSON
    # at the top of the notebook.
    for col_idx, (region_long_df, _caption) in enumerate(region_columns):
        plot_line_measures_ranks(long_df=region_long_df,
                                 max_rank=max_rank,
                                 variable_group_dict=variable_group_dict,
                                 fig=fig,
                                 axes=axes[:, col_idx])

    fig.supxlabel("Measure of dependence", fontsize=16)
    fig.supylabel("Rank", x=0.0, fontsize=16)
    fig.suptitle(f"Ranking of the top {max_rank} inputs for each output - Regions", fontsize=22)
    fig.subplots_adjust(left=0.015, top=0.9)

    # Caption each column just above its top axis (axes-fraction coordinates).
    for col_idx, (_region_long_df, caption) in enumerate(region_columns):
        header_axis = axes[0, col_idx]
        header_axis.text(0.5, 1.18, caption, transform=header_axis.transAxes,
                         ha="center", fontsize=16)

    return fig

In [None]:
# Quick preview of the global ranking plot (top 6 inputs per output).
preview_top_n = 6
preview_top_inputs_df = make_top_n_inputs_by_measure_for_each_output(
    input_df=significant_measures_df,
    top_n=preview_top_n,
    with_value=True,
)
fig = plot_line_measures_ranks(
    long_df=prepare_data_for_line_plot(preview_top_inputs_df),
    max_rank=preview_top_n,
    variable_group_dict=variable_group_dict,
)

# NOTE(review): scratch output written to the current working directory.
fig.savefig("test.png")

In [None]:
TOP_N = 6

# Destination of the rank figures; created if missing so that ``savefig`` does not
# fail with FileNotFoundError on a fresh checkout.
CWATM_FIGURES_FOLDER = Path("../reports/figures/CWatM_data")
CWATM_FIGURES_FOLDER.mkdir(parents=True, exist_ok=True)


def plot_top_inputs_ranks(measures: pd.DataFrame, scope_label: str):
    """Plot the rank lines of the ``TOP_N`` inputs per output for one set of measures.

    Selects the top ``TOP_N`` inputs per output with
    ``make_top_n_inputs_by_measure_for_each_output``, reshapes them with
    ``prepare_data_for_line_plot``, draws them with ``plot_line_measures_ranks``
    and appends ``scope_label`` (e.g. "Global", "Dry Cold") to the figure title.

    Returns the matplotlib figure.
    """
    fig = plot_line_measures_ranks(
        long_df=prepare_data_for_line_plot(
            make_top_n_inputs_by_measure_for_each_output(
                input_df=measures,
                top_n=TOP_N,
                with_value=True
            )
        ),
        max_rank=TOP_N,
        variable_group_dict=variable_group_dict
    )
    fig.suptitle(f"Ranking of the top {TOP_N} inputs for each output - {scope_label}")
    return fig


fig_global = plot_top_inputs_ranks(significant_measures_df, "Global")
fig_dc = plot_top_inputs_ranks(significant_measures_dc_df, "Dry Cold")
fig_dw = plot_top_inputs_ranks(significant_measures_dw_df, "Dry Warm")
fig_wc = plot_top_inputs_ranks(significant_measures_wc_df, "Wet Cold")
fig_ww = plot_top_inputs_ranks(significant_measures_ww_df, "Wet Warm")

# Save every figure with a file-name suffix that identifies its scope.
for rank_fig, file_suffix in [(fig_global, "global"),
                              (fig_dc, "dry-cold"),
                              (fig_dw, "dry-warm"),
                              (fig_wc, "wet-cold"),
                              (fig_ww, "wet-warm")]:
    rank_fig.savefig(
        CWATM_FIGURES_FOLDER / f"plot_top_inputs_by_measure_for_each_output_{file_suffix}.png",
        dpi=300,
    )

In [None]:

# Long-format rank data of every region, keyed by the parameter names of
# ``plot_line_measures_ranks_for_regions`` (dc, dw, wc, ww order is preserved).
region_long_dfs = {
    param_name: prepare_data_for_line_plot(
        make_top_n_inputs_by_measure_for_each_output(
            input_df=region_measures_df,
            top_n=TOP_N,
            with_value=True
        )
    )
    for param_name, region_measures_df in (("long_dc_df", significant_measures_dc_df),
                                           ("long_dw_df", significant_measures_dw_df),
                                           ("long_wc_df", significant_measures_wc_df),
                                           ("long_ww_df", significant_measures_ww_df))
}

fig_regions = plot_line_measures_ranks_for_regions(**region_long_dfs, max_rank=TOP_N)

fig_regions.savefig("../reports/figures/CWatM_data/plot_top_inputs_by_measure_for_each_output_regions.png", dpi=300)


### 2. Scatter plots of the relevant variable pairs from section 1

### 3. Bar plot of relationship types and value of MIC in detecting them

#### 3.1 Plot of top measures

In [345]:
def bar_group_tick_positions(n_groups: int, n_bars: int, width: float) -> np.ndarray:
    """Return the x position at the centre of each group of side-by-side bars.

    Bar ``k`` (0-based) of group ``g`` is drawn at ``g + k * width``, so the
    middle of a group of ``n_bars`` bars is ``g + width * (n_bars - 1) / 2``.
    """
    return np.arange(n_groups) + width * (n_bars - 1) / 2


def add_reference_line(axis, y_value: float) -> None:
    """Draw a dashed grey horizontal line at ``y_value`` and make sure it has a y tick.

    The line spans the current x view limits. The y tick is added only if no
    existing tick is (numerically) equal to ``y_value``.
    """
    axis.relim()
    axis.autoscale_view()
    x_min, x_max = axis.get_xlim()
    axis.hlines(y=y_value, xmin=x_min, xmax=x_max, color='gray', linestyle='--')
    # Recompute data limits from the bars (as the original code did) before touching ticks.
    axis.relim()
    y_ticks = list(axis.get_yticks())
    # isclose instead of ``in``: tick values are floats such as 0.6000000000000001.
    if not np.any(np.isclose(y_ticks, y_value)):
        y_ticks.append(y_value)
    axis.set_yticks(sorted(y_ticks))


def plot_bar_measures(measures_df: pd.DataFrame,
                      ranks_df: pd.DataFrame,
                      measures, 
                      sort_values_by: str = "MIC",
                      n_top: int = 10,
                      fig = None,
                      axes = None,
                      show_ranks: bool = False,
                      ):
    """Draw grouped bars of dependence-measure values for the top ``n_top`` rows.

    Parameters
    ----------
    measures_df : pd.DataFrame
        One row per variable pair, one column per measure; absolute values are plotted.
    ranks_df : pd.DataFrame
        Rank table; only read when ``show_ranks`` is True. In that case it must hold a
        ``f"{measure}_rank"`` column for every measure and every plotted row.
    measures : list of str
        Columns of ``measures_df`` to draw, one bar per measure in each group.
    sort_values_by : str or None
        Column (one of ``measures``) to sort descending before taking ``n_top``.
        ``None`` keeps the input order.
    n_top : int
        Number of rows (bar groups) to draw.
    fig, axes : optional
        Figure and single Axes to draw into; a new 14x6 figure is created when either
        is None.
    show_ranks : bool
        If True, label every bar with its rank from ``ranks_df``.

    Returns
    -------
    The figure that was drawn into.
    """
    df = measures_df.abs()
    if sort_values_by is not None:
        df = df[measures].sort_values(sort_values_by, ascending=False)
    df = df.head(n_top)

    if fig is None or axes is None:
        fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(14,6))

    x = np.arange(len(df))
    width = 0.18

    for multiplier, measure in enumerate(measures):
        rects = axes.bar(x + width * multiplier, df[measure], width, label=measure)
        if show_ranks:
            axes.bar_label(rects,
                           labels=ranks_df.loc[df.index, f"{measure}_rank"].values,
                           padding=3)

    # Centre every label under its whole group of bars (``x + width`` only centres 3 bars).
    axes.set_xticks(bar_group_tick_positions(len(df), len(measures), width),
                    df.index.to_list(), rotation=90)

    axes.set_ylim(0, 1)
    axes.set_ylabel("Value")
    axes.legend(loc="upper right", ncols=4)

    add_reference_line(axes, 0.7)

    fig.suptitle("Measures value by input variable")

    return fig


def plot_bar_measures_for_output(measures_df: pd.DataFrame,
                                 ranks_df: pd.DataFrame,
                                 output: str,
                                 measures, 
                                 sort_values_by: str = "MIC",
                                 n_top: int = 10,
                                 fig = None,
                                 axes = None,
                                 ):
    """Plot grouped measure bars for the input variables of a single ``output``.

    Both ``measures_df`` and ``ranks_df`` are restricted to the rows whose
    ``"output"`` index level equals ``output``. The result is handed to
    ``plot_bar_measures``, and the figure title is replaced with one naming the
    output.

    Returns the figure that was drawn into.
    """
    def _rows_of_output(frame: pd.DataFrame) -> pd.DataFrame:
        # Boolean mask on the "output" level: yields an empty frame if output is absent.
        return frame[frame.index.get_level_values("output") == output]

    figure = plot_bar_measures(measures_df=_rows_of_output(measures_df),
                               ranks_df=_rows_of_output(ranks_df),
                               measures=measures,
                               sort_values_by=sort_values_by,
                               n_top=n_top,
                               fig=fig,
                               axes=axes)

    figure.suptitle(f"Measures value by input variable for '{output}'")
    return figure


def input_name_from_tick_label(label_text: str):
    """Return the input-variable name from a tick label showing ``str((input, output))``.

    ``plot_bar_measures`` labels its x ticks with the string form of the
    ``(input, output)`` index tuples. This helper recovers the first element with
    ``ast.literal_eval``, which only evaluates literals (never arbitrary code).
    Labels that are not a non-empty tuple literal are returned unchanged.
    """
    import ast

    try:
        parsed = ast.literal_eval(label_text)
    except (ValueError, SyntaxError):
        return label_text
    if isinstance(parsed, tuple) and parsed:
        return parsed[0]
    return label_text


def plot_bar_measures_for_all_outputs(measures_df: pd.DataFrame,
                                      ranks_df: pd.DataFrame,
                                      measures, 
                                      sort_values_by: str = "MIC",
                                      n_top: int = 10,
                                      outputs=None,
                                      ):
    """Stack one grouped-bar "metrics profile" panel per output variable.

    Each panel shows the ``n_top`` inputs of one output, as drawn by
    ``plot_bar_measures_for_output``. X tick labels are shortened to the input
    name, and the per-panel legends are replaced by one figure-level legend.

    Parameters
    ----------
    measures_df, ranks_df : pd.DataFrame
        Frames indexed by (input, output), forwarded to ``plot_bar_measures_for_output``.
    measures : list of str
        Measure columns to draw.
    sort_values_by : str or None
        Column used to pick the top inputs of each output.
    n_top : int
        Number of inputs shown per output.
    outputs : sequence of str, optional
        Output variables to plot, top to bottom. Defaults to
        ``["evap-total", "potevap", "tws", "qtot", "qr"]``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    if outputs is None:
        outputs = ["evap-total", "potevap", "tws", "qtot", "qr"]

    # Height scales with the number of panels (16 inches for the default five outputs).
    # squeeze=False keeps ``axes`` indexable even when only one output is requested.
    fig, axes = plt.subplots(nrows=len(outputs), ncols=1,
                             figsize=(12, 16 * len(outputs) / 5), squeeze=False)
    axes = axes[:, 0]

    output_title_mapping = {
        "evap-total": "Total evapotranspiration (evap-total)",
        "potevap": "Potential evapotranspiration (potevap)",
        "tws": "Total water storage (tws)",
        "qtot": "Total runoff (qtot)",
        "qr": "Ground water recharge (qr)"
    }

    for axis, output in zip(axes, outputs):
        plot_bar_measures_for_output(measures_df=measures_df, ranks_df=ranks_df,
                                     output=output,
                                     measures=measures, sort_values_by=sort_values_by, n_top=n_top, fig=fig,
                                     axes=axis,
                                     )

        # A single figure-level legend is added below, so drop the per-panel one.
        legend = axis.get_legend()
        if legend is not None:
            legend.remove()

        ticks = axis.get_xticks()
        labels = [input_name_from_tick_label(label.get_text()) for label in axis.get_xticklabels()]
        # These two outputs get a slight rotation, presumably to keep long input names apart.
        rotation = 5 if output in ["qtot", "qr"] else 0
        axis.set_xticks(ticks=ticks, labels=labels, rotation=rotation)
        axis.set_ylabel("")

        axis.set_title(output_title_mapping.get(output, output), fontsize=16)

    fig.supylabel("Value", 
                  x=0.005,
                  fontsize=16
                  )
    fig.supxlabel("Input variable",
                  y=0.05,
                  fontsize=16
                  )

    # Handles/labels are still available from the artists after the axis legend is removed.
    handles, labels = axes[0].get_legend_handles_labels()

    fig.legend(handles=handles,
               labels=labels,
               title="Metrics",
               title_fontsize="large",
               markerscale=3,
               handleheight=2,
               loc="lower center",
               ncol=4,
               )

    fig.suptitle(f"Metrics profile of the top {n_top} input for each output - Global", fontsize=18)
    fig.tight_layout()
    fig.subplots_adjust(
        top=0.935,
        bottom=0.1,
    )

    return fig



In [346]:
# fig = plot_bar_measures(measures_df=significant_measures_df,
#                         ranks_df=significant_ranks_measures_df,
#                         measures=["pearson", "spearman", "MIC", "MAS"],
#                         sort_values_by="MIC",
#                         n_top=50
#                         )

In [347]:
# fig = plot_bar_measures_for_output(measures_df=significant_measures_df,
#                                    ranks_df=significant_ranks_measures_df,
#                                    output="qr",
#                                    measures=["pearson", "spearman", "MIC", "MAS"],
#                                    sort_values_by="MIC",
#                                    n_top=10
#                                    )

In [None]:
# Global metrics profile: top 6 inputs per output, ranked by MIC.
fig = plot_bar_measures_for_all_outputs(
    measures_df=significant_measures_df,
    ranks_df=significant_ranks_measures_df,
    measures=["pearson", "spearman", "MIC", "MAS"],
    sort_values_by="MIC",
    n_top=6,
)

metrics_profile_path = "../reports/figures/CWatM_data/plot_metrics_profile_top_inputs_for_each_output_global.png"
fig.savefig(metrics_profile_path, dpi=300)


#### 3.2 Plot of relationship types

In [None]:
PROCESSED_DATA_FOLDER_PATH = Path("../data/processed")
cwatm_data_folder = PROCESSED_DATA_FOLDER_PATH / "CWatM_data"

all_land_df = pd.read_parquet(cwatm_data_folder / "all_land.parquet")
forcings_land_df = pd.read_parquet(cwatm_data_folder / "forcings_land.parquet")
outputs_land_df = pd.read_parquet(cwatm_data_folder / "outputs_land.parquet")

# Place the three tables side by side, aligned on their row index.
data_df = pd.concat([all_land_df, forcings_land_df, outputs_land_df], axis=1)
data_df

In [283]:
from typing import List, Tuple

from matplotlib.markers import MarkerStyle
from matplotlib.ticker import MaxNLocator
import matplotlib.gridspec as gridspec


def plot_bar_measures_for_groups(measures_df: pd.DataFrame,
                                 ranks_df: pd.DataFrame,
                                 pairs_1: List[Tuple[str, str]],
                                 pairs_2: List[Tuple[str, str]],
                                 pairs_3: List[Tuple[str, str]],
                                 measures, 
                                 sort_values_by: str = "MIC",
                                 fig = None,
                                 axes = None,
                                 show_ranks: bool = False,
                                 ):
    """Draw grouped measure bars for three groups of variable pairs in three panels.

    The panels are titled "Simple", "Threshold-like" and "Complex relationships"
    (for ``pairs_1``, ``pairs_2`` and ``pairs_3`` respectively). Within a panel,
    pairs keep the order they are given in.

    Parameters
    ----------
    measures_df : pd.DataFrame
        Indexed by (input, output); absolute values of ``measures`` are plotted.
    ranks_df : pd.DataFrame
        Only read when ``show_ranks`` is True; must then hold ``f"{measure}_rank"``
        columns for all listed pairs.
    pairs_1, pairs_2, pairs_3 : list of (input, output)
        Index labels of the pairs shown in each panel.
    measures : list of str
        Measure columns to draw, one bar per measure in each group.
    sort_values_by : str or None
        Sorts the selected rows before plotting (each panel still follows its
        ``pairs`` order).
    fig, axes : optional
        Figure and a sequence of 3 Axes to draw into. When either is None, a new
        1x3 figure with shared y and width proportional to group size is created.
    show_ranks : bool
        If True, label every bar with its rank from ``ranks_df``.

    Returns
    -------
    The figure that was drawn into.
    """
    groups = [pairs_1, pairs_2, pairs_3]
    measures_df = measures_df.loc[pairs_1 + pairs_2 + pairs_3]

    df = measures_df.abs()
    if sort_values_by is not None:
        df = df[measures].sort_values(sort_values_by, ascending=False)

    if fig is None or axes is None:
        fig, axes = plt.subplots(nrows=1, ncols=3, 
                                 figsize=(14,6),
                                 width_ratios=(len(pairs_1)/len(df),
                                               len(pairs_2)/len(df),
                                               len(pairs_3)/len(df)),
                                 sharey=True)

    width = 0.18
    # Offset that centres a tick under a whole group of len(measures) bars.
    group_centre = width * (len(measures) - 1) / 2

    for i, pairs in enumerate(groups):
        group_df = df.loc[pairs]
        x = np.arange(len(group_df))

        for multiplier, measure in enumerate(measures):
            # Only the last panel labels its bars ("_" hides the others from legends),
            # so a figure-level legend lists each measure once.
            rects = axes[i].bar(x + width * multiplier, group_df[measure], width,
                                label=measure if i == 2 else "_")
            if show_ranks:
                axes[i].bar_label(rects,
                                  labels=ranks_df.loc[pairs, f"{measure}_rank"].values,
                                  padding=3)

        axes[i].set_xticks(x + group_centre, group_df.index.to_list(), rotation=0)

    axes[0].set_ylim(0, 1)

    axes[0].set_title("Simple relationships")
    axes[1].set_title("Threshold-like relationships")
    axes[2].set_title("Complex relationships")

    # Draw a dashed 0.7 reference line on every panel and make sure 0.7 is a y tick.
    y_value = 0.7
    for axis in axes:
        axis.relim()
        axis.autoscale_view()
        x_min, x_max = axis.get_xlim()
        axis.hlines(y=y_value, xmin=x_min, xmax=x_max, color='gray', linestyle='--')
        # Use the current ``axis`` (previously a stale ``axes[i]`` from the loop above).
        y_ticks = list(axis.get_yticks())
        if not np.any(np.isclose(y_ticks, y_value)):
            y_ticks.append(y_value)
        axis.set_yticks(sorted(y_ticks))

    return fig

def plot_scatterplot(df: pd.DataFrame,
                     input_col: str,
                     output_col: str,
                     axis
                     ):
    """Scatter ``df[output_col]`` against ``df[input_col]`` on ``axis``.

    Points are small, filled, semi-transparent black dots without edges. The axes
    are labelled with the column names, and the x axis is limited to at most 7 major
    tick intervals.
    """
    dot_style = dict(
        marker=MarkerStyle(marker=".", fillstyle="full"),
        color="black",
        edgecolor="none",
        s=25,
        alpha=0.25,
    )
    axis.scatter(x=df[input_col], y=df[output_col], **dot_style)

    axis.set_xlabel(input_col)
    axis.set_ylabel(output_col)

    axis.xaxis.set_major_locator(MaxNLocator(7))



def plot_bar_measures_for_groups_with_scatterplots(measures_df: pd.DataFrame,
                                                   ranks_df: pd.DataFrame,
                                                   data_df: pd.DataFrame,
                                                   pairs_1: List[Tuple[str, str]],
                                                   pairs_2: List[Tuple[str, str]],
                                                   pairs_3: List[Tuple[str, str]],
                                                   measures, 
                                                   sort_values_by: str = "MIC"
                                                   ):
    """Combine the grouped measure bars with an example scatter plot for every pair.

    Layout:

    * Top subfigure: ``plot_bar_measures_for_groups`` for the three pair groups.
    * Bottom subfigure: one scatter plot per pair, drawn from ``data_df``.
      Column ``g`` holds the pairs of group ``g``, one per row; unused cells are
      hidden.

    Every pair gets a letter (A, B, C, ... in group order). The letter is written
    both on its scatter plot and above its bar group, so the two can be matched.
    The scatter panels are derived from ``pairs_1..3``, so they always agree with
    the bars.

    Returns
    -------
    matplotlib.figure.Figure
    """
    groups = [pairs_1, pairs_2, pairs_3]
    group_titles = ["Simple relationships", "Threshold-like relationships", "Complex relationships"]

    fig = plt.figure(figsize=(14, 13))

    # Two stacked subfigures: bars on top, scatter examples below.
    subfigs = fig.subfigures(nrows=2, ncols=1,
                             height_ratios=(0.4, 0.6)
                             )
    subfig_top = subfigs[0]
    subfig_bottom = subfigs[1]

    # Bar panel widths proportional to the number of pairs in each group.
    axes_bar = subfig_top.subplots(
        nrows=1, ncols=3,
        sharey=True,
        gridspec_kw={'width_ratios': [len(pairs) for pairs in groups]}
    )

    plot_bar_measures_for_groups(measures_df=measures_df,
                                 ranks_df=ranks_df,
                                 pairs_1=pairs_1,
                                 pairs_2=pairs_2,
                                 pairs_3=pairs_3,
                                 measures=measures,
                                 sort_values_by=sort_values_by,
                                 fig=subfig_top,
                                 axes=axes_bar)

    subfig_top.supylabel("Value", 
                         x=0.008
                         )
    subfig_top.supxlabel("Variable pairs", y=0.01)
    subfig_top.suptitle("Relationship types and their metrics profile", fontsize=20)
    subfig_top.subplots_adjust(left=0.05)
    subfig_top.legend(title="Metrics",
                      title_fontsize="large",
                      markerscale=3,
                      handleheight=2,
                      loc="center right",
                      ncol=1,
                      )

    # One scatter row per pair of the largest group; squeeze=False keeps 2-D indexing.
    n_scatter_rows = max(len(pairs) for pairs in groups)
    axes_scatter = subfig_bottom.subplots(nrows=n_scatter_rows, ncols=3,
                                          squeeze=False,
                                          gridspec_kw={
                                              'hspace': 0.2,
                                              'wspace': 0.25
                                          },
                                          )

    letter_x = -0.08
    letter_y = 1.03
    letter_style = dict(fontweight='bold', va='center', ha='left', fontsize=14)

    panels = [(col, row, pair)
              for col, pairs in enumerate(groups)
              for row, pair in enumerate(pairs)]
    for panel_index, (col, row, (input_col, output_col)) in enumerate(panels):
        letter = chr(ord("A") + panel_index)
        scatter_axis = axes_scatter[row, col]
        plot_scatterplot(df=data_df,
                         input_col=input_col,
                         output_col=output_col,
                         axis=scatter_axis)
        scatter_axis.text(letter_x, letter_y, letter, transform=scatter_axis.transAxes,
                          **letter_style)
        # Spread the bar-panel letters evenly: 0.05 for the first pair, 0.55 for the
        # second pair of a two-pair group.
        axes_bar[col].text(0.05 + row / len(groups[col]), 0.95, letter,
                           transform=axes_bar[col].transAxes, **letter_style)

    # Hide scatter cells below groups that have fewer pairs than the largest group.
    for col, pairs in enumerate(groups):
        for row in range(len(pairs), n_scatter_rows):
            axes_scatter[row, col].axis("off")

    subfig_bottom.suptitle("Relationship types examples", y=0.9800, fontsize=20)
    for col, title in enumerate(group_titles):
        axes_scatter[0, col].set_title(title, y=1.05)

    return fig


In [None]:
# Example pairs for each relationship type, in the order shown on the figure.
simple_pairs = [("tasmax", "potevap"),
                ("pr", "evap-total")]
threshold_like_pairs = [("percolationImp", "qr"),
                        ("secondStorDepth", "qr")]
complex_pairs = [("chanleng", "potevap")]

fig = plot_bar_measures_for_groups_with_scatterplots(
    measures_df=significant_measures_df,
    ranks_df=significant_ranks_measures_df,
    data_df=data_df,
    pairs_1=simple_pairs,
    pairs_2=threshold_like_pairs,
    pairs_3=complex_pairs,
    measures=["pearson", "spearman", "MIC", "MAS"],
    sort_values_by=None,
)

fig.savefig("../reports/figures/CWatM_data/plot_relationship_types_by_metrics_with_examples.png", dpi=300)


In [None]:
# fig = plot_bar_measures_for_groups(measures_df=significant_measures_df,
#                                    ranks_df=significant_ranks_measures_df,
#                                    pairs_1=[("tasmax", "potevap"),
#                                             ("pr", "evap-total")],
#                                    pairs_2=[("percolationImp", "qr"),
#                                             ("secondStorDepth", "qr")],
#                                    pairs_3=[("chanleng", "potevap")],
#                                    measures=["pearson", "spearman", "MIC", "MAS"],
#                                    sort_values_by=None,
#                                    )

# fig.savefig(f"../reports/figures/CWatM_data/plot_relationship_types_by_metrics.png", dpi=300)


## Supporting information plots

In [None]:
def plot_corr_matrix(input_df):
    temp_df = input_df.copy()
    temp_df["pearson"] = temp_df["pearson"]**2
    temp_df["spearman"] = temp_df["spearman"]**2
    pearson_corr_matrix = temp_df.corr(method="pearson").round(2)
    fig = px.imshow(pearson_corr_matrix, text_auto=True, zmin=-1, zmax=+1)
    fig.show()

# One heatmap per scope: global first, then the four climate regions.
for scope_measures_df in (significant_measures_df,
                          significant_measures_dc_df,
                          significant_measures_dw_df,
                          significant_measures_wc_df,
                          significant_measures_ww_df):
    plot_corr_matrix(scope_measures_df)


In [None]:
def pair_hover_labels(input_df: pd.DataFrame) -> list:
    """Return one hover string per row of ``input_df``, in row order.

    Each label is ``str()`` of the row's index label; for an (input, output)
    MultiIndex this is e.g. ``"('tasmax', 'potevap')"``.
    """
    return [str(label) for label in input_df.index]


def plot_scatter_matrix(input_df, title='Measures of dependence'):
    """Show a plotly scatter-plot matrix of every column of ``input_df``.

    The diagonal is hidden. Hovering a point shows the index label of its row;
    the labels come from ``input_df`` itself, so they stay aligned with the
    plotted points for filtered or per-region frames. (Previously they came from
    the unfiltered global ``measures_df``.)
    """
    fig = go.Figure(
        data=go.Splom(
            dimensions=[dict(label=col,
                             values=input_df[col]) for col in input_df.columns],
            diagonal_visible=False, # remove plots on diagonal
            text=pair_hover_labels(input_df),
        )
    )
    fig.update_layout(
        title=title,
        width=900,
        height=600,
        hovermode="x",
    )
    fig.show()

# One scatter matrix per scope: global first, then the four climate regions.
for scope_measures_df in (significant_measures_df,
                          significant_measures_dc_df,
                          significant_measures_dw_df,
                          significant_measures_wc_df,
                          significant_measures_ww_df):
    plot_scatter_matrix(scope_measures_df)