## Check if we're running in Google Colab
If you are running in Google Colab, you may have to run the cell below twice: `condacolab.install()` restarts the kernel once the installation finishes, which Colab reports as a crash. After the restart, simply run the cell again.

In [None]:
## check if we're in Colab
try:
    import google.colab

    ## install package that allows us to use mamba in Colab
    !pip install -q condacolab
    import condacolab

    condacolab.install()

    ## install extra packages to colab environment
    !mamba install -c conda-forge python=3.10.13 cmocean xesmf cartopy cftime cartopy

    ## connect to Google Drive (will prompt you to ask for permissions)
    from google.colab import drive

    drive.mount("/content/drive")

    ## flag telling us the notebook is running in Colab
    IN_COLAB = True

except:
    IN_COLAB = False

## Filepaths
__To run this notebook, you'll need to update the filepaths below__, which specify the location of the data (otherwise, you'll get a ```FileNotFoundError``` message when you try to open the data). These filepaths will differ for Mac vs. Windows users and depend on how you've accessed the data (e.g., mounting the WHOI file server or downloading the data).

In [None]:
## Data directories for the current environment: the mounted Google Drive
## folder when running in Colab, otherwise the paths under /Volumes/cmip6.

## historical simulation
hist_path = (
    "/content/drive/My Drive/climate-data"
    if IN_COLAB
    else "/Volumes/cmip6/data/cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Amon/tas/gn/1"
)

## PI-control simulation
pico_path = (
    "/content/drive/My Drive/climate-data"
    if IN_COLAB
    else "/Volumes/cmip6/data/cmip6/CMIP/NCAR/CESM2/piControl/r1i1p1f1/Amon/tas/gn/1"
)

## Imports

In [None]:
import xarray as xr
import numpy as np
import os
import time
import tqdm
import glob
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

## default plot style: white axes background, grid lines switched off
sns.set(
    rc={
        "axes.grid": False,
        "axes.facecolor": "white",
    }
)

## initialize random number generator with a fixed seed, so that the random
## samples drawn with `rng` are identical after every kernel restart
SEED = 0
rng = np.random.default_rng(SEED)

## Open the data

Note: we set ```mask_and_scale=False``` to avoid a serialization warning.

In [None]:
## open every historical *.nc file as a single dataset and keep the "tas" variable
T2m_hist = xr.open_mfdataset(
    os.path.join(hist_path, "*.nc"),
    mask_and_scale=False,
)["tas"]

## the PI-control counterpart is left disabled here
# T2m_pico = xr.open_mfdataset(os.path.join(pico_path, "*.nc"), mask_and_scale=False)["tas"]

## define "climate index" function

First, a function to compute the index

In [None]:
def WH_index(T2m, lat_idx=140, lon_idx=231):
    """Compute the 'Woods Hole climate index': annual mean at one grid cell.

    Parameters
    ----------
    T2m : xarray object (or anything exposing ``isel`` and ``groupby``)
        Field with ``lat`` and ``lon`` dimensions and a ``time`` coordinate.
    lat_idx, lon_idx : int, optional
        Positional indices of the grid cell to extract. The defaults (140, 231)
        are the values that used to be hard-coded here.
        NOTE(review): presumably the cell closest to Woods Hole
        (lat=41.5, lon=288.5) on the grid of the data being read -- confirm
        before using data on a different grid.

    Returns
    -------
    Result of averaging the selected grid cell within each calendar year
    (grouped by ``time.year``).
    """

    ## select the grid cell by position
    ## (coordinate-based alternative:
    ##  T2m.interp(lat=41.5, lon=288.5, method="nearest"))
    T2m_WH = T2m.isel(lat=lat_idx, lon=lon_idx)

    ## Get annual average
    return T2m_WH.groupby("time.year").mean()

def WH_index_from_file(T2m_filepath):
    """Compute the WH index for the data stored in a single file.

    Parameters
    ----------
    T2m_filepath : str or path-like
        Path of a file containing a ``tas`` variable.

    Returns
    -------
    The output of ``WH_index`` for that file, loaded into memory so that it
    stays usable after the file has been closed.
    """

    ## The context manager closes the file even if the index computation raises
    with xr.open_dataset(T2m_filepath, mask_and_scale=False) as dataset:

        ## Compute the index and pull the result into memory before closing
        T2m_WH = WH_index(dataset["tas"]).load()

    return T2m_WH

#### Loop through files

In [None]:
## glob returns matches in filesystem-dependent order; sort so the files are
## always processed (and concatenated) in the same filename order
## NOTE(review): assumes filename order equals chronological order -- confirm
T2m_pico_files = sorted(glob.glob(os.path.join(pico_path, "*.nc")))

## fail early with a clear message if the path is wrong
if not T2m_pico_files:
    raise FileNotFoundError(
        f"No .nc files found in '{pico_path}'; update `pico_path` above."
    )

## compute the index one file at a time
T2m_pico_by_file = [WH_index_from_file(file) for file in tqdm.tqdm(T2m_pico_files)]

## Put in dataset format (join the per-file results along "year")
T2m_pico = xr.concat(T2m_pico_by_file, dim="year")

#### Compute WH index for historical data

In [None]:
## reduce the historical field to the Woods Hole index, then evaluate it.
## NOTE: this overwrites `T2m_hist` with the index, so re-running this cell
## requires re-opening the historical data first.
T2m_hist = WH_index(T2m_hist)
T2m_hist = T2m_hist.compute()

In [None]:
## inspect the "year" coordinate of the PI-control index
T2m_pico.year

In [None]:
## first entry of the "year" coordinate of the historical index
T2m_hist.year[0]

In [None]:
## offset added to the PI-control years before plotting
## NOTE(review): 650 looks chosen so the control run sits next to the
## historical run on a shared year axis -- confirm the intended alignment
PICO_YEAR_OFFSET = 650

fig, ax = plt.subplots()
ax.plot(T2m_pico.year + PICO_YEAR_OFFSET, T2m_pico, label="PI-control")
ax.plot(T2m_hist.year, T2m_hist, label="historical")

## label the figure so it can be read on its own
ax.set_xlabel(f"year (PI-control shifted by +{PICO_YEAR_OFFSET})")
ax.set_ylabel("annual-mean tas")
ax.legend();

#### Generate histogram from data

In [None]:
def get_random_sample_mean(data, nyears, generator=None):
    """Average `data` over one randomly placed window of `nyears` entries.

    Parameters
    ----------
    data : xarray object (or anything exposing ``year``, ``isel`` and ``mean``)
        Series with a ``year`` dimension.
    nyears : int
        Length of the window, counted in positions along ``year``.
    generator : numpy.random.Generator, optional
        Random number generator used to pick the window. Defaults to the
        notebook-level ``rng``.

    Returns
    -------
    Mean of the selected window over ``year``.

    Raises
    ------
    ValueError
        If `nyears` is larger than the number of entries along ``year``.
    """

    n_years_total = len(data.year)
    if nyears > n_years_total:
        raise ValueError(
            f"nyears={nyears} exceeds the {n_years_total} years available"
        )

    ## fall back to the notebook-level generator
    if generator is None:
        generator = rng

    ## get random start year for random sample; the last valid start index is
    ## (n_years_total - nyears), and endpoint=True makes sure it can be drawn
    idx_start = generator.integers(0, n_years_total - nyears, endpoint=True)

    ## get random sample
    sample = data.isel(year=slice(idx_start, idx_start + nyears))

    ## get sample mean
    return sample.mean("year")


def get_random_sample_means(data, nsamples, nyears=30):
    """Draw `nsamples` random window means from `data` and stack them.

    Each draw calls ``get_random_sample_mean(data, nyears)``; the results are
    concatenated along a new ``sample`` dimension labelled 0 .. nsamples-1.
    """

    sample_ids = np.arange(nsamples)

    ## one random window mean per sample id (with progress bar)
    window_means = []
    for _ in tqdm.tqdm(sample_ids):
        window_means.append(get_random_sample_mean(data, nyears))

    ## stack the draws along the new "sample" dimension
    return xr.concat(window_means, dim=pd.Index(sample_ids, name="sample"))


## draw 3000 random 30-year means from the PI-control index
sample_means = get_random_sample_means(
    data=T2m_pico,
    nyears=30,
    nsamples=3000,
)

#### Make histogram

In [None]:
## histogram bins: edges spaced by `bin_width`, starting at 284.5
bin_width = 1 / 12
bin_edges = np.arange(284.5, 286, bin_width)

## alternative bin range (kept for reference)
# bin_width = 1/12
# bin_edges = np.arange(271+1/12, 272+5/12, bin_width)

## midpoint of every pair of neighbouring edges
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

In [None]:
## count the sample means falling into each bin
## (the bin edges returned by np.histogram are discarded)
histogram_pico, _ = np.histogram(a=sample_means, bins=bin_edges)

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## distribution of the PI-control sample means
ax.stairs(values=histogram_pico, edges=bin_edges, color="k", label="PI-control samples")

## dashed black: mean over all PI-control samples
ax.axvline(sample_means.mean(), c="k", ls="--", label="PI-control mean")

## dashed red: mean over the last 30 entries of the historical index
ax.axvline(
    T2m_hist.isel(year=slice(-30, None)).mean("year"),
    c="r",
    ls="--",
    label="historical (last 30 yr)",
)

## label the figure so it can be read on its own
ax.set_xlabel("window-mean tas")
ax.set_ylabel("number of samples")
ax.legend(fontsize="small");

# Old stuff (exploratory timing cells; not needed for the analysis above)

In [None]:
## time how long it takes to load a single grid cell of `T2m_pico`
## NOTE(review): when the notebook is run top to bottom, `T2m_pico` is at this
## point the concatenated output of WH_index (a grid cell was already selected
## with isel), so it presumably no longer has lat/lon dimensions and this isel
## call would fail. The cell appears to expect the full PI-control field --
## confirm, then move it after the field is opened or remove it.
start = time.time()
T2m_pico.isel(lat=40, lon=50).load()
end = time.time()
print(f"{end-start:.2f} seconds")

In [None]:
## time the open -> re-chunk -> annual-mean pipeline for the historical run
start = time.time()

## open all historical files and re-chunk the "tas" variable in one chain
T2m_hist = xr.open_mfdataset(os.path.join(hist_path, "*.nc"), mask_and_scale=False)["tas"].chunk(
    {"time": 15000, "lon": 64, "lat": 64}
)

## annual mean at every grid cell, evaluated immediately
T2m_hist_annual = T2m_hist.groupby("time.year").mean().compute()

end = time.time()
print(f"{end - start:.2f} seconds")

In [None]:
## start the timer (the matching printout below is currently disabled)
start = time.time()

## re-open the full PI-control field from all files in `pico_path`
T2m_pico = xr.open_mfdataset(os.path.join(pico_path, "*.nc"), mask_and_scale=False)["tas"]

## disabled: re-chunk, annual mean and timing printout
# T2m_pico = T2m_pico.chunk({"time":15000, "lon":64,"lat":64})
# T2m_pico_annual = T2m_pico.groupby("time.year").mean().compute()

# end = time.time()
# print(f"{end - start:.2f} seconds")

#### Other attempts

In [None]:
## time opening the PI-control files when each file is reduced to the index
## while it is being opened
start = time.time()

## NOTE(review): this call used to pass `WH_climate_index`, a name that is not
## defined anywhere in this notebook (NameError). `WH_index` is the index
## function defined earlier -- confirm it is the intended replacement and that
## it behaves as expected on the per-file Dataset handed over by `preprocess`.
T2m_pico = xr.open_mfdataset(
    os.path.join(pico_path, "*.nc"),
    mask_and_scale=False,
    preprocess=WH_index,
)["tas"]

end = time.time()
print(f"{end-start:.2f} seconds")

In [None]:
## time how long it takes to load `T2m_pico` into memory
start = time.time()
T2m_pico.load()
end = time.time()
print(f"{end-start:.2f} seconds")

Do the computation

In [None]:
## compute the index for the historical and PI-control data
## NOTE(review): `WH_climate_index` is not defined anywhere in the visible
## notebook, so this cell raises NameError unless it is defined elsewhere;
## presumably it is an earlier name of `WH_index` -- confirm. Also check which
## version of `T2m_hist` / `T2m_pico` (full field vs. already-reduced index)
## this cell is meant to receive.
T2m_WH_hist = WH_climate_index(T2m_hist).compute()
T2m_WH_pico = WH_climate_index(T2m_pico).compute()

In [None]:
## display the PI-control index computed in the previous cell
T2m_WH_pico