# In [None]:
# %%
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Location of the Feather file that holds the input table
feather_path = "/Trex/case_results/i.e215.I2000Clm50SpGs.hw_production.05/research_results/summary/updated_local_hour_adjusted_variables_HW98.feather"

# Read the Feather file into a DataFrame and attach a 'date' column holding
# the calendar date of each row's 'time' value (the .dt accessor requires
# 'time' to be datetime-like).
df = pd.read_feather(feather_path).assign(
    date=lambda frame: frame['time'].dt.date
)


# %%

# 3. Sort so the rows of every (location_ID, event_ID) pair are in
#    chronological order
df = df.sort_values(['location_ID', 'event_ID', 'time'])

# 4. Compute "day_in_event": whole days elapsed since the earliest date of the
#    same (location_ID, event_ID) group, so the first day of a group is 0.
#    The dates are converted once for the whole column and the per-group
#    minimum comes from the built-in 'min' reduction, rather than a Python
#    lambda that re-parses the dates for every single group.
event_dates = pd.to_datetime(df['date'])
event_start = (event_dates
               .groupby([df['location_ID'], df['event_ID']])
               .transform('min'))
df['day_in_event'] = (event_dates - event_start).dt.days

# 5. Aggregate per (KGMajorClass, day_in_event): mean and standard deviation
#    of UHI_diff, plus the means of Q2M and SOILWATER_10CM
df_agg = (df.groupby(['KGMajorClass', 'day_in_event'], as_index=False)
            .agg(UHI_diff_mean=('UHI_diff', 'mean'),
                 UHI_diff_std=('UHI_diff', 'std'),
                 Q2M_mean=('Q2M', 'mean'),
                 SOILWATER_10CM_mean=('SOILWATER_10CM', 'mean')))


# %%
# Number of leading event days to keep (day_in_event values 0 .. N_DAYS - 1)
N_DAYS = 11

# Keep every row that is not in the 'Polar' class and that falls inside the
# first N_DAYS days of its event; both conditions are applied in one mask.
is_not_polar = df_agg['KGMajorClass'] != 'Polar'
is_early_day = df_agg['day_in_event'] < N_DAYS
df_plot = df_agg[is_not_polar & is_early_day].copy()

# Climate zones to draw, sorted, with missing labels dropped
zone_labels = df_plot['KGMajorClass'].dropna().unique()
unique_zones = sorted(zone_labels)
n_zones = len(unique_zones)

# Average each summary column again per (KGMajorClass, day_in_event).
# For the standard deviation this is the simple approach of averaging the
# standard deviations.
# NOTE(review): df_plot appears to be keyed by these same two columns already,
# so this step may be redundant - confirm before removing it.
mean_of_columns = {
    column: (column, 'mean')
    for column in ('UHI_diff_mean', 'UHI_diff_std',
                   'Q2M_mean', 'SOILWATER_10CM_mean')
}
df_plot_agg = (
    df_plot
    .groupby(['KGMajorClass', 'day_in_event'], as_index=False)
    .agg(**mean_of_columns)
)

# %% [markdown]
# ### Normalize Q2M and SOILWATER_10CM to a 0-1 scale (Min-Max scaling)

# %%
# Find the global min and max of Q2M_mean and SOILWATER_10CM_mean over all
# rows of df_plot_agg (only these two variables are normalized)
min_max_values = {
    var: {'min': df_plot_agg[var].min(), 'max': df_plot_agg[var].max()}
    for var in ['Q2M_mean', 'SOILWATER_10CM_mean']
}

# Apply Min-Max scaling, writing the result to a new '<var>_scaled' column.
# The arithmetic is done on whole columns rather than row by row with
# DataFrame.apply(axis=1), which gives the same values without a Python-level
# call per row.
for var in ['Q2M_mean', 'SOILWATER_10CM_mean']:
    lower = min_max_values[var]['min']
    value_range = min_max_values[var]['max'] - lower
    shifted = df_plot_agg[var] - lower
    if value_range == 0:
        # A constant column has no spread to normalize by; dividing would
        # produce NaN/inf. Every non-missing value equals the minimum here,
        # so the shifted column (all 0.0) is used as the scaled result.
        df_plot_agg[f'{var}_scaled'] = shifted
    else:
        df_plot_agg[f'{var}_scaled'] = shifted / value_range

# %%

# Set up 2x2 subplot grid
ncols = 2
nrows = 2

# squeeze=False makes plt.subplots always return a 2-D array of Axes, so it
# can be flattened unconditionally. (Wrapping the array in a list when only
# one zone exists would leave axes[0] as an array instead of an Axes.)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                         figsize=(5*ncols, 4*nrows), # Adjust figsize if needed for 2x2
                         sharex=True, sharey=True,
                         squeeze=False)
axes = axes.flatten()

colors = ['C0', 'C1', 'C2'] # Define colors
variables_left_yaxis = ['UHI_diff'] # Variables for left y-axis (original scale)
variables_right_yaxis = ['Q2M', 'SOILWATER_10CM'] # Variables for right y-axis (normalized)
variable_means_left = ['UHI_diff_mean']
variable_means_right_scaled = ['Q2M_mean_scaled', 'SOILWATER_10CM_mean_scaled']

# Only as many zones as there are subplots can be drawn; any further zones
# are skipped.
zones_to_plot = unique_zones[:nrows * ncols]

for i, zone in enumerate(zones_to_plot):
    ax = axes[i]

    # Subset data for this zone
    sub_agg = df_plot_agg[df_plot_agg['KGMajorClass'] == zone].copy()

    # Plot UHI_diff on the primary (left) y-axis
    for var, mean_var_col in zip(variables_left_yaxis, variable_means_left):
        ax.plot(
            sub_agg['day_in_event'],
            sub_agg[mean_var_col],
            label=f"{var}",
            color=colors[0], # Use consistent color for UHI_diff
            marker='o',
            markersize=3,
            linestyle='-'
        )
    ax.set_ylabel("UHI_diff (°C)", color=colors[0]) # Y-axis label for UHI_diff, matching color
    ax.tick_params(axis='y', labelcolor=colors[0]) # Tick color match

    # Create secondary y-axis, sharing x-axis
    ax_right = ax.twinx()

    # Plot normalized Q2M and SOILWATER_10CM on the secondary (right) y-axis
    for var_index, var in enumerate(variables_right_yaxis):
        mean_var_col_scaled = variable_means_right_scaled[var_index]
        ax_right.plot(
            sub_agg['day_in_event'],
            sub_agg[mean_var_col_scaled],
            label=f"{var} (normalized)",
            color=colors[var_index+1], # Colors for Q2M and SOILWATER_10CM
            marker='o',
            markersize=3,
            linestyle='-'
        )
    ax_right.set_ylabel("Normalized Q2M & SOILWATER_10CM (0-1)", color='black') # Y-label for normalized vars
    ax_right.tick_params(axis='y', labelcolor='black') # Tick color

    # Combine legends from both axes so one legend box lists all three lines
    lines_left, labels_left = ax.get_legend_handles_labels()
    lines_right, labels_right = ax_right.get_legend_handles_labels()
    ax.legend(lines_left + lines_right, labels_left + labels_right, loc='best', title="Variables")

    ax.set_title(zone)
    ax.set_xlabel("Day in Heatwave Event")
    ax.set_xlim(0, N_DAYS - 1) # set x limit

# Turn off extra axes. Slicing by the number of zones actually drawn also
# works when there are no zones at all, where a loop index would be undefined.
for unused_ax in axes[len(zones_to_plot):]:
    fig.delaxes(unused_ax)

# NOTE(review): the title says "First 10 Days" while the x-axis spans days
# 0 .. N_DAYS - 1 - confirm which wording is intended.
plt.suptitle("Day-by-Day Changes of UHI_diff, Normalized Q2M & SOILWATER_10CM by Climate Zone (First 10 Days)", fontsize=14, y=0.98)
plt.tight_layout()
plt.show()