Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add KNeighbors class for nearest neighbor interpolation (#378)
Add a gridder for nearest neighbor interpolation. Allows choosing the number of neighbors and which reduction function to use when combining values (default is mean). Based on initial work by Sarah M. Askevold. Co-authored-by: Sarah M. Askevold <91882127+SAskevold@users.noreply.github.com>
- Loading branch information
Showing
6 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ Interpolators | |
|
||
Spline | ||
SplineCV | ||
KNeighbors | ||
Linear | ||
Cubic | ||
VectorSpline2D | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) 2017 The Verde Developers. | ||
# Distributed under the terms of the BSD 3-Clause License. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# This code is part of the Fatiando a Terra project (https://www.fatiando.org) | ||
# | ||
""" | ||
Gridding with a nearest-neighbors interpolator | ||
============================================== | ||
Verde offers the :class:`verde.KNeighbors` class for nearest-neighbor gridding. | ||
The interpolation looks at the data values of the *k* nearest neighbors of a | ||
interpolated point. If *k* is 1, then the data value of the closest neighbor is | ||
assigned to the point. If *k* is greater than 1, the average value of the | ||
closest *k* neighbors is assigned to the point. | ||
The interpolation works on Cartesian data, so if we want to grid geographic | ||
data (like our Baja California bathymetry) we need to project them into a | ||
Cartesian system. We'll use `pyproj <https://github.com/jswhit/pyproj>`__ to | ||
calculate a Mercator projection for the data. | ||
For convenience, Verde still allows us to make geographic grids by passing the | ||
``projection`` argument to :meth:`verde.KNeighbors.grid` and the like. When | ||
doing so, the grid will be generated using geographic coordinates which will be | ||
projected prior to interpolation. | ||
""" | ||
import cartopy.crs as ccrs | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pyproj | ||
|
||
import verde as vd | ||
|
||
# We'll test this on the Baja California shipborne bathymetry data | ||
data = vd.datasets.fetch_baja_bathymetry() | ||
|
||
# Data decimation using verde.BlockReduce is not necessary here since the | ||
# averaging operation is already performed by the k nearest-neighbor | ||
# interpolator. | ||
|
||
# Project the data using pyproj so that we can use it as input for the gridder. | ||
# We'll set the latitude of true scale to the mean latitude of the data. | ||
projection = pyproj.Proj(proj="merc", lat_ts=data.latitude.mean()) | ||
proj_coordinates = projection(data.longitude, data.latitude) | ||
|
||
# Now we can set up a gridder using the 10 nearest neighbors and averaging | ||
# using using a median instead of a mean (the default). The median is better in | ||
# this case since our data are expected to have sharp changes at ridges and | ||
# faults. | ||
grd = vd.KNeighbors(k=10, reduction=np.median) | ||
grd.fit(proj_coordinates, data.bathymetry_m) | ||
|
||
# Get the grid region in geographic coordinates | ||
region = vd.get_region((data.longitude, data.latitude)) | ||
print("Data region:", region) | ||
|
||
# The 'grid' method can still make a geographic grid if we pass in a projection | ||
# function that converts lon, lat into the easting, northing coordinates that | ||
# we used in 'fit'. This can be any function that takes lon, lat and returns x, | ||
# y. In our case, it'll be the 'projection' variable that we created above. | ||
# We'll also set the names of the grid dimensions and the name the data | ||
# variable in our grid (the default would be 'scalars', which isn't very | ||
# informative). | ||
grid = grd.grid( | ||
region=region, | ||
spacing=1 / 60, | ||
projection=projection, | ||
dims=["latitude", "longitude"], | ||
data_names="bathymetry_m", | ||
) | ||
print("Generated geographic grid:") | ||
print(grid) | ||
|
||
# Cartopy requires setting the coordinate reference system (CRS) of the | ||
# original data through the transform argument. Their docs say to use | ||
# PlateCarree to represent geographic data. | ||
crs = ccrs.PlateCarree() | ||
|
||
plt.figure(figsize=(7, 6)) | ||
# Make a Mercator map of our gridded bathymetry | ||
ax = plt.axes(projection=ccrs.Mercator()) | ||
# Plot the gridded bathymetry | ||
pc = grid.bathymetry_m.plot.pcolormesh( | ||
ax=ax, transform=crs, vmax=0, zorder=-1, add_colorbar=False | ||
) | ||
plt.colorbar(pc).set_label("meters") | ||
# Plot the locations of the data | ||
ax.plot(data.longitude, data.latitude, ".k", markersize=0.1, transform=crs) | ||
# Use an utility function to setup the tick labels and the land feature | ||
vd.datasets.setup_baja_bathymetry_map(ax) | ||
ax.set_title("Nearest-neighbor gridding of bathymetry") | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) 2017 The Verde Developers. | ||
# Distributed under the terms of the BSD 3-Clause License. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# This code is part of the Fatiando a Terra project (https://www.fatiando.org) | ||
# | ||
""" | ||
Nearest neighbor interpolation | ||
""" | ||
import warnings | ||
|
||
import numpy as np | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
from .base import BaseGridder, check_fit_input, n_1d_arrays | ||
from .coordinates import get_region | ||
from .utils import kdtree | ||
|
||
|
||
class KNeighbors(BaseGridder): | ||
""" | ||
Nearest neighbor interpolation. | ||
This gridder assumes Cartesian coordinates. | ||
Interpolation based on the values of the *k* nearest neighbors of each | ||
interpolated point. The number of neighbors *k* can be controlled and | ||
mostly influences the spatial smoothness of the interpolated values. | ||
The data values of the *k* nearest neighbors are combined into a single | ||
value by a reduction function, which defaults to the mean. This can also be | ||
configured. | ||
.. note:: | ||
If installed, package ``pykdtree`` will be used for the nearest | ||
neighbors look-up instead of :class:`scipy.spatial.cKDTree` for better | ||
performance. | ||
Parameters | ||
---------- | ||
k : int | ||
The number of neighbors to use for each interpolated point. Default is | ||
1. | ||
reduction : function | ||
Function used to combine the values of the *k* neighbors into a single | ||
value. Can be any function that takes a 1D numpy array as input and | ||
outputs a single value. Default is :func:`numpy.mean`. | ||
Attributes | ||
---------- | ||
tree_ : K-D tree | ||
An instance of the K-D tree data structure for the data points that is | ||
used to query for nearest neighbors. | ||
data_ : 1D array | ||
A copy of the input data as a 1D array. Used to look up values for | ||
interpolation/prediction. | ||
region_ : tuple | ||
The boundaries (``[W, E, S, N]``) of the data used to fit the | ||
interpolator. Used as the default region for the | ||
:meth:`~verde.KNeighbors.grid`` method. | ||
""" | ||
|
||
def __init__(self, k=1, reduction=np.mean): | ||
super().__init__() | ||
self.k = k | ||
self.reduction = reduction | ||
|
||
def fit(self, coordinates, data, weights=None): | ||
""" | ||
Fit the interpolator to the given data. | ||
The data region is captured and used as default for the | ||
:meth:`~verde.KNeighbors.grid` method. | ||
Parameters | ||
---------- | ||
coordinates : tuple of arrays | ||
Arrays with the coordinates of each data point. Should be in the | ||
following order: (easting, northing, vertical, ...). Only easting | ||
and northing will be used, all subsequent coordinates will be | ||
ignored. | ||
data : array | ||
The data values that will be interpolated. | ||
weights : None or array | ||
Data weights are **not supported** by this interpolator and will be | ||
ignored. Only present for compatibility with other gridders. | ||
Returns | ||
------- | ||
self | ||
Returns this gridder instance for chaining operations. | ||
""" | ||
if weights is not None: | ||
warnings.warn( | ||
"{} does not support weights and they will be ignored.".format( | ||
self.__class__.__name__ | ||
) | ||
) | ||
coordinates, data, weights = check_fit_input(coordinates, data, weights) | ||
self.region_ = get_region(coordinates[:2]) | ||
self.tree_ = kdtree(coordinates[:2]) | ||
# Make sure this is an array and not a subclass of array (pandas, | ||
# xarray, etc) so that we can index it later during predict. | ||
self.data_ = np.asarray(data).ravel().copy() | ||
return self | ||
|
||
def predict(self, coordinates): | ||
""" | ||
Interpolate data on the given set of points. | ||
Requires a fitted gridder (see :meth:`~verde.KNeighbors.fit`). | ||
Parameters | ||
---------- | ||
coordinates : tuple of arrays | ||
Arrays with the coordinates of each data point. Should be in the | ||
following order: (easting, northing, vertical, ...). Only easting | ||
and northing will be used, all subsequent coordinates will be | ||
ignored. | ||
Returns | ||
------- | ||
data : array | ||
The data values interpolated on the given points. | ||
""" | ||
check_is_fitted(self, ["tree_"]) | ||
distances, indices = self.tree_.query( | ||
np.transpose(n_1d_arrays(coordinates, 2)), k=self.k | ||
) | ||
if indices.ndim == 1: | ||
indices = np.atleast_2d(indices).T | ||
neighbor_values = np.reshape(self.data_[indices.ravel()], indices.shape) | ||
data = self.reduction(neighbor_values, axis=1) | ||
shape = np.broadcast(*coordinates[:2]).shape | ||
return data.reshape(shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) 2017 The Verde Developers. | ||
# Distributed under the terms of the BSD 3-Clause License. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# This code is part of the Fatiando a Terra project (https://www.fatiando.org) | ||
# | ||
""" | ||
Test the nearest neighbors interpolator. | ||
""" | ||
import warnings | ||
|
||
import numpy as np | ||
import numpy.testing as npt | ||
import pytest | ||
|
||
from ..coordinates import grid_coordinates | ||
from ..neighbors import KNeighbors | ||
from ..synthetic import CheckerBoard | ||
|
||
|
||
def test_neighbors_same_points(): | ||
"See if the gridder recovers known points." | ||
region = (1000, 5000, -8000, -7000) | ||
synth = CheckerBoard(region=region) | ||
data = synth.scatter(size=1000, random_state=0) | ||
coords = (data.easting, data.northing) | ||
# The interpolation should be perfect on top of the data points | ||
gridder = KNeighbors() | ||
gridder.fit(coords, data.scalars) | ||
predicted = gridder.predict(coords) | ||
npt.assert_allclose(predicted, data.scalars) | ||
npt.assert_allclose(gridder.score(coords, data.scalars), 1) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"gridder", | ||
[ | ||
KNeighbors(), | ||
KNeighbors(k=1), | ||
KNeighbors(k=2), | ||
KNeighbors(k=10), | ||
KNeighbors(k=1, reduction=np.median), | ||
], | ||
ids=[ | ||
"k=default", | ||
"k=1", | ||
"k=2", | ||
"k=10", | ||
"median", | ||
], | ||
) | ||
def test_neighbors(gridder): | ||
"See if the gridder recovers known points." | ||
region = (1000, 5000, -8000, -6000) | ||
synth = CheckerBoard(region=region) | ||
data_coords = grid_coordinates(region, shape=(100, 100)) | ||
data = synth.predict(data_coords) | ||
coords = grid_coordinates(region, shape=(95, 95)) | ||
true_data = synth.predict(coords) | ||
# nearest will never be too close to the truth | ||
gridder.fit(data_coords, data) | ||
npt.assert_allclose(gridder.predict(coords), true_data, rtol=0, atol=100) | ||
|
||
|
||
def test_neighbors_weights_warning(): | ||
"Check that a warning is issued when using weights." | ||
data = CheckerBoard().scatter(random_state=100) | ||
weights = np.ones_like(data.scalars) | ||
grd = KNeighbors() | ||
msg = "KNeighbors does not support weights and they will be ignored." | ||
with warnings.catch_warnings(record=True) as warn: | ||
grd.fit((data.easting, data.northing), data.scalars, weights=weights) | ||
assert len(warn) == 1 | ||
assert issubclass(warn[-1].category, UserWarning) | ||
assert str(warn[-1].message) == msg |