In [1]:
import warnings
from itertools import product
import glob
from datetime import datetime
from datetime import timedelta
import numpy as np
import pandas as pd
import xarray as xr

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.path as mpath
import cartopy
import cartopy.crs as ccrs
import cartopy.feature
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
import cartopy.feature as cf
import shapely.geometry as sgeom
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

from sklearn.decomposition import PCA
from scipy import stats
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from sklearn.metrics import davies_bouldin_score

import pickle
import copy
from shapely import geometry
from sklearn.metrics.pairwise import euclidean_distances
import statsmodels.api as sm
from scipy.stats import linregress

# Functions

In [2]:
def extract_region(data_array, regioncoords):
    """
    Select a longitude/latitude box from data on a 0-360 longitude grid.

    When the box wraps past the 360/0 seam (``min_lon > max_lon``), the piece
    from ``min_lon`` up to 360 and the piece from 0 up to ``max_lon`` are
    selected separately and joined along 'lon', in that order. Longitude
    values are neither re-sorted nor re-labelled.

    Parameters:
    data_array: xarray object with 'lon' and 'lat' coordinates (anything
        supporting ``.sel(lon=slice, lat=slice)``).
    regioncoords (sequence of 4 numbers): (min_lon, max_lon, min_lat, max_lat).

    Returns:
    The selected region, as returned by ``.sel`` (or ``xr.concat`` when wrapping).
    """
    min_lon, max_lon, min_lat, max_lat = regioncoords
    lat_window = slice(min_lat, max_lat)

    # Ordinary box: one label-based slice covers it.
    if not (min_lon > max_lon):
        return data_array.sel(lon=slice(min_lon, max_lon), lat=lat_window)

    # Wrapping box: take the part before the seam, then the part after it,
    # and stitch them together along longitude.
    before_seam = data_array.sel(lon=slice(min_lon, 360), lat=lat_window)
    after_seam = data_array.sel(lon=slice(0, max_lon), lat=lat_window)
    return xr.concat([before_seam, after_seam], dim='lon')
    
def get_average_fields_for_centroids(dataarray,labels):
    """
    Build one time-mean composite field per weather-regime label.

    Parameters:
    dataarray: xarray object with a 'time' dimension.
    labels (pd.DataFrame): indexed by time, with a 'WR' column of class labels.

    Returns:
    The per-label time means concatenated along a new 'WR' dimension, in the
    order of the sorted unique label values.
    """
    # Keep only the first occurrence of any repeated timestamp on both inputs
    # so the label index can be used directly as a time selector.
    field_unique = dataarray.drop_duplicates('time',keep='first')
    labels_unique = labels[~labels.index.duplicated(keep='first')]

    composites = [
        field_unique.sel(time=labels_unique[labels_unique['WR']==regime].index).mean('time')
        for regime in np.unique(labels_unique['WR'])
    ]
    return xr.concat(composites,dim='WR')

import math

def plot_multiple_maps(da,freqs_labels,regioncoords,names = None, path_save=None, n_cols=2,
                       title='CESM2_pi - Weather Regimes'):
    """
    Draw weather-regime composites as a grid of filled-contour maps.

    Every entry along the 'WR' dimension of `da` EXCEPT THE LAST ONE is drawn
    (entries are addressed as ``da.sel(WR=i)`` for i = 0 .. n-2). The grid is
    sized for all entries; the panels left over are removed and a shared
    horizontal colorbar is added at a fixed figure position.
    NOTE(review): the skipped last entry is presumably the "No WR" class and the
    colorbar position presumably targets its empty slot - confirm with callers.

    Parameters:
    - da: xarray.DataArray with a 'WR' dimension and 'lat' / 'lon' coordinates.
      Longitudes are passed to the plot as stored (no 0-360 -> -180-180 shift).
    - freqs_labels: sequence indexed like 'WR'; shown (rounded to 2 decimals,
      followed by '%') in each panel title.
    - regioncoords: (min_lon, max_lon, min_lat, max_lat) used as the map extent.
    - names: optional sequence of panel names; if falsy, "Cluster <i+1>" is used.
    - path_save: output path for the figure. If falsy (None, False, ''), the
      figure is shown instead of saved.
    - n_cols: number of columns of the subplot grid (default 2).
    - title: overall figure title (default keeps the historical hard-coded text).
    """
    min_lon, max_lon, min_lat, max_lat = regioncoords

    # Grid is sized for every WR entry, although the last one is not drawn.
    n_maps = len(da.WR)
    n_rows = math.ceil(n_maps / n_cols)
    n_drawn = max(n_maps - 1, 0)

    # squeeze=False always yields a 2-D axes array, so axes[row, col] works for
    # single-row AND single-column layouts.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(9, 2 * (n_rows)), squeeze=False,
                             subplot_kw={'projection': ccrs.PlateCarree(central_longitude=-100)})

    # Contour levels shared by all panels: 21 levels spanning [-2, 2].
    mini=-2
    maxi=2
    intervals = 21
    bounds=np.linspace(mini,maxi,intervals)
    tick_fontsize = 10

    mesh = None  # handle of the last drawn panel, used for the shared colorbar
    for i in range(n_drawn):
        ax = axes[i // n_cols, i % n_cols]

        ax.set_extent([min_lon, max_lon, min_lat, max_lat], crs=ccrs.PlateCarree())

        # Gridlines: meridians every 60 degrees, parallels every 20 degrees (0-90N)
        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                          linewidth=0.2, color='gray', alpha=0.5, linestyle='--')
        gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 60))
        gl.ylocator = plt.FixedLocator(np.arange(0, 91, 20))
        gl.top_labels = False
        gl.right_labels = False
        gl.xformatter = LongitudeFormatter(zero_direction_label=True)
        gl.yformatter = LatitudeFormatter()
        gl.xlabel_style = {'size': tick_fontsize}
        gl.ylabel_style = {'size': tick_fontsize}

        # Filled contours of the i-th composite on the data's own lon/lat coordinates
        mesh = ax.contourf(da.lon, da.lat, da.sel(WR=i).values, levels=bounds, vmin=mini, vmax=maxi,
                                 cmap='bwr', transform=ccrs.PlateCarree(),extend='both')

        # Add coastlines for context
        ax.coastlines()

        # Panel title: regime name (or generic cluster number) plus its frequency
        panel_name = names[i] if names else f'Cluster {i+1}'
        ax.set_title(f'{panel_name} - Freq.: {np.round(freqs_labels[i],2)}%',fontsize=11)

    # Remove every panel that was not drawn (including the skipped last WR slot)
    for j in range(n_drawn, n_rows * n_cols):
        fig.delaxes(axes[j // n_cols, j % n_cols])

    # Adjust layout to prevent overlapping
    plt.tight_layout(w_pad=0.1)

    if mesh is not None:
        # Shared horizontal colorbar at a fixed position in figure coordinates
        cax = fig.add_axes([0.55, 0.22, 0.4, 0.03])
        cbar = fig.colorbar(mesh, cax=cax, orientation='horizontal')
        cbar.set_label(r'Z Anomaly ($\sigma$)')

    # Overall title for the figure
    fig.suptitle(title, fontsize=14, y=1.04,ha='center')

    # `not path_save` (rather than `== False`) so the default None shows the
    # figure instead of calling savefig(None).
    if not path_save:
        plt.show()
    else:
        plt.savefig(path_save, bbox_inches='tight',dpi=200)
    plt.close('all')

In [3]:
def compute_pcs(dataarray, n_components=12, random_state=None):
    """
    Project a (time, lat, lon) field onto its leading whitened principal components.

    Parameters:
    dataarray: xarray object with 'time', 'lat' and 'lon' dimensions; 'lat' and
        'lon' are stacked into one feature axis before the decomposition.
    n_components (int): number of components to keep (default 12, the value
        previously hard-coded).
    random_state: forwarded to sklearn's PCA. The default None keeps the
        previous behaviour; pass an int for reproducible results when PCA
        picks a randomized solver.

    Returns:
    (datatransformed, variance_explained): the [time, n_components] array of
    whitened scores, and the percentage of variance explained by the kept
    components.
    """
    dataflattened = dataarray.stack(flat=('lat','lon')).transpose('time','flat')
    pca_obj = PCA(n_components, whiten=True, random_state=random_state)
    pca_obj = pca_obj.fit(dataflattened)
    datatransformed = pca_obj.transform(dataflattened)

    variance_explained = np.sum(pca_obj.explained_variance_ratio_) * 100
    return datatransformed, variance_explained
    
def compute_wrs_seeded(dataarray,n=5, init_centers=None, n_components=12):
    """
    Cluster a field in whitened PC space with k-means started from given centroids.

    Parameters:
    dataarray: xarray object with 'time', 'lat' and 'lon' dimensions.
    n (int): number of clusters; must match the number of rows of the seeds.
    init_centers (array-like, [n, n_components], optional): initial centroids
        handed to KMeans. If None, the notebook-level global
        `era5_clusters_centers` is used (the previous, implicit behaviour).
    n_components (int): number of principal components to keep (default 12).

    Returns:
    (clusters_centers, labels, distances, variance_explained, datatransformed, k_means)
    where `distances` is the [n_clusters, n_times] matrix of Euclidean distances
    between the fitted centroids and every transformed sample, and
    `variance_explained` is a percentage.

    Raises:
    NameError: if `init_centers` is None and `era5_clusters_centers` is not defined.
    """
    if init_centers is None:
        # Backward-compatible fallback on the global the function used to
        # require; fail with an actionable message if it was never defined.
        try:
            init_centers = era5_clusters_centers
        except NameError:
            raise NameError(
                "compute_wrs_seeded needs initial centroids: pass `init_centers` "
                "or define `era5_clusters_centers` before calling."
            ) from None

    dataflattened = dataarray.stack(flat=('lat','lon')).transpose('time','flat')

    pca_obj = PCA(n_components, whiten=True)
    pca_obj = pca_obj.fit(dataflattened)
    datatransformed = pca_obj.transform(dataflattened)

    variance_explained = np.sum(pca_obj.explained_variance_ratio_) * 100

    # Seeded k-means: a single run (n_init=1) starting from the supplied centroids
    k_means = KMeans(n_clusters=n,
                     init=init_centers,
                     n_init=1,
                     max_iter=300,
                     tol=0.0001,
                     verbose=0,
                     random_state=42)
    k_means.fit(datatransformed)
    clusters_centers = k_means.cluster_centers_
    labels = k_means.labels_
    distances = euclidean_distances(clusters_centers, datatransformed)

    return clusters_centers, labels, distances, variance_explained, datatransformed, k_means

In [4]:
def get_EOFs_from_PCs(PCs,da_region):
    """
    Recover spatial patterns (EOFs) from principal-component series by least squares.

    With X the [time, space] data matrix taken from `da_region` and PC the
    [time, K] matrix of component series, the patterns are pinv(PC) . X,
    which is then laid back onto the lat/lon grid of `da_region`.

    Parameters:
    PCs: numpy array, or an object exposing ``.values`` (e.g. a DataFrame),
        with one row per time step and one column per component.
    da_region: xarray object with 'time', 'lat' and 'lon' dimensions.

    Returns:
    xr.DataArray named 'EOFs' with dims ['lat', 'lon', 'mode']; the 'mode'
    coordinate runs 0 .. K-1.
    """
    pc_matrix = PCs if isinstance(PCs, np.ndarray) else PCs.values
    pc_pinv = np.linalg.pinv(pc_matrix)  # [K, time]

    # [time, space] with space = lat-major stacking of (lat, lon)
    field_2d = da_region.stack(flat=('lat','lon')).transpose('time','flat')

    # pinv(PC) . X is [K, space]; transpose to [space, K]
    patterns = np.dot(pc_pinv, field_2d).T
    n_modes = patterns.shape[-1]

    n_lat = len(da_region['lat'])
    n_lon = len(da_region['lon'])

    return xr.DataArray(
        patterns.reshape(n_lat, n_lon, n_modes),  # unstack space back to (lat, lon)
        dims=['lat', 'lon', 'mode'],
        coords={
            'lat': da_region.coords['lat'],
            'lon': da_region.coords['lon'],
            'mode': np.arange(n_modes),
        },
        name='EOFs',
    )

In [5]:
import numpy as np
import xarray as xr

def reorder_model_eofs(obs_eofs, model_eofs):
    """
    Greedily pair each observed EOF with the most similar unused model EOF.

    Similarity is the absolute Pearson correlation between the flattened
    spatial patterns. Observed modes are processed in order; each one takes
    the best-correlated model mode that no earlier observed mode has claimed.

    Parameters:
    - obs_eofs: xarray.DataArray of observation EOFs, dims [lat, lon, mode]
    - model_eofs: xarray.DataArray of model EOFs, dims [lat, lon, mode]

    Returns:
    - reordered_indices: list; entry i is the model mode matched to observed mode i.
    - signs: list of +1 / -1; the sign of the correlation of each matched pair
      (+1 when the correlation is >= 0).
    If the model modes run out, the remaining observed modes get no entry.
    """
    n_obs_modes = obs_eofs.shape[-1]
    n_model_modes = model_eofs.shape[-1]

    # Collapse lat/lon into one spatial axis so each mode is a 1-D pattern
    obs_flat = obs_eofs.stack(spatial=('lat', 'lon'))
    model_flat = model_eofs.stack(spatial=('lat', 'lon'))

    # correlation_matrix[i, j] = Pearson r between observed mode i and model mode j
    correlation_matrix = np.zeros((n_obs_modes, n_model_modes))
    for obs_mode in range(n_obs_modes):
        obs_pattern = obs_flat.sel(mode=obs_mode)
        for model_mode in range(n_model_modes):
            model_pattern = model_flat.sel(mode=model_mode)
            correlation_matrix[obs_mode, model_mode] = np.corrcoef(obs_pattern, model_pattern)[0,1]

    reordered_indices, signs = [], []
    claimed = set()
    for obs_mode in range(n_obs_modes):
        # Model modes ranked from strongest to weakest |r| for this observed mode
        ranking = np.argsort(-np.abs(correlation_matrix[obs_mode, :]))
        choice = next((cand for cand in ranking if cand not in claimed), None)
        if choice is None:
            continue  # every model mode is already taken
        claimed.add(choice)
        reordered_indices.append(choice)
        signs.append(1 if correlation_matrix[obs_mode, choice] >= 0 else -1)

    return reordered_indices, signs

# Compute weather regimes (WRs)

In [6]:
import pickle
# Load the saved object (it is used below through `cluster_centers_` and `labels_`).
# NOTE(security): pickle.load can execute arbitrary code from the file; only load
# model files from a trusted source.
with open('kmeans_models/k_means_model_era5.pkl', 'rb') as f:
    k_means = pickle.load(f)
# NOTE(review): absolute, user-specific paths; consider a configurable data directory.
path_files = '/glade/derecho/scratch/jhayron/Data4WRsClimateChange/ProcessedDataReanalyses/'
path_pcs = '/glade/derecho/scratch/jhayron/Data4WRsClimateChange/PCs_Z500/'

reanalysis = 'ERA5'
anoms_era5 = xr.open_dataset(f'{path_files}Z500Anoms_{reanalysis}.nc')
# (min_lon, max_lon, min_lat, max_lat), the order extract_region unpacks
region = [180, 330, 20, 80]
data_region_era5 = extract_region(anoms_era5, region)
# First line of the CSV is skipped (presumably a header) and integer column names
# are supplied instead; the first column becomes the parsed-date index.
pcs_era5 = pd.read_csv(f'{path_pcs}PCs_{reanalysis}.csv',
                       index_col=0,parse_dates=True, names=np.arange(0,12),skiprows=1)

clusters_centers = k_means.cluster_centers_
labels = k_means.labels_
# Append an all-zero 12-component centre as one extra class after the fitted
# centres (presumably the "No WR" state listed last in `names` - confirm).
# Note the two similar names: `clusters_centers` (fitted) vs `cluster_centers` (extended).
cluster_centers = np.vstack([clusters_centers,np.zeros(12)])
# distances[k, t]: Euclidean distance from centre k to the PC vector of row t
distances = euclidean_distances(cluster_centers, pcs_era5.values)
# Each row is assigned to its nearest centre
labels_era5 = distances.argmin(axis=0)

df_labels = pd.DataFrame(labels_era5,index=pcs_era5.index)
df_labels.columns=['WR']
df_labels['distances'] = distances.min(axis=0)
# Pearson correlation between each row's PC vector and its assigned centre.
# For rows assigned to the all-zero centre the centre has zero variance, so
# np.corrcoef yields NaN; this is the likely source of the RuntimeWarning
# printed in this cell's output.
corrs = np.array([np.corrcoef(pcs_era5.values[i],cluster_centers[df_labels['WR'].iloc[i]])[0,1] for i in range(len(df_labels))])
df_labels['corr'] = corrs
# Disabled: reassign weakly correlated rows (corr <= 0.25) to the last class.
# df_labels.loc[df_labels['corr']<=0.25,'WR']=np.unique(df_labels['WR'])[-1]

labels_era5 = df_labels['WR'].values

  c /= stddev[:, None]
  c /= stddev[None, :]


In [7]:
# Spatial patterns (EOFs) of ERA5, recovered from its PCs and the regional Z_anom field
EOF_ERA5 = get_EOFs_from_PCs(pcs_era5,data_region_era5.Z_anom)

In [8]:
# Model anomalies file (E3SMv2_pi; "standardized" per the file name).
# NOTE(review): absolute, user-specific path; consider a configurable data directory.
path_anoms = '/glade/derecho/scratch/jhayron/Data4WRsClimateChange/E3SMv2_pi/anoms_standardized.nc'

In [9]:
# Load the model anomalies into memory and cut out the analysis region
# (the PC reordering itself happens in a later cell).
anoms = xr.open_dataset(path_anoms)
# Rebinds `anoms` from the Dataset to its Z_anom variable; re-running this line
# alone would fail, so re-run the whole cell.
anoms = anoms.Z_anom.compute()
# anoms = anoms
# Same (min_lon, max_lon, min_lat, max_lat) box as used for ERA5
region = [180, 330, 20, 80]
data_region = extract_region(anoms, region)#.sel(time=slice(None, '2100-12-31'))

In [10]:
# Whitened PCs of the model field and the percentage of variance they explain
pcs_model, variance_explained_member = compute_pcs(data_region)

In [11]:
# Align the model PCs with the ERA5 ones: recover the model EOFs, sample the
# ERA5 EOFs at the nearest model grid points, match patterns by spatial
# correlation, then permute the PC columns and flip their signs accordingly.
# NOTE(review): this cell is not idempotent - it rebinds `pcs_model` (array ->
# reordered array -> DataFrame), so it must be re-run together with the cell
# that computes `pcs_model`.
EOF_model = get_EOFs_from_PCs(pcs_model,data_region)
EOF_ERA5_modelcoords = EOF_ERA5.sel(lat=EOF_model.lat,lon=EOF_model.lon,method='nearest')
indices_pcs_ordered, signs_pcs = reorder_model_eofs(EOF_ERA5_modelcoords, EOF_model)
pcs_model = pcs_model[:,indices_pcs_ordered] * signs_pcs
pcs_model = pd.DataFrame(pcs_model,index = data_region.time)

In [12]:
# Assign every model time step to the nearest centre in `cluster_centers`
# (the ERA5 centres plus the appended all-zero centre).
distances_model = euclidean_distances(cluster_centers, pcs_model.values)
labels_model = distances_model.argmin(axis=0)
df_labels_model = pd.DataFrame(labels_model,index=pcs_model.index,columns = ['WR'])

df_labels_model['distances'] = distances_model.min(axis=0)
# Pearson correlation between each PC vector and its assigned centre; NaN (with
# a numpy RuntimeWarning) when the assigned centre is the all-zero one.
corrs = np.array([np.corrcoef(pcs_model.values[i],cluster_centers[df_labels_model['WR'].iloc[i]])[0,1] for i in range(len(df_labels_model))])
df_labels_model['corr'] = corrs
# Disabled: reassign weakly correlated rows (corr <= 0.25) to the last class.
# df_labels_model.loc[df_labels_model['corr']<=0.25,'WR']=np.unique(df_labels_model['WR'])[-1]


  c /= stddev[:, None]
  c /= stddev[None, :]


In [13]:
# Persist the model labels (WR, distances, corr); the path is relative to the
# working directory and the 'labels_pi' folder must already exist.
df_labels_model.to_csv(f'labels_pi/df_labels_pi_e3smv2.csv')

In [14]:
# Time-mean composite of the full (not region-cropped) anomaly field for each label
avgs_model = get_average_fields_for_centroids(anoms,df_labels_model)

In [15]:
# Percentage of time steps assigned to each label present in `labels_model`,
# in sorted label order.
# NOTE(review): entries are positional - if some label never occurs, the list
# is shorter and no longer lines up index-for-index with `names`.
freqs_model = []
for label in np.unique(labels_model):
    freqs_model.append(100*len(labels_model[labels_model==label])/len(labels_model))

In [16]:
# Display names indexed by label position; the sixth entry corresponds to the
# extra all-zero centre appended to the five fitted ones
# (presumably this order matches the saved ERA5 k-means centres - confirm).
names = ["Polar High", "Pacific Trough", "Pacific Ridge", "Alaskan Ridge", "Atlantic Ridge" ,"No WR"]

In [17]:
# Save the composite maps for the model.
# NOTE(review): by default plot_multiple_maps titles the figure
# 'CESM2_pi - Weather Regimes', while this notebook processes E3SMv2_pi data -
# the title on the saved figure does not match the data set.
plot_multiple_maps(avgs_model,freqs_model,region,names=names, 
                   path_save=f'Figures/Composites_E3SMv2pi.png')