# Outline of script:

# Packages

In [None]:
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
import matplotlib as mpl
import matplotlib.ticker as mticker
import cmocean
import xesmf as xe
import os

## custom imports
import src.utils
import src.utils_azores

# Define constants/functions

In [None]:
## Filepaths
# local directory where trimmed data is cached
DATA_FP = "/home/kcarr/whoi-climate-data-tutorial/data"

# CESM LME sea level pressure (on clidex)
fp_lme = "/vortex/clidex/data/model/CESM/LME/atm/psl"

# reanalysis sea level pressure (on cmip5 server);
# long paths are split across lines for readability
fp_slp_era = (
    "/mnt/cmip5-data/reanalysis/era.20c"
    "/sfc/msl/moda/msl.mon.mean.nc"
)
fp_slp_noaa = (
    "/mnt/cmip5-data/reanalysis/noaa.cires.20crv2c"
    "/monolevel/prmsl/monthly/prmsl.mon.mean.nc"
)

# # also potentially useful:
# # (note: LME only has single ensemble member at the given directory)
# lme_fp = "/mnt/cmip5-data/CMIP5/output1/NCAR/CCSM4/past1000/mon/atmos/Amon/r1i1p1/psl/psl_Amon_CCSM4_past1000_r1i1p1_085001-185012.nc"
# era_fp = "/vortexfs1/share/clidex/data/reanalysis/20CR/prmsl/prmsl.mon.mean.nc"

## Set plotting defaults: white axes background, gridlines off
sns.set(
    rc={
        "axes.grid": False,
        "axes.facecolor": "white",
    }
)

# Load data

## Datasets used:
- ERA-20C
- NOAA–CIRES 20CR
- HadSLP2
- CESM1 LME

### Functions for fetching 'raw' data from server

In [None]:
def load_lme_member(forcing_type, member_id):
    """
    Function loads data from single member of CESM Last Millennium Ensemble (LME).
    Args:
    - 'forcing_type' is one of {"all","volcanic","GHG","orbital"}
    - 'member_id' is integer in [1,13] specifying which ensemble member to load

    Returns the "PSL" variable read from the member's two files
    (085001-184912 and 185001-200512), opened lazily in chunks along
    'time', after passing it through src.utils.switch_longitude_range.

    Raises ValueError if 'forcing_type' is not one of the options above.
    """

    ## suffix added to the case name for each forcing type
    forcing_suffixes = {
        "all": "",
        "GHG": ".GHG",
        "volcanic": ".VOLC_GRA",
        "orbital": ".ORBITAL",
    }

    ## fail loudly on a bad forcing type (rather than returning None,
    ## which would only cause a confusing error further downstream)
    if forcing_type not in forcing_suffixes:
        raise ValueError(
            f"Error: not a valid forcing type: {forcing_type!r} "
            f"(expected one of {sorted(forcing_suffixes)})"
        )

    ## get prefix
    prefix = f"b.e11.BLMTRC5CN.f19_g16{forcing_suffixes[forcing_type]}"

    ## Get names of two files for each ensemble member
    ## ('fp_lme' is the LME directory defined with the other filepaths)
    fp_and_prefix = f"{fp_lme}/{prefix}.{member_id:03d}.cam.h0.PSL"
    fp0 = f"{fp_and_prefix}.085001-184912.nc"
    fp1 = f"{fp_and_prefix}.185001-200512.nc"

    ## Load data
    data = xr.open_mfdataset([fp0, fp1], chunks={"time": 2000})["PSL"]

    ## switch longitude range from [0,360) to (-180,180]
    data = src.utils.switch_longitude_range(data)

    return data


def load_noaa():
    """Open NOAA-CIRES sea level pressure ("prmsl") and
    standardize its longitude/latitude coordinates."""

    ## read the pressure variable from the raw monthly-mean file
    prmsl = xr.open_dataset(fp_slp_noaa)["prmsl"]

    ## longitude: [0,360) -> (-180,180]
    prmsl = src.utils.switch_longitude_range(prmsl)

    ## latitude: [90,-90] -> [-90,90]
    return src.utils.reverse_latitude(prmsl)


def load_era():
    """Open the ERA sea level pressure file and standardize its
    coordinate names and longitude/latitude conventions."""

    ## the file is opened directly as a single array
    msl = xr.open_dataarray(fp_slp_era)

    ## use the short coordinate names "lat" and "lon"
    coord_names = {"latitude": "lat", "longitude": "lon"}
    msl = msl.rename(coord_names)

    ## longitude: [0,360) -> (-180,180]
    msl = src.utils.switch_longitude_range(msl)

    ## latitude: [90,-90] -> [-90,90]
    return src.utils.reverse_latitude(msl)

### Functions to trim raw data and save locally

In [None]:
## For each dataset, get DJF average and trim to North Atlantic
def trim(data):
    """Reduce 'data' to DJF averages and return a single dataset
    holding two variables:
    - "slp": DJF average, trimmed to the North Atlantic
    - "slp_global_avg": spatial average of the DJF average
    Nothing is written to disk by this function."""

    ## DJF average, loaded to memory so both products reuse it
    djf = src.utils.djf_avg(data).compute()

    ## regional field
    regional = src.utils_azores.trim_to_north_atlantic(djf).rename("slp")

    ## global average
    global_avg = src.utils.spatial_avg(djf).rename("slp_global_avg")

    ## package both in one dataset
    return xr.merge([regional, global_avg])


def get_trimmed_data(data_loader_fn, fp_out):
    """Return trimmed data, using 'fp_out' as an on-disk cache.

    Args:
    - 'data_loader_fn' is a zero-argument function returning the raw data;
      it is only called if 'fp_out' does not exist yet
    - 'fp_out' is the path of the NetCDF file holding the trimmed data

    If 'fp_out' exists, its contents are loaded to memory and returned.
    Otherwise the raw data is loaded, passed through 'trim', saved to
    'fp_out', and returned."""

    ## check if file exists
    if os.path.isfile(fp_out):

        ## Load pre-trimmed file. xr.load_dataset reads everything into
        ## memory and closes the file (open_dataset(...).compute() would
        ## leave the file handle open).
        data_trimmed = xr.load_dataset(fp_out)

    else:
        ## Load data, trim it, and save to file
        data = data_loader_fn()
        data_trimmed = trim(data)
        data_trimmed.to_netcdf(fp_out)

    return data_trimmed


def get_trimmed_data_lme(forcing_type, member_ids):
    """Process multiple ensemble members of the LME.

    Args:
    - 'forcing_type' is passed through to 'load_lme_member'
    - 'member_ids' is an iterable of integer ensemble member ids

    Each member is obtained with 'get_trimmed_data', which loads the
    cached file "{DATA_FP}/LME_{forcing_type}_{member_id:03d}.nc" if it
    exists and otherwise creates it. Results are concatenated along a new
    "ensemble_member" dimension labeled by 'member_ids'."""

    ## materialize ids once: they are iterated here and reused for the index
    member_ids = list(member_ids)

    ## Loop through each ensemble member
    data_trimmed = []
    for member_id in tqdm(member_ids):

        ## get filepath for saving data
        fp_out = f"{DATA_FP}/LME_{forcing_type}_{member_id:03d}.nc"

        ## function to load the given ensemble member
        ## (member_id bound as default to avoid late binding in the loop)
        data_loader_fn = lambda member_id=member_id: load_lme_member(
            forcing_type, member_id
        )

        ## get_trimmed_data handles the cache check and the save,
        ## so neither is repeated here
        data_trimmed.append(get_trimmed_data(data_loader_fn, fp_out))

    ## merge data from each ensemble member to single dataset
    ensemble_member_dim = pd.Index(member_ids, name="ensemble_member")
    data_trimmed = xr.concat(data_trimmed, dim=ensemble_member_dim)

    return data_trimmed

#### Load in trimmed data

In [None]:
## reanalysis products; arguments are (data_loader_fn, fp_out)
slp_noaa = get_trimmed_data(load_noaa, f"{DATA_FP}/slp_noaa.nc")
slp_era = get_trimmed_data(load_era, f"{DATA_FP}/slp_era.nc")

## LME with "all" forcing, ensemble members 1 through 13
slp_lme = get_trimmed_data_lme("all", np.arange(1, 14))

# Compute AHA metric

From Cresswell-Clay et al. (2022): "The AHA was defined as the area (km²) over the North Atlantic and Western Europe that had mean winter (December–January–February) SLP greater than 0.5 s.d. from the mean of the spatio-temporal winter SLP distribution (Fig. 1b). The region considered when calculating the AHA is bounded by the 60° W and 10° E meridians as well as the 10° N and 52° N latitudes."

#### Plot SLP over time in dataset

In [None]:
## specify which dataset to use
## (everything below reads from 'data', so changing this line is
## enough to switch datasets)
data = slp_noaa

## Compute SLP averaged over Azores region
slp_azores = src.utils.spatial_avg(src.utils_azores.trim_to_azores(data["slp"]))
slp_global = data["slp_global_avg"]

## get linear trend for global
global_trend = src.utils.get_trend(slp_global)

## Plot
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(data.year, slp_azores, label="Azores")
ax.plot(data.year, slp_global, label="Global")
ax.plot(data.year, global_trend, label="Global trend", c="k", ls="--")
ax.legend()
ax.set_xlabel("Year")
ax.set_ylabel("SLP (Pa)")
plt.show()

## Plot before/after normalizing.
## In both cases the mean of the removed series is added back,
## so the two curves stay at a comparable level.
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(
    data.year,  # use 'data' (not slp_noaa) so x matches the selected dataset
    (slp_azores - slp_global + slp_global.mean()),
    label="remove global mean",
)
ax.plot(
    data.year,
    (slp_azores - global_trend + global_trend.mean()),
    label="remove trend",
)
ax.legend()
ax.set_xlabel("Year")
ax.set_ylabel("SLP (Pa)")
plt.show()

#### Fig 2c: # of AHA extremes since ~1850

In [None]:
## count extremes in each reanalysis product
count_noaa = src.utils_azores.count_extremes_wrapper(slp_noaa)
count_era = src.utils_azores.count_extremes_wrapper(slp_era)

## count in historical component of LME (1850 onward)
slp_lme_hist = slp_lme.sel(year=slice(1850, None))
count_lme = src.utils_azores.count_extremes_wrapper(slp_lme_hist)

## summarize across ensemble members: mean, min, and max
count_lme_mean = count_lme.mean("ensemble_member")
count_lme_min = count_lme.min("ensemble_member")
count_lme_max = count_lme.max("ensemble_member")

## make plot
fig, ax = plt.subplots(figsize=(6, 3))

## one line per reanalysis product
for count, label, color in [
    (count_noaa, "NOAA", "purple"),
    (count_era, "ERA", "blue"),
]:
    ax.plot(count.year, count, label=label, c=color)

## LME ensemble mean, with thin lines (same color) for the min/max envelope
count_lme_plot = ax.plot(count_lme.year, count_lme_mean, label="LME", c="orange")
lme_color = count_lme_plot[0].get_color()
ax.plot(count_lme_min.year, count_lme_min, c=lme_color, lw=0.5)
ax.plot(count_lme_max.year, count_lme_max, c=lme_color, lw=0.5)

## label plot
ax.legend()
ax.set_xlabel("Year")
ax.set_ylabel("Count (25-yr rolling)")
plt.show()

# Spatial plots

#### Functions to compute and plot composite

In [None]:
def make_composite(data, year_range0=(1950, 1979), year_range1=(1980, 2007)):
    """Get composite, defined as difference between data
    when averaged over year_range1 and year_range0.

    Args:
    - 'data' is an object with a "year" coordinate supporting
      .sel(year=slice(...)) and .mean("year")
    - 'year_range0' and 'year_range1' are (first_year, last_year) pairs
      (any 2-item sequence); each is unpacked into a slice over "year"

    Returns mean over year_range1 minus mean over year_range0.
    """

    ## compute means
    ## (defaults are tuples rather than lists, to avoid mutable defaults)
    mean1 = data.sel(year=slice(*year_range1)).mean("year")
    mean0 = data.sel(year=slice(*year_range0)).mean("year")

    return mean1 - mean0


def plot_setup_helper(ax, scale=1):
    """Create map background for plotting spatial data, using a fixed
    lon/lat range and tick locations.
    Returns the pair (ax, gl) produced by src.utils.plot_setup."""

    ## map extent (longitude, latitude)
    lon_range, lat_range = [-70, 10], [3, 70]

    ## tick locations (x ticks, y ticks)
    xticks, yticks = [-60, -40, -20, 0], [20, 40, 60]

    ax, gl = src.utils.plot_setup(ax, lon_range, lat_range, xticks, yticks, scale)
    return ax, gl

#### Data loading functions

In [None]:
def load_noaa_uv10():
    """Load NOAA U10 and V10 data (eastward and northward windspeed at 10m).
    Returns a single in-memory dataset holding both variables, with
    standardized lon/lat coordinates, trimmed to the North Atlantic."""

    ## paths to raw files for u and v
    fp = "/mnt/cmip5-data/reanalysis/noaa.cires.20crv2c/gaussian"
    fp_u = f"{fp}/uwnd.10m/monthly/uwnd.10m.mon.mean.nc"
    fp_v = f"{fp}/vwnd.10m/monthly/vwnd.10m.mon.mean.nc"

    ## merge into single dataset and load to memory; the context
    ## managers close both files once the data has been read
    with xr.open_dataset(fp_u) as u10_noaa, xr.open_dataset(fp_v) as v10_noaa:
        uv10_noaa = xr.merge([u10_noaa, v10_noaa]).compute()

    ## update coordinates
    uv10_noaa = src.utils.standardize_lonlat(
        uv10_noaa, rename_coords=True, update_longitude=True, reverse_latitude=True
    )

    ## trim to north atlantic
    uv10_noaa = src.utils_azores.trim_to_north_atlantic(uv10_noaa)

    return uv10_noaa


def load_noaa_precip():
    """Load NOAA precipitation data from CMIP server.
    Returns xr.DataArray named "precip" holding the monthly total
    precipitation (units: mm), trimmed to the North Atlantic."""

    ## load in precipitation rate (units: kg/m^2/s); the context
    ## manager closes the file once the data is in memory
    fp = "/mnt/cmip5-data/reanalysis/noaa.cires.20crv2c/gaussian/prate/monthly/prate.mon.mean.nc"
    with xr.open_dataset(fp) as raw:
        precip_rate = raw["prate"].compute()

    ## fix lonlat coordinates
    precip_rate = src.utils.standardize_lonlat(
        precip_rate, rename_coords=True, update_longitude=True, reverse_latitude=True
    )

    ## trim to North Atlantic domain
    precip_rate = src.utils_azores.trim_to_north_atlantic(precip_rate)

    ## Convert units from mass/area to height by dividing
    ## by density of liquid water:
    ## (kg m^-2 s^-1) * (kg^-1 m^3) = m s^-1
    density_water = 1000  # units: kg m^-3
    precip_rate = precip_rate / density_water  # new units of m/s

    ## convert from month-averaged rate (units: m/s) to
    ## month total (units: m), by multiplying by no. of seconds
    ## in each month. Units: (m s^-1) * s = m.
    ## Note: days_per_month depends on month (so is
    ## array, not a scalar, in the code below).
    days_per_month = precip_rate.time.dt.days_in_month
    seconds_per_day = 86400
    seconds_per_month = seconds_per_day * days_per_month
    precip = precip_rate * seconds_per_month

    ## convert units from m to mm
    mm_per_m = 1000
    precip = precip * mm_per_m

    return precip.rename("precip")


def load_noaa_lsm():
    """Load land-sea mask for NOAA reanalysis.
    Returns the "land" variable (first time step, time dimension dropped)
    with standardized lon/lat coordinates, trimmed to the North Atlantic."""

    ## load raw data from the remote server; the context manager
    ## closes the connection once the mask is in memory
    url = "http://psl.noaa.gov/thredds/dodsC/Datasets/20thC_ReanV2c/gaussian/time_invariant/land.nc"
    with xr.open_dataset(url) as raw:
        lsm = raw["land"].isel(time=0, drop=True).compute()

    ## fix coordinates to match other datasets
    lsm = src.utils.standardize_lonlat(
        lsm, rename_coords=True, update_longitude=True, reverse_latitude=True
    )

    ## trim to north atlantic
    lsm = src.utils_azores.trim_to_north_atlantic(lsm)

    return lsm

#### Do the data loading

In [None]:
## load u/v data
uv10_noaa = load_noaa_uv10()
uv10_noaa = src.utils.djf_avg(uv10_noaa)

## get land-sea mask
lsm_noaa = load_noaa_lsm()

## load precip data, get DJF average
precip_noaa = load_noaa_precip()
precip_noaa = src.utils.djf_avg(precip_noaa)

## Regrid SLP to match precip and uv10;
## this will make compositing easier.
## Only the "slp" variable is regridded: 'slp_noaa' is a Dataset, and
## .rename("slp") is only valid on a DataArray (Dataset.rename expects
## a mapping of old to new names).
regridder = xe.Regridder(ds_in=slp_noaa, ds_out=lsm_noaa, method="bilinear")
slp_noaa_regrid = regridder(slp_noaa["slp"]).rename("slp")

## merge datasets
data_noaa = xr.merge([uv10_noaa, precip_noaa, slp_noaa_regrid])

## compute composite and climatology
## (climatology spans both default composite periods: 1950-2007)
comp = make_composite(data_noaa)
clim = data_noaa.sel(year=slice(1950, 2007)).mean("year")

In [None]:
## functions to plot individual variables on plotting background.
## these functions take in an 'ax' object
def plot_precip(fig, ax):
    """Plot precipitation composite as filled contours on 'ax', and add
    a horizontal colorbar to 'fig'. Reads the module-level 'comp' dataset."""

    ## plot data; the data coordinates are lon/lat, so state the
    ## transform explicitly (as plot_slp does)
    precip_plot = ax.contourf(
        comp.longitude,
        comp.latitude,
        comp["precip"],
        cmap="cmo.diff_r",
        levels=src.utils.make_cb_range(35, 3.5),
        extend="both",
        transform=ccrs.PlateCarree(),
    )

    ## add colorbar
    fig.colorbar(
        precip_plot,
        pad=0.05,
        ax=ax,
        orientation="horizontal",
        label=r"$\Delta$ Precip. (mm/month)",
        ticks=np.arange(-28, 42, 14),
    )


def plot_slp(fig, ax, mask=False, add_colorbar=False):
    """Draw the SLP composite as filled contours on 'ax'.
    - mask: if True, keep only points where the land-sea mask is below 1
    - add_colorbar: if True, attach a horizontal colorbar to 'fig'
    Reads the module-level 'comp' and 'lsm_noaa'."""

    ## start from the full field; blank out land if requested
    slp_comp = comp["slp"]
    if mask:
        slp_comp = slp_comp.where(lsm_noaa < 1)

    filled = ax.contourf(
        comp.longitude,
        comp.latitude,
        slp_comp,
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(500, 50),
        transform=ccrs.PlateCarree(),
    )

    ## nothing more to do unless a colorbar was requested
    if not add_colorbar:
        return

    fig.colorbar(
        filled,
        ax=ax,
        pad=0.05,
        orientation="horizontal",
        label=r"$\Delta$ SLP (Pa)",
        ticks=np.arange(-400, 600, 200),
    )


def plot_slp_clim(fig, ax):
    """Plot SLP climatology as black contour lines on 'ax'.
    Reads the module-level 'clim' dataset; 'fig' is accepted for
    consistency with the other plotting functions but is not used."""

    ## the data coordinates are lon/lat, so state the transform
    ## explicitly (as plot_slp does)
    ax.contour(
        clim.longitude,
        clim.latitude,
        clim["slp"],
        colors="k",
        levels=np.arange(99600, 102400, 400),
        linewidths=0.8,
        transform=ccrs.PlateCarree(),
    )


def plot_wind(fig, ax, n=3):
    """Draw the low-level wind composite as arrows on 'ax', keeping only
    every 'n'th grid point in each direction to avoid overcrowding.
    Reads the module-level 'comp' dataset."""

    # subsample every n-th point along each axis
    stride = slice(None, None, n)
    lon_sub = comp.longitude.values[stride]
    lat_sub = comp.latitude.values[stride]
    u_sub = comp["uwnd"].values[stride, stride]
    v_sub = comp["vwnd"].values[stride, stride]

    # 2-D grid of arrow positions
    lon_grid, lat_grid = np.meshgrid(lon_sub, lat_sub)

    # draw the arrows
    arrows = ax.quiver(lon_grid, lat_grid, u_sub, v_sub)

    # reference arrow (2 m/s) as legend
    ax.quiverkey(arrows, X=1.07, Y=-0.28, U=2, label=r"2 $m/s$")

In [None]:
## Two side-by-side map panels
fig = plt.figure(figsize=(10, 6))

## left panel: map background
ax0, gl0 = plot_setup_helper(
    fig.add_subplot(1, 2, 1, projection=ccrs.PlateCarree()), scale=1.2
)

## left panel: precip in colors, then SLP (masked over land) drawn
## on top so that precip stays visible over land, plus wind vectors
plot_precip(fig, ax0)
plot_slp(fig, ax0, mask=True)
plot_wind(fig, ax0, n=3)

## right panel: map background, without the labels on its left
## side (the latitude labels, already shown on the left panel)
ax1, gl1 = plot_setup_helper(
    fig.add_subplot(1, 2, 2, projection=ccrs.PlateCarree()), scale=1.2
)
gl1.left_labels = False

## right panel: SLP composite in colors and climatology in contours
plot_slp(fig, ax1, add_colorbar=True)
plot_slp_clim(fig, ax1)

plt.show()