# A collection of various horizontal interpolation (or projection or regridding) methods
## Created by Ehsan Erfani

In [2]:
from scipy import interpolate
from scipy.interpolate import griddata
from scipy.interpolate import NearestNDInterpolator
import pyresample


import numpy as np
import os, glob
import pandas as pd
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from mpl_toolkits.axes_grid1 import make_axes_locatable
import sys
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import datetime
from datetime import timedelta
import seaborn as sb
from netCDF4 import Dataset
from datetime import date, timedelta

import warnings
# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides DeprecationWarnings, so deprecated library
# calls fail silently later on — consider a narrower filter while developing.
warnings.filterwarnings("ignore")


## Part1: Regular and/or structured grid:
- When the grid is rectangular: lat_orig(x), lon_orig(y), var(x,y), lat_target(x), lon_target(y)
- Since the data is structured, this method is relatively fast.
- Technically this method can be used for an unstructured grid, but you have to use brute force and do it for each grid point separately, so it takes a long time.

### interp2d

In [None]:
# interp2d(x, y, z) wants z laid out as (len(y), len(x)). With x=lat_orig and
# y=lon_orig, var(lat, lon) therefore has to be transposed before it is passed in.
# NOTE(review): interp2d is deprecated since SciPy 1.10 and removed in SciPy 1.14;
# on recent SciPy use RegularGridInterpolator or RectBivariateSpline instead.
f = interpolate.interp2d(lat_orig, lon_orig, np.transpose(var), kind='linear') # Other valid kinds: 'cubic' and 'quintic' ('nearest' is not supported)
# For array inputs the result comes back as (len(lon_target), len(lat_target));
# it is transposed back to (lat, lon) order and [0] keeps the first row only.
var_intrp = np.transpose(f(lat_target, lon_target))[0]  

### RectBivariateSpline

In [None]:
# RectBivariateSpline(x, y, z) expects z with shape (len(x), len(y)). With
# x=lat_orig and y=lon_orig this is already the layout of var(lat, lon), so the
# field is passed as-is (transposing it would swap the axes). lat_orig and
# lon_orig must be strictly ascending. The default is a bicubic spline
# (kx=ky=3); pass kx=1, ky=1 for a bilinear fit.
f = interpolate.RectBivariateSpline(lat_orig, lon_orig, var)
# Evaluated on the target axes the result has shape
# (len(lat_target), len(lon_target)); [0] keeps the row of the first target
# latitude. NOTE(review): drop the [0] to get the full 2-D field.
var_intrp = f(lat_target, lon_target)[0]

## Part2: Irregular and/or unstructured grid:
- When the grid is not rectangular: lat_orig(x,y), lon_orig(x,y), var(x,y), lat_target(x,y), lon_target(x,y)

### griddata
- This method is slow for large data
- Probably, you need to create a meshgrid of the original grid before interpolation

In [None]:
## griddata(points, values, xi, method=...) takes the source coordinates as ONE
## tuple of flattened arrays, the source values as a matching flattened array,
## and the target coordinates as ONE tuple as well.
## Most of the time, you need to do flattening or reshaping first:
var_intrp = griddata((np.reshape(lat_orig_msh, -1), np.reshape(lon_orig_msh, -1)),
                     np.reshape(var, -1),
                     (lat_target, lon_target), method='linear')

## Another similar way (flatten everything, then restore the target shape).
## The coordinates still have to be grouped into tuples; passing the five
## arrays as separate positional arguments would put them into the wrong
## parameters (values, xi, method, fill_value).
var_intrp = griddata((lat_orig_msh.flatten(), lon_orig_msh.flatten()),
                     var.flatten(),
                     (lat_target.flatten(), lon_target.flatten()),
                     method='linear').reshape(lat_target.shape)

### NearestNDInterpolator
- This method is a bit slower than griddata for me.

In [None]:
# Build the nearest-neighbour lookup from the flattened source coordinates
# (given as a (lat, lon) tuple) and the equally flattened source values.
f = NearestNDInterpolator(
    (np.ravel(lat_orig_msh), np.ravel(lon_orig_msh)),
    np.ravel(var),
)

# Query the lookup at the target coordinates; the output follows their shape.
var_intrp = f(lat_target, lon_target)

### Rbf   and   RBFInterpolator
- These methods seem to be superior to griddata, but they are much slower for large datasets.

## pyresample
- This package does the interpolation based on kd tree. 
- Online resources mentioned that this is faster than scipy kd tree, but it is still slower than griddata for me.
- There are three methods here and they are sorted from fast to slow:

In [None]:
### Preparing grids:
orig_def = pyresample.geometry.SwathDefinition(lons=lon_orig_msh, lats=lat_orig_msh)
targ_def = pyresample.geometry.SwathDefinition(lons=lon_target, lats=lat_target)

### weight function (used by resample_custom only):
def wf(r):
    """Return inverse-distance-squared weights for the neighbour distances r."""
    # Clip tiny distances: a target point that coincides with a source point
    # (r == 0) would otherwise divide by zero and produce an infinite weight.
    return 1 / np.maximum(r, 1e-6)**2

### 3 methods (each call overwrites grid_z0, so run only the one you need).
### radius_of_influence is the search cut-off distance; fill_value=None makes
### pyresample return a masked array with the undetermined points masked.
### Nearest neighbor:
grid_z0 = pyresample.kd_tree.resample_nearest(
    orig_def, var, targ_def,
    radius_of_influence=500000, fill_value=None)

### IDW of square distance (weights from wf, up to 10 neighbours per point):
grid_z0 = pyresample.kd_tree.resample_custom(
    orig_def, var, targ_def,
    radius_of_influence=500000, weight_funcs=wf, neighbours=10, fill_value=None)

### Gauss-shape of distance (sigmas sets the width of the Gaussian weighting):
grid_z0 = pyresample.kd_tree.resample_gauss(
    orig_def, var, targ_def,
    radius_of_influence=500000, sigmas=250000, neighbours=10, fill_value=None)

## The fastest method: modified griddata:
- If you need to use griddata multiple times between two fixed grids, there is a trick to speed up the process!
- The first step is to calculate the vertices and weights between the original and target grids (first function below). You need to do this relatively time-consuming process only once.
- The second step is to do the interpolation process (second function below) which is now much faster!

In [None]:
import scipy.interpolate as spint
import scipy.spatial.qhull as qhull
import itertools

n = 3e3  # NOTE(review): not referenced by any of the code shown in this notebook
d = 2    # number of coordinate dimensions of the grids: 2 or 3

def interp_weights(xyz, uvw):
    """Precompute the linear-interpolation stencil from source to target points.

    This is the expensive, geometry-only half of ``griddata(method='linear')``:
    it triangulates the source points and locates every target point in that
    triangulation. The result can be reused for any number of fields defined
    on the same source points.

    Parameters
    ----------
    xyz : array_like, shape (n_source, ndim)
        Coordinates of the source points, one row per point.
    uvw : array_like, shape (n_target, ndim)
        Coordinates of the target points, one row per point.

    Returns
    -------
    vertices : ndarray of int, shape (n_target, ndim + 1)
        Indices (into ``xyz``) of the corners of the simplex that encloses
        each target point.
    weights : ndarray of float, shape (n_target, ndim + 1)
        Barycentric weights of each target point with respect to those
        corners; every row sums to 1.

    Notes
    -----
    For a target point outside the convex hull of ``xyz``, ``find_simplex``
    returns -1, so the stencil is taken from the last simplex and is not a
    valid interpolation stencil. Such rows contain a negative weight, which is
    what ``interpolate`` uses to replace them with its fill value.
    """
    # Local import from the public namespace; ``scipy.spatial.qhull`` is a
    # deprecated alias.
    from scipy.spatial import Delaunay

    uvw = np.asarray(uvw)
    tri = Delaunay(np.asarray(xyz))
    # Take the dimension from the data instead of a module-level constant so
    # the function cannot get out of sync with its input (works for 2-D, 3-D, ...).
    ndim = tri.points.shape[1]

    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    # transform[i, :ndim] is the inverse of the simplex edge matrix and
    # transform[i, ndim] the reference corner r, so T^-1 (p - r) gives the
    # first ndim barycentric coordinates of p.
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, ndim]
    bary = np.einsum('njk,nk->nj', temp[:, :ndim, :], delta)
    # The last coordinate follows from the weights summing to one.
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

# def interpolate(values, vtx, wts):
#     return np.einsum('nj,nj->n', np.take(values, vtx), wts)

def interpolate(values, vtx, wts, fill_value=np.nan):
    """Apply a precomputed interpolation stencil to one field.

    NOTE: once defined, this function shadows the ``scipy.interpolate`` module
    that the top of the file imports under the same name.

    Parameters
    ----------
    values : array_like
        Field values at the source points; indexed by the entries of ``vtx``.
    vtx : ndarray of int, shape (n_target, k)
        For each target point, the indices of the source points to combine.
    wts : ndarray of float, shape (n_target, k)
        Weight of each of those source points.
    fill_value : scalar, optional
        Value written for every target point that has a negative weight.
        Defaults to NaN.

    Returns
    -------
    ndarray, shape (n_target,)
        Weighted sum of the selected source values per target point.
    """
    corner_values = np.take(values, vtx)
    has_negative_weight = (wts < 0).any(axis=1)
    result = np.einsum('nj,nj->n', corner_values, wts)
    result[has_negative_weight] = fill_value
    return result

In [None]:
### Original grid: one row per source point, columns = (lat, lon)
xyz = np.vstack((np.reshape(lat_orig_msh, -1), np.reshape(lon_orig_msh, -1))).T

### Target grid: one row per target point, columns = (lat, lon)
uvw = np.vstack((np.reshape(lat_target, -1), np.reshape(lon_target, -1))).T

### Field on the original grid, flattened in the same order as xyz:
f   = np.reshape(var, -1)
### Slow step: only needs to run once per pair of grids.
vtx, wts = interp_weights(xyz, uvw)

### Fast step: repeat this for every field defined on the original grid.
tmp = interpolate(f, vtx, wts, fill_value=np.nan)
### Restore the shape of the target grid (reshape the interpolated values in tmp):
var_intrp = tmp.reshape(lat_target.shape)

### Resources:
- https://stackoverflow.com/questions/20915502/speedup-scipy-griddata-for-multiple-interpolations-between-two-irregular-grids
- https://stackoverflow.com/questions/37872171/how-can-i-perform-two-dimensional-interpolation-using-scipy
- http://earthpy.org/interpolation_between_grids_with_pyresample.html