# Example: Woods Hole temperature variability
In this example, we'll go over:

1. Loading and plotting gridded climate data  
2. Seasonal cycle and long-term trends, and how to compute them
3. How to compute/visualize some statistics (e.g., mean & correlation)
4. Power spectral density

```{admonition} To-do
write code as functions (which take in an index, for example)
```

Before beginning, let's import some packages:

## Package imports

In [None]:
import datetime
import pathlib

import cartopy.crs as ccrs
import cmocean
import matplotlib.dates as mdates
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal
import scipy.stats
import seaborn as sns
import xarray as xr

## (optional) plot styling: white axes background and no gridlines
sns.set(
    rc={
        "axes.grid": False,
        "axes.facecolor": "white",
    }
)

## Functions

### Plotting functions

In [None]:
def plot_setup(fig, projection, lon_range, lat_range):
    """
    Create a map subplot on 'fig' using the given projection, restricted
    to the given lon/lat range, with coastlines drawn.

    Note: the 'threshold' attribute of the passed-in projection is
    modified in place (divided by 1000).

    Returns the new Axes object.
    """

    ## refine the projection's threshold so that lines plotted on the
    ## surface follow curved trajectories
    projection.threshold = projection.threshold / 1000

    ## new subplot that uses this projection
    map_ax = fig.add_subplot(projection=projection)

    ## limit the view to the requested region (bounds given in lon/lat)
    bounds = list(lon_range) + list(lat_range)
    map_ax.set_extent(bounds, crs=ccrs.PlateCarree())

    ## thin coastlines
    map_ax.coastlines(linewidths=0.5)

    return map_ax


def plot_box_outline(ax, lon_range, lat_range):
    """
    Draw a black, unfilled rectangle around the specified lon/lat range
    on the given ax object. Returns the same ax object.
    """

    ## corners of the box
    lon_min, lon_max = lon_range[0], lon_range[1]
    lat_min, lat_max = lat_range[0], lat_range[1]

    ## rectangle anchored at the (lon_min, lat_min) corner,
    ## with coordinates interpreted as lon/lat
    box = mpatches.Rectangle(
        xy=[lon_min, lat_min],
        width=lon_max - lon_min,
        height=lat_max - lat_min,
        transform=ccrs.PlateCarree(),
        facecolor="none",
        edgecolor="k",
        linewidth=1,
    )
    ax.add_patch(box)

    return ax


def make_cb_range(amp, delta):
    """Make colorbar_range for cmo.balance: levels symmetric about zero,
    with zero itself excluded.

    Args:
        - 'amp': amplitude of maximum value for colorbar
        - 'delta': increment for colorbar (must be positive)

    Returns:
        1-D array [-n*delta, ..., -delta, delta, ..., n*delta], where n is
        amp / delta rounded up to a whole number of steps.

    Raises:
        ValueError if 'delta' is not positive.
    """
    if delta <= 0:
        raise ValueError("'delta' must be positive")

    ## Count the steps as an integer instead of calling np.arange with a
    ## float step: with float steps, rounding error can add or drop a level.
    ## The small tolerance keeps e.g. 3.0000000000000004 from becoming 4.
    n_steps = int(np.ceil(amp / delta - 1e-9))

    ## Build the positive side once and mirror it, so the negative levels
    ## are exactly the negatives of the positive ones.
    positive = delta * np.arange(1, n_steps + 1)

    return np.concatenate([-positive[::-1], positive])


def spatial_avg(data):
    """Average 'data' over its "longitude" and "latitude" dimensions,
    weighting each point by the cosine of its latitude (intended for
    grids with constant longitude/latitude spacing)."""

    ## weights: cos(latitude), with latitude converted from degrees to radians
    weights = np.cos(np.deg2rad(data.latitude))

    ## weighted mean over both horizontal dimensions, using xarray
    weighted_data = data.weighted(weights=weights)

    return weighted_data.mean(["longitude", "latitude"])


def get_trend(data, dim="time", deg=1):
    """
    Fit a polynomial of degree 'deg' to an xr.dataarray along dimension
    'dim', and return the fitted polynomial evaluated at the coordinate
    values of 'dim' (a straight line when deg=1).
    """

    ## least-squares fit; keep only the coefficients
    fit_result = data.polyfit(dim=dim, deg=deg)
    coefs = fit_result["polyfit_coefficients"]

    ## evaluate the fitted polynomial at each coordinate value
    return xr.polyval(data[dim], coefs)


def detrend(data, dim="time", deg=1):
    """
    Subtract the best-fit polynomial of degree 'deg' (fitted along
    dimension 'dim') from data, and return the residual.
    """

    fitted = get_trend(data, dim=dim, deg=deg)

    return data - fitted

### Region-specific plotting functions

In [None]:
def plot_setup_atlantic(fig):
    """Set up a map of the Atlantic region (100W-0, 10N-70N) on 'fig'.
    Returns the Axes object created by plot_setup."""

    return plot_setup(
        fig,
        ## orthographic projection centered on (50W, 40N)
        ccrs.Orthographic(central_longitude=-50, central_latitude=40),
        lon_range=[-100, 0],
        lat_range=[10, 70],
    )


def plot_setup_woodshole(fig):
    """Set up a zoomed-in map around Woods Hole (75W-60W, 35N-45N) on
    'fig'. Returns the Axes object created by plot_setup."""

    ## orthographic projection centered on the middle of the region
    ortho = ccrs.Orthographic(central_longitude=-67.5, central_latitude=40)

    map_ax = plot_setup(fig, ortho, lon_range=[-75, -60], lat_range=[35, 45])

    return map_ax

## Loading and plotting the data

### Load the data

```{admonition} To-do
Specify the path to the file server in the following code cell.
```

In [None]:
## location of the netCDF file on disk
fp = pathlib.Path("/Users/theo/research/nashvar/data/HadISST.nc")

## open the file and select the "sst" variable
data = xr.open_mfdataset(fp)["sst"]

## load the values into memory
data = data.compute()

### Plot the data at the most recent time
We can select the last time step with `data.isel(time=-1)`.

```{admonition} To-do
Label this!
```

In [None]:
## blank canvas to plot on
fig = plt.figure(figsize=(5, 3))

## draw background map of Atlantic
ax = plot_setup_atlantic(fig)

## most recent time step in the dataset
latest = data.isel(time=-1)

## plot the data (coordinates are lon/lat, hence the PlateCarree transform)
plot_data = ax.contourf(
    latest.longitude,
    latest.latitude,
    latest,
    transform=ccrs.PlateCarree(),
    levels=10,
    extend="both",
    cmap="cmo.thermal",
)

## create colorbar
## NOTE(review): the label says K; confirm the units of the "sst" variable
colorbar = fig.colorbar(plot_data, label=r"$K$")

## Mark Woods Hole on map
ax.scatter(
    288.5, 41.5, transform=ccrs.PlateCarree(), marker="*", c="k", s=50, zorder=10
)

## label with the date that is actually plotted (taken from the time
## coordinate rather than hard-coded) and the variable that was loaded
latest_date = latest.time.dt.strftime("%B %Y").item()
ax.set_title(f"{latest_date} SST")

plt.show()

## Define an index
Next, let's define the "Woods Hole temperature index", $T_{wh}$, as the temperature averaged near Woods Hole. Below, we outline the region to average over.

In [None]:
## define outline of area for computing the index:
## [min, max] longitude and [min, max] latitude of the averaging box.
## NOTE(review): the longitude values look like degrees east on a 0-360
## convention -- confirm this matches the dataset's longitude coordinate.
T_WH_LON_RANGE = [287.5, 293.5]
T_WH_LAT_RANGE = [39, 44]

Next, overlay an outline of this region on the map from before.

In [None]:
## blank canvas to plot on
fig = plt.figure(figsize=(5, 3))

## draw zoomed-in background map around Woods Hole
ax = plot_setup_woodshole(fig)

## most recent time step in the dataset
latest = data.isel(time=-1)

## plot the data on the 2-D lon/lat grid
xx, yy = np.meshgrid(data.longitude.values, data.latitude.values)
plot_data = ax.pcolormesh(
    xx,
    yy,
    latest,
    transform=ccrs.PlateCarree(),
    cmap="cmo.thermal",
)

## create colorbar
## NOTE(review): the label says K; confirm the units of the "sst" variable
colorbar = fig.colorbar(plot_data, label=r"$K$")

## Mark Woods Hole on map
ax.scatter(
    288.5, 41.5, transform=ccrs.PlateCarree(), marker="*", c="k", s=100, zorder=10
)

## plot outline of region used to compute index
ax = plot_box_outline(ax, lon_range=T_WH_LON_RANGE, lat_range=T_WH_LAT_RANGE)

## label the plot with the date that is actually plotted (taken from the
## time coordinate rather than hard-coded) and the variable that was loaded
latest_date = latest.time.dt.strftime("%B %Y").item()
ax.set_title(f"{latest_date} SST")

plt.show()

### Compute the index

In [None]:
## a function to do the computation
def compute_T_wh(x):
    """Compute Woods Hole temperature index.

    Selects the part of 'x' inside the box given by T_WH_LON_RANGE and
    T_WH_LAT_RANGE, and returns its spatial average (see spatial_avg).
    'x' must have "latitude" and "longitude" coordinates.
    """

    ## Label-based slicing follows the order of the coordinate: for a
    ## descending latitude axis the bounds must be given as (max, min),
    ## for an ascending one as (min, max). Check the order instead of
    ## assuming it, so an ascending grid does not yield an empty selection.
    lat_values = x["latitude"].values
    is_ascending = lat_values.size > 1 and lat_values[0] < lat_values[-1]
    lat_bounds = T_WH_LAT_RANGE if is_ascending else T_WH_LAT_RANGE[::-1]

    ## get subset of data inside the box
    data_subset = x.sel(
        latitude=slice(*lat_bounds), longitude=slice(*T_WH_LON_RANGE)
    )

    ## compute spatial average
    return spatial_avg(data_subset)


## carry out the computation here
idx = compute_T_wh(data)

### Plot the index

In [None]:
## figure for the index time series
fig, ax = plt.subplots(figsize=(4, 3))

## draw the index against time
ax.plot(idx.time, idx)

## only show 1970 onward (right limit left unchanged)
ax.set_xlim(datetime.date(1970, 1, 1), None)

## title, y-label, and three year ticks formatted as "YYYY"
ax.set_title(r"Woods Hole temperature index")
ax.set_ylabel(r"$K$")
tick_dates = [
    datetime.date(1979, 1, 1),
    datetime.date(2000, 6, 30),
    datetime.date(2021, 12, 31),
]
ax.set_xticks(tick_dates)
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))

plt.show()

## Diagnostics

### Seasonal cycle

In [None]:
## group the index by calendar month, then take the mean and
## standard deviation of each month over time
by_month = idx.groupby("time.month")
seasonal_mean = by_month.mean("time")
seasonal_std = by_month.std("time")

## figure for the climatology
fig, ax = plt.subplots(figsize=(4, 3))

## x-axis: month number (1-12)
months = np.arange(1, 13)

## mean (thick line)
ax.plot(months, seasonal_mean, c="k", label="mean")

## one-standard-deviation envelope around the mean (thin lines)
upper = seasonal_mean + seasonal_std
lower = seasonal_mean - seasonal_std
ax.plot(months, upper, c="k", lw=0.5, label="mean ± 1 std.")
ax.plot(months, lower, c="k", lw=0.5)

## ticks, title, and legend
ax.set_xticks([1, 3, 8, 12], labels=["Jan", "Mar", "Aug", "Dec"])
ax.set_title(r"$T_{wh}$ climatology")
ax.legend()

plt.show()

### Compute anomalies

In [None]:
## anomalies at gridpoint level: subtract each calendar month's mean
## from the data for that month
monthly_groups = data.groupby("time.month")
data_anom = monthly_groups - monthly_groups.mean()

## recompute the temperature index from the anomalies
idx_anom = compute_T_wh(data_anom).compute()

### Compute linear trend

In [None]:
## fit a straight line to the anomaly index along time
## (use deg=2 instead for a quadratic trend)
trend_linear = get_trend(idx_anom, dim="time", deg=1)

### Next, let's plot the anomalies

In [None]:
## figure for the anomaly time series and its trend
fig, ax = plt.subplots(figsize=(4, 3))

## anomalies, with the fitted linear trend drawn on top in black
time_axis = idx.time
ax.plot(time_axis, idx_anom)
ax.plot(time_axis, trend_linear, c="k", label="linear trend")

## only show 1970 onward; label axes with three year ticks ("YYYY")
ax.set_xlim(datetime.date(1970, 1, 1), None)
ax.set_ylabel(r"$K$")
tick_dates = [
    datetime.date(1979, 1, 1),
    datetime.date(2000, 6, 30),
    datetime.date(2021, 12, 31),
]
ax.set_xticks(tick_dates)
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.set_title(r"$T_{wh}$ anomalies")

ax.legend()

plt.show()

### Histogram

In [None]:
## compute histogram (counts per bin, using numpy's default binning)
hist, bin_edges = np.histogram(idx_anom)

## normalize to a probability density (PDF): divide by the total area
## under the histogram so that the result integrates to one
bin_width = np.diff(bin_edges)
pdf = hist / (hist * bin_width).sum()

## get a normal distribution with the same mean and standard deviation,
## for comparison (requires 'scipy.stats' to be imported; the statistics
## are passed as plain floats rather than 0-d arrays)
gaussian_pdf = scipy.stats.norm(
    loc=float(idx_anom.mean()), scale=float(idx_anom.std())
)

#### Plot result
fig, ax = plt.subplots(figsize=(4, 3))

## plot histogram
ax.stairs(values=pdf, edges=bin_edges)

## plot gaussian, evaluated on a fixed range of anomaly values
x = np.linspace(-3, 3)
ax.plot(x, gaussian_pdf.pdf(x), c="k")

## label (the y-axis shows a density, not a probability)
ax.set_xlabel(r"$^{\circ}C$ anomaly")
ax.set_ylabel("Probability density")

plt.show()

### Power spectral density (PSD)

In [None]:
## estimate the PSD of the anomaly index with Welch's method;
## 'fs' is the sampling frequency (units: samples/year)
freq, psd = scipy.signal.welch(x=idx_anom.values, fs=12)

## show the spectrum on log-log axes
fig, ax = plt.subplots(figsize=(4, 3))
ax.loglog(freq, psd)

## axis labels
ax.set_xlabel("Freq (1/year)")
ax.set_ylabel(r"PSD (variance $\cdot$ year)")

plt.show()

### Spatial correlation

In [None]:
## remove the trend along time from both the gridded anomalies and the
## index, so that the warming trend is not picked up in the correlation
data_detrended = detrend(data_anom)
idx_detrended = detrend(idx_anom)

## correlation over time between the index and each gridpoint
corr = xr.corr(data_detrended, idx_detrended, dim="time")

## blank canvas, with background map of the Atlantic
fig = plt.figure(figsize=(5, 3))
ax = plot_setup_atlantic(fig)

## contour levels from -1 to 1 in steps of 0.1
corr_levels = make_cb_range(1, 0.1)

## filled contours of the correlation (coordinates are lon/lat)
plot_data = ax.contourf(
    corr.longitude,
    corr.latitude,
    corr,
    transform=ccrs.PlateCarree(),
    levels=corr_levels,
    extend="both",
    cmap="cmo.balance",
)

## create colorbar
colorbar = fig.colorbar(plot_data, label="Corr.", ticks=[-1, -0.5, 0, 0.5, 1])

## Mark Woods Hole on map
ax.scatter(
    288.5, 41.5, transform=ccrs.PlateCarree(), marker="*", c="magenta", s=50, zorder=10
)

## title
ax.set_title(r"Correlation with $T_{wh}$")

plt.show()