In [1]:
import numpy as np
from sklearn.linear_model import LinearRegression
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('default')
sns.set_palette("colorblind")
from matplotlib import rcParams
rcParams['font.family'] = 'sans-serif'
rcParams['font.weight'] = 'light'
rcParams['mathtext.fontset'] = 'cm'
rcParams['mathtext.rm'] = 'serif'
mpl.rcParams["figure.dpi"] = 500
import cartopy.crs as ccrs
import cartopy as ct
import matplotlib.colors as c
import regionmask
import cmasher as cmr
import scipy
from cartopy.util import add_cyclic_point
mpl.rcParams['hatch.linewidth'] = 0.375
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from geocat.comp import eofunc_eofs, eofunc_pcs
from datetime import datetime
import warnings
from matplotlib.patches import Rectangle
import pdo_functions
import importlib
importlib.reload(pdo_functions)
import random
import numba
import statsmodels.api as sm
from sklearn.metrics import r2_score

In [2]:
# Open datasets
v_200 = xr.open_mfdataset('/hurrell-scratch2/ivyglade/pdo/conus404/v200/daily/*.nc', combine='nested', concat_dim='Time')
u_200 = xr.open_mfdataset('/hurrell-scratch2/ivyglade/pdo/conus404/u200/daily/*.nc', combine='nested', concat_dim='Time')

In [7]:
# subset only MAM
v_200_mam = v_200.sel(Time=v_200.Time.dt.month.isin([3, 4, 5]))
u_200_mam = u_200.sel(Time=u_200.Time.dt.month.isin([3, 4, 5]))

In [8]:
# Calculate wind speed
wind_200_mam = np.sqrt(u_200_mam['__xarray_dataarray_variable__'] **2 + v_200_mam['__xarray_dataarray_variable__'] **2)

In [23]:
# Group by month and take the monthly mean
wind_200_mam = wind_200_mam.resample(Time='1M').mean().dropna(dim='Time').load()

  self.index_grouper = pd.Grouper(


In [18]:
def detrend_dim(da, dim, deg):
    # detrend along a single dimension
    p = da.polyfit(dim=dim, deg=deg)
    fit = xr.polyval(da[dim], p.polyfit_coefficients)
    return da - fit

In [27]:
# Detrend
wind_200_mam_detrend = detrend_dim(wind_200_mam, 'Time', 2)

In [29]:
# Remove the seasonal cycle 
wind_200_mam_anoms = wind_200_mam_detrend.groupby('Time.month') - wind_200_mam_detrend.groupby('Time.month').mean()

In [35]:
# Open SST data and compute the PDO and Nino3.4
sst = xr.open_dataset('/hurrell-scratch2/ivyglade/pdo/HadISST_sst.nc')['sst']

pdo = pdo_functions.pdo_from_hadisst(sst, 1980, 2010)

nino_34 = pdo_functions.calc_nino_34_timeseries(sst, False, 1980, 2010)

In [36]:
# Standardize Nino3.4
nino_34_std = (nino_34 - nino_34.mean()) / nino_34.std()

In [54]:
# Convert pdo to xarray
pdo_xr = xr.DataArray(pdo, coords={'time':nino_34['time']}, dims=['time'])

# Subset only 1980-2024
pdo_1980_2022 = pdo_xr.sel(time=pdo_xr.time.dt.year.isin(np.arange(1980, 2023, 1)))
nino_34_1980_2022 = nino_34_std.sel(time=nino_34_std.time.dt.year.isin(np.arange(1980, 2023, 1)))

# only MAM and take seasonal averages
pdo_1980_2022_mam = pdo_1980_2022.sel(time=pdo_1980_2022.time.dt.month.isin([3, 4, 5])).resample(time='1M').mean().dropna(dim='time')
nino_34_1980_2022_mam = nino_34_1980_2022.sel(time=nino_34_1980_2022.time.dt.month.isin([3, 4, 5])).resample(time='1M').mean().dropna(dim='time')

  self.index_grouper = pd.Grouper(
  self.index_grouper = pd.Grouper(


In [55]:
lat_len = len(wind_200_mam_anoms.west_east)
lon_len = len(wind_200_mam_anoms.south_north)

In [None]:
coef      = np.full((lat_len, lon_len, 2), np.nan)
intercept = np.full((lat_len, lon_len), np.nan)
r2_partial = np.full((lat_len, lon_len, 2), np.nan)  # for partial R²
r2_total = np.full((lat_len, lon_len), np.nan)

model = LinearRegression()

for i in range(lat_len):
    for j in range(lon_len):
        PDO = pdo_1980_2022_mam.values
        Nino = nino_34_1980_2022_mam.values
        wind = wind_200_mam_anoms[:, i, j]

        X_full = np.column_stack([PDO, Nino])

        valid = ~np.isnan(X_full).any(axis=1) & ~np.isnan(wind)

        if np.sum(valid) >= 3:
            X_valid = X_full[valid]
            y_valid = wind[valid]

            # Fit full model
            model.fit(X_valid, y_valid)
            y_pred_full = model.predict(X_valid)
            r2_full = r2_score(y_valid, y_pred_full)

            r2_total[i, j] = model.score(X_valid, y_valid)

            coef[i, j, :] = model.coef_
            intercept[i, j] = model.intercept_

            # Partial R² for PDO (remove PDO, use only Nino)
            model.fit(X_valid[:, [1]], y_valid)
            y_pred_nino_only = model.predict(X_valid[:, [1]])
            r2_nino_only = r2_score(y_valid, y_pred_nino_only)
            r2_partial[i, j, 1] = r2_full - r2_nino_only

            # Partial R² for Nino (remove Nino, use only PDO)
            model.fit(X_valid[:, [0]], y_valid)
            y_pred_pdo_only = model.predict(X_valid[:, [0]])
            r2_pdo_only = r2_score(y_valid, y_pred_pdo_only)
            r2_partial[i, j, 0] = r2_full - r2_pdo_only

    print(f'latitude {i+1} out of {lat_len} is complete.')

latitude 1 out of 1367 is complete.
latitude 2 out of 1367 is complete.
latitude 3 out of 1367 is complete.
latitude 4 out of 1367 is complete.
latitude 5 out of 1367 is complete.
latitude 6 out of 1367 is complete.
latitude 7 out of 1367 is complete.
latitude 8 out of 1367 is complete.
latitude 9 out of 1367 is complete.
latitude 10 out of 1367 is complete.
latitude 11 out of 1367 is complete.
latitude 12 out of 1367 is complete.
latitude 13 out of 1367 is complete.
latitude 14 out of 1367 is complete.
latitude 15 out of 1367 is complete.
latitude 16 out of 1367 is complete.
latitude 17 out of 1367 is complete.
latitude 18 out of 1367 is complete.
latitude 19 out of 1367 is complete.
latitude 20 out of 1367 is complete.
latitude 21 out of 1367 is complete.
latitude 22 out of 1367 is complete.
latitude 23 out of 1367 is complete.
latitude 24 out of 1367 is complete.
latitude 25 out of 1367 is complete.
latitude 26 out of 1367 is complete.
latitude 27 out of 1367 is complete.
latitude 2