In [None]:
import numpy as np
from sklearn.linear_model import LinearRegression
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('default')
sns.set_palette("colorblind")
from matplotlib import rcParams
rcParams['font.family'] = 'sans-serif'
rcParams['font.weight'] = 'light'
rcParams['mathtext.fontset'] = 'cm'
rcParams['mathtext.rm'] = 'serif'
mpl.rcParams["figure.dpi"] = 500
import cartopy.crs as ccrs
import cartopy as ct
import matplotlib.colors as c
import regionmask
import cmasher as cmr
import scipy
from cartopy.util import add_cyclic_point
mpl.rcParams['hatch.linewidth'] = 0.375
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from geocat.comp import eofunc_eofs, eofunc_pcs
from datetime import datetime
import warnings
from matplotlib.patches import Rectangle
import pdo_functions
import importlib
importlib.reload(pdo_functions)
import random
import numba
import statsmodels.api as sm
from sklearn.metrics import r2_score
import pandas as pd

In [None]:
# Open the three NDSEV files (full, "no CIN" and "CS only" variants, per the file names).
# Each file stores its single field under xarray's default variable name.
NDSEV_VAR = '__xarray_dataarray_variable__'


def _open_ndsev(path):
    """Open one NDSEV NetCDF file and return its only data variable."""
    return xr.open_dataset(path)[NDSEV_VAR]


ndsev = _open_ndsev('/hurrell-scratch2/ivyglade/pdo/ndsev/ndsev_1940-2024_mam.nc')
ndsev_nocin = _open_ndsev('/hurrell-scratch2/ivyglade/pdo/ndsev/ndsev_no_cin_1940-2024_mam.nc')
ndsev_csonly = _open_ndsev('/hurrell-scratch2/ivyglade/pdo/ndsev/ndsev_cs_only_1940_2024_mam.nc')

In [None]:
# Collapse the daily "CS only" NDSEV field into one total per calendar month.
# Output is ordered year-major: [Mar y0, Apr y0, May y0, Mar y1, ...],
# i.e. 85 years x 3 months = 255 time steps on a 101 x 237 grid.
ndsev_monthly = np.zeros((255, 101, 237))

# Positions along the 'date' axis that belong to March (31 days),
# April (30 days) and May (everything that remains).
month_slices = (slice(0, 31), slice(31, 61), slice(61, None))

for i in range(85):
    # Pull one year into memory as a (date, latitude, longitude) array
    ndsev_load = ndsev_csonly.isel(year=i).transpose('date', 'latitude', 'longitude').values

    # Sum the daily values within each month
    for m, days in enumerate(month_slices):
        ndsev_monthly[i*3 + m] = ndsev_load[days].sum(axis=0)

    # Progress
    print(f'{1940+i} is complete.')

In [None]:
months = pd.date_range('1940-03-01', '2025-01-01', freq='ME')
mam = months[months.month.isin([3, 4, 5])]

In [None]:
# Wrap the monthly totals in a labelled DataArray: MAM month-end timestamps on
# the time axis, and the latitude/longitude coordinates borrowed from `ndsev`.
ndsev_monthly_xr = xr.DataArray(
    ndsev_monthly,
    dims=['time', 'latitude', 'longitude'],
    coords={
        'time': mam,
        'latitude': ndsev['latitude'],
        'longitude': ndsev['longitude'],
    },
)

In [None]:
# Remove the seasonal cycle: subtract each calendar month's long-term mean
# from every value belonging to that month.
ndsev_by_month = ndsev_monthly_xr.groupby('time.month')
ndsev_monthly_anoms = ndsev_by_month - ndsev_by_month.mean()

In [None]:
# Load HadISST sea-surface temperature and derive the two climate indices
# with the project helpers in `pdo_functions`.
sst = xr.open_dataset('/hurrell-scratch2/ivyglade/pdo/HadISST_sst.nc')['sst']

# Both helpers receive the same pair of years.
# NOTE(review): presumably the climatological base period - confirm in pdo_functions.
ref_years = (1980, 2010)

pdo = pdo_functions.pdo_from_hadisst(sst, *ref_years)

# NOTE(review): the meaning of the positional `False` flag is defined in pdo_functions - confirm.
nino_34 = pdo_functions.calc_nino_34_timeseries(sst, False, *ref_years)

In [None]:
# Standardize Nino3.4 to zero mean and unit standard deviation over its whole record
nino_34_mean = nino_34.mean()
nino_34_spread = nino_34.std()
nino_34_std = (nino_34 - nino_34_mean) / nino_34_spread

In [None]:
# Give the PDO series the same time axis as the Nino3.4 series
pdo_xr = xr.DataArray(pdo, coords={'time': nino_34['time']}, dims=['time'])

# Restrict both indices to calendar years 1940 through 2024
analysis_years = np.arange(1940, 2025, 1)
pdo_1940_2024 = pdo_xr.sel(time=pdo_xr.time.dt.year.isin(analysis_years))
nino_34_1940_2024 = nino_34_std.sel(time=nino_34_std.time.dt.year.isin(analysis_years))

# Keep the individual March, April and May values.
# No seasonal averaging is applied here: the monthly values are used as-is.
mam_months = [3, 4, 5]
pdo_1940_2024_mam = pdo_1940_2024.sel(time=pdo_1940_2024.time.dt.month.isin(mam_months))
nino_34_1940_2024_mam = nino_34_1940_2024.sel(time=nino_34_1940_2024.time.dt.month.isin(mam_months))

In [None]:
# Number of grid points along each horizontal axis of the NDSEV grid
lat_len = ndsev.latitude.size
lon_len = ndsev.longitude.size

In [None]:
# Grid-point multiple linear regression of monthly NDSEV anomalies on the
# PDO and Nino3.4 indices:  y = b0 + b_pdo * PDO + b_nino * Nino3.4.
#
# Outputs (all NaN where fewer than 3 valid samples exist):
#   coef[..., 0]       PDO slope
#   coef[..., 1]       Nino3.4 slope
#   intercept          regression intercept
#   r2_total           R^2 of the two-predictor model
#   r2_partial[..., 1] R^2 gained by adding PDO to a Nino3.4-only model
#   r2_partial[..., 0] R^2 gained by adding Nino3.4 to a PDO-only model
#
# NOTE: the last axis of `r2_partial` is ordered [Nino3.4, PDO], which is the
# reverse of `coef` ([PDO, Nino3.4]). Downstream cells index it accordingly,
# so the ordering is deliberately left unchanged here.
# NOTE: "full R^2 minus reduced R^2" is the increment in explained variance
# (squared semipartial correlation) attributable to the removed predictor.
coef       = np.full((lat_len, lon_len, 2), np.nan)
intercept  = np.full((lat_len, lon_len), np.nan)
r2_partial = np.full((lat_len, lon_len, 2), np.nan)
r2_total   = np.full((lat_len, lon_len), np.nan)

model = LinearRegression()

# The predictors are identical for every grid cell, so build them (and their
# NaN mask) once instead of once per cell.
PDO = pdo_1940_2024_mam.values
Nino = nino_34_1940_2024_mam.values
X_full = np.column_stack([PDO, Nino])
predictors_ok = ~np.isnan(X_full).any(axis=1)

# Pull the anomalies into a plain (time, lat, lon) NumPy array once; indexing
# the DataArray inside the double loop is far slower and gives the same numbers.
ndsev_anoms_np = ndsev_monthly_anoms.values

for i in range(lat_len):
    for j in range(lon_len):
        # Despite the name, this is the NDSEV anomaly time series at cell (i, j)
        cape = ndsev_anoms_np[:, i, j]

        valid = predictors_ok & ~np.isnan(cape)

        if np.sum(valid) >= 3:
            X_valid = X_full[valid]
            y_valid = cape[valid]

            # Full model (PDO + Nino3.4)
            model.fit(X_valid, y_valid)
            r2_full = r2_score(y_valid, model.predict(X_valid))

            # Same value as model.score(X_valid, y_valid); computed only once
            r2_total[i, j] = r2_full

            coef[i, j, :] = model.coef_
            intercept[i, j] = model.intercept_

            # Contribution of PDO: full model vs. Nino3.4-only model
            model.fit(X_valid[:, [1]], y_valid)
            r2_nino_only = r2_score(y_valid, model.predict(X_valid[:, [1]]))
            r2_partial[i, j, 1] = r2_full - r2_nino_only

            # Contribution of Nino3.4: full model vs. PDO-only model
            model.fit(X_valid[:, [0]], y_valid)
            r2_pdo_only = r2_score(y_valid, model.predict(X_valid[:, [0]]))
            r2_partial[i, j, 0] = r2_full - r2_pdo_only

    print(f'latitude {i+1} out of {lat_len} is complete.')

In [None]:
# Split the slope array into one (lat, lon) map per predictor.
# The last axis of `coef` is ordered [PDO, Nino3.4].
pdo_coef = coef[..., 0]
nino_coef = coef[..., 1]

In [None]:
# Split the R^2 increments into one (lat, lon) map per predictor.
# Careful: the last axis of `r2_partial` is ordered [Nino3.4, PDO], the
# reverse of `coef`, so PDO is index 1 and Nino3.4 is index 0.
pdo_r_2 = r2_partial[..., 1]
nino_r_2 = r2_partial[..., 0]

In [None]:
# Flag positive regression slopes: positive values become 1, everything else
# (negative, zero, NaN) keeps its original value.
pdo_pos_coef, nino_pos_coef = (
    np.where(slopes > 0, 1, slopes) for slopes in (pdo_coef, nino_coef)
)

In [None]:
# Reduce to a sign map: -1 where the slope is negative, +1 everywhere else
# (including zero and NaN slopes).
pdo_coef_sign, nino_coef_sign = (
    np.where(flagged < 0, -1, 1) for flagged in (pdo_pos_coef, nino_pos_coef)
)

In [None]:
# Signed correlation-like quantity: sqrt of the R^2 increment, carrying the
# sign of the corresponding regression slope.
# The increment is a difference of two R^2 values, so floating-point rounding
# can leave it a hair below zero; clip at 0 so np.sqrt returns 0 instead of NaN
# (genuine NaNs, e.g. cells that were never fitted, are preserved by np.clip).
pdo_r = pdo_coef_sign * np.sqrt(np.clip(pdo_r_2, 0, None))
nino_r = nino_coef_sign * np.sqrt(np.clip(nino_r_2, 0, None))

In [None]:
# t-statistic for a correlation coefficient: t = r * sqrt(n - 2) / sqrt(1 - r^2),
# with n = 255 samples hard-coded.
# NOTE(review): n is fixed at 255 even for cells where NaNs reduced the number
# of samples actually used in the fit - confirm this is intended.
dof = 255 - 2
pdo_t = pdo_r * (dof**0.5) / ((1 - pdo_r**2)**0.5)
nino_t = nino_r * (dof**0.5) / ((1 - nino_r**2)**0.5)

In [None]:
# Two-sided p-values from Student's t with 255 - 2 degrees of freedom.
# Use the survival function rather than 1 - cdf: for large |t| the cdf rounds
# to exactly 1.0 and the subtraction would return 0 instead of the true
# (tiny) tail probability.
pdo_p = 2 * scipy.stats.t.sf(np.abs(pdo_t), 255-2)
nino_p = 2 * scipy.stats.t.sf(np.abs(nino_t), 255-2)

In [None]:
# False-discovery-rate control via the project helper `control_FDR`.
# The two trailing arguments are the grid dimensions; take them from the
# p-value map itself instead of hard-coding 101 x 237, so this cell keeps
# working if the analysis domain changes.
# NOTE(review): the exact meaning of the returned value (threshold vs.
# adjusted p-values) is defined in pdo_functions - confirm there.
pdo_adj_p = pdo_functions.control_FDR(pdo_p, *pdo_p.shape)
nino_adj_p = pdo_functions.control_FDR(nino_p, *nino_p.shape)

In [None]:
# Significance map: 3 where the raw p-value is below the FDR-controlled
# value, 0 otherwise.
pdo_sig, nino_sig = (
    np.where(raw_p < fdr_p, 3, 0)
    for raw_p, fdr_p in ((pdo_p, pdo_adj_p), (nino_p, nino_adj_p))
)

In [None]:
# Quick check of how much of the Nino3.4 map is flagged. Flagged cells are
# coded 3, so this displays 3 x (number of flagged cells); 0 means none.
np.sum(nino_sig)

In [None]:
# Full diverging colormaps from cmasher (and the reversed variant)
cmap_full = cmr.fusion_r
cmap_full_other = cmr.fusion

# Keep only the upper half (0.5 -> 1.0) of each diverging map, giving a
# sequential colormap suitable for non-negative quantities such as r^2.
cmap_half = c.LinearSegmentedColormap.from_list(
    'fusion_r_half',
    cmap_full(np.linspace(0.5, 1.0, 256))
)

# Same construction from the non-reversed map. It gets its own name so the two
# derived colormaps can be told apart (the original reused 'fusion_r_half').
cmap_other_half = c.LinearSegmentedColormap.from_list(
    'fusion_half',
    cmap_full_other(np.linspace(0.5, 1.0, 256))
)

In [None]:
# Land mask on the NDSEV grid: 1 over land, NaN elsewhere, so multiplying a
# field by it blanks out the non-land cells.
land_110 = regionmask.defined_regions.natural_earth_v4_1_0.land_110
is_land = land_110.mask_3D(ndsev)
era5_land = xr.where(is_land, 1, np.nan).squeeze()

In [None]:
# 2 x 3 map figure. Rows: PDO (top) and Nino3.4 (bottom).
# Columns: regression coefficient, signed correlation, r^2.
# Every field is multiplied by the land mask before plotting.
map_proj = ccrs.AlbersEqualArea(central_longitude=-97, central_latitude=36.5)
fig, ax = plt.subplots(2, 3, subplot_kw=dict(projection=map_proj))

# Flatten to a row-major list: 0-2 = top row, 3-5 = bottom row
ax = list(ax.flat)

# --- Discrete colour scales -------------------------------------------------
# Regression coefficients
bounds = [-1.2, -1.1, -1, -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, -0.05, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2]
# Alternative, narrower coefficient scale:
# bounds = [-0.6, -0.55, -0.5, -0.45, -0.4, -0.35, -0.3, -0.25, -0.2, -0.15, -0.1, -0.05, -0.025, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]
norm = c.BoundaryNorm(bounds, plt.get_cmap('cmr.fusion_r').N)

# Correlation
r_bounds = [-0.3, -0.275, -0.25, -0.225, -0.2, -0.175, -0.15, -0.125, -0.1, -0.075, -0.05, -0.025, -0.0125, 0.0125, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25, 0.275, 0.3]
r_norm = c.BoundaryNorm(r_bounds, plt.get_cmap('cmr.fusion_r').N)

# r^2
r_2_bounds = np.arange(0, 0.07375, 0.00375)
# Alternative r^2 scale:
# r_2_bounds = np.arange(0, 0.055, 0.005)
r_2_norm = c.BoundaryNorm(r_2_bounds, cmap_half.N)

# --- Base map styling, identical for every panel ----------------------------
for panel in ax:
    panel.add_feature(ct.feature.STATES, lw=0.25, edgecolor='xkcd:gunmetal')
    panel.coastlines(lw=0.25, color='xkcd:gunmetal')
    panel.spines['geo'].set_linewidth(0.25)
    panel.spines['geo'].set_edgecolor('xkcd:gunmetal')
    panel.set_facecolor('xkcd:grayish')

# --- Filled maps ------------------------------------------------------------
grid_lon = ndsev['longitude']
grid_lat = ndsev['latitude']
data_crs = ccrs.PlateCarree()

# (panel index, field, colormap, norm)
panel_specs = [
    (0, pdo_coef,  'cmr.fusion_r', norm),
    (3, nino_coef, 'cmr.fusion_r', norm),
    (1, pdo_r,     'cmr.fusion_r', r_norm),
    (4, nino_r,    'cmr.fusion_r', r_norm),
    (2, pdo_r_2,   cmap_half,      r_2_norm),
    (5, nino_r_2,  cmap_half,      r_2_norm),
]
for panel_idx, field, field_cmap, field_norm in panel_specs:
    ax[panel_idx].pcolormesh(grid_lon, grid_lat, field*era5_land, transform=data_crs,
                             shading='auto', cmap=field_cmap, norm=field_norm)

# Hatch the cells flagged in the PDO significance map on the PDO r^2 panel.
# The fill itself is fully transparent (alpha=0); only the hatching shows.
ax[2].contourf(grid_lon, grid_lat, pdo_sig*era5_land, transform=data_crs,
               hatches=[None, '\\\\\\\\\\\\\\\\\\'], colors=None, alpha=0)

# --- Titles (top row) and row labels ----------------------------------------
for panel, column_title in zip(ax[:3], ('Regression Coefficients', 'Correlation', 'r$^2$')):
    panel.set_title(column_title)

ax[0].text(-3400000, -200000, 'PDO', rotation='vertical', fontweight='normal', fontsize=12)
ax[3].text(-3400000, -600000, 'Nino3.4', rotation='vertical', fontweight='normal', fontsize=12)


# --- Colorbars, one under each column ---------------------------------------
def _horizontal_colorbar(rect, cbar_cmap, cbar_norm, extend, ticks, label):
    """Add a styled horizontal colorbar in figure rectangle `rect`; return (axes, colorbar)."""
    cbar_axes = plt.axes(rect)
    colorbar = plt.colorbar(mpl.cm.ScalarMappable(cmap=cbar_cmap, norm=cbar_norm), cax=cbar_axes,
                            orientation='horizontal', spacing='proportional', extend=extend,
                            ticks=ticks)
    colorbar.set_label(label, size=8, fontweight='normal', color='black')
    colorbar.ax.tick_params(which='both', labelsize=8, width=0.5, length=0, labelcolor='black')
    colorbar.outline.set_linewidth(0.5)
    colorbar.outline.set_color('black')
    return cbar_axes, colorbar


cax, cbar = _horizontal_colorbar(
    [0.01, .2, 0.31, 0.02], 'cmr.fusion_r', norm, 'both',
    [-1.2, -0.8, -0.4, 0, 0.4, 0.8, 1.2],
    r'days month$^{-1}$ standard deviation$^{-1}$',
)

cax2, cbar2 = _horizontal_colorbar(
    [0.34, .2, 0.31, 0.02], 'cmr.fusion_r', r_norm, 'both',
    [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3],
    r'correlation (r)',
)

cax3, cbar3 = _horizontal_colorbar(
    [0.68, .2, 0.31, 0.02], cmap_half, r_2_norm, 'max',
    [0.01, 0.03, 0.05, 0.07],
    r'r$^2$',
)

# --- Final layout -----------------------------------------------------------
plt.subplots_adjust(left=0,
                    bottom=0.25,
                    right=1.,
                    top=0.7,
                    wspace=0.05,
                    hspace=0.05)