# Overview
__In this notebook, we'll:__
1. Look at examples of commonly-used xarray operations (e.g. loading and plotting data, preprocessing)
2. Look at an example of how to apply concepts from lecture to output from a "stochastic climate model"
3. Apply concepts from lecture to output from CMIP data

__"Practical" learning objectives:__
- Become familiar with the basics of manipulating gridded climate data.

__Conceptual learning objectives:__
- Understand the purpose of model validation.
- Understand the difference between externally forced and internal climate variability.
- Understand why we need ensembles of climate simulations.

# 0. Preliminaries

#### Specify path to data

In [None]:
## filepath for the mount
cmip6_fp = "/Volumes/cmip6/data"

#### Import packages

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import cmocean
from matplotlib.dates import DateFormatter
import seaborn as sns
import glob
import cftime
import src.utils
import cartopy.crs as ccrs
import matplotlib.ticker as mticker

## set plotting style
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## initialize random number generator
rng = np.random.default_rng()

# 1. Commonly-used operations...
...and examples using ```xarray``` & ```matplotlib```

### Opening data
We're going to work with a "reanalysis" product located on the CMIP6 archive. As a reminder from lecture, a reanalysis is a hybrid of model output and observations. Observations (e.g., of rainfall, temperature, or ocean salinity) are sparse (and irregular) in time and space. The purpose of the reanalysis is to fill in the gaps, creating a "gridded" dataset which *is* regular in time and space.

In [None]:
## specify file path to ERA5 reanalysis product.
## We'll look at surface temperature
era5_path = f"{cmip6_fp}/era5/reanalysis/single-levels/monthly-means/2m_temperature"

## List the first few files in the folder
## (sorted, so that the files come in a reproducible order):
file_list = sorted(glob.glob(f"{era5_path}/*.nc"))
print(file_list[:4])

## Load a single file using xarray
T2m_1980 = xr.open_dataset(f"{era5_path}/1980_2m_temperature.nc")
T2m_1980.load()  # loads into memory

## open the first 3 files (but don't load to memory)
T2m = xr.open_mfdataset(file_list[:3])

To subset data, use the ```.isel``` / ```.sel``` methods (```.isel``` selects by integer position, ```.sel``` by coordinate value):

In [None]:
## select data for the first time step (stamped 1980-01-01), two different ways
print(
    np.allclose(
        T2m_1980["t2m"].isel(time=0).values,
        T2m_1980["t2m"].sel(time="1980-01-01").values,
    )
)


## function to subset longitude/latitude
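def trim_to_north_atl(x):
    """trims data to the North Atlantic region"""

    ## lon/lat boundaries for N. Atlantic
    ## (the latitude slice runs north-to-south, which matches a
    ## latitude coordinate stored in descending order)
    lon_range = [260, 360]
    lat_range = [70, 3]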

    ## trim the data
    x_trimmed = x.sel(longitude=slice(*lon_range), latitude=slice(*lat_range))

    return x_trimmed

If we only care about a subset of the data (i.e., not the whole globe), it can be helpful to subset it *while* loading. To do this, pass a subsetting function to ```xr.open_mfdataset```:

In [None]:
## Load trimmed data
T2m_trimmed = xr.open_mfdataset(file_list[:3], preprocess=trim_to_north_atl)

## Compare size to original data
print(f"Shape of raw data:     {T2m['t2m'].shape}")
print(f"Shape of trimmed data: {T2m_trimmed['t2m'].shape}")

### Plotting data

First, let's define a function that draws a blank map (don't worry about the details for now).

In [None]:
## First, a generic plot setup function
def plot_setup(ax, lon_range, lat_range, xticks, yticks, scale):
    """
    Create map background for plotting spatial data.
    Arguments:
        - ax: Matplotlib object containing everything in the plot.
            (I think of it as the plot "canvas")
        - lon_range/lat_range: 2-element arrays, representing plot boundaries
        - xticks/yticks: location for lon/lat labels
        - scale: number which controls linewidth and fontsize

    Returns the modified 'ax' object and the gridliner object 'gl'.
    """

    ## specify transparency/linewidths
    grid_alpha = 0.1 * scale
    grid_linewidth = 0.5 * scale
    coastline_linewidth = 0.3 * scale
    label_size = 8 * scale

    ## crop map and plot coastlines
    ax.set_extent([*lon_range, *lat_range], crs=ccrs.PlateCarree())
    ax.coastlines(linewidth=coastline_linewidth)

    ## plot grid
    gl = ax.gridlines(
        draw_labels=True,
        linestyle="--",
        alpha=grid_alpha,
        linewidth=grid_linewidth,
        color="k",
        zorder=1.05,
    )

    ## add tick labels
    gl.bottom_labels = False
    gl.right_labels = False
    gl.xlabel_style = {"size": label_size}
    gl.ylabel_style = {"size": label_size}
    gl.ylocator = mticker.FixedLocator(yticks)
    gl.xlocator = mticker.FixedLocator(xticks)

    return ax, gl


## Next, a function to plot the North Atlantic
def plot_setup_north_atl(ax, scale=1):
    """Create map background for plotting spatial data.
    Returns the modified 'ax' object and the gridliner object 'gl'."""

    ## specify range and ticklabels for plot
    lon_range = [-100, 10]
    lat_range = [3, 70]
    xticks = [-80, -60, -40, -20, 0]
    yticks = [20, 40, 60]

    ax, gl = plot_setup(ax, lon_range, lat_range, xticks, yticks, scale)

    return ax, gl

Next, create a blank canvas and modify it using our function.

In [None]:
## Create a figure object (can contain multiple "Axes" objects)
fig = plt.figure(figsize=(6, 3))

## add Axes object (blank canvas for our plot)
ax = fig.add_subplot(projection=ccrs.PlateCarree())
ax, gl = plot_setup_north_atl(ax)

## Let's plot T2m data for the first time step of 1980
t2m_plot = ax.contourf(
    T2m_1980.longitude,
    T2m_1980.latitude,
    T2m_1980["t2m"].isel(time=0),
    levels=np.arange(260, 304, 4),  # contour levels to plot
    cmap="cmo.thermal",  # colormap (see https://matplotlib.org/cmocean/)
    extend="both",  # includes values outside of contour bounds
)

## add a colorbar
cb = fig.colorbar(t2m_plot, orientation="vertical", label=r"$K$")

plt.show()

### Grid-wise and grid-to-point correlation
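
As a minimal sketch of both kinds of correlation, the cell below uses a small synthetic field (random numbers, not the ERA5 data) and ```xr.corr```. The variable names (```field```, ```ref```, ```other```) and the grid are made up for illustration.

```python
import numpy as np
import xarray as xr

## Synthetic data (NOT the ERA5 data): 200 time steps on a small lat/lon grid
rng = np.random.default_rng(0)
field = xr.DataArray(
    rng.normal(size=(200, 4, 5)),
    coords={
        "time": np.arange(200),
        "latitude": [60.0, 50.0, 40.0, 30.0],
        "longitude": [300.0, 310.0, 320.0, 330.0, 340.0],
    },
    dims=["time", "latitude", "longitude"],
)

## Grid-to-point: correlate the time series at one reference point
## with the time series at every grid point
ref = field.sel(latitude=50.0, longitude=320.0).reset_coords(drop=True)
corr_map = xr.corr(field, ref, dim="time")

## Grid-wise: correlate two fields point by point
## (here, the field and a noisy copy of itself)
other = field + rng.normal(scale=1.0, size=field.shape)
gridwise_corr = xr.corr(field, other, dim="time")

print(corr_map.round(2).values)
print(gridwise_corr.round(2).values)
```

Both results are maps with dimensions (latitude, longitude). In the grid-to-point map, the reference point itself has a correlation of 1.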

### Spatial average
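
On a regular lat/lon grid, grid cells shrink towards the poles, so a spatial average should weight each cell by its area (proportional to the cosine of latitude). The sketch below illustrates this with a synthetic field (not the ERA5 data) and xarray's ```.weighted``` method; the field and grid spacing are assumptions chosen for illustration.

```python
import numpy as np
import xarray as xr

## Synthetic stand-in for a gridded field (NOT the ERA5 data):
## temperature decreasing from equator to poles, on a 10-degree grid
lat = np.arange(-85.0, 90.0, 10.0)
lon = np.arange(0.0, 360.0, 10.0)
values = 300.0 - 40.0 * np.abs(np.sin(np.deg2rad(lat)))[:, None] * np.ones(lon.size)
field = xr.DataArray(
    values,
    coords={"latitude": lat, "longitude": lon},
    dims=["latitude", "longitude"],
)

## grid-cell area is proportional to cos(latitude)
weights = np.cos(np.deg2rad(field.latitude))

## area-weighted vs. naive (unweighted) spatial mean
weighted_mean = field.weighted(weights).mean(["latitude", "longitude"])
unweighted_mean = field.mean(["latitude", "longitude"])

print(f"weighted:   {weighted_mean.item():.2f} K")
print(f"unweighted: {unweighted_mean.item():.2f} K")
```

The unweighted mean gives the cold high latitudes too much influence, so it comes out lower than the area-weighted mean.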

### Remove seasonal cycle
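
One common approach is to compute a monthly climatology (the mean of all Januaries, all Februaries, and so on) and subtract it from each time step. The sketch below does this with ```groupby("time.month")``` on a synthetic monthly series (not the ERA5 data); the amplitude and noise level are arbitrary.

```python
import numpy as np
import pandas as pd
import xarray as xr

## Synthetic monthly series (NOT the ERA5 data): seasonal cycle + noise
rng = np.random.default_rng(0)
time = pd.date_range("1980-01-01", periods=120, freq="MS")
month = time.month.to_numpy()
seasonal = 10.0 * np.cos(2 * np.pi * (month - 1) / 12)
data = xr.DataArray(
    285.0 + seasonal + rng.normal(scale=0.5, size=time.size),
    coords={"time": time},
    dims=["time"],
    name="t2m",
)

## monthly climatology: one value per calendar month
climatology = data.groupby("time.month").mean("time")

## anomalies: subtract the matching calendar-month mean from each time step
anomalies = data.groupby("time.month") - climatology

print(f"std. dev. before: {data.std().item():.2f} K")
print(f"std. dev. after:  {anomalies.std().item():.2f} K")
```

The same pattern should carry over to gridded data, since the groupby operates along the time dimension only.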

### Detrending
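
A simple way to detrend is to fit a straight line along the time dimension and subtract it. The sketch below uses ```polyfit``` and ```xr.polyval``` on a synthetic annual series (not real data); the trend and noise values are assumptions for illustration.

```python
import numpy as np
import xarray as xr

## Synthetic annual series (NOT real data): linear trend + noise
rng = np.random.default_rng(0)
year = np.arange(1950, 2020)
data = xr.DataArray(
    0.02 * (year - year[0]) + rng.normal(scale=0.1, size=year.size),
    coords={"year": year},
    dims=["year"],
)

## fit a straight line (degree-1 polynomial) along the 'year' dimension
coefs = data.polyfit(dim="year", deg=1)["polyfit_coefficients"]

## evaluate the fitted line at each year and subtract it
trend = xr.polyval(data["year"], coefs)
detrended = data - trend

## the residual should have (numerically) zero slope
residual_coefs = detrended.polyfit(dim="year", deg=1)["polyfit_coefficients"]
residual_slope = residual_coefs.sel(degree=1)

print(f"fitted slope:   {coefs.sel(degree=1).item():.4f} per year")
print(f"residual slope: {residual_slope.item():.1e} per year")
```

Note that subtracting the full fitted line removes the mean as well as the trend.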

### Interpolation / resampling

### Regridding data

# 2. Example: analysis of synthetic climate

#### Simulate synthetic climate

In [None]:
## specify number of ensemble members and end year for simulation
n_members = 1000
tf = 2006

## simulate pre-industrial and warming scenarios
T_PI = src.utils.markov_simulation(ti=1850, tf=tf, n_members=n_members, trend=0)
T_warming = src.utils.markov_simulation(
    ti=1850, tf=tf, n_members=n_members, trend=0.005
)

## for convenience, get subset of pre-industrial control which overlaps with warming
T_PI_hist = T_PI.sel(year=T_warming.year)
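
The internals of ```src.utils.markov_simulation``` are not shown in this notebook. To give a feel for what a "stochastic climate model" can look like, here is a hypothetical stand-in: red noise (an AR(1) process) for internal variability, plus a linear trend for the externally forced part. The function name ```ar1_simulation``` and the parameters ```phi``` and ```sigma``` are assumptions made for this sketch, and the real function may work differently.

```python
import numpy as np


def ar1_simulation(ti, tf, n_members, trend, phi=0.8, sigma=0.2, seed=None):
    """Hypothetical stand-in (NOT src.utils.markov_simulation):
    AR(1) "red noise" plus a linear trend, for years ti..tf-1.
    phi (memory) and sigma (noise amplitude) are illustrative values."""
    rng = np.random.default_rng(seed)
    years = np.arange(ti, tf)
    noise = rng.normal(scale=sigma, size=(n_members, years.size))

    ## internal variability: each year remembers a fraction phi of the last
    T = np.empty_like(noise)
    T[:, 0] = noise[:, 0] / np.sqrt(1 - phi**2)  # start from stationary spread
    for t in range(1, years.size):
        T[:, t] = phi * T[:, t - 1] + noise[:, t]

    ## externally forced component: linear trend (deg C / year)
    forced = trend * (years - ti)

    return years, T + forced


years, T_PI_demo = ar1_simulation(1850, 2006, n_members=200, trend=0.0, seed=1)
_, T_warming_demo = ar1_simulation(1850, 2006, n_members=200, trend=0.005, seed=1)
print(T_PI_demo.shape, T_warming_demo.shape)
```

Because both runs use the same seed, they share the same internal variability and differ only by the forced trend.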

#### Plot a random sample from each simulation

In [None]:
# choose a random sample
idx = rng.choice(T_PI.ensemble_member)

# make the plot
fig, ax = plt.subplots(figsize=(4, 3))

ax.plot(
    T_PI_hist.year,
    T_PI_hist.sel(ensemble_member=idx),
    color="black",
    label="P.I. control",
)
ax.plot(
    T_warming.year, T_warming.sel(ensemble_member=idx), color="red", label="warming"
)

## label axes
ax.set_xlabel("Year")
ax.set_ylabel(r"SST anomaly ($^{\circ}C$)")
ax.legend(prop={"size": 10})
ax.set_title("random ensemble member")

plt.show()

#### Plot ensemble mean and spread

In [None]:
def plot_ensemble_spread(ax, T, color, label=None):
    """plot mean and +/- 1 standard dev. of ensemble on
    given ax object."""

    ## compute stats
    mean = T.mean("ensemble_member")
    std = T.std("ensemble_member")

    ## plot mean
    mean_plot = ax.plot(mean.year, mean, label=label, color=color)

    ## plot spread
    ax.plot(mean.year, mean + std, lw=0.5, c=mean_plot[0].get_color())
    ax.plot(mean.year, mean - std, lw=0.5, c=mean_plot[0].get_color())

    return


## Plot ensemble stats
fig, ax = plt.subplots(figsize=(4, 3))

## plot data
plot_ensemble_spread(ax, T_PI_hist, color="black", label="P.I. control")
plot_ensemble_spread(ax, T_warming, color="red", label="warming")

## label axes
ax.set_xlabel("Year")
ax.set_ylabel(r"SST anomaly ($^{\circ}C$)")
ax.legend(prop={"size": 10})
ax.set_title("Ensemble results")

plt.show()

#### Is the trend significant?
- histogram of 40-year trends in the PI control (one trend per ensemble member, 1,000 in total)
- compare to the histogram of trends in the warming simulation, for both the first 40 years (1850-1889) and the last 40 years (1966-2005)
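
One way to turn this comparison into a number is an empirical p-value: the fraction of control-run trends that are at least as large as the trend being tested. The sketch below uses synthetic trend values (not output from the simulations in this notebook), and the helper ```empirical_p_value``` is a hypothetical name introduced here.

```python
import numpy as np


def empirical_p_value(null_trends, observed_trend):
    """One-sided empirical p-value: fraction of null-distribution
    trends that are at least as large as the observed trend."""
    null_trends = np.asarray(null_trends)
    return np.mean(null_trends >= observed_trend)


## Synthetic example (NOT output from the simulations in this notebook)
rng = np.random.default_rng(0)
null_trends = rng.normal(loc=0.0, scale=1.0, size=1000)  # deg C / century

## trend that only 5% of the control trends exceed
threshold_95 = np.percentile(null_trends, 95)

## p-value for a hypothetical observed trend of 2.5 deg C / century
p = empirical_p_value(null_trends, observed_trend=2.5)

print(f"95th percentile of control trends: {threshold_95:.2f}")
print(f"empirical p-value: {p:.3f}")
```

A small p-value means a trend this large is unlikely to arise from internal variability alone.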

In [None]:
def get_slope(data, dim="year"):
    """Function to compute linear trend of SST,
    in deg/century."""

    ## fit linear trend to data
    coefs = data.polyfit(dim=dim, deg=1)["polyfit_coefficients"]

    ## Get slope (degree=1; intercept is given by degree=0).
    ## Note: units are in deg/year
    slope = coefs.sel(degree=1)

    ## convert units to deg/century
    slope *= 100

    return slope


def plot_histogram_comparison(ax, samples0, samples1, label0=None, label1=None):
    """
    Compute two histograms, one each for samples0 and samples1.
    Plot the results on the specified ax object, and label histograms
    'label0' and 'label1', respectively.
    """

    ## First, make the histograms.
    # specify histogram bins
    bin_width = 0.5
    bin_edges = np.arange(-4.75, 4.75 + bin_width, bin_width)

    # compute histograms
    hist0, _ = np.histogram(samples0, bins=bin_edges)
    hist1, _ = np.histogram(samples1, bins=bin_edges)

    ## plot histograms
    ax.stairs(values=hist0, edges=bin_edges, color="k", label=label0)
    ax.stairs(
        values=hist1,
        edges=bin_edges,
        color="r",
        label=label1,
        fill=True,
        alpha=0.3,
    )

    ## label plot
    ax.set_ylabel("Count")
    ax.set_xlabel(r"Warming trend ($^{\circ}C~/~$century)")

    return ax


def plot_histogram_comparison_wrapper(ax, years):
    """wrapper function to plot histogram comparison for given subset of years"""

    ## Get trends for each ensemble member
    T_PI_subset = T_PI.sel(year=years)
    T_warming_subset = T_warming.sel(year=years)
    trends_PI = get_slope(T_PI_subset)
    trends_warming = get_slope(T_warming_subset)

    ## make the plot
    ax = plot_histogram_comparison(
        ax,
        samples0=trends_PI,
        samples1=trends_warming,
        label0="PI control",
        label1="warming",
    )

    ## plot trends of the ensemble means
    T_PI_mean = T_PI_subset.mean("ensemble_member")
    T_warming_mean = T_warming_subset.mean("ensemble_member")
    ax.axvline(float(get_slope(T_PI_mean)), ls="--", c="k", lw=1)
    ax.axvline(float(get_slope(T_warming_mean)), ls="--", c="r", lw=1)

    return ax

In [None]:
## Make plot
fig, axs = plt.subplots(1, 2, figsize=(8, 3))

## Plot data for each subset of years
axs[0] = plot_histogram_comparison_wrapper(axs[0], years=np.arange(1850, 1890))
axs[1] = plot_histogram_comparison_wrapper(axs[1], years=np.arange(1966, 2006))

## label plot
axs[0].set_title("First 40 years")
axs[1].set_title("Last 40 years")
axs[0].legend(prop={"size": 10})
axs[1].set_yticks([])
axs[1].set_ylabel(None)
axs[1].set_ylim(axs[0].get_ylim())

plt.show()

# 3. Now it's your turn...
Start by picking a variable/index to analyze (e.g., SST at a specific location or averaged over a specified region), a reanalysis product from the CMIP archive, and a model from the CMIP archive.

#### Model validation
1. Load the data and make a plot.
2. Compute the index for both the reanalysis and the model (in the historical simulation).
3. How do the statistics compare over the overlapping period? (e.g., mean, standard deviation, seasonal cycle, power spectrum).

#### Is there evidence of a climate change signal?
1. Next, compute the index in the model's pre-industrial control simulation (should this be a last-millennium simulation?).
2. Create a histogram of trends in the control simulation by computing the trends of randomly-sampled 40-year segments.
3. Based on this histogram, is the model-simulated trend (from the last 40 years, in the historical simulation) significant?
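
With a single long control run (rather than a large ensemble), the trend histogram can be built by drawing random 40-year segments from the one time series. Here is a minimal sketch using a synthetic "control run" of white noise (not CMIP output); ```random_segment_trends``` is a hypothetical helper written for this example.

```python
import numpy as np


def random_segment_trends(series, segment_length=40, n_samples=1000, seed=None):
    """Linear trends (per time step) of randomly chosen,
    possibly overlapping segments of a 1-D series."""
    rng = np.random.default_rng(seed)
    series = np.asarray(series, dtype=float)

    ## random start index for each segment
    starts = rng.integers(0, series.size - segment_length + 1, size=n_samples)

    ## index array of shape (n_samples, segment_length)
    idx = starts[:, None] + np.arange(segment_length)
    segments = series[idx]

    ## np.polyfit fits every column at once, so put time on the first axis
    t = np.arange(segment_length)
    slopes = np.polyfit(t, segments.T, deg=1)[0]

    return slopes


## Synthetic 500-year "control run" (NOT CMIP output), annual values
rng = np.random.default_rng(0)
control = rng.normal(scale=0.3, size=500)

## trends in units per century (annual data, so multiply by 100)
trends_per_century = 100 * random_segment_trends(control, seed=0)
print(np.percentile(trends_per_century, [2.5, 97.5]))
```

Because the segments overlap, the samples are not independent; the histogram is still a useful picture of the trends that internal variability alone can produce.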
   
#### What are the projected future changes?
1. Next, compute the index in a future warming scenario (using the same model).
2. How do the statistics/histogram shift in the future warming scenario?