# Regrid with xesmf

This notebook shows a bare-bones example of regridding with xesmf.


In [None]:

import platform  # only for getting python version number
from collections import namedtuple
from pathlib import Path
import numpy as np
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
# auxiliary stuff
import colorcet as cc
import esmlab
import cartopy
import cartopy.crs as ccrs

from numba import jit
# USE AS: @jit(nopython=True)

import xesmf as xe

# ERA-Interim climatology

This is grabbing the monthly climatology used by the AMWG diagnostics. 

The path reflects CGD's `project` directory mounted on my Mac. The same code should work on CGD systems once the path is adjusted, as long as you are using a recent version of xarray.

In [None]:
def load_erai():
    obs_stem = Path("/Volumes/project/amp/amwg/amwg_diagnostics/obs_data")
    # NOTE: Is there documentation about these derived data? In particular, no notes on years/sample-size included.
    monthly_fils = sorted(list(obs_stem.glob('ERAI_[0-1][0-9]_climo.nc')))
    ds = xr.open_mfdataset(monthly_fils, combine='nested', concat_dim='time', decode_times=False)
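    # the times are mangled, so just label the months 1..12
    # (assign_coords avoids writing to .values of a dimension coordinate, which recent xarray rejects)
    ds = ds.assign_coords(time=np.arange(1, 13))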
    ds = ds.transpose('time', 'lev', 'lat', 'lon')
    return ds



In [None]:
ds_erai = load_erai()

# CESM data
This is just a CESM dataset that has native history files. These are monthly averages, and cover 5 years. 

I like to use the pathlib module to construct path objects, but this could also be done with plain strings.

In [None]:
stem = Path("/Volumes/project/amp/brianpm/vres/vres_L032")
# pass a sorted list of paths rather than the raw glob generator
ds_cesm = xr.open_mfdataset(sorted(stem.glob("*.h0.*.nc")), combine='by_coords')

## Horizontal regrid

We're going to do a simple bilinear remapping.

Average ERAI in time (not needed for this example, but it simplifies the data).

Then use `xesmf` to create a regridding object. The first argument supplies the lat and lon coordinates of the source grid, and the second supplies the lat and lon of the destination grid. The third argument sets the interpolation method (bilinear). The `periodic` keyword argument is set to `True` so that longitude wraps around, and `reuse_weights=True` tells xesmf to reuse a previously written weights file instead of recomputing the weights, which helps if we do multiple regrids.

The result `regridder` is not the regridded data, it is the operator that performs regridding. 
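To make "bilinear with wraparound longitude" concrete, here is a minimal NumPy sketch of the idea on a regular lat-lon grid. It is not how xesmf or ESMF computes its weights; the function name `bilinear_periodic` and its arguments are hypothetical and exist only for illustration. It assumes `lat` and `lon` are strictly increasing, `lon` spans the full circle, and `lat_out` lies inside the range of `lat`.

```python
import numpy as np


def bilinear_periodic(field, lat, lon, lat_out, lon_out):
    """Bilinear interpolation of field(lat, lon) with periodic longitude.

    Illustration only (hypothetical helper, not part of xesmf).
    """
    lat_out = np.atleast_1d(np.asarray(lat_out, dtype=float))
    lon_out = np.atleast_1d(np.asarray(lon_out, dtype=float))

    # Append a copy of the first longitude column at lon[0] + 360 so that
    # target points between the last and first longitudes are bracketed.
    lon_pad = np.append(lon, lon[0] + 360.0)
    field_pad = np.concatenate([field, field[:, :1]], axis=1)
    # Map target longitudes into [lon[0], lon[0] + 360).
    lon_q = (lon_out - lon[0]) % 360.0 + lon[0]

    # Index of the western / southern neighbour of each target point.
    j = np.clip(np.searchsorted(lon_pad, lon_q, side="right") - 1, 0, len(lon_pad) - 2)
    i = np.clip(np.searchsorted(lat, lat_out, side="right") - 1, 0, len(lat) - 2)

    # Fractional distance from the western / southern neighbour.
    wx = ((lon_q - lon_pad[j]) / (lon_pad[j + 1] - lon_pad[j]))[None, :]
    wy = ((lat_out - lat[i]) / (lat[i + 1] - lat[i]))[:, None]

    # Interpolate in longitude along the two bracketing latitude rows,
    # then interpolate between those rows in latitude.
    south = (1 - wx) * field_pad[np.ix_(i, j)] + wx * field_pad[np.ix_(i, j + 1)]
    north = (1 - wx) * field_pad[np.ix_(i + 1, j)] + wx * field_pad[np.ix_(i + 1, j + 1)]
    return (1 - wy) * south + wy * north
```

Without the padded column, a target longitude such as 355° on a 0°, 10°, ..., 350° grid would have no eastern neighbour; the wraparound is what lets it be interpolated between 350° and 0°.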

In [None]:
ds_erai_climo = ds_erai.mean(dim='time')
# source grid: ERA-Interim; destination grid: the CESM dataset loaded above
regridder = xe.Regridder(ds_erai_climo, ds_cesm, 'bilinear', periodic=True, reuse_weights=True)

In [None]:
# pull out some example variable
Xerai = ds_erai_climo['U']
Xcesm = ds_cesm['U']

In [None]:
print('(horizontal) regrid step')
xobs_orig = Xerai.compute()  # explicitly say to load into memory
xobs = regridder(xobs_orig)  # xobs is the regridded data.
print('regrid done')

## Vertical regrid

Now that the ERA-Interim data is on the CAM horizontal grid, we want to interpolate the CAM data to pressure levels. Below are a few functions that do this.
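CAM output sits on hybrid levels, so the pressure at each level has to be computed from surface pressure first, using p = a(k)*p0 + b(k)*ps. The sketch below applies that formula to plain NumPy arrays with explicit broadcasting. The coefficients and surface pressures are made-up numbers, and `pres_from_hybrid_np` is a hypothetical stand-in used only for this illustration.

```python
import numpy as np


def pres_from_hybrid_np(psfc, hya, hyb, p0=100000.0):
    """Return pressure (lev, lat, lon) in Pa from p = a(k)*p0 + b(k)*ps."""
    hya = np.asarray(hya, dtype=float)[:, None, None]   # (lev, 1, 1)
    hyb = np.asarray(hyb, dtype=float)[:, None, None]   # (lev, 1, 1)
    psfc = np.asarray(psfc, dtype=float)[None, :, :]    # (1, lat, lon)
    return hya * p0 + hyb * psfc


# Made-up three-level example: pure pressure level on top, pure sigma at the bottom.
hya = np.array([0.05, 0.02, 0.0])
hyb = np.array([0.0, 0.5, 1.0])
psfc = np.array([[100000.0, 85000.0]])  # one latitude, two longitudes, in Pa

p = pres_from_hybrid_np(psfc, hya, hyb)
print(p.shape)  # (3, 1, 2)
```

With xarray objects the broadcasting happens by dimension name instead, so it is worth checking that the dimension order of the resulting pressure field matches the data being interpolated.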

In this approach, I use numba as a just-in-time (jit) compiler. The `@jit`-decorated functions are compiled the first time they are called, which *should* make them very fast, especially on subsequent calls when the compilation step is skipped. The main reason I think it helps here is that it lets us write plain nested loops around a simple function (`np.interp`) and still expect speed comparable to a compiled numpy function.
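As a small, self-contained illustration of the pattern, the sketch below loops over columns and calls `np.interp` inside a jit-compiled function. The function name `interp_columns` is hypothetical, and the fallback to an uncompiled no-op decorator when numba is not installed is an assumption made only so the sketch runs anywhere.

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Assumption for this sketch only: run as plain Python if numba is missing.
    def jit(*args, **kwargs):
        def decorator(func):
            return func
        return decorator


@jit(nopython=True)
def interp_columns(x_in, p, pnew):
    """Interpolate each (lat, lon) column of x_in from pressures p to pnew."""
    nlev, nlat, nlon = x_in.shape
    x_out = np.empty((len(pnew), nlat, nlon))
    for j in range(nlat):
        for i in range(nlon):
            # np.interp expects the sample points p[:, j, i] to be increasing
            x_out[:, j, i] = np.interp(pnew, p[:, j, i], x_in[:, j, i])
    return x_out
```

The first call pays the compilation cost; later calls with the same argument types reuse the compiled code. Note that `np.interp` holds the end values constant for targets outside the range of `p`, so levels above the model top or below the surface are filled with the nearest value rather than masked.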

In [None]:
# get set up for pressure levels... 

def pres_from_hybrid(psfc, hya, hyb, p0=100000.):
    # p = a(k)*p0 + b(k)*ps.
    return hya*p0 + hyb*psfc


@jit(nopython=True)
def to_pres_4d(x_in, p, pnew):    
    ntime, nlev_in, nlat, nlon = x_in.shape
    nlev_out = len(pnew)
    x_out = np.empty((ntime, nlev_out, nlat, nlon))
    for time in range(ntime):
        for lat in range(nlat):
            for lon in range(nlon):
                pin = p[time, :, lat, lon]
                x_out[time, :, lat, lon] = np.interp(
                    pnew, pin, x_in[time, :, lat, lon]
                )
    return x_out


@jit(nopython=True)
def to_pres_3d(x_in, p, pnew):
    nlev_in, nlat, nlon = x_in.shape
    nlev_out = len(pnew)
    x_out = np.empty((nlev_out, nlat, nlon))
    for lat in range(nlat):
        for lon in range(nlon):
            pin = p[:, lat, lon]
            x_out[:, lat, lon] = np.interp(
                pnew, pin, x_in[:, lat, lon]
            )
    return x_out


# @jit(nopython=True)
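def to_pres(x_in, p, pnew):
    # dispatch on the number of dimensions: (time, lev, lat, lon) or (lev, lat, lon)
    s = x_in.shape
    if len(s) == 4:
        xout = to_pres_4d(x_in, p, pnew)
    elif len(s) == 3:
        xout = to_pres_3d(x_in, p, pnew)
    else:
        raise ValueError(f"to_pres expects 3-D or 4-D input, got shape {s}")
    return xout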



@jit(nopython=True)
def shape_checker(x):
    s = x.shape
    l = len(s)
    return l



def to_pres_wrap(x_in, p_in, p_new):
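    # first unwrap into bare numpy; xarray broadcasts by name, so put the
    # pressure dims in the same order as x_in before dropping the labels
    x_np = x_in.values
    p_np = p_in.transpose(*x_in.dims).values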
    if isinstance(p_new, xr.DataArray):
        print("Got dataarray for p, convert")
        p_new_np = p_new.values
    else:
        p_new_np = p_new
    # interpolate to the new pressure levels
    print(f"[to_pres_wrap] Shapes being sent to to_pres: {x_np.shape} (length: {len(x_np.shape)}), {p_np.shape}, {p_new_np.shape}")
    x_plev = to_pres(x_np, p_np, p_new_np)  # numba should make this speedy.
    # wrap back into DataArray
    new_coords = dict()
    for i in x_in.coords:
        new_coords[i] = x_in[i]
    new_coords['lev'] = p_new
    return xr.DataArray(x_plev, coords=new_coords, dims=x_in.dims)




In [None]:

# info for pressure interp -- USE ERA-I PRESSURE LEVELS.
pnew = ds_erai_climo['lev'].copy(deep=True) # Pa in the amwg netcdf files (n = 37, bottom to top)
if pnew.max() > 2000:
    print("ERAI lev appears to be in Pa")
else:
    print("ERAI lev appears to be in hPa; converting to Pa")
    pnew *= 100.
    
print(pnew.values)


# pressure field on level midpoints
pressure = pres_from_hybrid(ds_cesm['PS'], ds_cesm['hyam'], ds_cesm['hybm']).compute()




In [None]:
# regrid model to pressure levels:
print("Start pressure level interpolation")
Xplev = to_pres_wrap(Xcesm, pressure, pnew)
print("Finished pressure level interpolation")

# Outcome

We should now have ERA-Interim data regridded to the CAM horizontal grid, and the CAM data interpolated to the ERA-Interim pressure levels.