In [None]:
from utils.ocean_basins import get_zoned_df
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.basemap import Basemap

In [None]:
# Model-output variables included in every correlation heatmap.
features = ['SST', 'SAL', 'ice_frac', 'mixed_layer_depth', 'heat_flux_down', 'water_flux_up',
            'stress_X', 'stress_Y', 'currents_X', 'currents_Y',
            'fco2', 'fco2_pre', 'co2flux', 'co2flux_pre']

# Location and naming of the pickled hindcast DataFrames -- edit here to relocate the data.
# NOTE(review): this relative path climbs to the filesystem root to reach an external drive;
# an absolute path (or an environment variable) would be less fragile.
DATA_DIR = '../../../../../../../Volumes/T7 Shield/exp_1'
RUN_NAME = 'ORCA025.L46.LIM2vp.CFCSF6.MOPS.JRA.LP04-KLP002'


def load_hindcast(year):
    """Load the pickled hindcast DataFrame for `year` and add a `month` column.

    The month is taken from each entry of `time_counter` via its `.month` attribute.
    Row-wise access is used so that both datetime and cftime-style objects work.

    Note: unpickling can execute arbitrary code -- only load trusted files.
    """
    df = pd.read_pickle(f'{DATA_DIR}/{RUN_NAME}.hind_{year}_df.pkl')
    df["month"] = df["time_counter"].apply(lambda t: t.month)
    return df


data_1960 = load_hindcast(1960)
data_2015 = load_hindcast(2015)

# Months to plot (1 = January ... 12 = December).
months = [1, 4, 7, 10]
# Indexed directly by month number: slot 0 is unused, and months not plotted are blank.
month_labels = ['', 'January', '', '', 'April', '', '', 'July', '', '', 'October', '', '']
# NOTE(review): assumed to match the order of the regions returned by get_zoned_df -- confirm.
region_labels = ['Arctic', 'North Atlantic', 'Equatorial Pacific', 'Southern Ocean']

# Year label -> hindcast DataFrame, iterated by the plotting cell.
datasets = {
    "1960": data_1960,
    "2015": data_2015
}

In [None]:
import os

HEATMAP_DIR = "../plots/general/heatmaps"
# Create the output folder up front so savefig does not fail after the expensive work.
os.makedirs(HEATMAP_DIR, exist_ok=True)


def _plot_corr_heatmap(corr_matrix, title):
    """Draw an annotated correlation heatmap and return its figure.

    Parameters
    ----------
    corr_matrix : pandas.DataFrame
        Square correlation matrix; its column labels are used for both axes.
    title : str
        Title placed on the axes.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. The caller is responsible for saving and closing it.
    """
    labels = list(corr_matrix.columns)
    values = corr_matrix.to_numpy()
    n = len(labels)

    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(values, cmap='coolwarm', vmin=-1, vmax=1)
    fig.colorbar(image, ax=ax)

    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_yticklabels(labels)

    # Write each coefficient in its cell. NaN (e.g. a constant column in this subset)
    # is rendered as "nan".
    for i in range(n):
        for j in range(n):
            ax.text(j, i, f"{values[i, j]:.2f}", ha='center', va='center', color='black', fontsize=7)

    ax.set_title(title)
    fig.tight_layout()
    return fig


for year_label, df in datasets.items():
    regions = get_zoned_df(df)  # expected: one DataFrame per entry of region_labels
    assert len(regions) == len(region_labels), "get_zoned_df returned an unexpected number of regions"

    for region_label, region_df in zip(region_labels, regions):
        for month in months:
            month_df = region_df[region_df['month'] == month]
            if month_df.empty:
                print(f"Skipping {region_label} {month_labels[month]} {year_label}: no rows")
                continue

            # Correlate the features once. Calling .corr() on this result again would
            # correlate the columns of the correlation matrix, not the features.
            corr_matrix = month_df[features].corr()

            title = f"Correlation Heatmap {region_label} {month_labels[month]} {year_label}"
            fig = _plot_corr_heatmap(corr_matrix, title)
            fig.savefig(f"{HEATMAP_DIR}/heatmap_{region_label}_{month_labels[month]}_{year_label}.png",
                        dpi=300, bbox_inches='tight')
            # Display now, then release the figure; otherwise dozens of figures pile up in memory.
            plt.show()
            plt.close(fig)