# Overview
__In this notebook, we'll:__
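- Walk through an example of model validation: compare sea surface height (SSH) and the Gulf Stream's position in the CESM2 historical simulation against the ORAS5 ocean reanalysis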

# Preliminaries

#### Specify filepath to data

In [None]:
## filepath for the mount; set the CMIP6_DATA_DIR environment variable to
## point the notebook at a different copy of the data
import os

cmip6_fp = os.environ.get("CMIP6_DATA_DIR", "/Volumes/cmip6/data")

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import cmocean
import seaborn as sns
import glob
import cftime
import src.utils
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
import matplotlib.dates as mdates
import xesmf as xe

## set plotting style: white axes background, no grid lines
## (set_theme is seaborn's preferred interface; sns.set is an alias of it)
sns.set_theme(rc={"axes.facecolor": "white", "axes.grid": False})

## initialize random number generator with a fixed seed, so that
## "Restart & Run All" reproduces the same random sample plotted below
## (set RNG_SEED = None to draw a fresh sample on every run)
RNG_SEED = 0
rng = np.random.default_rng(RNG_SEED)

#### Load in ORAS5 data

In [None]:
## ORAS5 SSH: the "zos_oras" variable of the GREPv2 monthly files,
## located under the mount root defined above
oras_fp = f"{cmip6_fp}/ocean_reanalysis/ORAS5/oras5/monthly/global-reanalysis-phy-001-031"
with xr.open_mfdataset(
    f"{oras_fp}/global-reanalysis-phy-001-031-grepv2-monthly_*.nc"
) as oras_ds:
    ## load into memory before the context manager closes the files
    zos_oras = oras_ds["zos_oras"].load()

#### Load in CESM data

In [None]:
## CESM2 historical (r1i1p1f1) monthly SSH on the model's native ("gn") grid,
## 1850-01 to 2014-12, located under the mount root defined above
cesm_fp = f"{cmip6_fp}/cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon/zos/gn/1"
with xr.open_dataset(
    f"{cesm_fp}/zos_Omon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc"
) as cesm_ds:
    ## load into memory before the context manager closes the file
    zos_cesm = cesm_ds["zos"].load()

#### Regrid CESM to match ORAS5

In [None]:
## Bilinearly interpolate CESM2 SSH onto the ORAS5 grid so the two can be
## compared point by point. Both datasets are global, so longitude is treated
## as periodic; with periodic=False, target points lying between the source
## grid's last and first longitude columns (its seam) are left unmapped.
regridder = xe.Regridder(
    ds_in=zos_cesm, ds_out=zos_oras, method="bilinear", periodic=True
)

## NOTE: this overwrites the native-grid field, so re-running this cell
## without re-running the load cell regrids the already-regridded data
zos_cesm = regridder(zos_cesm)

#### Time-mean SSH

In [None]:
def plot_setup(ax, lon_range, lat_range, xticks, yticks, scale):
    """
    Create map background for plotting spatial data.
    Arguments:
        - ax: cartopy GeoAxes to draw on (the plot "canvas"); it must support
            set_extent, coastlines and gridlines.
        - lon_range/lat_range: 2-element sequences, representing plot
            boundaries in degrees (interpreted in PlateCarree coordinates)
        - xticks/yticks: location for lon/lat labels
        - scale: number which controls linewidth, fontsize and grid opacity

    Returns (ax, gl): the modified 'ax' object and the gridliner created on it,
    so callers can restyle the grid labels further.
    """

    ## specify transparency/linewidths; the grid alpha is capped at 1 because
    ## matplotlib rejects alpha values outside [0, 1] (i.e. when scale > 10)
    grid_alpha = min(0.1 * scale, 1.0)
    grid_linewidth = 0.5 * scale
    coastline_linewidth = 0.3 * scale
    label_size = 8 * scale

    ## crop map and plot coastlines
    ax.set_extent([*lon_range, *lat_range], crs=ccrs.PlateCarree())
    ax.coastlines(linewidth=coastline_linewidth)

    ## plot grid
    gl = ax.gridlines(
        draw_labels=True,
        linestyle="--",
        alpha=grid_alpha,
        linewidth=grid_linewidth,
        color="k",
        zorder=1.05,
    )

    ## keep labels on the top and left edges only, at the requested locations
    gl.bottom_labels = False
    gl.right_labels = False
    gl.xlabel_style = {"size": label_size}
    gl.ylabel_style = {"size": label_size}
    gl.ylocator = mticker.FixedLocator(yticks)
    gl.xlocator = mticker.FixedLocator(xticks)

    return ax, gl


## Next, a function to plot the North Atlantic
def plot_setup_north_atl(ax, scale=1):
    """Set up a map background for the western North Atlantic.

    Thin wrapper around plot_setup with a fixed 80W-42W / 30N-50N extent,
    longitude labels at 70W/60W/50W and latitude labels at 35N/45N.
    'scale' is passed through unchanged.

    Returns the (ax, gl) pair produced by plot_setup.
    """
    return plot_setup(
        ax,
        lon_range=[-80, -42],
        lat_range=[30, 50],
        xticks=[-70, -60, -50],
        yticks=[35, 45],
        scale=scale,
    )

In [None]:
fig = plt.figure(figsize=(6, 6))

## One panel per dataset, on a shared colour scale for direct comparison.
## NOTE(review): the two time means cover different periods (the CESM file
## spans 1850-2014; the ORAS5 period is whatever the globbed files cover) -
## consider restricting both to their overlapping years.
for i, (name, zos) in enumerate(
    [("ORAS5", zos_oras), ("CESM2", zos_cesm)], start=1
):
    ax = fig.add_subplot(2, 1, i, projection=ccrs.PlateCarree())
    ax, gl = plot_setup_north_atl(ax)

    ## extend="both" draws values outside the level range in the end colours
    ## instead of leaving them blank
    ssh_plot = ax.contourf(
        zos.longitude,
        zos.latitude,
        zos.mean("time"),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(1, 0.1),
        extend="both",
    )
    ax.set_title(f"{name} time-mean SSH", fontsize=10)
    cb = fig.colorbar(ssh_plot, ax=ax, ticks=[-1, 0, 1], label=r"SSH ($m$)")

plt.show()

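## Estimate Gulf Stream position from the magnitude of the SSH gradient
The Gulf Stream shows up as a sharp front in SSH. At each longitude, we therefore take its position to be the latitude where the (smoothed) SSH-gradient magnitude is largest.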

In [None]:
def grad(f):
    """Compute the horizontal gradient of data called 'f' (forward differences).

    Arguments:
        - f: DataArray with 'longitude' and 'latitude' coordinates in degrees.
          Assumes both are increasing and uniformly spaced (only the first
          spacing of each is used).

    Returns (df_dx, df_dy), in units of f per metre:
        - df_dx: difference along longitude divided by the zonal spacing
          dx = R cos(lat) dlon; labelled at the western point of each pair,
          so it has one fewer longitude than f.
        - df_dy: difference along latitude divided by the meridional spacing
          dy = R dlat; labelled at the southern point of each pair, so it has
          one fewer latitude than f.
    """

    ## get change in lon/lat in radians (assume constant grid)
    theta = np.deg2rad(f.longitude)
    phi = np.deg2rad(f.latitude)
    dtheta = theta.values[1] - theta.values[0]
    dphi = phi.values[1] - phi.values[0]

    ## get grid spacing in metres; the zonal spacing shrinks with cos(lat)
    R = 6.37e6  # radius of earth in meters
    dy = R * dphi
    dx = R * np.cos(phi) * dtheta

    ## forward differences labelled at the lower coordinate of each pair
    ## (i.e. f[1:] - f[:-1] placed on the f[:-1] grid). Longitude differences
    ## pair with dx and latitude differences with dy.
    df_dx = f.diff("longitude", label="lower") / dx
    df_dy = f.diff("latitude", label="lower") / dy

    return df_dx, df_dy


def get_grad_mag(f, n=3):
    """Get the spatially smoothed magnitude of the horizontal gradient of 'f'.

    Arguments:
        - f: DataArray with 'longitude'/'latitude' coordinates (see grad).
        - n: width, in grid points, of the n x n running-mean smoother
          (default 3).

    Returns sqrt(df_dx**2 + df_dy**2) smoothed with a centred n x n boxcar.
    Points whose window runs off the grid edge or contains a NaN are NaN.
    """

    # get gradient
    df_dx, df_dy = grad(f)

    ## get gradient magnitude (xarray arithmetic keeps only the grid points
    ## shared by both components)
    grad_mag = np.sqrt(df_dx**2 + df_dy**2)

    ## smooth in space; centre the window so smoothing does not move features
    ## (a trailing window shifts them by (n - 1) / 2 grid points north/east)
    grad_mag = grad_mag.rolling(
        {"latitude": n, "longitude": n}, center=True
    ).mean()

    return grad_mag


def get_GS_yposn(grad_mag, lon_range=(-80, -42), lat_range=(30, 45)):
    """Get position of Gulf Stream based on magnitude of gradient.

    At each longitude (and time), the Gulf Stream is taken to sit at the
    latitude of maximum gradient magnitude within a North Atlantic search box.

    Arguments:
        - grad_mag: gradient magnitude with 'latitude'/'longitude' dims
          (e.g. output of get_grad_mag). Assumes both coordinates are
          increasing, since the box is selected with label slices.
        - lon_range/lat_range: (min, max) bounds of the search box in degrees;
          use None for an open bound. The defaults match the North Atlantic
          map extent in longitude and keep the original 45N upper bound.

    Returns the latitude of the maximum for each remaining (time, longitude),
    or NaN where that column of the search box holds no valid data.
    """

    ## Trim in space. Bounding latitude from below (and longitude from above)
    ## keeps the argmax from latching onto other fronts outside the Gulf
    ## Stream region, e.g. in the tropics or the Southern Hemisphere.
    grad_mag_trimmed = grad_mag.sel(
        longitude=slice(*lon_range), latitude=slice(*lat_range)
    )

    ## columns that are entirely NaN (e.g. land) would make argmax raise;
    ## fill NaN with -inf for the search and mask those columns afterwards
    has_data = grad_mag_trimmed.notnull().any("latitude")
    idx_max = grad_mag_trimmed.fillna(-np.inf).argmax("latitude")

    ## get gulf stream y_posn
    y_posn = grad_mag_trimmed.latitude.isel(latitude=idx_max)

    return y_posn.where(has_data)

Compute the SSH-gradient magnitude and the Gulf Stream position for both datasets.

In [None]:
## gradient magnitude for each dataset (ORAS5 first, then CESM)
mag_oras, mag_cesm = [get_grad_mag(zos) for zos in (zos_oras, zos_cesm)]

## Gulf Stream latitude estimated from each gradient-magnitude field
y_posn_oras, y_posn_cesm = [get_GS_yposn(mag) for mag in (mag_oras, mag_cesm)]

In [None]:
## specify whether to plot time mean or random sample
plot_mean = False

## specify colorbar max for each plot (ORAS5, CESM2)
vmax = [1e-5, 5e-6]

if plot_mean:

    def get_plot_data(x):
        """Reduce a field with a 'time' dimension to its time mean."""
        return x.mean("time")

else:
    ## One random time *index*, shared by both datasets and valid in both.
    ## The same index maps to different dates in each (the CESM file starts
    ## in 1850), so each panel title shows its own date. CESM2 "historical"
    ## is a simulation, not a reanalysis, so its weather is not expected to
    ## match ORAS5 on any given date anyway.
    n_times = min(len(y_posn_oras.time), len(y_posn_cesm.time))
    t_idx = rng.integers(n_times)

    def get_plot_data(x):
        """Select the randomly chosen time step."""
        return x.isel(time=t_idx)


fig = plt.figure(figsize=(6, 6))

for i, (name, mag, y_posn, vmax_) in enumerate(
    zip(
        ["ORAS5", "CESM2"],
        [mag_oras, mag_cesm],
        [y_posn_oras, y_posn_cesm],
        vmax,
    ),
    start=1,
):

    ax = fig.add_subplot(2, 1, i, projection=ccrs.PlateCarree())
    ax, gl = plot_setup_north_atl(ax)

    ## Plot magnitude of gradient; it is non-negative, so only the top of the
    ## colour scale can be exceeded
    plot = ax.contourf(
        mag.longitude,
        mag.latitude,
        get_plot_data(mag),
        cmap="cmo.amp",
        levels=np.linspace(0, vmax_, 11),
        extend="max",
    )

    ## Plot estimated position
    ax.plot(y_posn.longitude, get_plot_data(y_posn), c="w", ls="--")

    ## label the panel with the dataset and the plotted date (or "time mean")
    when = "time mean" if plot_mean else str(mag.time.values[t_idx])[:10]
    ax.set_title(f"{name} ({when})", fontsize=10)

    ## colorbar
    cb = fig.colorbar(plot, ax=ax, ticks=[0, vmax_], label=r"SSH grad.")

plt.show()