# Notebook for calculating the statistical significance of the precipitation (PRECT) difference between the PI and Plio simulations for Fig. 1

In [1]:
import glob
import os

import xarray as xr
from scipy import stats

### Time shift helper function

In [2]:
def _shift_one_month_back_single(t):
    """Shift a single cftime datetime back by one month."""
    year = t.year
    month = t.month
    day = t.day
    hour = getattr(t, "hour", 0)
    minute = getattr(t, "minute", 0)
    second = getattr(t, "second", 0)

    if month == 1:
        return t.__class__(year - 1, 12, day, hour, minute, second)
    else:
        return t.__class__(year, month - 1, day, hour, minute, second)


def shift_time_back_one_month(ds):
    """
    Build a copy of ``ds`` whose ``time`` values are each moved one
    calendar month earlier.

    The input object is not modified.  Each value of ``ds["time"]`` is passed
    through ``_shift_one_month_back_single``, so the axis is expected to hold
    cftime datetimes.
    """
    shifted = ds.copy()
    original_stamps = shifted["time"].values
    shifted["time"] = (
        "time",
        [_shift_one_month_back_single(stamp) for stamp in original_stamps],
    )
    return shifted

### Load files and subset years

In [3]:
def load_prect_timeseries(path, precc_glob, precl_glob, year_start, year_end):
    """
    Load PRECT = PRECC + PRECL, shift time back 1 month,
    and subset by [year_start, year_end] inclusive.

    Parameters:
        path: directory containing the input files.
        precc_glob: glob pattern (relative to ``path``) for the PRECC files.
        precl_glob: glob pattern (relative to ``path``) for the PRECL files.
        year_start: first year to keep (inclusive), compared against the
            year of the shifted time axis.
        year_end: last year to keep (inclusive).

    Returns: DataArray PRECT(time, ncol, ...)

    Raises:
        FileNotFoundError: if either glob pattern matches no files.
        ValueError: if no time step falls inside [year_start, year_end].
    """
    # Build each pattern once so the progress message, the glob call and the
    # error messages all refer to exactly the same path.
    precc_pattern = os.path.join(path, precc_glob)
    precl_pattern = os.path.join(path, precl_glob)

    print(f"Reading in {precc_pattern}...")
    precc_files = sorted(glob.glob(precc_pattern))
    precl_files = sorted(glob.glob(precl_pattern))

    if not precc_files:
        raise FileNotFoundError(f"No PRECC files matched {precc_pattern}")
    if not precl_files:
        raise FileNotFoundError(f"No PRECL files matched {precl_pattern}")

    ds_precc = xr.open_mfdataset(precc_files, combine="by_coords")
    ds_precl = xr.open_mfdataset(precl_files, combine="by_coords")

    # Shift time first, so the year selection below acts on the shifted axis
    ds_precc = shift_time_back_one_month(ds_precc)
    ds_precl = shift_time_back_one_month(ds_precl)

    prect = ds_precc["PRECC"] + ds_precl["PRECL"]
    prect.name = "PRECT"
    prect.attrs["long_name"] = "total precipitation rate (PRECC+PRECL)"

    # Subset by model year
    years = prect["time"].dt.year
    prect = prect.sel(time=(years >= year_start) & (years <= year_end))

    # Fail with a clear message instead of an IndexError on prect.time[0].
    if prect.sizes["time"] == 0:
        raise ValueError(
            f"No time steps in years {year_start}-{year_end} "
            f"for files matching {precc_pattern}"
        )

    first_stamp = prect.time[0].values.item().strftime("%Y-%m-%d")
    last_stamp = prect.time[-1].values.item().strftime("%Y-%m-%d")
    print(f"    PRECT years:\n    {first_stamp} to {last_stamp}\n")

    return prect

### Subset by months and run a gridpoint-wise Welch t-test

In [4]:
def subset_months(da, months):
    """Keep only the time steps of ``da`` whose calendar month is in ``months``.

    ``months`` is a list of integer month numbers.
    """
    in_requested_months = da["time"].dt.month.isin(months)
    return da.sel(time=in_requested_months)


def diff_and_ttest(plio_da, pi_da, months):
    """
    For a given month subset, compute:
      - the time mean of each input
      - mean difference (Plio - PI) over time
      - p-values from Welch t-test along time axis

    Every time step that survives the month subset is treated as one sample
    in the test; NaNs are ignored (``nan_policy="omit"``).

    Parameters:
        plio_da: Pliocene DataArray with a 'time' dimension.
        pi_da: preindustrial DataArray with the same dimensions as
            ``plio_da`` (their order may differ).
        months: list of integer calendar months to keep.

    Returns: (plio_mean, pi_mean, diff, p_da), all with no 'time' dimension
    (only ncol, etc.).
    """
    # Subset months
    print("    Subsetting months...")
    plio_sub = subset_months(plio_da, months)
    pi_sub = subset_months(pi_da, months)

    # Means (Plio - PI)
    print("    Taking difference...")
    plio_sub_mean = plio_sub.mean("time")
    pi_sub_mean = pi_sub.mean("time")
    diff = plio_sub_mean - pi_sub_mean

    # Convert to numpy arrays for scipy.  Look the time axis up by name and
    # put the PI array in the Plio dimension order, so the test does not
    # silently depend on 'time' being the leading dimension of both inputs.
    time_axis = plio_sub.get_axis_num("time")
    a = plio_sub.values
    b = pi_sub.transpose(*plio_sub.dims).values

    # t-test along time axis (only the p-value is needed)
    print("    Calculating significance...")
    _, p_val = stats.ttest_ind(
        a,
        b,
        axis=time_axis,
        equal_var=False,
        nan_policy="omit",
    )

    # Build DataArray for p-values (same non-time dims as diff)
    non_time_dims = [d for d in plio_sub.dims if d != "time"]
    coords = {d: plio_sub[d] for d in non_time_dims}

    p_da = xr.DataArray(
        p_val,
        dims=non_time_dims,
        coords=coords,
        name="p_value",
    )
    p_da.attrs["long_name"] = "Welch t-test p-value (Plio vs PI)"

    return plio_sub_mean, pi_sub_mean, diff, p_da

### Load all four experiments

In [None]:
# Directory holding the remapped PRECC / PRECL files
path = "/glade/work/malbright/final_nam_manuscript_files/remapped/"

# Keyword arguments for each load_prect_timeseries call, one dict per
# experiment.  The year bounds are inclusive.

# High resolution (ne120): Pliocene years 30–59, preindustrial years 70–99
plio_hr_spec = {
    "precc_glob": "b.e13.B1850C5CN.ne120_g16.pliohiRes.002.cam.h0.PRECC.*.nc",
    "precl_glob": "b.e13.B1850C5CN.ne120_g16.pliohiRes.002.cam.h0.PRECL.*.nc",
    "year_start": 30,
    "year_end": 59,
}
pi_hr_spec = {
    "precc_glob": "b.e13.B1850C5CN.ne120_g16.tuning.005.cam.h0.PRECC.*.nc",
    "precl_glob": "b.e13.B1850C5CN.ne120_g16.tuning.005.cam.h0.PRECL.*.nc",
    "year_start": 70,
    "year_end": 99,
}

# Low resolution (ne30): Pliocene years 376–425, preindustrial years 239–288
plio_lr_spec = {
    "precc_glob": "b.e13.B1850C5CN.ne30_g16.plio.001.cam.h0.PRECC.*.nc",
    "precl_glob": "b.e13.B1850C5CN.ne30_g16.plio.001.cam.h0.PRECL.*.nc",
    "year_start": 376,
    "year_end": 425,
}
pi_lr_spec = {
    "precc_glob": "b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.PRECC.*.nc",
    "precl_glob": "b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.PRECL.*.nc",
    "year_start": 239,
    "year_end": 288,
}

# Load in the same order as before: HR Plio, HR PI, LR Plio, LR PI
prect_plio_hr = load_prect_timeseries(path, **plio_hr_spec)
prect_pi_hr = load_prect_timeseries(path, **pi_hr_spec)
prect_plio_lr = load_prect_timeseries(path, **plio_lr_spec)
prect_pi_lr = load_prect_timeseries(path, **pi_lr_spec)

Reading in /glade/work/malbright/final_nam_manuscript_files/remapped/b.e13.B1850C5CN.ne30_g16.plio.001.cam.h0.PRECC.*.nc...
    PRECT years:
    0376-01-01 to 0425-12-01

Reading in /glade/work/malbright/final_nam_manuscript_files/remapped/b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.PRECC.*.nc...
    PRECT years:
    0239-01-01 to 0288-12-01



### Run t-tests for JJAS, JJ, AS and build output datasets

In [None]:
# Season label -> calendar months that enter the test
month_groups = {"JJAS": [6, 7, 8, 9], "JJ": [6, 7], "AS": [8, 9]}

# One NetCDF file per (resolution, season) is written into this directory
outdir = "/glade/work/malbright/final_nam_manuscript_files/"
os.makedirs(outdir, exist_ok=True)

for label, months in month_groups.items():

    # ------------------------------------------------------------------
    # High-resolution pair
    # ------------------------------------------------------------------
    print("=========== Processing HR ===========")
    plio_mean_hr, pi_mean_hr, diff_hr, p_hr = diff_and_ttest(
        prect_plio_hr, prect_pi_hr, months
    )

    # Boolean mask of grid points with p < 0.05, described before the cast
    sig_hr = p_hr < 0.05
    sig_hr.name = f"sig_mask_{label}"
    sig_hr.attrs.update(
        {
            "long_name": f"Significance mask (p < 0.05) ({label})",
            "note": (
                "1 = significant, 0 = not significant; "
                "two-sided Welch t-test, unequal variances"
            ),
        }
    )

    # Store the mask as 0/1 integers rather than booleans
    sig_hr = sig_hr.astype("int16")

    ds_hr = xr.Dataset()
    ds_hr[sig_hr.name] = sig_hr

    plio_var_name_hr = f"prect_plio_mean_{label}"
    pi_var_name_hr = f"prect_pi_mean_{label}"
    diff_name_hr = f"prect_diff_{label}"

    # Add the Plio mean, the PI mean and their difference, in that order;
    # attributes are attached after each variable is in the dataset.
    for var_name, field, long_name, note in (
        (
            plio_var_name_hr,
            plio_mean_hr,
            f"Pliocene PRECT mean ({label})",
            f"High-res; Plio years: 30–59; months: {months}",
        ),
        (
            pi_var_name_hr,
            pi_mean_hr,
            f"Preindustrial PRECT mean ({label})",
            f"High-res; PI years: 70–99; months: {months}",
        ),
        (
            diff_name_hr,
            diff_hr,
            f"Plio - PI PRECT mean difference ({label})",
            f"High-res; Plio years: 30–59, PI years: 70–99; months: {months}",
        ),
    ):
        ds_hr[var_name] = field
        ds_hr[var_name].attrs["long_name"] = long_name
        ds_hr[var_name].attrs["note"] = note

    hr_outfile = os.path.join(outdir, f"prect_plio_minus_pi_HR_ttest_{label}.nc")
    print("    Saving file...")
    ds_hr.to_netcdf(hr_outfile)
    print("    Saved!\n")

    # ------------------------------------------------------------------
    # Low-resolution pair
    # ------------------------------------------------------------------
    print("=========== Processing LR ===========")
    plio_mean_lr, pi_mean_lr, diff_lr, p_lr = diff_and_ttest(
        prect_plio_lr, prect_pi_lr, months
    )

    sig_lr = p_lr < 0.05
    sig_lr.name = f"sig_mask_{label}"
    sig_lr.attrs.update(
        {
            "long_name": f"Significance mask (p < 0.05) ({label})",
            "note": (
                "1 = significant, 0 = not significant; "
                "two-sided Welch t-test, unequal variances"
            ),
        }
    )

    sig_lr = sig_lr.astype("int16")

    ds_lr = xr.Dataset()
    ds_lr[sig_lr.name] = sig_lr

    plio_var_name_lr = f"prect_plio_mean_{label}"
    pi_var_name_lr = f"prect_pi_mean_{label}"
    diff_name_lr = f"prect_diff_{label}"

    for var_name, field, long_name, note in (
        (
            plio_var_name_lr,
            plio_mean_lr,
            f"Pliocene PRECT mean ({label})",
            f"Low-res; Plio years: 376–425; months: {months}",
        ),
        (
            pi_var_name_lr,
            pi_mean_lr,
            f"Preindustrial PRECT mean ({label})",
            f"Low-res; PI years: 239–288; months: {months}",
        ),
        (
            diff_name_lr,
            diff_lr,
            f"Plio - PI PRECT mean difference ({label})",
            f"Low-res; Plio years: 376–425, PI years: 239–288; months: {months}",
        ),
    ):
        ds_lr[var_name] = field
        ds_lr[var_name].attrs["long_name"] = long_name
        ds_lr[var_name].attrs["note"] = note

    lr_outfile = os.path.join(outdir, f"prect_plio_minus_pi_LR_ttest_{label}.nc")
    print("    Saving file...")
    ds_lr.to_netcdf(lr_outfile)
    print("    Saved!\n")

    # Summary of what was written for this season
    print(f">>Wrote HR + LR files for {label}:")
    print("    ", hr_outfile)
    print("    ", lr_outfile, "\n")

    Subsetting months...
    Taking difference...
    Calculating significance...
    Saving file...
    Saved!

>>Wrote HR + LR files for JJAS:
     /glade/work/malbright/final_nam_manuscript_files/prect_plio_minus_pi_LR_ttest_JJAS.nc 

    Subsetting months...
    Taking difference...
    Calculating significance...
    Saving file...
    Saved!

>>Wrote HR + LR files for JJ:
     /glade/work/malbright/final_nam_manuscript_files/prect_plio_minus_pi_LR_ttest_JJ.nc 

    Subsetting months...
    Taking difference...
    Calculating significance...
    Saving file...
    Saved!

>>Wrote HR + LR files for AS:
     /glade/work/malbright/final_nam_manuscript_files/prect_plio_minus_pi_LR_ttest_AS.nc 

