# Load the necessary libraries

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

## Load the data and preprocess it

In [None]:
# -----------------------------
# Load dataset
# -----------------------------
# Coordinates of the single grid cell to analyse; rows are kept only on an
# exact match of both the "lat" and "lon" columns.
lat, lon = 18.0, -23.75

df = pd.read_csv(
    "CV5_GOOD_CSV_DATA/combined_monthly_data_cv5.csv", parse_dates=["time"]
)
at_site = (df["lat"] == lat) & (df["lon"] == lon)
# Index the selected rows by time, in chronological order.
df = df.loc[at_site].set_index("time").sort_index()

# Columns of `df` that get an anomaly series; the anomaly frame shares df's
# time index and starts with no columns.
variables = ["TS_C", "CHL", "AODANA", "wind_speed"]
anomalies_df = pd.DataFrame(index=df.index)


# Function: anomaly + detrending
def compute_anomaly_detrend(series):
    """Return the linearly detrended monthly anomaly of a time series.

    The anomaly is each observed value minus the mean of all observed values
    that fall in the same calendar month. A straight line is then fitted to
    the anomalies and subtracted.

    The line is fitted against each sample's position in ``series`` rather
    than against a counter over the non-missing samples, so gaps in the
    record still count as elapsed time and do not distort the slope. This
    treats consecutive index entries as equally spaced (e.g. monthly data).

    Parameters
    ----------
    series : pandas.Series
        Values indexed by an index exposing ``.month`` (a ``DatetimeIndex``).

    Returns
    -------
    pandas.Series
        Same index as ``series``; NaN wherever the input was NaN. If the
        input has no observed values the result is all NaN, and if fewer
        than three anomalies are available they are returned without
        detrending.
    """
    observed = series.dropna()
    if observed.empty:
        return series * np.nan  # preserve index with NaN

    # Monthly climatology broadcast back onto every observed timestamp.
    monthly_mean = observed.groupby(observed.index.month).transform("mean")
    # Align back to the original index, keeping NaN where the input was NaN.
    anomaly = (observed - monthly_mean).reindex(series.index)

    has_value = anomaly.notna().to_numpy()
    if has_value.sum() < 3:
        return anomaly  # not enough points to fit a trend

    # Positions within the full series act as the time axis for the fit.
    positions = np.flatnonzero(has_value)
    values = anomaly.to_numpy(dtype=float)[has_value]
    slope, intercept, *_ = stats.linregress(positions, values)

    detrended = anomaly.copy()
    detrended.iloc[positions] = values - (slope * positions + intercept)
    return detrended


# Detrended monthly anomaly of every selected variable at the grid cell.
for column in variables:
    anomalies_df[column] = compute_anomaly_detrend(df[column])

# Add deposition
cv_dry_wet_df = (
    pd.read_csv(
        "CV5_DEPOSITIONS/total_deposition_end_of_month.csv", parse_dates=["time"]
    )
    .set_index("time")
    .sort_index()
)

for dep_var in ("total_wet_deposition", "total_dry_deposition"):
    # Put the deposition series on the anomaly frame's timestamps first;
    # timestamps absent from the deposition file become NaN.
    aligned = cv_dry_wet_df[dep_var].reindex(anomalies_df.index)
    anomalies_df[dep_var] = compute_anomaly_detrend(aligned)

# Keep only the timestamps where every anomaly column has a value.
anomalies_df = anomalies_df.dropna()


# -----------------------------
# Seasons
# -----------------------------
def assign_standard_season(month):
    """Map a calendar month number to its season label.

    12, 1, 2 -> "Winter"; 3-5 -> "Spring"; 6-8 -> "Summer"; any other value
    -> "Fall".
    """
    season_months = (
        ("Winter", (12, 1, 2)),
        ("Spring", (3, 4, 5)),
        ("Summer", (6, 7, 8)),
    )
    for label, months in season_months:
        if month in months:
            return label
    return "Fall"


# Label every row with its standard season and flag membership of the
# July-September ("JAS") and August-October ("ASO") windows.
month_of_row = anomalies_df.index.month
anomalies_df["season"] = month_of_row.map(assign_standard_season)
anomalies_df["JAS"] = month_of_row.isin([7, 8, 9])
anomalies_df["ASO"] = month_of_row.isin([8, 9, 10])

# Every column that takes part in the correlation analysis.
all_variables = [*variables, "total_wet_deposition", "total_dry_deposition"]


# -----------------------------
# Compute correlation + p-values
# -----------------------------
def corr_pvalue(df_vars):
    """Return the Pearson correlation matrix and matching p-value matrix.

    Rows containing a NaN in any column are dropped first, so every pair is
    evaluated on the same set of samples.

    Parameters
    ----------
    df_vars : pandas.DataFrame
        One column per variable.

    Returns
    -------
    (corr, pvals) : tuple of pandas.DataFrame
        Square frames labelled by variable name. ``pvals`` holds 0.0 on the
        diagonal and 1.0 for every pair whose test cannot be evaluated (a
        column with fewer than two distinct values). When fewer than three
        complete rows remain, ``corr`` is all NaN and every off-diagonal
        p-value is 1.0.
    """
    complete = df_vars.dropna()
    enough_rows = complete.shape[0] >= 3

    if enough_rows:
        corr = complete.corr()
    else:
        cols = df_vars.columns
        corr = pd.DataFrame(np.nan, index=cols, columns=cols)

    names = corr.columns
    n_vars = len(names)

    # Assemble the p-values in a plain array and wrap it at the end; writing
    # through ``DataFrame.values`` is not reliable because that array may be
    # a read-only view.
    p_matrix = np.ones((n_vars, n_vars), dtype=float)
    np.fill_diagonal(p_matrix, 0.0)

    if enough_rows:
        # A Pearson test is undefined for a constant column; those pairs
        # keep the default p-value of 1.0.
        varies = {name: complete[name].nunique() >= 2 for name in names}
        for a in range(n_vars):
            first = names[a]
            if not varies[first]:
                continue
            # The p-value is symmetric: compute the upper triangle, mirror it.
            for b in range(a + 1, n_vars):
                second = names[b]
                if not varies[second]:
                    continue
                _, p = stats.pearsonr(complete[first], complete[second])
                p_matrix[a, b] = p_matrix[b, a] = p

    pvals = pd.DataFrame(p_matrix, index=corr.index, columns=names)
    return corr, pvals


# Correlation and p-value matrices keyed by season name.
seasonal_corrs = {}
seasonal_pvals = {}

# Only the four standard seasons are analysed here; rows are selected through
# the "season" label column. The boolean "JAS" / "ASO" columns of
# `anomalies_df` are not used by this loop.
for season in ["Winter", "Spring", "Summer", "Fall"]:
    in_season = anomalies_df["season"] == season
    season_df = anomalies_df.loc[in_season, all_variables]
    corr, pvals = corr_pvalue(season_df)
    seasonal_corrs[season] = corr
    seasonal_pvals[season] = pvals

# -----------------------------
# Plotting (lower triangle only + significance, no diagonals)
# -----------------------------
# Column name -> label shown on the heatmap axes.
variable_display_names = {
    "TS_C": "SST",
    "CHL": "CHL",
    "AODANA": "AOD",
    "wind_speed": "Wind Speed",
    "total_wet_deposition": "Wet Dep.",
    "total_dry_deposition": "Dry Dep.",
}

# Font used for every text element of the panels.
FONT_FAMILY = "Times New Roman"


def _significance_stars(p):
    """Return "***", "**" or "*" for p below 0.001, 0.01 or 0.05, else ""."""
    if pd.isna(p):
        return ""
    for threshold, marker in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p < threshold:
            return marker
    return ""


fig, axs = plt.subplots(2, 2, figsize=(16, 12))
axs = axs.flatten()

seasons_order = list(seasonal_corrs.keys())

# One panel per season, in the order the matrices were computed.
for ax, season in zip(axs, seasons_order):
    # Rename for display
    corr_display = seasonal_corrs[season].rename(
        index=variable_display_names, columns=variable_display_names
    )
    pvals_display = seasonal_pvals[season].rename(
        index=variable_display_names, columns=variable_display_names
    )

    # Mask upper triangle + diagonal so only the lower triangle is drawn
    mask = np.triu(np.ones_like(corr_display, dtype=bool), k=0)

    # Cell text: correlation to two decimals followed by *, ** or ***.
    # Diagonal cells stay empty.
    annot = pd.DataFrame("", index=corr_display.index, columns=corr_display.columns)
    for r in corr_display.index:
        for c in corr_display.columns:
            if r == c:
                continue
            val = corr_display.loc[r, c]
            value_text = f"{val:.2f}" if pd.notna(val) else ""
            annot.loc[r, c] = value_text + _significance_stars(
                pvals_display.loc[r, c]
            )

    sns.heatmap(
        corr_display,
        mask=mask,
        annot=annot,
        fmt="",  # annotations are ready-made strings
        cmap="coolwarm",
        center=0,
        linewidths=0.5,
        cbar=True,
        ax=ax,
        annot_kws={"size": 11, "fontfamily": FONT_FAMILY},
        # Explicitly set labels so seaborn knows what to draw
        xticklabels=corr_display.columns,
        yticklabels=corr_display.index,
    )

    ax.set_title(season, fontsize=18, fontweight="bold", fontfamily=FONT_FAMILY)

    # The mask leaves the last column and the first row without any cell, so
    # drop the last x tick and the first y tick together with their labels.
    xticks = ax.get_xticks()
    yticks = ax.get_yticks()
    if len(xticks) > 0:
        ax.set_xticks(xticks[:-1])
    if len(yticks) > 0:
        ax.set_yticks(yticks[1:])

    # Label the remaining ticks; font and rotation are applied here, once.
    ax.set_xticklabels(corr_display.columns[:-1], fontfamily=FONT_FAMILY, rotation=0)
    ax.set_yticklabels(corr_display.index[1:], fontfamily=FONT_FAMILY, rotation=90)

# Panel label in the top-left corner of the figure.
fig.suptitle(
    "e)", fontsize=28, fontweight="bold", fontfamily="Times New Roman", x=0.0, ha="left"
)
# NOTE(review): the footnote string is empty, so no significance legend is
# drawn; the wording is kept in the trailing comment - confirm whether it
# should be displayed.
fig.text(
    0.01, 0.02, "", fontsize=11, fontfamily="Times New Roman"
)  # Significance: * p<0.05, ** p<0.01, *** p<0.001

fig.tight_layout(rect=[0, 0.03, 1, 1])  # leave room for the footnote

output_path = "Seasonal_anomalies_with_depositions/Good/seasonal_correlation_matrices_significant_lower_cv5_good_v2_ggod.png"
# savefig does not create missing folders; make sure the target exists so the
# run does not fail at the very last step.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(
    output_path,
    dpi=360,
    bbox_inches="tight",
    facecolor="white",
)
plt.show()