# Overview
__In this notebook, we'll:__
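- Walk through an example of model validation: compare sea surface height (`zos`) from the CESM2 historical simulation with the ORAS5 ocean reanalysis in the North Atlantic
- Estimate the Gulf Stream's latitude in each dataset from the maximum of the smoothed SSH gradient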

# Preliminaries

#### Specify filepath to data

In [None]:
## filepath for the mount (override by setting the CMIP6_FP environment variable)
import os

cmip6_fp = os.environ.get("CMIP6_FP", "/Volumes/cmip6/data")

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import cmocean
import seaborn as sns
import glob
import cftime
import src.utils
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
import matplotlib.dates as mdates
import xesmf as xe

## set plotting style (white background, no grid)
sns.set_theme(rc={"axes.facecolor": "white", "axes.grid": False})

## initialize random number generator with a fixed seed so that any random
## draws are reproducible on Restart & Run All
SEED = 0
rng = np.random.default_rng(SEED)

#### Load in ORAS5 data

In [None]:
## ORAS5 SSH: open all matching monthly files under the data mount, keep only
## the "zos_oras" variable, and load it into memory
fp = f"{cmip6_fp}/ocean_reanalysis/ORAS5/oras5/monthly/global-reanalysis-phy-001-031"
zos_oras = xr.open_mfdataset(f"{fp}/global-reanalysis-phy-001-031-grepv2-monthly_*.nc")
zos_oras = zos_oras["zos_oras"].load()

#### Load in CESM data

In [None]:
## CESM2 historical (r1i1p1f1) monthly SSH, grid label "gn"; load into memory
fp = f"{cmip6_fp}/cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon/zos/gn/1"
zos_cesm = xr.open_dataset(
    f"{fp}/zos_Omon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc"
)
zos_cesm = zos_cesm["zos"].load()

#### Regrid CESM to match ORAS5

In [None]:
## Bilinear regridding of CESM2 SSH onto the ORAS5 grid.
## The CESM2 ocean output is a global grid, so it is treated as periodic in
## longitude: with periodic=False, destination points lying between the grid's
## last and first longitude columns (its seam) are left unmapped.
## NOTE(review): for the CESM2 ocean grid this seam presumably falls in the
## Atlantic (~40W) -- confirm on the data.
regridder = xe.Regridder(
    ds_in=zos_cesm, ds_out=zos_oras, method="bilinear", periodic=True
)
zos_cesm = regridder(zos_cesm)

#### Plot time-mean SSH (ORAS5 vs. regridded CESM2)

In [None]:
def plot_setup(ax, lon_range, lat_range, xticks, yticks, scale):
    """
    Draw a cropped map background (coastlines + labelled gridlines) on `ax`.

    Arguments:
        - ax: cartopy GeoAxes to draw on (the plot "canvas")
        - lon_range/lat_range: 2-element sequences giving the map boundaries
            in degrees (interpreted in PlateCarree coordinates)
        - xticks/yticks: longitudes/latitudes at which gridlines are drawn
            and labelled
        - scale: multiplier applied to gridline alpha/linewidth, coastline
            linewidth and label fontsize

    Returns (ax, gl): the same axes object and the gridliner drawn on it.
    Labels are shown on the top and left edges only.
    """

    ## crop the map to the requested box, then draw coastlines
    extent = [*lon_range, *lat_range]
    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines(linewidth=0.3 * scale)

    ## faint dashed black gridlines, drawn just above the default zorder
    gl = ax.gridlines(
        draw_labels=True,
        linestyle="--",
        alpha=0.1 * scale,
        linewidth=0.5 * scale,
        color="k",
        zorder=1.05,
    )

    ## hide bottom/right labels, size the rest, and pin gridline positions
    font_size = 8 * scale
    gl.bottom_labels = False
    gl.right_labels = False
    gl.xlabel_style = {"size": font_size}
    gl.ylabel_style = {"size": font_size}
    gl.xlocator = mticker.FixedLocator(xticks)
    gl.ylocator = mticker.FixedLocator(yticks)

    return ax, gl


## Next, a function to plot the North Atlantic
def plot_setup_north_atl(ax, scale=1):
    """Draw the North Atlantic map background (80W-30W, 30N-50N) on `ax`.

    `scale` is forwarded to `plot_setup` (linewidths / font size).
    Returns (ax, gl): the axes and its gridliner, as from `plot_setup`."""
    return plot_setup(
        ax,
        lon_range=[-80, -30],
        lat_range=[30, 50],
        xticks=[-70, -55, -40],
        yticks=[35, 45],
        scale=scale,
    )

In [None]:
## time-mean SSH from ORAS5 (top) and regridded CESM2 (bottom)
## NOTE(review): the two means may cover different periods -- the CESM2 file
## spans 1850-2014 (per its filename), while ORAS5 covers whatever its files
## contain. Consider averaging over the overlapping years only.
fig = plt.figure(figsize=(6, 6))

for i, (label, field) in enumerate(
    [("ORAS5", zos_oras), ("CESM2 (regridded)", zos_cesm)], start=1
):
    ax = fig.add_subplot(2, 1, i, projection=ccrs.PlateCarree())
    ax, gl = plot_setup_north_atl(ax)
    ax.contourf(
        field.longitude,
        field.latitude,
        field.mean("time"),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(1, 0.1),
        extend="both",  # draw out-of-range values in the end colors, not blank
    )
    ## extra pad keeps the title clear of the top gridline labels
    ax.set_title(f"{label}: time-mean zos", fontsize=10, pad=15)

plt.show()

In [None]:
def grad(f):
    """Compute the horizontal gradient of `f` in physical (per-meter) units.

    Arguments:
        - f: DataArray with 1-D "longitude" and "latitude" coordinates in
            degrees; any other dims (e.g. "time") are carried along.

    Returns (df_dx, df_dy): the zonal (along-longitude) and meridional
    (along-latitude) derivatives of `f` per meter, on the same grid as `f`.

    Derivatives come from xarray's `differentiate` (second-order central
    differences in the interior, one-sided at the edges, using the actual
    coordinate values). The grid therefore need not be evenly spaced or
    increasing, and both components share the grid of `f`.
    """

    ## meters per degree of arc on a sphere of radius R
    R = 6.37e6  # radius of earth in meters
    m_per_deg = R * np.pi / 180

    ## a degree of longitude spans R*cos(latitude) less distance than a degree of latitude
    cos_lat = np.cos(np.deg2rad(f.latitude))

    ## derivative per degree -> per meter.
    ## longitude differences are zonal (x), latitude differences are meridional (y).
    df_dx = f.differentiate("longitude") / (m_per_deg * cos_lat)
    df_dy = f.differentiate("latitude") / m_per_deg

    return df_dx, df_dy

def grad_mag(f):
    """Return the magnitude of the horizontal gradient of `f`.

    Combines the two components returned by `grad` as
    sqrt((df/dx)**2 + (df/dy)**2).
    """
    dfdx, dfdy = grad(f)
    squared_norm = dfdx**2 + dfdy**2
    return np.sqrt(squared_norm)
    

In [None]:
## compute magnitude of the SSH gradient
mag_oras = grad_mag(zos_oras)
mag_cesm = grad_mag(zos_cesm)

## smooth the gradient before estimating the Gulf Stream position.
## center=True labels each window mean at the window's middle cell; the default
## trailing window would shift the smoothed field by (n - 1) / 2 cells in both
## latitude and longitude, biasing the detected latitude.
n = 3
gs_region = dict(longitude=slice(-80, None), latitude=slice(None, 45))
mag_oras = mag_oras.rolling(latitude=n, longitude=n, center=True).mean().sel(**gs_region)
mag_cesm = mag_cesm.rolling(latitude=n, longitude=n, center=True).mean().sel(**gs_region)

## Gulf Stream latitude: latitude of the maximum smoothed gradient at each
## (time, longitude); idxmax gives NaN for columns that are entirely missing
y_posn_oras = mag_oras.idxmax("latitude")
y_posn_cesm = mag_cesm.idxmax("latitude")

In [None]:
## upper color limits for the gradient magnitude (ORAS5, CESM2), and the
## time index to display
vmax = [1e-5, 5e-6]
t_idx = 9

fig = plt.figure(figsize=(6, 6))

panels = zip(
    ["ORAS5", "CESM2"], [mag_oras, mag_cesm], [y_posn_oras, y_posn_cesm], vmax
)
for i, (label, mag, y_posn, vmax_) in enumerate(panels, start=1):

    ax = fig.add_subplot(2, 1, i, projection=ccrs.PlateCarree())
    ax, gl = plot_setup_north_atl(ax)

    ## smoothed SSH-gradient magnitude at the chosen time
    ax.contourf(
        mag.longitude,
        mag.latitude,
        mag.isel(time=t_idx),
        cmap="cmo.amp",
        levels=np.linspace(0, vmax_, 11),
        extend="both",
    )

    ## estimated Gulf Stream latitude at the same time
    ax.plot(y_posn.longitude, y_posn.isel(time=t_idx))

    ## reference lines at 55W and 40N
    ax.axvline(-55)
    ax.axhline(40, c="w")

    ## extra pad keeps the title clear of the top gridline labels
    ax.set_title(
        f"{label}: SSH gradient and Gulf Stream latitude (t_idx={t_idx})",
        fontsize=10,
        pad=15,
    )

plt.show()

In [None]:
## ORAS5 Gulf Stream latitude through time: at 55W vs. averaged over longitude.
## method="nearest" avoids a KeyError if -55 is not exactly a grid longitude.
fig, ax = plt.subplots()
ax.plot(y_posn_oras.sel(longitude=-55, method="nearest"), label="55W")
ax.plot(y_posn_oras.mean("longitude"), label="longitude mean")
ax.set(
    xlabel="time index (months)",
    ylabel="latitude",
    title="ORAS5 Gulf Stream latitude",
)
ax.legend()
plt.show()

In [None]:
## CESM2 Gulf Stream latitude at the grid longitude nearest 55W
## (method="nearest" avoids a KeyError if -55 is not exactly on the grid)
y_posn_cesm.sel(longitude=-55, method="nearest")

In [None]:
## CESM2 Gulf Stream latitude through time, averaged over longitudes west of 50W
fig, ax = plt.subplots()
ax.plot(y_posn_cesm.sel(longitude=slice(None, -50)).mean("longitude"))
ax.set(
    xlabel="time index (months)",
    ylabel="latitude",
    title="CESM2 Gulf Stream latitude (mean west of 50W)",
)
plt.show()

# Old stuff

In [None]:
## raw CESM2 SSH on its original grid (`zos_cesm` above was regridded in place)
fp = f"{cmip6_fp}/cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon/zos/gn/1"
zos = xr.open_dataset(f"{fp}/zos_Omon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc")
zos = zos["zos"].compute()

## NOTE(review): `zos_cesm_clim`-style climatology `zos_clim` is used by the cells
## below but was not defined anywhere in this notebook. It must be on the same
## grid as `zos` (a regridder built from it is applied to `zos`), so the
## full-record time mean is assumed here -- confirm the intended averaging period.
zos_clim = zos.mean("time")


def trim_to_natl(data, lon_range=(280, 340), lat_range=(30, 50)):
    """Subset curvilinear output to the grid rows/columns touching a lon/lat box.

    Arguments:
        - data: DataArray/Dataset with "lon"/"lat" coordinates defined over
            the "nlat"/"nlon" dims (as in the native CESM2 ocean output)
        - lon_range/lat_range: inclusive (min, max) bounds in degrees. They
            are compared directly against `data.lon`/`data.lat`, so they must
            use the same longitude convention; the defaults (280-340E,
            30-50N) assume 0-360 longitudes and cover the North Atlantic.

    Returns `data` restricted to every nlat row and nlon column that contains
    at least one grid point inside the box. On a curvilinear grid this
    index-space rectangle can include points outside the box.
    """

    def _in_range(x, bounds):
        """Elementwise test bounds[0] <= x <= bounds[1]."""
        return (x >= bounds[0]) & (x <= bounds[1])

    ## mask of grid points inside the box
    in_box = _in_range(data.lon.compute(), lon_range) & _in_range(
        data.lat.compute(), lat_range
    )

    ## keep rows/cols with at least one point in the box
    lat_idx = in_box.any("nlon")
    lon_idx = in_box.any("nlat")

    return data.isel(nlon=lon_idx, nlat=lat_idx)

In [None]:
## regrid ORAS5 onto the grid of `zos_clim` (presumably CESM2's own grid)
regridder = xe.Regridder(
    ds_in=zos_oras, ds_out=zos_clim, method="bilinear", periodic=False
)

## NOTE: despite its name, this is ORAS5 SSH interpolated to the `zos_clim` grid
zos_clim_regrid = regridder(zos_oras)

fig = plt.figure(figsize=(6, 6))

## top: `zos_clim` on its own grid
ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree())
ax, gl = plot_setup_north_atl(ax)
ax.pcolormesh(zos_clim.lon, zos_clim.lat, zos_clim, vmax=2, vmin=-2, cmap="cmo.balance")
ax.set_title("zos_clim (CESM2 grid)", fontsize=10, pad=15)

## bottom: time-mean ORAS5 SSH on the same grid
ax = fig.add_subplot(2, 1, 2, projection=ccrs.PlateCarree())
ax, gl = plot_setup_north_atl(ax)
ax.pcolormesh(
    zos_clim_regrid.lon,
    zos_clim_regrid.lat,
    zos_clim_regrid.mean("time"),
    cmap="cmo.balance",
    vmax=2,
    vmin=-2,
)
ax.set_title("ORAS5 time mean, regridded to CESM2 grid", fontsize=10, pad=15)
plt.show()

In [None]:
## regrid raw CESM2 SSH (`zos`) onto the ORAS5 grid; `zos_clim` is used only to
## describe the source grid, so it must share `zos`'s horizontal grid.
## The CESM2 ocean grid is global, so treat it as periodic in longitude;
## otherwise destination points across its longitude seam are left unmapped.
regridder = xe.Regridder(
    ds_in=zos_clim, ds_out=zos_oras, method="bilinear", periodic=True
)
regrid = regridder(zos)

## plot a single month: the third-to-last time step
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax, gl = plot_setup_north_atl(ax)
ax.contourf(
    regrid.longitude,
    regrid.latitude,
    regrid.isel(time=-3),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(2, 0.2),
)
ax.set_title("CESM2 zos on ORAS5 grid (time index -3)", fontsize=10, pad=15)
plt.show()

# Old

#### CESM velocity data

In [None]:
## Load CESM2 ocean velocities (uo and vo combined into one dataset)
cesm_fp = f"{cmip6_fp}/cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon"
cesm_fp_u = f"{cesm_fp}/uo/gn/1/uo_Omon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc"
cesm_fp_v = f"{cesm_fp}/vo/gn/1/vo_Omon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc"
u = xr.open_mfdataset([cesm_fp_u, cesm_fp_v])

## select the first vertical level (the level closest to the surface)
u = u.isel(lev=0, drop=True)

## restrict to the North Atlantic box (280-340E, 30-50N) via the shared helper
u = trim_to_natl(u)

In [None]:
## load the trimmed velocities into memory; `.compute()` returns a new object,
## so assign it back instead of discarding it, then display the dataset
u = u.compute()
u