def test_fill_nans():
    """
    Test filling NaNs on a small sample grid
    """
    # Every NaN below has exactly one closest valid cell (distances are
    # measured between array indices), so the expected values do not depend
    # on how the nearest-neighbor search resolves distance ties.
    grid = xr.DataArray(
        [
            [1, np.nan, np.nan, 4],
            [5, np.nan, np.nan, 8],
        ]
    )

    filled_grid = fill_nans(grid)
    expected_values = xr.DataArray(
        [
            [1, 1, 4, 4],
            [5, 5, 8, 8],
        ]
    )

    # No missing values may remain once the grid has been filled.
    assert not np.any(np.isnan(filled_grid))
    assert np.allclose(filled_grid, expected_values)
    # The input grid is expected to be left untouched (a copy is returned).
    assert np.isnan(grid.values).sum() == 4
def fill_nans(grid):
    """
    Fill missing values in a grid by nearest neighbor interpolation

    Each NaN cell receives the value of the closest non-NaN cell. Distances
    are measured between array indices, not between the coordinate values
    attached to the grid.

    Parameters
    ----------
    grid : :class:`xarray.DataArray`
        The grid whose NaN values should be filled. It is not modified.

    Returns
    -------
    grid : :class:`xarray.DataArray`
        A copy of the input grid with the NaN values filled.

    Raises
    ------
    ValueError
        If every value in the grid is NaN, since there is no valid value to
        interpolate from.
    """
    values = np.asarray(grid.values)
    missing = np.isnan(values)

    # Nothing to fill: still return a copy so callers can freely mutate it.
    if not missing.any():
        return grid.copy()
    if missing.all():
        raise ValueError(
            "Cannot fill NaNs in a grid that contains no valid values."
        )

    # Query the KD-tree directly instead of going through the top-level
    # ``verde`` package: importing ``verde`` from inside ``verde.utils``
    # would presumably create a circular import.
    tree = cKDTree(np.argwhere(~missing))
    # The tree is built from the valid cells and queried at the *missing*
    # cells, giving for each NaN the position of its closest valid cell.
    _, nearest = tree.query(np.argwhere(missing))

    filled_values = values.copy()
    # ``np.argwhere`` and boolean-mask indexing both walk the array in
    # row-major order, so ``nearest`` lines up with ``values[~missing]`` and
    # the assignment lines up with the order of the queried cells.
    filled_values[missing] = values[~missing][nearest]

    return grid.copy(data=filled_values)