# Azores high: data prep
In this notebook, we'll load global data from WHOI's CMIP server, "reduce" it by trimming in time and space, and save the pre-processed data locally. While not always possible, saving a reduced version of the data can speed up subsequent analysis: the code will run much faster if the data fit into "memory" (i.e., random access memory, or RAM).

To compute the Azores High index, download the trimmed reanalysis and CESM Last Millennium Ensemble (LME) data here: [https://drive.google.com/drive/folders/1kFF6a-ArVBdNaAK-3TKj-tpSV6Vx8yk4?usp=sharing](https://drive.google.com/drive/folders/1kFF6a-ArVBdNaAK-3TKj-tpSV6Vx8yk4?usp=sharing) (the folder called "LME" contains the trimmed Last Millennium Ensemble data).

## Packages

In [None]:
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cmocean
import os
import pathlib

## custom imports
# import src.utils
# import src.utils_azores

## Plotting defaults: white axes background, grid lines off
plot_rc_params = {"axes.facecolor": "white", "axes.grid": False}
sns.set(rc=plot_rc_params)

## Define constants/functions

__Last Millennium Ensemble (LME) data access__: If you don't have access to the "clidex" data server, download the LME folder from the shared Google Drive and save it inside your local project's data filepath, ```save_fp``` (by default ```data/```), so that the trimmed files are found at ```data/LME/```.  

__Reanalysis data access__: Note that this script uses reanalysis data on the CMIP<b>5</b> server, rather than the CMIP<b>6</b> server.

### Utility functions

In [None]:
def switch_longitude_range(data):
    """Move longitude of a dataset/dataarray from [0, 360) to (-180, 180].

    Args:
    - data: xr.Dataset or xr.DataArray with a 1-D "lon" coordinate (degrees)

    Returns an object of the same type. The "lon" coordinate is relabeled
    and the data are sorted so that longitude is increasing. The input
    object (including its coordinate arrays) is not modified.
    """

    ## relabel longitudes on a *new* coordinate (keeping the coordinate's
    ## attributes), rather than mutating the existing index in place
    new_lon = convert_longitude(data.lon.values)
    data = data.assign_coords(lon=("lon", new_lon, data.lon.attrs))

    ## sort data so the updated longitude coordinate is increasing
    data = data.sortby("lon")

    return data


def convert_longitude(longitude):
    """Move longitude values from range [0, 360) to (-180, 180].

    Args:
    - longitude: array-like of longitudes in degrees

    Returns a new numpy array; the input is left unchanged.
    """

    longitude = np.asarray(longitude)

    ## values above 180 wrap around to (negative) western longitudes
    return np.where(longitude > 180, longitude - 360, longitude)


def reverse_latitude(data):
    """Flip the ordering of the latitude coordinate.

    A [-90, 90] ordering becomes [90, -90] and vice versa. Accepts an
    xr.Dataset or xr.DataArray and returns the same type.
    """

    flipped = {"lat": data.lat.values[::-1]}
    return data.reindex(flipped)


def djf_avg(data):
    """Compute the Dec/Jan/Feb (DJF) winter average for each year.

    Args:
    - data: xr.Dataset/xr.DataArray with a "time" dimension (assumed to be
      monthly data, one sample per month)

    Returns the DJF averages, indexed by a "year" dimension that holds the
    year of January/February. Seasons missing one or more months (e.g., at
    the start/end of the record) are dropped.
    """

    ## average over quarters starting in December (DJF, MAM, JJA, SON);
    ## each quarter is labeled by its first day (1 Dec for DJF)
    data_seasonal_avg = data.resample(time="QS-DEC").mean()

    ## count samples in each quarter, so incomplete seasons can be dropped
    ## (rather than assuming the record starts in Jan and ends in Dec)
    n_samples = data.time.dt.month.resample(time="QS-DEC").count()

    ## keep only complete DJF seasons
    is_djf = data_seasonal_avg.time.dt.month == 12
    is_complete = n_samples == 3
    data_djf_avg = data_seasonal_avg.sel(time=is_djf & is_complete)

    ## Replace time index with year, corresponding to january.
    ## '+1' is because 'time's year uses december
    year = data_djf_avg.time.dt.year + 1
    data_djf_avg["time"] = year
    data_djf_avg = data_djf_avg.rename({"time": "year"})

    return data_djf_avg


def spatial_avg(data, lon_range=(None, None), lat_range=(None, None)):
    """Area-weighted average of a quantity over (part of) the sphere.

    Args:
    - data: xr.Dataset/xr.DataArray on a regular lon/lat grid (equal lon/lat
      spacing between all data points), with "lon"/"lat" in degrees
    - lon_range, lat_range: (start, stop) bounds passed to ``slice`` to select
      the averaging region. The default (None, None) keeps the full range,
      i.e. returns the global average.

    Returns data averaged over "lat" and "lon", each grid cell weighted by
    cos(latitude).
    """

    ## trim data to the specified region
    data_trim = data.sel(lon=slice(*lon_range), lat=slice(*lat_range))

    ## latitude-dependent weights (cell area on a regular grid ~ cos(lat))
    weights = np.cos(np.deg2rad(data_trim.lat))

    return data_trim.weighted(weights).mean(["lat", "lon"])


def spatial_int(data, lon_range=(None, None), lat_range=(None, None)):
    """Compute the spatial integral of a quantity on the sphere.

    Args:
    - data: xr.Dataset/xr.DataArray on a regular lon/lat grid (constant
      spacing in each coordinate), with "lon"/"lat" in degrees
    - lon_range, lat_range: (start, stop) bounds passed to ``slice`` to select
      the integration region; the default (None, None) keeps the full range

    Returns data integrated over "lat" and "lon", i.e. in units of
    (data units) x m^2, using an Earth radius of 6.371e6 m.
    """

    ## grid spacing in radians, taken from the untrimmed grid (assumes constant
    ## spacing). abs() keeps cell areas positive for decreasing coordinates.
    rad_per_deg = 2 * np.pi / 360
    lon = data.lon.values
    lat = data.lat.values
    dtheta = abs(lon[1] - lon[0]) * rad_per_deg
    dphi = abs(lat[1] - lat[0]) * rad_per_deg

    ## trim data to the specified region
    data_trim = data.sel(lon=slice(*lon_range), lat=slice(*lat_range))
    phi = data_trim.lat * rad_per_deg

    ## Get area of each grid cell
    R = 6.371e6  # radius of earth in m
    dA = R**2 * np.cos(phi) * dphi * dtheta

    ## Integrate
    return (data_trim * dA).sum(["lat", "lon"])


def get_trend(data, dim="year"):
    """Return the least-squares linear fit of an xr.DataArray along `dim`,
    evaluated at each coordinate value of `dim`."""

    coefs = data.polyfit(dim=dim, deg=1)["polyfit_coefficients"]
    return xr.polyval(data[dim], coefs)

### Azores-specific utility functions

In [None]:
def compute_AHA(slp, slp_global_avg, norm_type="global_mean"):
    """Compute the Azores High Area (AHA) index, similar to Cresswell-Clay et al. (2022).

    Defined as: area of the Azores region in which (normalized) SLP exceeds
    its long-term mean by more than 0.5 standard deviations. Mean and standard
    deviation are taken over "year", separately at each grid point.

    Args:
    - slp: gridded SLP data covering (at least) the Azores region, with a
        "year" dimension
    - slp_global_avg: globally-averaged SLP, with a "year" dimension
    - norm_type: one of {None, "global_mean", "detrend"}, specifying how to
        normalize SLP before computing the index:
        - "global_mean": subtract slp_global_avg
        - "detrend": subtract the linear trend (along "year") of slp_global_avg
        - None: no normalization

    Returns: AHA index (area in km^2), summed over lon/lat.

    Raises: ValueError if norm_type is not one of the options above.
    """

    ## validate before doing any (potentially expensive) computation
    valid_norm_types = (None, "global_mean", "detrend")
    if norm_type not in valid_norm_types:
        raise ValueError(
            f"Invalid norm_type {norm_type!r}; expected one of {valid_norm_types}"
        )

    ## trim SLP to the Azores region
    slp_azores = trim_to_azores(slp)

    ## get normalized SLP (func of lon/lat/year)
    if norm_type == "global_mean":
        slp_azores_norm = slp_azores - slp_global_avg
    elif norm_type == "detrend":
        slp_azores_norm = slp_azores - get_trend(slp_global_avg)
    else:  # norm_type is None
        slp_azores_norm = slp_azores

    ## Get long-term mean and standard deviation (func of lon/lat)
    ## NOTE(review): Cresswell-Clay et al. (2022) define the threshold relative
    ## to the "spatio-temporal" winter SLP distribution; here mean/std are taken
    ## over "year" only, separately at each grid point -- confirm this is intended.
    slp_azores_mean = slp_azores_norm.mean("year")
    slp_azores_std = slp_azores_norm.std("year")

    ## Get mask of grid cells exceeding 0.5 std threshold
    threshold = slp_azores_mean + 0.5 * slp_azores_std
    exceeds_thresh = slp_azores_norm > threshold

    ## Sum area of lon/lat cells exceeding threshold
    ## convert from True/False to 1.0/0.0 for integration
    AHA = spatial_int(exceeds_thresh.astype(float))

    ## convert from m^2 to km^2
    m2_per_km2 = 1000**2
    return AHA / m2_per_km2

### Data loading functions

In [None]:
def load_lme_member(forcing_type, member_id, fp_in):
    """
    Function loads PSL from a single member of the CESM last-millennium ensemble (LME).
    Args:
    - 'forcing_type' is one of {"all","volcanic","GHG","orbital"}
    - 'member_id' is an integer specifying which ensemble member to load
      (e.g., in [1,13] for "all")
    - 'fp_in' is the directory containing the LME PSL files
    Returns the "PSL" DataArray (files 085001-184912 and 185001-200512
    combined), with longitude switched to (-180, 180].
    Raises ValueError for an unknown 'forcing_type'.
    """

    ## forcing-specific part of the filename prefix ("all" has none)
    forcing_suffixes = {
        "all": "",
        "GHG": ".GHG",
        "volcanic": ".VOLC_GRA",
        "orbital": ".ORBITAL",
    }
    if forcing_type not in forcing_suffixes:
        raise ValueError(
            f"Invalid forcing type {forcing_type!r}; "
            f"expected one of {list(forcing_suffixes)}"
        )

    ## full filename prefix for this forcing type / member
    prefix = (
        f"b.e11.BLMTRC5CN.f19_g16{forcing_suffixes[forcing_type]}"
        f".{member_id:03d}.cam.h0.PSL"
    )

    ## Get names of two files for each ensemble member
    fp_and_prefix = os.path.join(fp_in, prefix)
    fp0 = f"{fp_and_prefix}.085001-184912.nc"
    fp1 = f"{fp_and_prefix}.185001-200512.nc"

    ## Load data
    data = xr.open_mfdataset([fp0, fp1], chunks={"time": 2000})["PSL"]

    ## switch longitude range from [0,360) to (-180,180]
    data = switch_longitude_range(data)

    return data


def load_noaa(fp_slp_noaa):
    """Load NOAA-CIRES reanalysis SLP and standardize its coordinates.

    Returns the "prmsl" variable with longitude switched to (-180, 180]
    and latitude reordered from [90, -90] to [-90, 90].
    """

    raw = xr.open_dataset(fp_slp_noaa)["prmsl"]
    return reverse_latitude(switch_longitude_range(raw))


def load_era(fp_slp_era):
    """Load ERA reanalysis SLP and standardize its coordinates.

    Renames "latitude"/"longitude" to "lat"/"lon", switches longitude to
    (-180, 180], and reorders latitude from [90, -90] to [-90, 90].
    """

    renamed = xr.open_dataarray(fp_slp_era).rename(
        {"latitude": "lat", "longitude": "lon"}
    )
    return reverse_latitude(switch_longitude_range(renamed))


## Pre-processing applied to each dataset: DJF average, then spatial reduction
def trim(data):
    """Reduce raw data in time and space.

    Returns an xr.Dataset with two variables:
    - "slp": DJF-averaged data, trimmed to the North Atlantic
    - "slp_global_avg": DJF-averaged data, averaged over the globe

    The DJF average is loaded into memory before the spatial reductions.
    """

    winter = djf_avg(data).compute()
    return xr.merge(
        [
            trim_to_north_atlantic(winter).rename("slp"),
            spatial_avg(winter).rename("slp_global_avg"),
        ]
    )


def trim_to_north_atlantic(data):
    """Select the North Atlantic box: 70W-15E, 0-70N.

    Expects increasing "lon"/"lat" coordinates, with lon in (-180, 180].
    """

    lon_bounds = slice(-70, 15)
    lat_bounds = slice(0, 70)
    return data.sel(lon=lon_bounds, lat=lat_bounds)


def trim_to_azores(data):
    """Select the Azores region used for the AHA index: 60W-10E, 10-52N.

    Expects increasing "lon"/"lat" coordinates, with lon in (-180, 180].
    """

    lon_bounds = slice(-60, 10)
    lat_bounds = slice(10, 52)
    return data.sel(lon=lon_bounds, lat=lat_bounds)


def load_prepped_data(data_loader_fn, prep_fn, fp_out):
    """Load pre-processed data from 'fp_out', creating (and caching) it if needed.

    If 'fp_out' is not an existing file, 'prep_fn' is applied to the data
    returned by 'data_loader_fn' and the result is saved to 'fp_out'
    (creating parent directories as needed).
    Args:
        - data_loader_fn: function to load raw data
        - prep_fn: function to pre-process raw data
        - fp_out: filepath to save pre-processed data
    Returns the pre-processed data (fully loaded into memory when read from file).
    """

    fp_out = pathlib.Path(fp_out)

    if fp_out.is_file():
        ## Load pre-trimmed file into memory, then close it so the file
        ## handle isn't left open
        with xr.open_dataset(fp_out) as data_prepped:
            return data_prepped.load()

    ## Load data, trim it, and save to file (directory may not exist yet)
    data_prepped = prep_fn(data_loader_fn())
    fp_out.parent.mkdir(parents=True, exist_ok=True)
    data_prepped.to_netcdf(fp_out)

    return data_prepped


def get_trimmed_data_lme(forcing_type, member_ids, fp_in, save_fp):
    """Load (or create and cache) trimmed data for several LME ensemble members.

    Each member is pre-processed with 'trim' and cached at
    '<save_fp>/LME/LME_<forcing_type>_<member_id:03d>.nc'. The results are
    concatenated along a new "ensemble_member" dimension indexed by 'member_ids'.
    """

    def _prepped_member(member_id):
        # cache path for this member; the loader is only called on a cache miss
        cache_fp = pathlib.Path(save_fp, "LME", f"LME_{forcing_type}_{member_id:03d}.nc")
        return load_prepped_data(
            lambda: load_lme_member(forcing_type, member_id, fp_in=fp_in),
            trim,
            cache_fp,
        )

    members = [_prepped_member(member_id) for member_id in tqdm(member_ids)]
    return xr.concat(members, dim=pd.Index(member_ids, name="ensemble_member"))


def get_trimmed_data(data_loader_fn, fp_out):
    """Load trimmed data from 'fp_out'; if absent, apply 'trim' to the output
    of 'data_loader_fn' and save the result there."""

    return load_prepped_data(data_loader_fn, trim, fp_out)

## Load data

### Set filepaths

#### Save filepath 

In [None]:
## Directory for intermediate (pre-processed) results
save_fp = pathlib.Path() / "data"

#### Reanalysis

Note: I've downloaded this data locally (and saved it to the folder ```data/cmip5_reanalysis_raw```).

In [None]:
## Local copies of the raw monthly reanalysis files
raw_reanalysis_dir = pathlib.Path("data", "cmip5_reanalysis_raw")
fp_noaa_u10 = raw_reanalysis_dir / "uwnd.10m.mon.mean.nc"
fp_noaa_v10 = raw_reanalysis_dir / "vwnd.10m.mon.mean.nc"
fp_noaa_slp = raw_reanalysis_dir / "prmsl.mon.mean.nc"
fp_noaa_precip = raw_reanalysis_dir / "prate.mon.mean.nc"
fp_era_slp = raw_reanalysis_dir / "msl.mon.mean.nc"

The data was downloaded from the following locations on the CMIP5 server:

Path to the reanalysis data on the CMIP5 server:
```python
SERVER_FP = pathlib.Path("/Volumes/data")
CMIP5_REANALYSIS_FP = pathlib.Path(SERVER_FP, "reanalysis")
```

Paths to the NOAA 20th century reanalysis data products:
```python
NOAA_FP = pathlib.Path(CMIP5_REANALYSIS_FP, "noaa.cires.20crv2c")
fp_noaa_u10 = pathlib.Path(NOAA_FP, "gaussian/uwnd.10m/monthly/uwnd.10m.mon.mean.nc")
fp_noaa_v10 = pathlib.Path(NOAA_FP, "gaussian/vwnd.10m/monthly/vwnd.10m.mon.mean.nc")
fp_noaa_slp = pathlib.Path(NOAA_FP, "monolevel/prmsl/monthly/prmsl.mon.mean.nc")
fp_noaa_precip = pathlib.Path(NOAA_FP, "gaussian/prate/monthly/prate.mon.mean.nc")
```

Path to the ERA 20th century reanalysis:
```python
fp_era_slp = pathlib.Path(CMIP5_REANALYSIS_FP, "era.20c/sfc/msl/moda/msl.mon.mean.nc")
```

#### LME
Note: reading this data requires access to the shared data server called "clidex".

In [None]:
## LME sea-level pressure ("psl") files on the clidex server
fp_lme = pathlib.Path("/vortexfs1/share/clidex/data/model/CESM/LME") / "atm" / "psl"

### LME filepath

### Load the data

In [None]:
## reanalysis: trimmed SLP, cached under save_fp
slp_noaa = get_trimmed_data(
    lambda: load_noaa(fp_noaa_slp), pathlib.Path(save_fp, "slp_noaa.nc")
)
slp_era = get_trimmed_data(
    lambda: load_era(fp_era_slp), pathlib.Path(save_fp, "slp_era.nc")
)

## LME: one call per forcing type; np.arange gives the member IDs to load
kwargs = dict(save_fp=save_fp, fp_in=fp_lme)
slp_lme = get_trimmed_data_lme("all", np.arange(1, 14), **kwargs)
slp_lme_volc = get_trimmed_data_lme("volcanic", np.arange(1, 6), **kwargs)
slp_lme_GHG = get_trimmed_data_lme("GHG", np.arange(1, 4), **kwargs)
slp_lme_orb = get_trimmed_data_lme("orbital", np.arange(1, 4), **kwargs)

### Plot SLP over time in dataset

In [None]:
## specify which dataset to use
data = slp_noaa

## Compute SLP averaged over Azores region, and get global-mean SLP
slp_azores = spatial_avg(trim_to_azores(data["slp"]))
slp_global = data["slp_global_avg"]

## get linear trend for global
global_trend = get_trend(slp_global)


def format_slp_axis(ax):
    """Label the y-axis of `ax` in hPa (SLP values assumed to be in Pa).

    Uses a tick formatter rather than fixed labels, so labels always match
    tick positions (e.g. 1012.5 is not rounded to a misleading 1012).
    """
    ax.set_ylabel("SLP (hPa)")
    ax.yaxis.set_major_formatter(lambda y, pos: f"{y / 100:g}")


## Plot
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(data.year, slp_azores, label="Azores")
ax.plot(data.year, slp_global, label="Global")
ax.plot(data.year, global_trend, label="Global trend", c="k", ls="--")
ax.legend(prop=dict(size=8))
ax.set_xlabel("Year")
format_slp_axis(ax)
plt.show()

## Plot before/after normalizing; the mean of the removed signal is added
## back so both curves stay on the scale of the raw SLP values.
## Use `data.year` (not a specific dataset) so the cell works for any `data`.
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(
    data.year,
    (slp_azores - slp_global + slp_global.mean()),
    label="remove global mean",
)
ax.plot(
    data.year,
    (slp_azores - global_trend + global_trend.mean()),
    label="remove trend",
)
ax.legend(prop=dict(size=8))
ax.set_xlabel("Year")
format_slp_axis(ax)
plt.show()

## Compute AHA metric

From Cresswell-Clay et al. (2022): "The AHA was defined as the area (km²) over the North Atlantic and Western Europe that had mean winter (December–January–February) SLP greater than 0.5 s.d. from the mean of the spatio-temporal winter SLP distribution (Fig. 1b). The region considered when calculating the AHA is bounded by the 60° W and 10° E meridians as well as the 10° N and 52° N latitudes."

In [15]:
def compute_AHA_wrapper(slp_data, save_fp, norm_type="detrend"):
    """Load AHA from given save_fp. If file doesn't exist,
    compute AHA based on given SLP data and save result.

    Args:
    - slp_data: dataset with variables "slp" and "slp_global_avg"
    - save_fp: netCDF filepath for the cached AHA result
    - norm_type: normalization passed to 'compute_AHA' (default "detrend");
        only used when computing -- an existing file is returned as-is
    """

    save_fp = pathlib.Path(save_fp)

    ## load the cached result into memory and close the file, rather than
    ## catching every exception (which would also hide real errors)
    if save_fp.is_file():
        with xr.open_dataarray(save_fp) as AHA:
            return AHA.load()

    AHA = compute_AHA(
        slp_data["slp"], slp_data["slp_global_avg"], norm_type=norm_type
    )

    ## output directory may not exist yet
    save_fp.parent.mkdir(parents=True, exist_ok=True)
    AHA.to_netcdf(save_fp)

    return AHA


## directory holding cached AHA results (one netCDF file per dataset)
AHA_save_fp = pathlib.Path(save_fp) / "AHA"

## compute (or load cached) AHA metric for each dataset
AHA_noaa = compute_AHA_wrapper(slp_noaa, AHA_save_fp.joinpath("AHA_noaa.nc"))
AHA_era = compute_AHA_wrapper(slp_era, AHA_save_fp.joinpath("AHA_era.nc"))
AHA_lme = compute_AHA_wrapper(slp_lme, AHA_save_fp.joinpath("AHA_lme_full.nc"))
AHA_lme_orbital = compute_AHA_wrapper(
    slp_lme_orb, AHA_save_fp.joinpath("AHA_lme_orb.nc")
)
AHA_lme_volc = compute_AHA_wrapper(
    slp_lme_volc, AHA_save_fp.joinpath("AHA_lme_volc.nc")
)
AHA_lme_ghg = compute_AHA_wrapper(
    slp_lme_GHG, AHA_save_fp.joinpath("AHA_lme_GHG.nc")
)