## Check if we're running in Google Colab
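If you are running in Google Colab, you may have to run the cell below twice: `condacolab.install()` restarts the Python kernel after installing conda (this looks like a crash, but it is expected), which interrupts the cell. Once the kernel has restarted, run the cell again to finish the setup.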

In [None]:
## check if we're in Colab
try:
    import google.colab

    ## install package that allows us to use mamba in Colab
    !pip install -q condacolab
    import condacolab

    condacolab.install()

    ## install extra packages to colab environment
    !mamba install -c conda-forge python=3.10.13 cmocean xesmf cartopy cftime cartopy

    ## connect to Google Drive (will prompt you to ask for permissions)
    from google.colab import drive

    drive.mount("/content/drive")

    ## flag telling us the notebook is running in Colab
    IN_COLAB = True

except:
    IN_COLAB = False

## Filepaths
__To run this notebook, you'll need to update the filepaths below__, which specify the location of the data (otherwise, you'll get a ```FileNotFoundError``` message when you try to open the data). These filepaths will differ for Mac vs. Windows users and depend on how you've accessed the data (e.g., mounting the WHOI file server or downloading the data).

In [None]:
if not IN_COLAB:

    ## filepaths on the mounted WHOI file server (CESM2 `tas`, member r1i1p1f1)
    hist_path = "/Volumes/cmip6/data/cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Amon/tas/gn/1"
    pico_path = "/Volumes/cmip6/data/cmip6/CMIP/NCAR/CESM2/piControl/r1i1p1f1/Amon/tas/gn/1"

else:

    ## filepaths for historical/PI-control data in Google Drive (mounted at /content/drive)
    hist_path = "/content/drive/My Drive/climate-data"
    pico_path = "/content/drive/My Drive/climate-data/tas_Amon_CESM2_piControl"

## Imports

In [None]:
import glob
import os
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm
import xarray as xr

## set default plot style: white axes background, no grid lines
## (`sns.set` is a deprecated alias of `sns.set_theme`)
sns.set_theme(rc={"axes.facecolor": "white", "axes.grid": False})

## initialize random number generator with a fixed seed so that the random samples
## drawn below (and hence the histogram) are reproducible from run to run
SEED = 42
rng = np.random.default_rng(SEED)

## Open the data

In [None]:
def trim(data, lon_bounds=(285, 293), lat_bounds=(39, 44)):
    """Trim data in lon/lat space.

    Selects the box ``lon_bounds`` x ``lat_bounds`` using label-based ``.sel`` slices
    (endpoints included). The defaults are the region used throughout this notebook.
    Because only ``data`` is required, the function can be passed directly as the
    ``preprocess`` argument of ``xr.open_mfdataset``.

    Args:
        data: xarray object with ``lon`` and ``lat`` coordinates. The default bounds
            presumably assume lon in [0, 360) and ascending lat — TODO confirm.
        lon_bounds: (min, max) longitude, in the same convention as ``data.lon``.
        lat_bounds: (min, max) latitude.

    Returns:
        ``data`` restricted to the requested lon/lat box.
    """
    lon_min, lon_max = lon_bounds
    lat_min, lat_max = lat_bounds
    return data.sel(lon=slice(lon_min, lon_max), lat=slice(lat_min, lat_max))

Note: we open the data with `mask_and_scale=False` to avoid a serialization warning.

#### Load Hist

In [None]:
hist_filename = "tas_Amon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc"
hist_full_path = os.path.join(hist_path, hist_filename)

## open lazily, trim to the region of interest, then load only that subset into memory;
## the `with` block closes the underlying netCDF file once the data are loaded
with xr.open_dataset(hist_full_path, mask_and_scale=False) as hist_ds:
    T2m_hist = trim(hist_ds["tas"]).compute()

#### Load PICO

In [None]:
## list the PI-control files in a deterministic (sorted) order, and fail early with a
## clear message if the path is wrong (instead of a less specific error when opening)
T2m_pico_files = sorted(glob.glob(os.path.join(pico_path, "*.nc")))
if not T2m_pico_files:
    raise FileNotFoundError(f"No .nc files found in pico_path={pico_path!r}")

In [None]:
## open dataset (each file is trimmed by `trim` as it is read); the `with` block
## closes the underlying netCDF files once the data have been loaded
with xr.open_mfdataset(
    T2m_pico_files,
    preprocess=trim,
    mask_and_scale=False,
) as pico_ds:

    ## Load into memory (timed; perf_counter is the clock intended for durations)
    start = time.perf_counter()
    T2m_pico = pico_ds["tas"].load()
    end = time.perf_counter()

print(f"Loaded PI-control data in {end - start:.1f} s")

## define "climate index" function

First, a function to compute the index

In [None]:
def WH_index(T2m):
    """Compute the 'Woods Hole climate index' from a T2m field.

    Takes the value at lat=41.5, lon=288.5 (``interp`` with ``method="nearest"``)
    and averages it over each calendar year of the ``time`` coordinate.

    Returns:
        Annual means indexed by a ``year`` coordinate.
    """
    woods_hole = {"lat": 41.5, "lon": 288.5}
    return T2m.interp(**woods_hole, method="nearest").groupby("time.year").mean()

In [None]:
## compute the index for both runs (historical first) and load the results into memory
T2m_WH_hist, T2m_WH_pico = (WH_index(T2m).compute() for T2m in (T2m_hist, T2m_pico))

In [None]:
## Offset added to the PI-control model years when plotting. Presumably it places the
## PI-control record on the same calendar axis as the historical run — TODO confirm
## against the PI-control file dates.
PICO_YEAR_OFFSET = 650

fig, ax = plt.subplots()
ax.plot(T2m_WH_pico.year + PICO_YEAR_OFFSET, T2m_WH_pico, label="PI-control")
ax.plot(T2m_WH_hist.year, T2m_WH_hist, label="historical")
ax.set_xlabel("Year")
ax.set_ylabel(r"$K$")
ax.set_title(r"Annual-mean $T_{2m}$ in Woods Hole")
ax.legend()

plt.show()

#### Generate histogram from data

In [None]:
def get_random_sample_mean(data, nyears, generator=None):
    """Draw one random contiguous `nyears`-long window from `data` and average over it.

    Every start index 0 .. len(data.year) - nyears (inclusive) is equally likely, so the
    window that ends on the final year can be drawn too.

    Args:
        data: object with a ``year`` dimension (e.g. the annual series returned by
            ``WH_index``); must support ``.isel(year=...)`` and ``.mean("year")``.
        nyears: window length in years; must satisfy 1 <= nyears <= len(data.year).
        generator: optional ``np.random.Generator``; defaults to the notebook-level ``rng``.

    Returns:
        The mean of ``data`` over the sampled window.

    Raises:
        ValueError: if ``nyears`` is not in [1, len(data.year)].
    """
    if generator is None:
        generator = rng

    n_total = len(data.year)
    if not 1 <= nyears <= n_total:
        raise ValueError(f"nyears must be between 1 and {n_total}, got {nyears}")

    ## get random start index; `endpoint=True` makes the last valid start
    ## (n_total - nyears) reachable, so the final window can be sampled
    idx_start = generator.integers(0, n_total - nyears, endpoint=True)

    ## get random sample, then its mean
    sample = data.isel(year=slice(idx_start, idx_start + nyears))
    return sample.mean("year")


def get_random_sample_means(data, nsamples, nyears=30):
    """Draw `nsamples` random window means from `data`.

    Each sample is the mean over a random contiguous `nyears`-long window, obtained from
    ``get_random_sample_mean``. The results are stacked along a new ``sample``
    dimension labelled 0 .. nsamples - 1.
    """
    draws = []
    for _ in tqdm.tqdm(np.arange(nsamples)):
        draws.append(get_random_sample_mean(data, nyears))

    ## stack into a single xr.DataArray with a labelled "sample" dimension
    return xr.concat(draws, dim=pd.Index(np.arange(nsamples), name="sample"))


## draw 3000 random 30-year means from the PI-control index
sample_means = get_random_sample_means(T2m_WH_pico, nsamples=3000, nyears=30)

#### Make histogram

In [None]:
## histogram bins (K). np.linspace includes its end point, whereas np.arange excludes its
## stop value (and is fragile with fractional steps): np.arange(284.5, 286, 0.1) ends at
## 285.9, so sample means just below 286 were silently dropped.
bin_width = 0.1
bin_min, bin_max = 284.5, 286.0
n_bins = int(round((bin_max - bin_min) / bin_width))
bin_edges = np.linspace(bin_min, bin_max, n_bins + 1)
histogram_pico, _ = np.histogram(sample_means, bins=bin_edges)

## np.histogram ignores values outside the bin range; make any such loss visible
n_outside = sample_means.size - histogram_pico.sum()
if n_outside:
    warnings.warn(
        f"{n_outside} of {sample_means.size} sample means fall outside "
        f"[{bin_min}, {bin_max}] K and are not shown in the histogram"
    )

In [None]:
## last 30 years of the historical run; the legend label is built from the years actually
## selected, so it always matches what is averaged
hist_last30 = T2m_WH_hist.isel(year=slice(-30, None))
hist_label = f"{int(hist_last30.year[0])}-{int(hist_last30.year[-1])}"

## blank canvas for plotting
fig, ax = plt.subplots(figsize=(4, 3))

## plot the histogram
ax.stairs(values=histogram_pico, edges=bin_edges, color="k", label="PI-control")

## plot mean value of the PI-control samples
ax.axvline(float(sample_means.mean()), c="k", ls="--", label="PI-control mean")

## plot mean over the last 30 years of the historical run
ax.axvline(
    float(hist_last30.mean("year")),
    c="r",
    ls="--",
    label=hist_label,
)

## label the plot
ax.set_ylabel("# samples")
ax.set_xlabel(r"$K$")
ax.set_title(r"30-year average $T_{2m}$ in Woods Hole")
ax.legend()

plt.show()