In [1]:
import datetime as dt
import numpy as np
import xarray as xr

import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
import pandas as pd
from dateutil.relativedelta import relativedelta
import cartopy.crs as ccrs
import scipy as sp
import subprocess
from joblib import Parallel, delayed
import dask.array as da
from scipy.stats import spearmanr
import seaborn as sns
import regionmask
import warnings
import cartopy.feature as cfeature
warnings.filterwarnings("ignore")
import matplotlib.colors as mcolors

In [66]:
# Create the working directories this notebook reads from / writes to.
import os

for directory in ('download', 'data', 'figures'):
    # exist_ok=True makes the call a no-op when the directory is already there
    # and avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(directory, exist_ok=True)

In [3]:
# Set to True to re-fetch the NMME precipitation files with curl in the download cell.
# With False those curl calls are skipped, so the files must already exist under download/.
download = False

1_ SST vs SPI teleconnection (lead 2.5, October start, OND season)

2_ Difference in skill between ENSO and non-ENSO years (season-wise): correlation and RPSS, and reliability if possible

3_ Trend in Model SPI

## Load Data

In [4]:
# --- NMME precipitation sources on iridl.ldeo.columbia.edu ---
# Every URL requests the `.prec` variable restricted to Y 12..37 and X 32..60,
# with the S axis expressed in "days since 1960-01-01", as a netCDF file.
cfsv2_url = "https://iridl.ldeo.columbia.edu/SOURCES/.Models/.NMME/.NCEP-CFSv2/.HINDCAST/.PENTAD_SAMPLES_FULL/.prec/Y/12/37/RANGE/X/32/60/RANGE/S/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
cfsv2_path = 'download/cfsv2_precip.nc'

gfdlspear_url = "https://iridl.ldeo.columbia.edu/SOURCES/.Models/.NMME/.GFDL-SPEAR/.HINDCAST/.MONTHLY/.prec/Y/12/37/RANGE/X/32/60/RANGE/S/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
gfdlspear_path = 'download/gfdlspear_precip.nc'

cansipsic4_url = "https://iridl.ldeo.columbia.edu/SOURCES/.Models/.NMME/.CanSIPS-IC4/.HINDCAST/.MONTHLY/.prec/Y/12/37/RANGE/X/32/60/RANGE/S/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
cansipsic4_path = 'download/cansipsic4_precip.nc'

cesm1_url = "https://iridl.ldeo.columbia.edu/SOURCES/.Models/.NMME/.COLA-RSMAS-CESM1/.MONTHLY/.prec/Y/12/37/RANGE/X/32/60/RANGE/S/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
cesm1_path = 'download/cesm1_precip.nc'

colaccsm4_url = "https://iridl.ldeo.columbia.edu/SOURCES/.Models/.NMME/.COLA-RSMAS-CCSM4/.MONTHLY/.prec/Y/12/37/RANGE/X/32/60/RANGE/S/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
colaccsm4_path = 'download/colaccsm4_precip.nc'

nasageos1_url = "https://iridl.ldeo.columbia.edu/SOURCES/.Models/.NMME/.NASA-GEOSS2S/.HINDCAST/.MONTHLY/.prec/Y/12/37/RANGE/X/32/60/RANGE/S/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
nasageos1_path = 'download/nasageos1_precip.nc'
nasageos2_url = "https://iridl.ldeo.columbia.edu/SOURCES/.Models/.NMME/.NASA-GEOSS2S/.FORECAST/.MONTHLY/.prec/Y/12/37/RANGE/X/32/60/RANGE/S/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
nasageos2_path = 'download/nasageos2_precip.nc'

# (url, destination) pairs fetched when `download` is True, in fetch order.
# CanSIPS-IC4 is defined above but deliberately not in this list (its download is disabled).
nmme_downloads = [
    (cfsv2_url, cfsv2_path),
    (gfdlspear_url, gfdlspear_path),
    (cesm1_url, cesm1_path),
    (colaccsm4_url, colaccsm4_path),
    (nasageos1_url, nasageos1_path),
    (nasageos2_url, nasageos2_path),
]

if download:
    for url, path in nmme_downloads:
        print(url)
        # NOTE(review): `-k` turns off TLS certificate verification; drop it if the
        # server certificate validates in your environment.
        status = subprocess.call(['curl', '-b', 'cookies.txt', '-k', url, '-o', path])
        if status != 0:
            # Report the failure instead of silently continuing with a missing or
            # partial file (print is used because warnings are filtered in this notebook).
            print(f"WARNING: curl exited with status {status} for {path}")

In [5]:
# Open the single-file NMME precipitation datasets with decoded time coordinates.
cfsv2, gfdlspear, cesm1, colaccsm4 = (
    xr.open_dataset(nc_path, decode_times=True)
    for nc_path in (cfsv2_path, gfdlspear_path, cesm1_path, colaccsm4_path)
)

# NASA-GEOSS2S is split over two files (HINDCAST and FORECAST URLs). Only the first
# four positions along M are kept from the second file before joining along S.
# NOTE(review): presumably M is the ensemble-member axis and 4 matches the first file — confirm.
nasageos_hindcast = xr.open_dataset(nasageos1_path, decode_times=True)
nasageos2 = xr.open_dataset(nasageos2_path, decode_times=True).isel(M=slice(0, 4))
nasageos = xr.concat([nasageos_hindcast, nasageos2], dim='S')
# CanSIPS-IC4 is not opened here (its download is disabled).

In [6]:
# CPC Merged Analysis (CMAP) monthly precipitation estimate, regridded by the server to a
# 1-degree X/Y grid and cut to Y 12..37, X 32..60 (all encoded in the URL).
cmap_url = "https://iridl.ldeo.columbia.edu/SOURCES/.NOAA/.NCEP/.CPC/.Merged_Analysis/.monthly/.latest/.ver1/.prcp_est/X/-180/1/179/GRID/Y/-90/1/90/GRID/Y/12/37/RANGE/X/32/60/RANGE/T/(days%20since%201960-01-01)/streamgridunitconvert/data.nc"
cmap_path = 'download/cmap_precip.nc'

# NOTE(review): unlike the NMME files this download is not guarded by `download`;
# it runs on every execution of the cell.
print(cmap_url)
# NOTE(review): `-k` turns off TLS certificate verification.
cmap_status = subprocess.call(['curl', '-b', 'cookies.txt', '-k', cmap_url, '-o', cmap_path])
if cmap_status != 0:
    # Report the failure instead of silently opening a stale or partial file.
    print(f"WARNING: curl exited with status {cmap_status} for {cmap_path}")

cmap = xr.open_dataset(cmap_path, decode_times=True)
# Re-label every time stamp with the first day of its month (midnight).
cmap['T'] = pd.DatetimeIndex(cmap['T'].values).to_period('M').to_timestamp()
obs = cmap['prcp_est']

# Three-step running sum labelled by its LAST time step: value at T = obs[T] + obs[T-1] + obs[T-2].
obs_3m = obs + obs.shift(T=1) + obs.shift(T=2)
# The first two steps have no full window and are NaN. dropna's default how='any' also
# removes any other time step that contains a NaN anywhere in the field.
obs_3m = obs_3m.dropna('T')


https://iridl.ldeo.columbia.edu/SOURCES/.NOAA/.NCEP/.CPC/.Merged_Analysis/.monthly/.latest/.ver1/.prcp_est/X/-180/1/179/GRID/Y/-90/1/90/GRID/Y/12/37/RANGE/X/32/60/RANGE/T/(days%20since%201960-01-01)/streamgridunitconvert/data.nc


In [7]:
# Observed SPI: load the file fully into memory, give the auto-named variable
# `__xarray_dataarray_variable__` the name `spi`, take it as a DataArray and
# order it along T.
spi_obs = (
    xr.load_dataset('../data/spi3_cmap_1x1.nc')
    .rename({'__xarray_dataarray_variable__': 'spi'})
    .spi
    .sortby('T')
)

In [8]:
# NOTE(review): presumably the number of forecast lead times to process — confirm against the cells that use it.
n_lead = 6

# Skill

In [9]:
# --- Load Data ---
# SPI hindcasts and precipitation forecasts, one netCDF file per model plus the
# multi-model file. Only the `spi` / `precip` variable of each file is kept.
spi_cfsv2 = xr.open_dataset('data/spi_hindcast_cfsv2.nc').spi
spi_gfdlspear = xr.open_dataset('data/spi_hindcast_gfdlspear.nc').spi
spi_cesm1 = xr.open_dataset('data/spi_hindcast_cesm1.nc').spi
spi_colaccsm4 = xr.open_dataset('data/spi_hindcast_colaccsm4.nc').spi
spi_nasageos = xr.open_dataset('data/spi_hindcast_nasageos.nc').spi
spi_multimodel = xr.open_dataset('data/spi_hindcast_multimodel.nc').spi

precip_cfsv2 = xr.open_dataset('data/precip_fcast_cfsv2.nc').precip
precip_gfdlspear = xr.open_dataset('data/precip_fcast_gfdlspear.nc').precip
precip_cesm1 = xr.open_dataset('data/precip_fcast_cesm1.nc').precip
precip_colaccsm4 = xr.open_dataset('data/precip_fcast_colaccsm4.nc').precip
precip_nasageos = xr.open_dataset('data/precip_fcast_nasageos.nc').precip
precip_multimodel = xr.open_dataset('data/precip_fcast_multimodel.nc').precip

# Both dictionaries share the same keys so they can be iterated together.
spi_hindcast_dict = {
    'cfsv2': spi_cfsv2, 'gfdlspear': spi_gfdlspear, 'cesm1': spi_cesm1,
    'colaccsm4': spi_colaccsm4, 'nasageos': spi_nasageos, 'MME': spi_multimodel
}


precip_fcast_dict = {
    'cfsv2': precip_cfsv2, 'gfdlspear': precip_gfdlspear, 'cesm1': precip_cesm1,
    'colaccsm4': precip_colaccsm4, 'nasageos': precip_nasageos, 'MME': precip_multimodel
}

# --- Create Output Directories ---
# One folder per model and variable; exist_ok=True makes re-running the cell safe.
for model in spi_hindcast_dict.keys():
    for var_name in ('precip', 'spi'):
        os.makedirs(f'figures/MME/corr/{var_name}/{model}', exist_ok=True)

# --- Create Global Land Mask (based on grid of multimodel as reference) ---
mask = regionmask.defined_regions.natural_earth_v5_0_0.land_110.mask(
    spi_multimodel.X, spi_multimodel.Y
)

for model in spi_hindcast_dict.keys():
    # Keep values where the land mask equals 0; every other cell is filled with the
    # sentinel -999 (not NaN).
    # NOTE(review): an earlier comment described this as "mask ocean as NaN"; switch the
    # fill to NaN only after confirming that nothing downstream relies on -999.
    spi_hindcast_dict[model] = spi_hindcast_dict[model].where(mask == 0, -999)
    precip_fcast_dict[model] = precip_fcast_dict[model].where(mask == 0, -999)

In [11]:
# NOTE(review): repeats the `n_lead = 6` assignment made in an earlier cell.
n_lead=6

In [57]:
# Load the Niño-3.4 table, parse the Date column and use it as the index
nino34 = pd.read_csv(
    "data/nino34.csv",
    parse_dates=["Date"]
).set_index("Date")

# Select 1950–2025
nino34_ = nino34.loc["1950-01-01":"2025-12-31", "NINO3.4"]

# 3-month trailing mean: the value at a given month averages that month and the two
# before it. Rows without a complete window (the first two, and any window touching
# a missing value) are NaN and are dropped here; otherwise both comparisons below
# would be False for them and they would be mislabelled as neutral (ENSO = 0).
nino_3m = (
    nino34_.rolling(window=3, center=False).mean()
    .to_frame("NINO3.4")
    .dropna()
)

# ENSO dummy: 1 when the 3-month mean is >= 0.5 or <= -0.5, else 0.
# NOTE(review): assumes the NINO3.4 column holds anomalies, not absolute SST — confirm.
nino_3m["ENSO"] = (nino_3m["NINO3.4"].abs() >= 0.5).astype(int)


# Correlation

In [63]:
# Calendar month abbreviations, January first.
months = 'Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec'.split()
# Three-month seasons ordered by their LAST month: seasons[m - 1] ends in calendar month m.
seasons = 'NDJ DJF JFM FMA MAM AMJ MJJ JJA JAS ASO SON OND'.split()
# Seasons that are analysed and plotted, in display order.
selected_seasons = ['OND', 'NDJ', 'DJF', 'JFM', 'FMA', 'MAM']
# Panel label for each selected season: "a)" ... "f)".
seasons_to_label = {
    season: f"{letter})" for season, letter in zip(selected_seasons, "abcdef")
}



In [73]:

# --------------------------------------------------
# Helpers
# --------------------------------------------------
def _start_to_target_time(fcst, n_months):
    """Re-label the start-time axis ``S`` of ``fcst`` as a target-time axis ``T``.

    Each value of ``S`` is moved forward by ``n_months`` calendar months and the
    dimension is renamed to ``T``. The time stamps are derived from ``fcst`` itself,
    so two arrays with different ``S`` values are never given each other's labels.
    """
    target_times = [s + pd.DateOffset(months=n_months) for s in fcst["S"].values]
    return fcst.assign_coords(S=target_times).rename({"S": "T"})


def _corr_for_state(fcst, obs, state_mask):
    """Correlate ``fcst`` with ``obs`` along ``T`` over the steps where ``state_mask`` is True.

    Returns the correlation field flattened to a 1-D NumPy array (NaNs included).
    """
    return xr.corr(
        fcst.where(state_mask, drop=True),
        obs.where(state_mask, drop=True),
        dim="T",
    ).values.flatten()


def _plot_enso_boxplot(corr_dict, var, nmme):
    """Draw one box plot (all seasons, ENSO vs Neutral) and save it as a PDF.

    Parameters
    ----------
    corr_dict : dict
        ``{season: {"ENSO": [...], "Neutral": [...]}}`` with correlation values.
    var : str
        Variable tag used in the output file name.
    nmme : str
        Model key; upper-cased in the title and used in the output file name.
    """
    df_plot = pd.concat(
        [
            pd.DataFrame({
                "Season": season,
                "ENSO_state": state,
                "Correlation": by_state[state]
            })
            for season, by_state in corr_dict.items()
            for state in ("ENSO", "Neutral")
        ],
        ignore_index=True
    )

    fig, ax = plt.subplots(figsize=(8, 6))

    sns.boxplot(
        data=df_plot,
        x="Season",
        y="Correlation",
        hue="ENSO_state",
        showfliers=False,
        palette={"Neutral": "lightblue", "ENSO": "cornflowerblue"},  # two shades of blue
        linewidth=1.6,
        ax=ax
    )

    # Zero line
    ax.axhline(0, color="black", linestyle="--", linewidth=1.2)

    # Title & labels
    ax.set_title(
        f"{nmme.upper()} – ENSO-conditioned skill",
        fontsize=18,
        fontweight="bold",
        pad=10
    )
    # NOTE(review): the plotted values are correlations computed along T and pooled
    # over the remaining dimensions; check that "Spatial Correlation" is the intended label.
    ax.set_ylabel("Spatial Correlation", fontsize=14, fontweight="bold")
    ax.set_xlabel("")

    # Tick styling
    plt.setp(ax.get_xticklabels(), fontsize=13, fontweight="bold")
    plt.setp(ax.get_yticklabels(), fontsize=13)

    # Legend styling
    ax.legend(
        title="",
        fontsize=14,
        frameon=False,
        loc="lower right"
    )

    # Axis spine styling
    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)

    for spine in ["left", "bottom"]:
        ax.spines[spine].set_linewidth(1.4)

    fig.tight_layout()

    # NOTE(review): "lead2p5" is hard-coded in the file name and does not follow `lead_time`.
    fig.savefig(
        f"figures/{var}_{nmme}_lead2p5_ENSO_seasons.pdf",
        dpi=300
    )
    # Close explicitly so figures do not pile up across the model loop.
    plt.close(fig)


# --------------------------------------------------
# ENSO dummy → xarray-ready (monthly, datetime index)
# --------------------------------------------------
enso_dummy = nino_3m["ENSO"].copy()
enso_dummy.index = pd.to_datetime(enso_dummy.index)

# Used both as the position picked along L and as the number of months added to the
# start date. NOTE(review): assumes position 2 on L is the 2.5-month lead — confirm.
lead_time = 2  # 2.5-month lead

for nmme in spi_hindcast_dict.keys():

    spi_hindcast = spi_hindcast_dict[nmme]
    precip_fcast = precip_fcast_dict[nmme]

    season_corr_precip = {s: {"ENSO": [], "Neutral": []} for s in selected_seasons}
    season_corr_spi    = {s: {"ENSO": [], "Neutral": []} for s in selected_seasons}

    for s_month in range(1, 13):

        # Target month = start month + lead, wrapped into 1..12 (`or 12` maps 0 to December).
        t_month = (s_month + lead_time) % 12 or 12
        target_season = seasons[t_month - 1]
        if target_season not in selected_seasons:
            continue

        # ---- forecasts: starts in s_month, chosen lead, S re-labelled as target time T ----
        P_fcst_mon = _start_to_target_time(
            precip_fcast
            .sel(S=precip_fcast["S.month"] == s_month)
            .isel(L=lead_time)
            .drop_vars("L"),
            lead_time
        )

        spi_fcst_mon = _start_to_target_time(
            spi_hindcast
            .sel(S=spi_hindcast["S.month"] == s_month)
            .isel(L=lead_time)
            .drop_vars("L"),
            lead_time
        )

        # ---- observations at the forecast target times ----
        P_obs_mon = obs_3m.sel(T=obs_3m["T.month"] == t_month).sel(T=P_fcst_mon["T"])
        spi_obs_mon = spi_obs.sel(T=spi_obs["T.month"] == t_month).sel(T=spi_fcst_mon["T"])

        # ---- ENSO mask as xarray DataArray ----
        # reindex yields NaN for target times missing from the ENSO series; those
        # steps match neither mask below and are left out of both groups.
        enso_state = xr.DataArray(
            enso_dummy.reindex(P_fcst_mon["T"].values).values,
            coords={"T": P_fcst_mon["T"].values},
            dims="T"
        )

        enso_mask = enso_state == 1
        neutral_mask = enso_state == 0

        # ---- ENSO-active and Neutral: correlate only when more than 3 time steps qualify ----
        for state, state_mask in (("ENSO", enso_mask), ("Neutral", neutral_mask)):
            if state_mask.sum() > 3:
                season_corr_precip[target_season][state].extend(
                    _corr_for_state(P_fcst_mon, P_obs_mon, state_mask)
                )
                season_corr_spi[target_season][state].extend(
                    _corr_for_state(spi_fcst_mon, spi_obs_mon, state_mask)
                )

    # --------------------------------------------------
    # FINAL PLOT: one figure per variable, all seasons, ENSO vs Neutral
    # --------------------------------------------------
    for var, corr_dict in [("precip", season_corr_precip), ("spi", season_corr_spi)]:
        _plot_enso_boxplot(corr_dict, var, nmme)
