# Outline of script:

# Packages

In [None]:
import xarray as xr
import numpy as np
from glob import glob
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Define constants/functions

In [None]:
## constants
RAD_PER_DEG = 2 * np.pi / 360  # radians per degree
R = 6.371e6  # radius of earth in m
M_PER_KM = 1000  # meters per km

## Filepaths
# NOTE(review): the active paths mix machines -- lme_fp below is the clidex
# path (load_lme_member expects one file pair per ensemble member there),
# while era_fp/noaa_fp are the cmip5-server paths. Confirm all three are
# reachable from wherever this notebook runs.
# on clidex
# NOTE(review): the commented-out path below points at a 20CR prmsl file,
# despite being assigned to the name era_fp.
# era_fp = "/vortexfs1/share/clidex/data/reanalysis/20CR/prmsl/prmsl.mon.mean.nc"
lme_fp = "/vortex/clidex/data/model/CESM/LME/atm/psl"

# on cmip5 server (note: LME only has single ensemble member at the given directory)
era_fp = "/mnt/cmip5-data/reanalysis/era.20c/sfc/msl/moda/msl.mon.mean.nc"
noaa_fp = "/mnt/cmip5-data/reanalysis/noaa.cires.20crv2c/monolevel/prmsl/monthly/prmsl.mon.mean.nc"
# lme_fp = "/mnt/cmip5-data/CMIP5/output1/NCAR/CCSM4/past1000/mon/atmos/Amon/r1i1p1/psl/psl_Amon_CCSM4_past1000_r1i1p1_085001-185012.nc"

## Set plotting defaults
# white axes background and no grid lines for all subsequent plots
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

# Load data

## Datasets used:
- ERA-20C
- NOAA–CIRES 20CR
- HadSLP2 (listed here, but no loader for it is defined in this notebook yet)
- CESM1 LME

## Functions

In [None]:
def trim_to_azores(data, lon_range=(-65, 15), lat_range=(15, 65)):
    """helper function to trim data to the Azores lon/lat range.

    Args:
    - data: xr.DataArray/xr.Dataset with "longitude" and "latitude"
        coordinates in degrees. Label slicing assumes both are increasing and
        longitude is in (-180, 180] (e.g. after standardize_lonlat).
    - lon_range, lat_range: (start, stop) bounds in degrees. The defaults are
        the box this helper has always used; pass e.g. lon_range=(-60, 10),
        lat_range=(10, 52) for the region quoted from Cresswell-Clay et al.
        (2022).

    Returns the subset of data inside the box (same type as data).
    """
    return data.sel(longitude=slice(*lon_range), latitude=slice(*lat_range))


def djf_avg(data):
    """Average data over Dec/Jan/Feb (DJF) for each winter.

    Args:
    - data: xr.DataArray/xr.Dataset with a monthly "time" dimension

    Returns DJF averages with "time" replaced by a "year" dimension; each
    winter is labelled by the year of its January (Dec 1900-Feb 1901 -> 1901).
    Winters with fewer than 3 time steps in the record (e.g. a partial first
    or last winter) are dropped.
    """

    ## 3-month averages over quarters starting in December
    ## (Dec-Feb, Mar-May, Jun-Aug, Sep-Nov), labelled by their first month
    data_seasonal_avg = data.resample(time="QS-DEC").mean()

    ## Number of time steps in each quarter. Used to drop incomplete winters
    ## instead of always discarding the first and last DJF average, which is
    ## only correct when the record runs exactly from a January to a December.
    n_steps = data["time"].resample(time="QS-DEC").count()

    ## keep complete DJF averages only
    is_djf = data_seasonal_avg.time.dt.month == 12
    is_complete = n_steps == 3
    data_djf_avg = data_seasonal_avg.sel(time=is_djf & is_complete)

    ## Replace time index with year, corresponding to january.
    ## '+1' is because 'time's year uses december
    year = data_djf_avg.time.dt.year + 1
    data_djf_avg["time"] = year
    data_djf_avg = data_djf_avg.rename({"time": "year"})

    return data_djf_avg


def djf_avg_alt(data):
    """alternative (and slightly more general) function to average data
    over Dec/Jan/Feb (DJF) months, using "3MS" bins instead of "QS-DEC".

    Same output convention as djf_avg: a "year" dimension labelling each
    winter by its January, with winters that have fewer than 3 time steps in
    the record dropped.

    Raises ValueError if data has no December time step.
    """

    ## "3MS" bins are anchored on the month of the first time step, so start
    ## the record at its first December to align the bins with DJF/MAM/JJA/SON
    ## (otherwise, e.g. for a January start, no bin begins in December).
    is_dec = (data.time.dt.month == 12).values
    if not is_dec.any():
        raise ValueError("data must contain at least one December time step")
    data = data.isel(time=slice(int(is_dec.argmax()), None))

    ## non-overlapping 3-month averages labelled by their first month,
    ## plus the number of time steps that fell into each bin
    data_seasonal_avg = data.resample({"time": "3MS"}, label="left").mean()
    n_steps = data["time"].resample({"time": "3MS"}, label="left").count()

    ## keep complete DJF averages only
    is_djf = data_seasonal_avg.time.dt.month == 12
    data_djf_avg = data_seasonal_avg.sel(time=is_djf & (n_steps == 3))

    ## Replace time index with year, corresponding to january.
    ## '+1' is because 'time's year uses december
    year = data_djf_avg.time.dt.year + 1
    data_djf_avg["time"] = year
    data_djf_avg = data_djf_avg.rename({"time": "year"})

    return data_djf_avg


def convert_longitude(longitude):
    """move longitude from range [0,360) to (-180,180].

    Accepts an array-like of longitude values and returns a NEW numpy array.
    The input is not modified: it is typically ``data.longitude.values``,
    whose buffer may back the xarray index (and may be read-only).
    """
    longitude = np.asarray(longitude)

    ## values above 180 wrap to negative longitudes; 180 itself is kept
    return np.where(longitude > 180, longitude - 360, longitude)


def update_longitude_coord(data):
    """move longitude for dataset or dataarray from [0,360)
    to (-180, 180], sorted to be increasing. Function accepts
    xr.dataset/xr.dataarray and returns object of same type"""

    ## Relabel the longitude coordinate (keeping its attrs). The data values
    ## stay attached to their grid points; only the labels change.
    updated_longitude = convert_longitude(data.longitude.values)
    data = data.assign_coords(
        longitude=("longitude", updated_longitude, data.longitude.attrs)
    )

    ## Reorder so longitude is increasing. Sorting the relabelled object does
    ## not rely on any in-place modification of the original index (reindexing
    ## the original by the new negative labels would find no matches).
    return data.sortby("longitude")


def standardize_lonlat(
    data, rename_coords=False, reverse_latitude=False, update_longitude=False
):
    """Make lon/lat coordinates consistent across datasets.

    Optional steps, applied in this order when their flag is set:
      - rename_coords: rename "lat"/"lon" to "latitude"/"longitude"
      - update_longitude: map longitude from [0, 360) to (-180, 180] and sort
        it increasing (via update_longitude_coord)
      - reverse_latitude: flip the latitude axis (intended for datasets stored
        north-to-south, so that latitude ends up increasing)

    Takes and returns an xr.DataArray or xr.Dataset.
    """
    out = data

    if rename_coords:
        out = out.rename({"lat": "latitude", "lon": "longitude"})

    if update_longitude:
        out = update_longitude_coord(out)

    if reverse_latitude:
        flipped_latitude = out["latitude"].values[::-1]
        out = out.reindex({"latitude": flipped_latitude})

    return out


def load_era():
    """Open ERA-20C mean sea-level pressure ("msl" in era_fp) lazily, with
    longitude mapped to (-180, 180] and the latitude axis flipped."""
    msl = xr.open_dataset(era_fp)["msl"]
    return standardize_lonlat(msl, reverse_latitude=True, update_longitude=True)


def load_noaa():
    """Open NOAA-CIRES 20CR sea-level pressure ("prmsl" in noaa_fp) lazily,
    renaming lat/lon, mapping longitude to (-180, 180] and flipping latitude."""
    prmsl = xr.open_dataset(noaa_fp)["prmsl"]
    return standardize_lonlat(
        prmsl,
        reverse_latitude=True,
        update_longitude=True,
        rename_coords=True,
    )


def load_lme_member(member_id):
    """
    Load PSL for a single member of the CESM last-millenium ensemble (LME).

    Args:
    - 'member_id' is integer in [1,13] specifying which ensemble member to load

    Each member is stored as two files (date ranges 085001-184912 and
    185001-200512 in the file names); both are opened together with dask
    chunks of 2000 time steps, then lat/lon are renamed and longitude is
    mapped to (-180, 180].
    """
    stem = f"{lme_fp}/b.e11.BLMTRC5CN.f19_g16.{member_id:03d}.cam.h0.PSL"
    paths = [f"{stem}.085001-184912.nc", f"{stem}.185001-200512.nc"]

    psl = xr.open_mfdataset(paths, chunks={"time": 2000})["PSL"]
    return standardize_lonlat(psl, rename_coords=True, update_longitude=True)


def load_lme(member_ids=None):
    """load multiple LME members and stack them along "ensemble_member".

    Args:
    - member_ids: iterable of integers in [1,13] specifying which ensemble
        members to load. Defaults to all 13 members (1, ..., 13).

    Returns an xarray object with a new "ensemble_member" dimension whose
    coordinate values are the member ids.
    """

    ## None instead of an array default: avoids a shared mutable default
    if member_ids is None:
        member_ids = np.arange(1, 14)

    ## Materialize once: the ids are iterated twice below (loading + index),
    ## which would silently break for a one-shot iterator/generator.
    member_ids = list(member_ids)

    ## Load all ensemble members in list
    data = [load_lme_member(i) for i in tqdm(member_ids)]

    ## convert list to xarray
    data = xr.concat(data, dim=pd.Index(member_ids, name="ensemble_member"))

    return data


def spatial_avg(data, lon_range=(None, None), lat_range=(None, None)):
    """cos(latitude)-weighted average of a quantity over a lon/lat box.

    With the default (None, None) ranges this is the global average over the
    sphere. Data is xr.dataarray/xr.dataset with a 'regular' lon/lat grid
    (equal lon/lat spacing between all data points), coordinates named
    "longitude"/"latitude" in degrees.

    Args:
    - lon_range, lat_range: (start, stop) bounds passed to .sel(slice(...))

    Returns data averaged over "latitude" and "longitude".
    """

    ## cos(latitude) weights: on a regular lon/lat grid, cell area scales
    ## with the cosine of latitude
    cos_lat = np.cos(data.latitude * RAD_PER_DEG)

    ## trim data (and weights) to the requested box
    data_trim = data.sel(longitude=slice(*lon_range), latitude=slice(*lat_range))
    cos_lat_trim = cos_lat.sel(latitude=slice(*lat_range))

    ## weighted mean over both spatial dims
    return data_trim.weighted(weights=cos_lat_trim).mean(["latitude", "longitude"])


def spatial_int(data, lon_range=(None, None), lat_range=(None, None)):
    """compute spatial integral of a quantity on the sphere over a lon/lat box.

    Integrates data * dA over "latitude" and "longitude", where
    dA = R^2 cos(lat) dlat dlon (angles in radians, R = earth radius in m),
    so the result has units of [data] * m^2. For convenience, assumes a
    regular grid (constant lat/lon spacing) with at least two points along
    each spatial dimension.

    Args:
    - data: xr.DataArray/xr.Dataset with "longitude"/"latitude" in degrees
    - lon_range, lat_range: (start, stop) bounds passed to .sel(slice(...));
        the default (None, None) integrates over the full extent of data.
    """

    ## restrict to the requested box (these arguments used to be ignored)
    data = data.sel(longitude=slice(*lon_range), latitude=slice(*lat_range))

    ## Get latitude/longitude in radians.
    ## denote (lon,lat) in radians as (theta, phi)
    theta = data.longitude * RAD_PER_DEG
    phi = data.latitude * RAD_PER_DEG

    ## grid spacing (assumes constant differences); abs() so that a
    ## decreasing coordinate does not flip the sign of the area
    dtheta = abs(float(theta[1] - theta[0]))
    dphi = abs(float(phi[1] - phi[0]))

    ## area of each grid cell, broadcast to the (latitude, longitude) grid
    dA = R**2 * np.cos(phi) * dphi * dtheta * xr.ones_like(theta)

    ## Integrate
    data_int = (data * dA).sum(["latitude", "longitude"])

    return data_int


def get_trend(data, dim="year"):
    """Return the least-squares linear fit of an xr.DataArray along `dim`,
    evaluated at every coordinate value of `dim`."""
    fit = data.polyfit(dim=dim, deg=1)
    slope_and_intercept = fit["polyfit_coefficients"]
    return xr.polyval(data[dim], slope_and_intercept)

## Do the actual data loading

In [None]:
## "Load" data (but not into memory yet)
# The load_* helpers open files lazily; LME members are stacked along an
# "ensemble_member" dimension by load_lme.
# slp_lme = load_lme(member_ids=[1, 2])
slp_lme = load_lme(member_ids=np.arange(1, 14))
slp_noaa = load_noaa()
slp_era = load_era()

## Get DJF average
# .compute() evaluates the DJF averages into memory.
# NOTE: after this cell these names hold DJF averages indexed by "year";
# the monthly "time" dimension is no longer available under them.
slp_noaa = djf_avg(slp_noaa).compute()
slp_era = djf_avg(slp_era).compute()
slp_lme = djf_avg(slp_lme).compute()

# Compute AHA metric

From Cresswell-Clay et al. (2022): "The AHA was defined as the area (km²) over the North Atlantic and Western Europe that had mean winter (December–January–February) SLP greater than 0.5 s.d. from the mean of the spatio-temporal winter SLP distribution (Fig. 1b). The region considered when calculating the AHA is bounded by the 60° W and 10° E meridians as well as the 10° N and 52° N latitudes."

In [None]:
def compute_AHA(slp, norm_type="global_mean"):
    """compute Azores High Area index, similar to Cresswell-Clay et al. (2022).

    Defined as: area of the Azores region (trim_to_azores) in which the
    normalized SLP exceeds its long-term mean by more than 0.5 std. Mean and
    std are taken over "year" separately at each grid point.

    Args:
    - slp is gridded SLP data (at global scale) with a "year" dimension,
        e.g. the output of djf_avg
    - norm_type is one of {"global_mean", "detrend"} specifying how SLP is
        normalized before thresholding:
          "global_mean": local SLP minus globally-averaged SLP (spatial_avg)
          "detrend": local SLP minus the linear trend of globally-averaged SLP

    Returns the area in KM^2 for each year (and any other non-spatial dims,
    e.g. ensemble_member).

    Raises ValueError if norm_type is not recognized.

    NOTE(review): the paper thresholds against the pooled spatio-temporal
    winter SLP distribution over 60W-10E, 10N-52N; this implementation uses
    per-grid-point statistics over trim_to_azores' default box.
    """

    ## validate first, so a typo fails fast instead of printing a message and
    ## then crashing later with an UnboundLocalError
    valid_norm_types = ("global_mean", "detrend")
    if norm_type not in valid_norm_types:
        raise ValueError(
            f"norm_type must be one of {valid_norm_types}, got {norm_type!r}"
        )

    ## get SLP in Azores region
    slp_azores = trim_to_azores(slp)

    ## get globally-averaged SLP
    slp_global_avg = spatial_avg(slp)

    ## get normalized anomaly (func of lon/lat)
    if norm_type == "global_mean":
        slp_azores_norm = slp_azores - slp_global_avg
    else:  # "detrend"
        slp_azores_norm = slp_azores - get_trend(slp_global_avg)

    ## long-term mean and standard deviation (func of lon/lat)
    slp_azores_mean = slp_azores_norm.mean("year")
    slp_azores_std = slp_azores_norm.std("year")

    ## Get mask of grid cells exceeding 0.5 std threshold
    threshold = slp_azores_mean + 0.5 * slp_azores_std
    exceeds_thresh = slp_azores_norm > threshold

    ## Sum area of lon/lat cells exceeding threshold
    ## convert from True/False to 1.0/0.0 for integration
    AHA = spatial_int(exceeds_thresh.astype(float))

    ## convert from m^2 to km^2
    AHA *= 1 / (M_PER_KM**2)

    return AHA


def count_extremes(AHA, cutoff_perc=90.0, window=25):
    """Get rolling count of Azores High extreme events.

    Args:
    - AHA: AHA index with a "year" dimension (e.g. output of compute_AHA)
    - cutoff_perc is percentile value in range (0, 100) used to define
        'extreme' events. The threshold is a single quantile over all values
        of AHA (i.e. pooled across any extra dims such as ensemble_member).
    - window is an integer specifying how many years the (centered) rolling
        window is.

    Returns the rolling count of years exceeding the threshold. Years whose
    window is incomplete (at the start/end of the record) are removed.
    """

    ## get threshold for extreme events
    threshold = AHA.quantile(q=cutoff_perc / 100)

    ## Get boolean array: True if AHA exceeds thresh
    exceeds_thresh = AHA > threshold

    ## Get rolling count (NaN where the window is incomplete)
    rolling_count = exceeds_thresh.rolling(dim={"year": window}, center=True).sum()

    ## Remove the incomplete-window NaNs at the beginning and end. Unlike
    ## slicing a fixed count off each end, this is correct for even windows
    ## and does not return an empty result for window=1 or window=2.
    rolling_count = rolling_count.dropna("year")

    return rolling_count


def count_extremes_wrapper(slp, norm_type="detrend"):
    """Compute the AHA index from SLP, then its rolling count of extremes
    (using count_extremes' default percentile and window)."""
    aha_index = compute_AHA(slp, norm_type=norm_type)
    return count_extremes(aha_index)

Look at SLP over time

In [None]:
## Compute area-averaged SLP
# NOAA 20CR DJF SLP: cos-lat weighted mean over the Azores box vs. global mean
slp_azores = spatial_avg(trim_to_azores(slp_noaa))
slp_global = spatial_avg(slp_noaa)

## get linear trend for global
trend_fit = get_trend(slp_global)

## Plot
# 1e-2 factor converts to hPa (assumes SLP is stored in Pa -- TODO confirm)
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(slp_noaa.year, 1e-2 * slp_azores, label="Azores")
ax.plot(slp_noaa.year, 1e-2 * slp_global, label="Global")
ax.plot(trend_fit.year, 1e-2 * trend_fit, label="Global trend", c="k", ls="--")
ax.legend()
ax.set_xlabel("Year")
ax.set_ylabel("SLP (hPa)")
plt.show()

## Plot before/after normalizing
# The two normalizations used by compute_AHA, applied to the Azores mean.
# The removed series' mean is added back so both curves stay near the
# Azores SLP level.
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(
    slp_noaa.year,
    1e-2 * (slp_azores - slp_global + slp_global.mean()),
    label="remove global mean",
)
ax.plot(
    trend_fit.year,
    1e-2 * (slp_azores - trend_fit + trend_fit.mean()),
    label="remove trend",
)
ax.legend()
ax.set_xlabel("Year")
ax.set_ylabel("SLP (hPa)")
plt.show()

Plot the number of Azores High extremes over time (LME ensemble vs. NOAA and ERA reanalyses)

In [None]:
## specify which type of normalization
## one of {"global_mean","detrend"}
norm_type = "global_mean"

## count extremes in reanalysis
count_noaa = count_extremes_wrapper(slp_noaa, norm_type=norm_type)
count_era = count_extremes_wrapper(slp_era, norm_type=norm_type)

## count in historical component of LME
## (years from 1850 onward; rolling counts per ensemble member, summarized
## below by the ensemble mean, min and max).
## NOTE(review): count_extremes uses one quantile threshold pooled over all
## members and years, not a per-member threshold.
slp_lme_hist = slp_lme.sel(year=slice(1850, None))
count_lme = count_extremes_wrapper(slp_lme_hist, norm_type=norm_type)
count_lme_mean = count_lme.mean("ensemble_member")
count_lme_min = count_lme.min("ensemble_member")
count_lme_max = count_lme.max("ensemble_member")

## make plot
fig, ax = plt.subplots(figsize=(4, 3))

## plot reanalysis
ax.plot(count_noaa.year, count_noaa, label="NOAA", ls="--")
ax.plot(count_era.year, count_era, label="ERA", ls=":")

## plot LME mean and range
## (min/max drawn as thin lines in the same color as the mean)
count_lme_plot = ax.plot(count_lme.year, count_lme_mean, label="LME", ls="-")
for bound in [count_lme_min, count_lme_max]:
    ax.plot(bound.year, bound, c=count_lme_plot[0].get_color(), lw=0.5)

## label plot
## NOTE: "25-yr" matches count_extremes' default window=25; keep them in sync
ax.legend()
ax.set_xlabel("Year")
ax.set_ylabel("Count (25-yr rolling)")
plt.show()

# Scratch

In [None]:
## Trim to Azores region
# slp_noaa_azores =

# ## now subset in space and get DJF average (loads into memory)
# print("Loading LME")
# slp_lme = reduce(slp_lme)

# print("Loading NOAA")
# slp_noaa = reduce(slp_noaa)

# print("Loading 20CR")
# slp_era = reduce(slp_era)

In [None]:
import matplotlib.pyplot as plt

## Scratch: compare cos(latitude)-weighted vs. unweighted global-mean SLP
## for NOAA (slp_noaa holds DJF averages at this point)
cos_lat = np.cos(slp_noaa.latitude * 2 * np.pi / 360)
slp_noaa_weighted = slp_noaa.weighted(weights=cos_lat)
slp_weighted = slp_noaa_weighted.mean(["longitude", "latitude"])
slp_unweighted = slp_noaa.mean(["longitude", "latitude"])
# manual form of the cos-lat weighting along latitude only:
# (slp_noaa * cos_lat).sum("latitude") / cos_lat.sum("latitude")

In [None]:
## cos(latitude)-weighted mean SLP over the Azores box, computed inline.
## NOTE: this overwrites slp_azores from the "Look at SLP over time" cell
## (same cos-lat weighted Azores mean, there computed via spatial_avg).
slp_azores = trim_to_azores(slp_noaa)
slp_azores = slp_azores.weighted(np.cos(slp_azores.latitude * 2 * np.pi / 360))
slp_azores = slp_azores.mean(["longitude", "latitude"])

In [None]:
import scipy.stats

## Correlate Azores-mean SLP with the weighted / unweighted global-mean SLP
## from year y0 onward. A notebook cell only displays its last expression,
## so both results are printed explicitly (the first used to be discarded).
y0 = 1900
print(
    scipy.stats.pearsonr(
        slp_azores.sel(year=slice(y0, None)), slp_weighted.sel(year=slice(y0, None))
    )
)
print(
    scipy.stats.pearsonr(
        slp_azores.sel(year=slice(y0, None)), slp_unweighted.sel(year=slice(y0, None))
    )
)

In [None]:
## Global-mean SLP estimates vs. Azores-mean SLP, with a reference line at
## 1.013e5 (~1013 hPa, assuming SLP in Pa)
fig, ax = plt.subplots()
ax.plot(slp_weighted.year, slp_weighted, label="weighted")
ax.plot(
    slp_weighted.year,
    spatial_avg(slp_noaa),
    label="weighted (spatial_avg)",
    ls="--",
    c="k",
)
ax.plot(slp_unweighted.year, slp_unweighted, label="unweighted")
ax.plot(slp_azores.year, slp_azores, label="azores")
ax.axhline(1.013e5, ls="--", c="k")
ax.legend()
plt.show()

## Azores SLP before/after removing the weighted global mean (its long-term
## mean is added back so both curves share a level); labels are distinct so
## the legend can tell the two curves apart
fig, ax = plt.subplots()
ax.plot(slp_azores.year, slp_azores, label="azores")
ax.plot(
    slp_azores.year,
    slp_azores - slp_weighted + slp_weighted.mean(),
    label="azores (global mean removed)",
)
ax.axhline(1.013e5, ls="--", c="k")
ax.legend()
plt.show()

In [None]:
## Compare normalizing Azores SLP by the weighted global mean as a ratio vs.
## a difference: Pearson correlation of (ratio, difference), then of
## (log-ratio, difference).
print(scipy.stats.pearsonr(slp_azores / slp_weighted, slp_azores - slp_weighted))
print(
    scipy.stats.pearsonr(
        np.log(slp_azores) - np.log(slp_weighted), slp_azores - slp_weighted
    )
)

In [None]:
def norm(x):
    """Standardize x: subtract its mean, then divide by its standard deviation."""
    centered = x - x.mean()
    return centered / x.std()


plt.plot(norm(slp_azores / slp_weighted))
plt.plot(norm(slp_azores - slp_weighted))

In [None]:
def get_anom(data, dim="year"):
    """Compute anomalies relative to long-term mean.

    Args:
    - data: xr.DataArray/xr.Dataset
    - dim: dimension along which the long-term mean is taken (defaults to
        "year", the dimension produced by djf_avg)

    Returns data minus its mean along `dim`.
    """
    return data - data.mean(dim)


def AHA(data_anom):
    """Placeholder for an AHA computation from anomalies (not implemented).

    Raises NotImplementedError rather than silently returning None; see
    compute_AHA for the implemented index.
    """
    raise NotImplementedError("AHA is a placeholder; use compute_AHA instead")

In [None]:
# NOTE(review): slp_noaa was replaced by its DJF average (dimension "year")
# in the loading cell, so there is no "time" dimension to rechunk and this
# line will likely raise; confirm which variable/dimension was intended.
slp_noaa.chunk({"time": None})

In [None]:
## Spot-check djf_avg against djf_avg_alt (the alternative DJF function
## defined above) at a single grid point.
## NOTE(review): after the loading cell, slp_noaa holds DJF averages indexed
## by "year" (no "time" dimension); run this on raw monthly data instead,
## e.g. load_noaa().
x = slp_noaa.isel(latitude=10, longitude=10)

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(djf_avg(x))
ax.plot(djf_avg_alt(x))
ax.set_xlim([-1, 24])

In [None]:
# djf_avg2 is not defined anywhere; djf_avg_alt is the alternative DJF function
djf_avg_alt(slp_noaa)

In [None]:
# year coordinate from the alternative DJF function (djf_avg2 is not defined)
djf_avg_alt(slp_noaa).year

In [None]:
# year coordinate produced by djf_avg, for comparison with djf_avg_alt
djf_avg(slp_noaa)["year"]

# Compute AHA index