This notebook defines a function that finds the nearest grid point to a specified target latitude and longitude, and demonstrates how to use it.

## Dependencies
The following packages are required. If you don't have any of these packages, try installing them with `conda install package-name`.
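To find out which of these packages are missing before you run the cells below, you can ask Python's import machinery whether each module can be located. The following is a minimal sketch that uses only the standard library. The helper name `missing_packages` is hypothetical. If `pykdtree` is not available on your default channel, you may need to install it from conda-forge (`conda install -c conda-forge pykdtree`).

```python
import importlib.util

# Third-party modules imported in the first code cell of this notebook.
REQUIRED = ["numpy", "pykdtree", "xarray"]


def missing_packages(names=REQUIRED):
    """Return the names in ``names`` that cannot be found (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed.")
```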

In [1]:
import math

import numpy as np
from pykdtree.kdtree import KDTree
import xarray as xr

# Function definition
The two cells below define the function `find_nearest_grid_point()` and a helper function `_find_index()`. The docstrings explain the arguments and the return values.
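The KD-tree search in `find_nearest_grid_point()` does not work on latitude and longitude directly. It first converts each point to a unit vector (cos φ cos λ, cos φ sin λ, sin φ). The straight-line (chord) distance between two such vectors grows monotonically with their great-circle separation, so the nearest point in 3-D is also the nearest point on the sphere, and the ±180° seam stops being a special case. The NumPy-only sketch below illustrates the idea. The function names and the 6371 km mean Earth radius are assumptions chosen for illustration.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius


def to_unit_vector(lat_deg, lon_deg):
    """Map latitude/longitude in degrees to (x, y, z) on the unit sphere."""
    lat = np.deg2rad(lat_deg)
    lon = np.deg2rad(lon_deg)
    return np.array([np.cos(lat) * np.cos(lon),
                     np.cos(lat) * np.sin(lon),
                     np.sin(lat)])


def chord_sq_to_km(dist_sq):
    """Convert a squared chord length on the unit sphere to great-circle km.

    A chord of length c subtends a central angle of 2 * arcsin(c / 2).
    """
    chord = np.sqrt(dist_sq)
    return 2.0 * EARTH_RADIUS_KM * np.arcsin(chord / 2.0)


# Two points on either side of the 180 degree meridian are neighbours in 3-D,
# even though their longitudes differ by 359 degrees.
a = to_unit_vector(0.0, 179.5)
b = to_unit_vector(0.0, -179.5)
dist_sq = np.sum((a - b) ** 2)
print(chord_sq_to_km(dist_sq))  # about 111.2 km, i.e. one degree of arc
```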

In [2]:
def find_nearest_grid_point(
        lat, lon, dataset, lat_var_name, lon_var_name, n=1
):
    """Find the nearest grid point to a given lat/lon pair.

    Parameters
    ----------
    lat : float
        Latitude value at which to find the nearest grid point.
    lon : float
        Longitude value at which to find the nearest grid point.
    dataset : xarray.Dataset
        An xarray Dataset containing the mesh variables.
    lat_var_name : str
        Name of the latitude variable in the dataset.
    lon_var_name : str
        Name of the longitude variable in the dataset.
    n : int, optional
        Number of nearest grid points to return. Default is to return the
        single closest grid point.

    Returns
    -------
    dist_sq, iy, ix, lat_near, lon_near
        A tuple of numpy arrays:

        - ``dist_sq``: the squared straight-line (chord) distance, on the unit
          sphere, between the given lat/lon location and the nearest grid
          points
        - ``iy``: the y indices of the nearest grid points
        - ``ix``: the x indices of the nearest grid points
        - ``lat_near``: the latitude values of the nearest grid points
        - ``lon_near``: the longitude values of the nearest grid points
    """

    # Note the use of the squeeze method: it removes single-dimensional entries
    # from the shape of an array. For example, in the GIOPS mesh file the
    # longitude of the U velocity points is defined as an array with shape
    # (1, 1, 1021, 1442). The squeeze method converts this into the equivalent
    # array with shape (1021, 1442).
    latvar = dataset.variables[lat_var_name].squeeze()
    lonvar = dataset.variables[lon_var_name].squeeze()

    rad_factor = math.pi / 180.0
    latvals = latvar[:] * rad_factor
    lonvals = lonvar[:] * rad_factor
    clat, clon = np.cos(latvals), np.cos(lonvals)
    slat, slon = np.sin(latvals), np.sin(lonvals)
    if latvar.ndim == 1:
        # If latitude and longitude are 1D arrays (as is the case with the
        # GIOPS forecast data currently pulled from Datamart), we need to
        # handle this situation specially. The clat array will have some size
        # m, say, and the clon array some size p. Because the two arrays are
        # defined along different dimensions, xarray automatically broadcasts
        # their product to shape (m, p). Thus, the array calculated from
        #
        #   np.ravel(clat * clon)
        #
        # will have m * p elements. However, the array
        #
        #   np.ravel(slat)
        #
        # will have only m elements. The three columns passed to np.array()
        # below would then have different lengths, and building the KDTree
        # would fail. To resolve this, we broadcast slat to the full (m, p)
        # shape.
        shape = (slat.size, slon.size)
        slat = np.broadcast_to(slat.values[:, np.newaxis], shape)
    else:
        shape = latvar.shape
    triples = np.array([np.ravel(clat * clon), np.ravel(clat * slon),
                        np.ravel(slat)]).transpose()
    kdt = KDTree(triples)

    dist_sq, iy, ix = _find_index(lat, lon, kdt, shape, n)
    # The results returned from _find_index are two-dimensional arrays (if
    # n > 1) because it can handle the case of finding indices closest to
    # multiple lat/lon locations (i.e., where lat and lon are arrays, not
    # scalars). Currently, this function is intended only for a single lat/lon,
    # so we redefine the results as one-dimensional arrays.
    if n > 1:
        dist_sq = dist_sq[0, :]
        iy = iy[0, :]
        ix = ix[0, :]

    if latvar.ndim == 1:
        lat_near = latvar.values[iy]
        lon_near = lonvar.values[ix]
    else:
        lat_near = latvar.values[iy, ix]
        lon_near = lonvar.values[iy, ix]

    # Many datasets define longitude over the range -180 to +180, but the
    # GIOPS forecast data currently uses 0 to 360. To return longitudes in a
    # consistent -180 to +180 range, shift any values above 180 by -360.
    lon_near[lon_near > 180] -= 360

    return dist_sq, iy, ix, lat_near, lon_near
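The sketch below reproduces the shape problem that makes the 1-D branch above necessary. Plain NumPy arrays stand in for the xarray variables, and the small 3 × 4 grid is invented for illustration. With 1-D coordinates, the products `clat * clon` and `clat * slon` must be formed on the full (m, p) grid, but `slat` has only m entries. It therefore has to be broadcast to the same shape before the three columns are stacked. xarray performs the outer product automatically by dimension name; here it is written out explicitly with `np.newaxis`. The last lines show the same 0 to 360 → -180 to +180 shift that the function applies to `lon_near`.

```python
import numpy as np

# Hypothetical 1-D coordinate vectors: 3 latitudes and 4 longitudes.
lat_1d = np.array([40.0, 42.0, 44.0])
lon_1d = np.array([290.0, 292.0, 294.0, 296.0])  # 0 to 360 convention

lat_rad = np.deg2rad(lat_1d)
lon_rad = np.deg2rad(lon_1d)
clat, slat = np.cos(lat_rad), np.sin(lat_rad)
clon, slon = np.cos(lon_rad), np.sin(lon_rad)

shape = (lat_1d.size, lon_1d.size)

# Outer products give one value per (lat, lon) pair, shape (3, 4).
x = clat[:, np.newaxis] * clon[np.newaxis, :]
y = clat[:, np.newaxis] * slon[np.newaxis, :]
# sin(lat) does not depend on longitude, so repeat it along each row.
z = np.broadcast_to(slat[:, np.newaxis], shape)

triples = np.array([np.ravel(x), np.ravel(y), np.ravel(z)]).transpose()
print(triples.shape)  # (12, 3): one 3-D point per grid cell

# Without the broadcast, np.ravel(slat) would have 3 entries instead of 12.
print(np.ravel(slat).size, np.ravel(x).size)

# Shift longitudes above 180 into the -180 to +180 range.
lon_wrapped = np.where(lon_1d > 180, lon_1d - 360, lon_1d)
print(lon_wrapped)
```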

In [3]:
def _find_index(lat0, lon0, kdt, shape, n=1):
    """Finds the y, x indicies that are closest to a latitude, longitude pair.

    Arguments:
        lat0 -- the target latitude
        lon0 -- the target longitude
        n -- the number of indicies to return

    Returns:
        squared distance, y, x indicies
    """
    if hasattr(lat0, "__len__"):
        lat0 = np.array(lat0)
        lon0 = np.array(lon0)
        multiple = True
    else:
        multiple = False
    rad_factor = math.pi / 180.0
    lat0_rad = lat0 * rad_factor
    lon0_rad = lon0 * rad_factor
    clat0, clon0 = np.cos(lat0_rad), np.cos(lon0_rad)
    slat0, slon0 = np.sin(lat0_rad), np.sin(lon0_rad)
    q = [clat0 * clon0, clat0 * slon0, slat0]
    if multiple:
        q = np.array(q).transpose()
    else:
        q = np.array(q)
        q = q[np.newaxis, :]
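    # pykdtree returns Euclidean distances by default; request squared
    # distances so the result matches the documented ``dist_sq``.
    dist_sq_min, minindex_1d = kdt.query(np.float32(q), k=n, sqr_dists=True)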
    iy_min, ix_min = np.unravel_index(minindex_1d, shape)
    return dist_sq_min, iy_min, ix_min
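`_find_index()` relies on the KD-tree returning indices into the flattened (raveled) grid, which `np.unravel_index` converts back to (y, x) using the original grid shape. On a small grid, a brute-force search offers a simple sanity check: compute the squared chord distance to every grid point, take the flat `argmin`, and unravel it in the same way. The sketch below uses a made-up regular grid and a hypothetical helper, `brute_force_nearest`. It is far too slow for a full model mesh, but on the same inputs its indices should agree with those from `find_nearest_grid_point(..., n=1)`.

```python
import numpy as np


def brute_force_nearest(lat0, lon0, lat2d, lon2d):
    """Return (dist_sq, iy, ix) of the grid point closest to (lat0, lon0).

    Distances are squared chord lengths between unit vectors, the same
    metric the KD-tree uses in find_nearest_grid_point().
    """
    lat, lon = np.deg2rad(lat2d), np.deg2rad(lon2d)
    pts = np.stack([np.cos(lat) * np.cos(lon),
                    np.cos(lat) * np.sin(lon),
                    np.sin(lat)], axis=-1)       # shape (ny, nx, 3)
    qlat, qlon = np.deg2rad(lat0), np.deg2rad(lon0)
    q = np.array([np.cos(qlat) * np.cos(qlon),
                  np.cos(qlat) * np.sin(qlon),
                  np.sin(qlat)])
    dist_sq = np.sum((pts - q) ** 2, axis=-1)    # shape (ny, nx)
    flat_index = np.argmin(dist_sq)              # index into dist_sq.ravel()
    iy, ix = np.unravel_index(flat_index, dist_sq.shape)
    return dist_sq[iy, ix], iy, ix


# Hypothetical 0.5 degree grid covering 40-48 N and 70-60 W.
lon2d, lat2d = np.meshgrid(np.arange(-70.0, -59.5, 0.5),
                           np.arange(40.0, 48.5, 0.5))
d2, iy, ix = brute_force_nearest(44.1, -65.9, lat2d, lon2d)
print(iy, ix, lat2d[iy, ix], lon2d[iy, ix])
```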

# Example of using the function
The cells below demonstrate how to use the function with the GLORYS12 mesh file.

In [4]:
d = xr.open_dataset('/terra/GLORYS12v1/Cgrid/PSY4V3R1_mesh_all.nc')
d

<xarray.Dataset>
Dimensions:       (t: 1, x: 4322, y: 1559, z: 50)
Dimensions without coordinates: t, x, y, z
Data variables:
    nav_lon       (y, x) float32 ...
    nav_lat       (y, x) float32 ...
    nav_lev       (z) float32 ...
    time_counter  (t) float64 ...
    tmask         (t, z, y, x) uint8 ...
    umask         (t, z, y, x) uint8 ...
    vmask         (t, z, y, x) uint8 ...
    fmask         (t, z, y, x) uint8 ...
    tmaskutil     (t, y, x) uint8 ...
    umaskutil     (t, y, x) uint8 ...
    vmaskutil     (t, y, x) uint8 ...
    fmaskutil     (t, y, x) uint8 ...
    glamt         (t, y, x) float32 ...
    glamu         (t, y, x) float32 ...
    glamv         (t, y, x) float32 ...
    glamf         (t, y, x) float32 ...
    gphit         (t, y, x) float32 ...
    gphiu         (t, y, x) float32 ...
    gphiv         (t, y, x) float32 ...
    gphif         (t, y, x) float32 ...
    e1t           (t, y, x) float64 ...
    e1u           (t, y, x) float64 ...
    e1v         

The longitude and latitude variables are called `nav_lon` and `nav_lat`. These names will be passed to `find_nearest_grid_point()`.

In [5]:
target_lon = -66
target_lat = 44
lon_var_name = 'nav_lon'
lat_var_name = 'nav_lat'

dist_sq, iy, ix, lat_near, lon_near = find_nearest_grid_point(target_lat, target_lon, d, lat_var_name, lon_var_name)

The y index and x index are printed below:

In [6]:
print(iy, ix)

[593] [2658]


We can also print the latitude and longitude of the nearest grid point:

In [7]:
print(lat_near, lon_near)

[43.975548] [-65.96261]
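As an independent check, you can compute the great-circle distance between the target location and the grid point found above with the haversine formula. The sketch below hard-codes the target (44, -66) and the printed nearest point, and assumes a 6371 km mean Earth radius. For this example the separation comes out at roughly 4 km. That is smaller than a single grid spacing at 44° N on a nominal 1/12° mesh, which is what you would expect for a correct nearest-neighbour match.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between two points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = (math.sin(dphi / 2) ** 2
         + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


# Target location and the nearest grid point printed above.
dist_km = haversine_km(44.0, -66.0, 43.975548, -65.96261)
print(f"{dist_km:.2f} km")
```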
