In [None]:
import numpy as np
import xarray as xr
# our local module:
import wavenumber_frequency_functions as wf
import matplotlib as mpl
import matplotlib.pyplot as plt

def wf_analysis(x, **kwargs):
    """Return normalized symmetric and anti-symmetric wavenumber-frequency spectra of x.

    Parameters
    ----------
    x : xarray.DataArray
        Field handed straight to ``wf.spacetime_power`` (presumably
        (time, lat, lon), e.g. daily precipitation -- TODO confirm in that module).
    **kwargs
        Forwarded to ``wf.spacetime_power``. Options used in this notebook:
        segsize, noverlap, spd, latitude_bounds (tuple: (south, north)),
        dosymmetries, rmvLowFrq.

    Returns
    -------
    (nspec_sym, nspec_asy)
        Raw power of component 0 and component 1, each divided by a background
        obtained by smoothing the mean of both components. Values at
        frequency 0 are NaN.
    """
    # Get the "raw" spectral power.
    z2 = wf.spacetime_power(x, **kwargs)
    # The background is supposed to be derived from both symmetric & antisymmetric.
    # It is averaged before the nu = 0 row is blanked below, as in the original analysis.
    z2avg = z2.mean(dim='component')
    z2.loc[{'frequency': 0}] = np.nan  # get rid of spurious power at \nu = 0
    background = wf.smooth_wavefreq(z2avg, kern=wf.simple_smooth_kernel(), nsmooth=50, freq_name='frequency')
    # Select components by dimension *name* so the result does not depend on where
    # 'component' sits in z2's dimension order.
    # NOTE(review): index 0 = symmetric, 1 = anti-symmetric, as the original code assumed.
    z2_sym = z2.isel(component=0)
    z2_asy = z2.isel(component=1)
    # normalize
    nspec_sym = z2_sym / background
    nspec_asy = z2_asy / background
    return nspec_sym, nspec_asy


def _plot_normalized_spectrum(s, ax, curve_indices, levels, title, ofil=None):
    """Contour a normalized spectrum on ``ax`` and overlay shallow-water dispersion curves.

    Shared implementation behind the symmetric / anti-symmetric plotting functions.

    Parameters
    ----------
    s : xarray.DataArray
        Normalized spectrum with 'wavenumber' and 'frequency' dimensions.
    ax : matplotlib Axes
        Axes to draw on; the colorbar is attached to ``ax.figure``.
    curve_indices : iterable of int
        Indices along the first axis of the arrays from ``wf.genDispersionCurves``
        to overlay.
    levels : array-like
        Contour levels for ``contourf``.
    title : str
        Axes title.
    ofil : str or path-like, optional
        If given, the parent figure is saved to this path.
    """
    fb = [0, .5]  # frequency bounds for plot
    c = 'darkgray'  # COLOR FOR DISPERSION LINES/LABELS
    # Dispersion-curve data; 1e20 marks missing points (swfreq.shape --> (6, 3, 50)).
    swfreq, swwn = wf.genDispersionCurves()
    swf = np.where(swfreq == 1e20, np.nan, swfreq)
    swk = np.where(swwn == 1e20, np.nan, swwn)

    # Name the dimension order explicitly: rows = frequency, columns = wavenumber,
    # matching the meshgrid below regardless of the input's order.
    z = s.transpose('frequency', 'wavenumber').sel(frequency=slice(*fb), wavenumber=slice(-10, 10))
    # .sel/.transpose can return views, so assigning NaN in place would modify the
    # caller's array; .where builds a new array with nu = 0 blanked instead.
    z = z.where(z['frequency'] != 0)
    kmesh0, vmesh0 = np.meshgrid(z['wavenumber'], z['frequency'])
    img = ax.contourf(kmesh0, vmesh0, z, levels=levels, cmap='RdGy_r', extend='both')
    for ii in curve_indices:
        for jj in range(3):
            ax.plot(swk[ii, jj, :], swf[ii, jj, :], color=c)
    ax.axvline(0, linestyle='dashed', color='lightgray')
    ax.set_xlim([-10, 10])
    ax.set_ylim(fb)
    ax.set_title(title)
    # Attach to the Axes' own figure rather than pyplot's "current" figure.
    ax.figure.colorbar(img, ax=ax)
    if ofil is not None:
        ax.figure.savefig(ofil, bbox_inches='tight', dpi=144)


def plot_normalized_symmetric_spectrum(s, ax, ofil=None,
                                       title="Normalized Symmetric Component EXPT04(L=1hr;O=4hr)"):
    """Basic plot of normalized symmetric power spectrum with shallow water curves.

    Overlays dispersion curves 3-5 from ``wf.genDispersionCurves`` and contours
    at 15 levels from 0.2 to 3.0. ``title`` defaults to the original fixed
    title. If ``ofil`` is given, the figure containing ``ax`` is saved there.
    """
    _plot_normalized_spectrum(s, ax, curve_indices=range(3, 6),
                              levels=np.linspace(0.2, 3.0, 15), title=title, ofil=ofil)


def plot_normalized_asymmetric_spectrum(s, ax, ofil=None,
                                        title="Normalized Anti-symmetric Component EXPT04(L=1hr;O=4hr)"):
    """Basic plot of normalized anti-symmetric power spectrum with shallow water curves.

    Overlays dispersion curves 0-2 from ``wf.genDispersionCurves`` and contours
    at 17 levels from 0.2 to 1.8. ``title`` defaults to the original fixed
    title. If ``ofil`` is given, the figure containing ``ax`` is saved there.
    """
    _plot_normalized_spectrum(s, ax, curve_indices=range(0, 3),
                              levels=np.linspace(0.2, 1.8, 17), title=title, ofil=ofil)

#
# LOAD DATA, x = DataArray(time, lat, lon), e.g., daily mean precipitation
#
def get_data(filename, variablename, time_range=('2001-01-01', '2015-12-31')):
    """Open ``filename`` with xarray and return its ``variablename`` DataArray.

    Parameters
    ----------
    filename : str or path-like
        File passed to ``xr.open_dataset``.
    variablename : str
        Name of the variable to return.
    time_range : (start, end) or None
        Inclusive date strings used to slice ``time`` when times decode
        successfully. Pass ``None`` to keep every time step.

    Notes
    -----
    If ``xr.open_dataset`` raises ``ValueError`` (e.g. undecodable time units),
    the file is re-opened with ``decode_times=False``. In that case the data are
    returned *unsliced*, because date strings cannot index raw numeric times.
    """
    try:
        ds = xr.open_dataset(filename)
    except ValueError:
        ds = xr.open_dataset(filename, decode_times=False)
    else:
        # Only the open is guarded: an error from .sel should surface instead of
        # being silently replaced by an unsliced, undecoded re-open.
        if time_range is not None:
            ds = ds.sel(time=slice(*time_range))
    return ds[variablename]

In [None]:
# Wheeler-Kiladis [WK] analysis options for OBSERVATION.
latBound = (-15, 15)  # latitude bounds, (south, north) as wf_analysis expects
spd      = 1    # SAMPLES PER DAY
nDayWin  = 96   # Wheeler-Kiladis [WK] temporal window length (days)
nDaySkip = -65  # time (days) between temporal windows [segments]
                # negative means there will be overlapping temporal segments
twoMonthOverlap = -nDaySkip  # overlap (days) between windows, derived so the two never disagree
opt      = {'segsize': nDayWin,
            'noverlap': twoMonthOverlap,
            'spd': spd,
            'latitude_bounds': latBound,
            'dosymmetries': True,
            'rmvLowFrq': True}

In [None]:
import xarray as xr
from pathlib import Path
import myfunctions as mf

In [None]:
# =========================
# User-defined metadata
# =========================

varname = "pr"  # CMIP6 variable name to analyse

# =========================
# Base CEDA paths
# =========================

# Root of the CMIP6 archive on the CEDA filesystem.
CEDA_BASE = Path("/badc") / "cmip6" / "data" / "CMIP6"

In [None]:
# Model name -> CEDA path metadata (institution, ensemble member, grid label).
MODELS = {
    name: {"institution": institution, "ensemble": ensemble, "grid": grid}
    for name, institution, ensemble, grid in (
        # (model,         institution,    ensemble,   grid)
        ("UKESM1-0-LL",   "MOHC",         "r1i1p1f2", "gn"),
        ("CNRM-ESM2-1",   "CNRM-CERFACS", "r1i1p1f2", "gr"),
        # ("MPI-ESM1-2-LR", "MPI-M",        "r1i1p1f1", "gn"),
        # ("CESM2-WACCM",   "NCAR",         "r1i1p1f1", "gn"),
        # ("IPSL-CM6A-LR",  "IPSL",         "r1i1p1f1", "gr"),
    )
}


In [None]:
# Experiment label -> CMIP6 project (activity) and scenario (experiment id).
# Only HIST carries a plotting colour.
EXPERIMENTS = dict(
    HIST=dict(project="CMIP", scenario="historical", color="black"),
    SSP245=dict(project="ScenarioMIP", scenario="ssp245"),
    SSP585=dict(project="ScenarioMIP", scenario="ssp585"),
    G6solar=dict(project="GeoMIP", scenario="G6solar"),
    G6sulfur=dict(project="GeoMIP", scenario="G6sulfur"),
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- Analysis choices ---------------------------------------------------------
# WK spectra are computed for one experiment at a time.
WK_EXPERIMENT = "SSP245"
WK_PERIOD = ("2071-01-01", "2100-12")  # inclusive time window passed to .sel

# --- WK OPTIONS (loop-invariant, so built once; these overwrite the earlier config cell) ---
latBound = (-15, 15)  # (south, north)
spd = 1
nDayWin = 96
twoMonthOverlap = 65

opt = dict(
    segsize=nDayWin,
    noverlap=twoMonthOverlap,
    spd=spd,
    latitude_bounds=latBound,
    dosymmetries=True,
    rmvLowFrq=True,
)

# Look the experiment up directly: a KeyError here is better than silently
# reusing `da` left over from a previous model when nothing matches.
exp_meta = EXPERIMENTS[WK_EXPERIMENT]

# squeeze=False keeps `axes` 2-D even when only one model is configured.
fig, axes = plt.subplots(
    nrows=len(MODELS),
    ncols=2,   # symmetric | anti-symmetric
    figsize=(12, 3 * len(MODELS)),
    squeeze=False,
)

for row, (model_name, model_meta) in enumerate(MODELS.items()):
    # CESM2-WACCM uses a different ensemble member for G6sulfur than for its other runs.
    if model_name == "CESM2-WACCM":
        ensemble = "r1i1p1f2" if exp_meta["scenario"] == "G6sulfur" else "r1i1p1f1"
    else:
        ensemble = model_meta["ensemble"]

    base = (
        CEDA_BASE
        / exp_meta["project"]
        / model_meta["institution"]
        / model_name
        / exp_meta["scenario"]
        / ensemble
        / "day"
        / varname
        / model_meta["grid"]
        / "latest"
    )

    print(f"Reading {model_name} {WK_EXPERIMENT}")
    ds = mf.open_files(str(base))
    # Sort latitudes ascending so the (south, north) slice is non-empty, and cut to the
    # analysis period and tropical band *before* .load() to keep memory small.
    # NOTE(review): assumes wf.spacetime_power selects the same latitude_bounds band,
    # so pre-subsetting only saves memory -- confirm in wavenumber_frequency_functions.
    da = (
        mf.read_var(ds, varname)
        .sortby("lat")
        .sel(time=slice(*WK_PERIOD), lat=slice(*latBound))
        .load()
    )
    # ensure daily
    da = da.resample(time="1D").mean()

    # --- WK ANALYSIS ---
    sym, asym = wf_analysis(da, **opt)

    # --- PLOT ---
    ax_sym, ax_asy = axes[row]
    plot_normalized_symmetric_spectrum(sym, ax_sym)
    plot_normalized_asymmetric_spectrum(asym, ax_asy)
    # The plot helpers draw a fixed default title; label every panel with its model.
    ax_sym.set_title(f"{model_name}: symmetric", fontsize=10)
    ax_asy.set_title(f"{model_name}: anti-symmetric", fontsize=10)
    ax_sym.set_ylabel("Frequency (cycles/day)", fontsize=9)

for ax in axes[-1]:
    ax.set_xlabel("Zonal wavenumber", fontsize=9)

fig.suptitle(
    f"Wheeler–Kiladis Spectra (Daily CMIP6, {WK_EXPERIMENT} "
    f"{WK_PERIOD[0][:4]}–{WK_PERIOD[1][:4]})"
)
fig.tight_layout()
plt.show()


TODO: plot the spectra for all models.
TODO: then make the arrow-type plot for each model, showing whether the MJO propagates faster or slower in a GeoMIP simulation by contrasting the first 30 years with the last 30 years.
