In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from adjustText import adjust_text
import pandas as pd
from scipy.stats import pearsonr, spearmanr, kendalltau
import numpy as np
import seaborn as sns
from matplotlib.patches import Patch
import sys

sys.path.insert(1,'../CompareNormalisation')
from Functions import *
import MetricMapping

# Look-up tables from the project-local MetricMapping module:
# `type_mapping` is applied to metric ids to obtain their group ("type2"),
# `name_mapping` turns a metric id into the title shown on its panel.
type_mapping, name_mapping = MetricMapping.type_mapping, MetricMapping.name_mapping

# Position of each coarsened resolution code, plus the matching legend labels
# (both are listed in the same 10 / 30 / 60 minute order).
resolution_index_res = {code: position for position, code in enumerate(('10m', '30m', '60m'))}
label_resolutions = ["10 minute", "30 minute", "60 minute"]

### Read in the data and strip the `_log` / `_yj` suffixes from the column names

In [None]:
# Load the table holding every metric at every resolution.
# NOTE(review): the variable name says "minmax_scaled" but the file read is
# "NotScaled_AllRes.csv" - confirm which one is intended.
transformed_minmax_scaled = pd.read_csv("../Data/NotScaled_AllRes.csv")

# Drop the '_log' / '_yj' markers so each column is named by its plain metric id.
raw_columns = transformed_minmax_scaled.columns
column_names = raw_columns.str.replace('_log', '').str.replace('_yj', '')
transformed_minmax_scaled.columns = column_names

### Remove columns we're not interested in (for now)

In [None]:
# Base names of the columns to discard; each exists once per resolution
# as "<base>_<resolution>".
cols_to_delete = [
    'relative_amp_scaled', 'peak_mean_ratio_scaled', 'total_precip', 'duration',
    'event_num', 'BSC_Index', 'T50', 'min_intensity',
]

# Expand to the full column names (resolution-major order) and remove them in place.
unwanted_columns = [
    f'{base}_{res_code}'
    for res_code in ('5m', '10m', '30m', '60m')
    for base in cols_to_delete
]
for unwanted in unwanted_columns:
    del transformed_minmax_scaled[unwanted]

#### Delineate categorical and continuous metrics

In [None]:
# Collect the distinct metric names found in the column headers.
# `split_metric_resolution` returns a (metric, resolution) pair; columns for
# which it reports no resolution are skipped.
#
# The names are de-duplicated through a set and then sorted: iterating a set
# of strings directly yields an order that can change between kernel restarts
# (string hashing is randomised per process), which would make everything
# derived from `metrics` order-unstable from run to run.
# Assumes the metric names are plain strings (they come from column headers).
metrics = sorted({
    metric
    for metric, res in (split_metric_resolution(col) for col in transformed_minmax_scaled.columns)
    if res is not None
})

In [None]:
# Metrics whose values are categories rather than continuous quantities.
categorical_metrics = [
    '3rd_ARR', '3rd_rcg', '3rd_w_peak', '4th_w_peak', '5th_w_peak',
    'third_ppr', '3rd_w_most', '4th_w_most', '5th_w_most',
]

# Intermittency-related metric ids (not referenced again in the cells shown here).
intermittency_metrics = ['intermittency', 'event_dry_ratio']

# Everything that is not categorical is treated as continuous; order follows `metrics`.
continuous_metrics = list(filter(lambda name: name not in categorical_metrics, metrics))

In [None]:
# Build the per-metric / per-resolution summary table with the project helper.
# Later cells read its "metric", "type", "rank_corr" and "val_diff" columns.
coarse_resolutions = ["10m", "30m", "60m"]
summary_df = compute_metric_sensitivity_by_resolution(
    df=transformed_minmax_scaled,
    continuous_metrics=continuous_metrics,
    categorical_metrics=categorical_metrics,
    resolutions=coarse_resolutions,
)

In [None]:
# Display order of the metric groups; groups missing from this table map to NaN.
type_display_order = {'Asymmetry': 0, 'Peakiness': 1, 'Concentration': 2, 'Intermittency': 3}


def _type_then_descending_corr(col):
    """Sort key: rank "type2" by its display order, negate any other column
    so that an ascending sort lists its largest values first."""
    if col.name == "type2":
        return col.map(type_display_order)
    return -col


# Attach the group of every metric, then order rows by group and, within a
# group, by decreasing rank correlation.
summary_df = (
    summary_df
    .assign(type2=summary_df["metric"].map(type_mapping))
    .sort_values(by=["type2", "rank_corr"], key=_type_then_descending_corr)
)

In [None]:
# Split intermittency and other types
df_intermittency = summary_df[summary_df["type2"] == "Intermittency"]
df_other = summary_df[(summary_df["type2"] != "Intermittency") & (summary_df["type"] != "categorical")]
df_categorical = summary_df[summary_df["type"] == "categorical"]

## Plot

In [None]:
def _label_outer_axes(axes, n_cols, n_rows, xlabel, ylabel, fontsize=23):
    """Label only the outer edge of a flattened subplot grid.

    Axes in the first column of each row receive `ylabel`, axes in the last
    row receive `xlabel`; every other axis label is cleared.
    """
    for idx, grid_ax in enumerate(axes):
        if idx % n_cols == 0:  # first column in each row
            grid_ax.set_ylabel(ylabel, fontsize=fontsize)
        else:
            grid_ax.set_ylabel('')
        if idx // n_cols == n_rows - 1:  # last row
            grid_ax.set_xlabel(xlabel, fontsize=fontsize)
        else:
            grid_ax.set_xlabel('')


# --- PLOT 1: metrics in `df_other` (neither intermittency nor categorical) ---
unique_metrics_main = df_other["metric"].unique()
n_cols_main = 6
n_rows_main = -(-len(unique_metrics_main) // n_cols_main)  # ceiling division

fig_main, axs_main = plt.subplots(ncols=n_cols_main, nrows=n_rows_main,
                                  figsize=(4.2 * n_cols_main, 3.5 * n_rows_main),
                                  sharex=True, sharey=True)
axs_main = axs_main.flatten()

# Data extent over all panels; used for common limits and ticks below.
x_min, x_max, y_min, y_max = float('inf'), float('-inf'), float('inf'), float('-inf')

for i, this_metric in enumerate(unique_metrics_main):
    ax = axs_main[i]
    metric_name_for_plot = name_mapping[this_metric]
    metric_data = df_other[df_other["metric"] == this_metric]
    scatter_without_labels(ax, metric_data, metric_name_for_plot, type_color_map_2, resolution_index_res)

    # Widen the global extent with this panel's values.
    x_vals = metric_data["rank_corr"]
    y_vals = metric_data["val_diff"]
    if not x_vals.empty and not y_vals.empty:
        x_min = min(x_min, x_vals.min())
        x_max = max(x_max, x_vals.max())
        y_min = min(y_min, y_vals.min())
        y_max = max(y_max, y_vals.max())

# Tick range, computed once from the final extent.
# x: floor/ceil of (value / 5); for values within [-5, 5] this snaps the
#    bounds to -1, 0 or 1.
# y: bounds are snapped outward to multiples of 5.
x_min_rounded = np.floor(x_min / 5) * 1
x_max_rounded = np.ceil(x_max / 5) * 1
y_min_rounded = np.floor(y_min / 5) * 5
y_max_rounded = np.ceil(y_max / 5) * 5

# x ticks every 0.25, y ticks every 25.
xticks = np.arange(x_min_rounded, x_max_rounded + 0.25, 0.25)
yticks = np.arange(y_min_rounded, y_max_rounded + 1, 25)

# Apply the common limits, ticks and grid to every panel that holds data.
# Limits are set before the ticks, the same order as before.
for ax in axs_main[:len(unique_metrics_main)]:
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.grid(True)

_label_outer_axes(axs_main, n_cols_main, n_rows_main,
                  xlabel="Spearman’s ρ", ylabel="sMAPE from 5m")

# Hide the grid slots that hold no metric.
for ax in axs_main[len(unique_metrics_main):]:
    ax.axis('off')

# The top 4% of the figure is left free (room for an optional suptitle).
fig_main.tight_layout(rect=[0, 0, 1, 0.96])
# Save before showing so the written file never depends on the backend's
# handling of the figure after it has been displayed.
fig_main.savefig("../Figures/metrics_main_types.png", dpi=300, facecolor='white')
plt.show()

# --- PLOT 2: Intermittency Figure ---
if not df_intermittency.empty:
    unique_metrics_int = df_intermittency["metric"].unique()
    n_cols_int = 6
    n_rows_int = -(-len(unique_metrics_int) // n_cols_int)  # ceiling division

    fig_int, axs_int = plt.subplots(ncols=n_cols_int, nrows=n_rows_int,
                                    figsize=(4.2 * n_cols_int, 4.5 * n_rows_int),
                                    sharex=True, sharey=True)
    axs_int = axs_int.flatten()

    for i, this_metric in enumerate(unique_metrics_int):
        ax = axs_int[i]
        metric_name_for_plot = name_mapping[this_metric]
        metric_data = df_intermittency[df_intermittency["metric"] == this_metric]
        scatter_without_labels(ax, metric_data, metric_name_for_plot, type_color_map_2, resolution_index_res)
        # Reuse the limits and ticks of the main figure so both figures share one scale.
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.grid(True)

    _label_outer_axes(axs_int, n_cols_int, n_rows_int,
                      xlabel="Spearman’s ρ", ylabel="sMAPE from 5m")

    # Use the free grid slots after the last metric as legend panels,
    # one per entry of `type_color_map_2`, as long as slots remain.
    resolutions = ["10m", "30m", "60m"]
    legend_types = list(type_color_map_2.keys())
    legend_axes_start = len(unique_metrics_int)  # first free subplot

    for i, metric_type in enumerate(legend_types):
        idx = legend_axes_start + i
        if idx < len(axs_int):
            ax_legend = axs_int[idx]
            ax_legend.axis('off')
            colors = type_color_map_2[metric_type]
            # Named `type_legend_handles` (not `patches`) so the imported
            # `matplotlib.patches` module is not shadowed.
            type_legend_handles = [
                Patch(facecolor=colors[j], label=label_resolutions[j], alpha=1)
                for j in range(len(resolutions))
            ]
            ax_legend.legend(handles=type_legend_handles, title=metric_type, loc='center',
                             frameon=False, ncol=1, handlelength=2,
                             fontsize=25, title_fontsize=30)

    # Hide every slot that holds no metric (legend panels are already off).
    for ax in axs_int[len(unique_metrics_int):]:
        ax.axis('off')

    fig_int.tight_layout(rect=[0, 0, 1, 0.96])
    fig_int.savefig("../Figures/metrics_intermittency.png", dpi=300, facecolor='white')
    plt.show()


In [None]:
# --- PLOT 3: categorical metrics ---
resolutions = ["10m", "30m", "60m"]

# Keep the categorical rows and list them by decreasing rank correlation.
# Every remaining row has the same "type", so sorting on it added nothing:
# the previous key mapped 'Categorical' while the values are 'categorical',
# which turned the whole column into NaN. A stable descending sort on
# "rank_corr" yields the same row order.
filtered_df = (
    summary_df[summary_df['type'] == 'categorical']
    .sort_values("rank_corr", ascending=False, kind="stable")
)

unique_metrics = filtered_df["metric"].unique()

# One panel per metric on a 6-column grid.
n_cols = 6
n_rows = -(-len(unique_metrics) // n_cols)  # ceiling division
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(4.2 * n_cols, 4 * n_rows), sharex=True, sharey=True)
axs = axs.flatten()

# Plot each metric
for i, this_metric in enumerate(unique_metrics):
    ax = axs[i]
    metric_name_for_plot = name_mapping[this_metric]
    metric_data = filtered_df[filtered_df["metric"] == this_metric]
    scatter_without_labels_cat(ax, metric_data, metric_name_for_plot, type_color_map_2, resolution_index_res)

# Hide unused axes
for ax in axs[len(unique_metrics):]:
    ax.axis('off')

# Label only the outer edge of the grid: x label on the last row,
# y label on the first column, nothing elsewhere.
for i, ax in enumerate(axs):
    if i // n_cols == n_rows - 1:  # last row
        ax.set_xlabel("Kendall’s τ", fontsize=23)
    else:
        ax.set_xlabel('')
    if i % n_cols == 0:  # first column in each row
        ax.set_ylabel("% diff. from 5m", fontsize=23)
    else:
        ax.set_ylabel('')

# One legend patch per resolution, coloured from the categorical palette.
palette = type_color_map_2['categorical']
legend_handles = [
    Patch(facecolor=palette[i], edgecolor='white', label=res, alpha=1)
    for i, res in enumerate(resolutions)]

# `labels` overrides the patch labels, so the legend shows the long
# "10 minute" style names. The anchor is in figure coordinates.
fig.legend(
    handles=legend_handles,
    labels=label_resolutions,
    title="Categorical",
    frameon=False,
    title_fontsize=25,    # larger title
    fontsize=22,          # larger labels
    bbox_to_anchor=(0.53, 0.26),
    loc='center left',
    labelspacing=0.5,     # space between labels
    handlelength=2.7,     # length of color boxes
    handletextpad=0.8     # space between box and text
)

# The top 4% of the figure is left free (room for an optional suptitle).
fig.tight_layout(rect=[0, 0, 1, 0.96])

# Save before showing so the written file never depends on the backend's
# handling of the figure after it has been displayed.
fig.savefig("../Figures/categorical_metrics_single_figure.png", dpi=300, facecolor='white')
plt.show()