Skip to content

Commit

Permalink
Merge pull request #232 from cta-observatory/NearestNeighborSearcher
Browse files Browse the repository at this point in the history
Add Nearest Neighbor Searcher as Alternative for Inter-/Extrapolation
  • Loading branch information
maxnoe committed Jul 4, 2023
2 parents fdccc13 + 83ea9aa commit 5c0d8ac
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/changes/232.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This PR adds a DiscretePDFNearestNeighborSearcher and a ParametrizedNearestNeighborSearcher to support nearest neighbor approaches
as alternatives to inter- and extrapolation.
8 changes: 8 additions & 0 deletions pyirf/interpolation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,25 @@
)
from .griddata_interpolator import GridDataInterpolator
from .moment_morph_interpolator import MomentMorphInterpolator
from .nearest_neighbor_searcher import (
BaseNearestNeighborSearcher,
DiscretePDFNearestNeighborSearcher,
ParametrizedNearestNeighborSearcher,
)
from .quantile_interpolator import QuantileInterpolator

__all__ = [
"BaseComponentEstimator",
"BaseInterpolator",
"BaseNearestNeighborSearcher",
"DiscretePDFComponentEstimator",
"DiscretePDFInterpolator",
"DiscretePDFNearestNeighborSearcher",
"GridDataInterpolator",
"MomentMorphInterpolator",
"ParametrizedComponentEstimator",
"ParametrizedInterpolator",
"ParametrizedNearestNeighborSearcher",
"QuantileInterpolator",
"EffectiveAreaEstimator",
"RadMaxEstimator",
Expand Down
2 changes: 1 addition & 1 deletion pyirf/interpolation/base_interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class BaseInterpolator(metaclass=ABCMeta):
"""

def __init__(self, grid_points):
"""BaseInterpolator
"""BaseInterpolator
Parameters
----------
Expand Down
31 changes: 17 additions & 14 deletions pyirf/interpolation/component_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
ParametrizedInterpolator,
)
from pyirf.interpolation.griddata_interpolator import GridDataInterpolator
from pyirf.interpolation.nearest_neighbor_searcher import (
DiscretePDFNearestNeighborSearcher,
ParametrizedNearestNeighborSearcher,
)
from pyirf.interpolation.quantile_interpolator import QuantileInterpolator
from pyirf.utils import cone_solid_angle
from scipy.spatial import Delaunay
Expand Down Expand Up @@ -39,7 +43,6 @@ def __init__(self, grid_points):
Parameters
----------
grid_points: np.ndarray, shape=(n_points, n_dims):
Grid points at which interpolation templates exist
Raises
------
Expand Down Expand Up @@ -169,7 +172,7 @@ def __init__(
grid_points: np.ndarray, shape=(n_points, n_dims):
Grid points at which interpolation templates exist
bin_edges: np.ndarray, shape=(n_bins+1)
Common set of bin-edges for all discretized PDFs
Common set of bin-edges for all discretized PDFs.
bin_contents: np.ndarray, shape=(n_points, ..., n_bins)
Discretized PDFs for all grid points and arbitrary further dimensions
(in IRF term e.g. field-of-view offset bins). Actual interpolation dimension,
Expand All @@ -191,14 +194,14 @@ def __init__(
Raises
------
TypeError:
When bin_edges is not a np.ndarray
When bin_edges is not a np.ndarray.
TypeError:
When bin_content is not a np.ndarray
TypeError:
When interpolator_cls is not a BinnedInterpolator subclass.
When interpolator_cls is not a DiscretePDFInterpolator subclass.
ValueError:
When number of bins in bin_edges and contents bin_contents is
not matching
When number of bins in bin_edges and contents in bin_contents is
not matching.
ValueError:
When number of histograms in bin_contents and points in grid_points
is not matching
Expand All @@ -212,21 +215,21 @@ def __init__(
grid_points,
)

if not isinstance(bin_edges, np.ndarray):
raise TypeError("Input bin_edges is not a numpy array.")
elif not isinstance(bin_contents, np.ndarray):
if not isinstance(bin_contents, np.ndarray):
raise TypeError("Input bin_contents is not a numpy array.")
elif bin_contents.shape[-1] != (bin_edges.shape[0] - 1):
raise ValueError(
f"Shape missmatch, bin_edges ({bin_edges.shape[0] - 1} bins) "
f"and bin_contents ({bin_contents.shape[-1]} bins) not matching."
)
elif self.n_points != bin_contents.shape[0]:
raise ValueError(
f"Shape missmatch, number of grid_points ({self.n_points}) and "
f"number of histograms in bin_contents ({bin_contents.shape[0]}) "
"not matching."
)
elif not isinstance(bin_edges, np.ndarray):
raise TypeError("Input bin_edges is not a numpy array.")
elif bin_contents.shape[-1] != (bin_edges.shape[0] - 1):
raise ValueError(
f"Shape missmatch, bin_edges ({bin_edges.shape[0] - 1} bins) "
f"and bin_contents ({bin_contents.shape[-1]} bins) not matching."
)

if interpolator_kwargs is None:
interpolator_kwargs = {}
Expand Down
172 changes: 172 additions & 0 deletions pyirf/interpolation/nearest_neighbor_searcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import numpy as np

from .base_interpolators import (
BaseInterpolator,
DiscretePDFInterpolator,
ParametrizedInterpolator,
)

# Names exported by ``from <this module> import *``: the generic searcher and
# the two API-specific variants defined below.
__all__ = [
"BaseNearestNeighborSearcher",
"DiscretePDFNearestNeighborSearcher",
"ParametrizedNearestNeighborSearcher",
]


class BaseNearestNeighborSearcher(BaseInterpolator):
    """
    Dummy NearestNeighbor approach usable instead of
    actual Interpolation/Extrapolation.

    Rather than combining templates, the contents stored for the grid point
    closest to the target point are returned as they are.
    """

    def __init__(self, grid_points, contents, norm_ord=2):
        """
        BaseNearestNeighborSearcher

        Parameters
        ----------
        grid_points: np.ndarray, shape=(n_points, n_dims)
            Grid points at which templates exist
        contents: np.ndarray, shape=(n_points, ...)
            Corresponding IRF contents at grid_points
        norm_ord: positive int
            Order of the norm which is used to compute the distances,
            passed to numpy.linalg.norm. Defaults to 2,
            which uses the euclidean norm.

        Raises
        ------
        ValueError:
            If norm_ord is not a finite, positive integer

        Note
        ----
        Also calls pyirf.interpolation.BaseInterpolators.__init__
        """

        super().__init__(grid_points)

        self.contents = contents

        # The checks are ordered so that ``int(norm_ord)`` is only evaluated
        # for finite, positive numbers (``int(inf)`` / ``int(nan)`` would
        # raise themselves). Values that cannot be compared or checked
        # numerically at all (None, str, complex, ...) raise TypeError
        # somewhere in the chain and are treated as invalid as well.
        try:
            is_valid = (
                (norm_ord > 0)
                and np.isfinite(norm_ord)
                and (norm_ord == int(norm_ord))
            )
        except TypeError:
            is_valid = False

        if not is_valid:
            # Message text is kept verbatim so that callers matching on it
            # keep working.
            raise ValueError(
                f"Only positiv integers allowed for norm_ord, got {norm_ord}."
            )

        self.norm_ord = norm_ord

    def interpolate(self, target_point):
        """
        Takes a grid of IRF contents for a bunch of different parameters
        and returns the contents at the nearest grid point
        as seen from the target point.

        Parameters
        ----------
        target_point: array-like
            Value for which the nearest neighbor should be found (target point).
            Anything accepted by numpy.asarray; a 1D input is treated as a
            single point.

        Returns
        -------
        content_new: numpy.ndarray
            Contents at the nearest neighbor, i.e. ``contents[index, :]``.
            For array contents the integer index removes the grid-point
            axis, so the shape is ``contents.shape[1:]``.

        Note
        ----
        In case of multiple nearest neighbors, the contents corresponding
        to the first one are returned.
        """

        # Accept lists/tuples as well as arrays; ndarrays pass through as-is.
        target_point = np.asarray(target_point)

        if target_point.ndim == 1:
            target_point = target_point.reshape(1, *target_point.shape)

        # One distance per grid point; self.grid_points is expected to be
        # provided by BaseInterpolator.__init__ as a 2D array.
        distances = np.linalg.norm(
            self.grid_points - target_point, ord=self.norm_ord, axis=1
        )

        # argmin returns the first minimum, which settles ties.
        index = np.argmin(distances)

        return self.contents[index, :]


class DiscretePDFNearestNeighborSearcher(BaseNearestNeighborSearcher):
    """
    Nearest-neighbor stand-in for an actual inter-/extrapolation that
    follows the call signature of the discretized PDF IRF component API.
    """

    def __init__(self, grid_points, bin_edges, bin_contents, norm_ord=2):
        """
        Set up a nearest neighbor search over discretized PDFs.

        Parameters
        ----------
        grid_points: np.ndarray, shape=(n_points, n_dims)
            Grid points at which templates exist
        bin_edges: np.ndarray, shape=(n_bins+1)
            Edges of the data binning. Accepted for signature compatibility
            only; the nearest neighbor search does not use it.
        bin_contents: np.ndarray, shape=(n_points, ..., n_bins)
            Content of each bin in bin_edges for each point in grid_points.
            The first dimension has to correspond to the number of
            grid_points, the last one to the number of bins of the quantity
            in question (e.g. the Migra axis for EDisp).
        norm_ord: non-zero int
            Order of the norm used to compute the distances, handed on to
            numpy.linalg.norm. The default of 2 selects the euclidean norm.

        Note
        ----
        Also calls pyirf.interpolation.BaseNearestNeighborSearcher.__init__
        """

        # bin_edges is deliberately not forwarded: only the contents are
        # needed to return the template of the closest grid point.
        super().__init__(grid_points, bin_contents, norm_ord)


# Virtual-subclass registration: the searcher is reported as a
# DiscretePDFInterpolator by issubclass/isinstance without inheriting from it.
DiscretePDFInterpolator.register(DiscretePDFNearestNeighborSearcher)


class ParametrizedNearestNeighborSearcher(BaseNearestNeighborSearcher):
    """
    Nearest-neighbor stand-in for an actual inter-/extrapolation that
    follows the call signature of the parametrized IRF component API.
    """

    def __init__(self, grid_points, params, norm_ord=2):
        """
        Set up a nearest neighbor search over parametrized components.

        Parameters
        ----------
        grid_points: np.ndarray, shape=(n_points, n_dims)
            Grid points at which templates exist
        params: np.ndarray, shape=(n_points, ..., n_params)
            Parameter values belonging to each point in grid_points.
            The first dimension has to correspond to the number of
            grid_points.
        norm_ord: non-zero int
            Order of the norm used to compute the distances, handed on to
            numpy.linalg.norm. The default of 2 selects the euclidean norm.

        Note
        ----
        Also calls pyirf.interpolation.BaseNearestNeighborSearcher.__init__
        """

        # The parameter array plays the role of the generic ``contents``.
        super().__init__(grid_points, params, norm_ord)


# Virtual-subclass registration: the searcher is reported as a
# ParametrizedInterpolator by issubclass/isinstance without inheriting from it.
ParametrizedInterpolator.register(ParametrizedNearestNeighborSearcher)
15 changes: 15 additions & 0 deletions pyirf/interpolation/tests/test_base_interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,18 @@ def interpolate(self, target_point, **kwargs):

interp2D = DummyBinnedInterpolator(grid_points2D, bin_edges, bin_content)
assert interp2D(target2D) == 42


def test_virtual_subclasses():
    """Tests that each nearest neighbor searcher is a virtual subclass of its matching interpolator only"""
    from pyirf.interpolation import (
        DiscretePDFInterpolator,
        DiscretePDFNearestNeighborSearcher,
        ParametrizedInterpolator,
        ParametrizedNearestNeighborSearcher,
    )

    # (searcher class, interpolator base class, expected issubclass result)
    expectations = (
        (DiscretePDFNearestNeighborSearcher, DiscretePDFInterpolator, True),
        (ParametrizedNearestNeighborSearcher, ParametrizedInterpolator, True),
        (ParametrizedNearestNeighborSearcher, DiscretePDFInterpolator, False),
        (DiscretePDFNearestNeighborSearcher, ParametrizedInterpolator, False),
    )

    for searcher_cls, interpolator_cls, should_match in expectations:
        assert issubclass(searcher_cls, interpolator_cls) is should_match

0 comments on commit 5c0d8ac

Please sign in to comment.