Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Nearest Neighbor Searcher as Alternative for Inter-/Extrapolation #232

Merged
merged 9 commits into from
Jul 4, 2023
2 changes: 2 additions & 0 deletions docs/changes/232.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This PR adds a DiscretePDFNearestNeighborSearcher and a ParametrizedNearestNeighborSearcher to support nearest neighbor approaches
as alternatives to inter- and extrapolation.
8 changes: 8 additions & 0 deletions pyirf/interpolation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,25 @@
)
from .griddata_interpolator import GridDataInterpolator
from .moment_morph_interpolator import MomentMorphInterpolator
from .nearest_neighbor_searcher import (
BaseNearestNeighborSearcher,
DiscretePDFNearestNeighborSearcher,
ParametrizedNearestNeighborSearcher,
)
from .quantile_interpolator import QuantileInterpolator

__all__ = [
"BaseComponentEstimator",
"BaseInterpolator",
"BaseNearestNeighborSearcher",
"DiscretePDFComponentEstimator",
"DiscretePDFInterpolator",
"DiscretePDFNearestNeighborSearcher",
"GridDataInterpolator",
"MomentMorphInterpolator",
"ParametrizedComponentEstimator",
"ParametrizedInterpolator",
"ParametrizedNearestNeighborSearcher",
"QuantileInterpolator",
"EffectiveAreaEstimator",
"RadMaxEstimator",
Expand Down
50 changes: 32 additions & 18 deletions pyirf/interpolation/component_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
ParametrizedInterpolator,
)
from pyirf.interpolation.griddata_interpolator import GridDataInterpolator
from pyirf.interpolation.nearest_neighbor_searcher import (
DiscretePDFNearestNeighborSearcher,
ParametrizedNearestNeighborSearcher,
)
from pyirf.interpolation.quantile_interpolator import QuantileInterpolator
from pyirf.utils import cone_solid_angle
from scipy.spatial import Delaunay
Expand Down Expand Up @@ -169,7 +173,7 @@ def __init__(
grid_points: np.ndarray, shape=(n_points, n_dims):
Grid points at which interpolation templates exist
bin_edges: np.ndarray, shape=(n_bins+1)
Common set of bin-edges for all discretized PDFs
Common set of bin-edges for all discretized PDFs.
bin_contents: np.ndarray, shape=(n_points, ..., n_bins)
Discretized PDFs for all grid points and arbitrary further dimensions
(in IRF term e.g. field-of-view offset bins). Actual interpolation dimension,
Expand All @@ -191,14 +195,15 @@ def __init__(
Raises
------
TypeError:
When bin_edges is not a np.ndarray
When bin_edges is not a np.ndarray.
TypeError:
When bin_content is not a np.ndarray
TypeError:
When interpolator_cls is not a BinnedInterpolator subclass.
When interpolator_cls is not a DiscretePDFInterpolator subclass or
DiscretePDFNearestNeighborSearcher.
ValueError:
When number of bins in bin_edges and contents bin_contents is
not matching
When number of bins in bin_edges and contents in bin_contents is
not matching.
ValueError:
When number of histograms in bin_contents and points in grid_points
is not matching
Expand All @@ -212,31 +217,35 @@ def __init__(
grid_points,
)

if not isinstance(bin_edges, np.ndarray):
raise TypeError("Input bin_edges is not a numpy array.")
elif not isinstance(bin_contents, np.ndarray):
if not isinstance(bin_contents, np.ndarray):
raise TypeError("Input bin_contents is not a numpy array.")
elif bin_contents.shape[-1] != (bin_edges.shape[0] - 1):
raise ValueError(
f"Shape missmatch, bin_edges ({bin_edges.shape[0] - 1} bins) "
f"and bin_contents ({bin_contents.shape[-1]} bins) not matching."
)
elif self.n_points != bin_contents.shape[0]:
raise ValueError(
f"Shape missmatch, number of grid_points ({self.n_points}) and "
f"number of histograms in bin_contents ({bin_contents.shape[0]}) "
"not matching."
)
elif not isinstance(bin_edges, np.ndarray):
raise TypeError("Input bin_edges is not a numpy array.")
elif bin_contents.shape[-1] != (bin_edges.shape[0] - 1):
raise ValueError(
f"Shape missmatch, bin_edges ({bin_edges.shape[0] - 1} bins) "
f"and bin_contents ({bin_contents.shape[-1]} bins) not matching."
)

if interpolator_kwargs is None:
interpolator_kwargs = {}

if extrapolator_kwargs is None:
extrapolator_kwargs = {}

if not issubclass(interpolator_cls, DiscretePDFInterpolator):
if not (
issubclass(interpolator_cls, DiscretePDFInterpolator)
or issubclass(interpolator_cls, DiscretePDFNearestNeighborSearcher)
RuneDominik marked this conversation as resolved.
Show resolved Hide resolved
):
raise TypeError(
f"interpolator_cls must be a DiscretePDFInterpolator subclass, got {interpolator_cls}"
"interpolator_cls must be a DiscretePDFInterpolator subclass or "
f"DiscretePDFNearestNeighborSearcher, got {interpolator_cls}"
)

self.interpolator = interpolator_cls(
Expand Down Expand Up @@ -296,7 +305,8 @@ def __init__(
Raises
------
TypeError:
When interpolator_cls is not a ParametrizedInterpolator subclass.
When interpolator_cls is not a ParametrizedInterpolator subclass
or ParametrizedNearestNeighborSearcher.
TypeError:
When params is not a np.ndarray
ValueError:
Expand Down Expand Up @@ -324,9 +334,13 @@ def __init__(
if extrapolator_kwargs is None:
extrapolator_kwargs = {}

if not issubclass(interpolator_cls, ParametrizedInterpolator):
if not (
issubclass(interpolator_cls, ParametrizedInterpolator)
or issubclass(interpolator_cls, ParametrizedNearestNeighborSearcher)
):
raise TypeError(
f"interpolator_cls must be a ParametrizedInterpolator subclass, got {interpolator_cls}"
"interpolator_cls must be a ParametrizedInterpolator subclass or "
f"ParametrizedNearestNeighborSearcher, got {interpolator_cls}"
)

self.interpolator = interpolator_cls(grid_points, params, **interpolator_kwargs)
Expand Down
162 changes: 162 additions & 0 deletions pyirf/interpolation/nearest_neighbor_searcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import numpy as np

from .base_interpolators import BaseInterpolator

# Public names of this module (what ``from <module> import *`` exports).
__all__ = [
"BaseNearestNeighborSearcher",
"DiscretePDFNearestNeighborSearcher",
"ParametrizedNearestNeighborSearcher",
]


class BaseNearestNeighborSearcher(BaseInterpolator):
    """
    Dummy NearestNeighbor approach usable instead of
    actual Interpolation/Extrapolation.

    Instead of computing new contents, the contents stored for the grid
    point closest to the requested target point are returned unchanged.
    """

    def __init__(self, grid_points, contents, norm_ord=2):
        """
        BaseNearestNeighborSearcher

        Parameters
        ----------
        grid_points: np.ndarray, shape=(n_points, n_dims)
            Grid points at which templates exist
        contents: np.ndarray, shape=(n_points, ...)
            Corresponding IRF contents at grid_points
        norm_ord: positive int
            Order of the norm which is used to compute the distances,
            passed as ``ord`` to numpy.linalg.norm. Defaults to 2,
            which uses the euclidean norm. Integer-valued floats
            (e.g. ``2.0``) are accepted as well.

        Raises
        ------
        ValueError:
            If norm_ord is not a finite, positive integer. This includes
            values that cannot be compared to numbers at all.

        Note
        ----
        Also calls BaseInterpolator.__init__ with grid_points.
        """

        super().__init__(grid_points)

        self.contents = contents

        # NOTE: the (misspelled) message is part of the runtime behavior and
        # is therefore kept unchanged.
        message = f"Only positiv integers allowed for norm_ord, got {norm_ord}."

        # Every way the checks themselves can fail (non-numeric input,
        # types unsupported by np.isfinite such as Fraction/Decimal,
        # multi-element arrays with ambiguous truth value, ...) is mapped
        # onto the one documented exception type.
        try:
            is_positive_integer = bool(
                norm_ord > 0
                and np.isfinite(norm_ord)
                and norm_ord == int(norm_ord)
            )
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(message) from err

        if not is_positive_integer:
            raise ValueError(message)

        self.norm_ord = norm_ord

    def interpolate(self, target_point):
        """
        Takes a grid of IRF contents for a bunch of different parameters
        and returns the contents at the nearest grid point
        as seen from the target point.

        Parameters
        ----------
        target_point: numpy.ndarray
            Value for which the nearest neighbor should be found (target point).
            A one-dimensional array is treated as a single point and
            reshaped to ``(1, n_dims)``.

        Returns
        -------
        content_new: numpy.ndarray
            Contents at nearest neighbor, i.e. ``contents[index, :]`` with
            ``index`` being the position of the nearest grid point. The
            leading grid-point axis of ``contents`` is indexed away.

        Note
        ----
        In case of multiple nearest neighbors, the contents corresponding
        to the first one are returned.
        """

        if target_point.ndim == 1:
            target_point = target_point.reshape(1, *target_point.shape)

        # NOTE(review): self.grid_points is not set in this class; presumably
        # BaseInterpolator.__init__ stores it -- confirm against the base class.
        # One distance per grid point, computed along the n_dims axis.
        distances = np.linalg.norm(
            self.grid_points - target_point, ord=self.norm_ord, axis=1
        )

        # np.argmin returns the first minimum, which settles ties.
        index = np.argmin(distances)

        return self.contents[index, :]


class DiscretePDFNearestNeighborSearcher(BaseNearestNeighborSearcher):
    """
    Nearest-neighbor stand-in for an actual inter-/extrapolation that
    follows the call signature used for discretized PDF IRF components.
    """

    def __init__(self, grid_points, bin_edges, bin_contents, norm_ord=2):
        """
        Set up a nearest neighbor search on discretized PDFs.

        Parameters
        ----------
        grid_points: np.ndarray, shape=(n_points, n_dims)
            Grid points at which templates exist
        bin_edges: np.ndarray, shape=(n_bins+1)
            Edges of the data binning. Only accepted for signature
            compatibility, it is not used by the nearest neighbor search.
        bin_contents: np.ndarray, shape=(n_points, ..., n_bins)
            Content of each bin in bin_edges for each point in grid_points.
            First dimension has to correspond to the number of grid_points,
            last dimension has to correspond to the number of bins of the
            binned quantity (e.g. the Migra axis for EDisp).
        norm_ord: positive int
            Order of the norm which is used to compute the distances,
            passed to numpy.linalg.norm. Defaults to 2,
            which uses the euclidean norm.

        Note
        ----
        Forwards grid_points, bin_contents and norm_ord to
        BaseNearestNeighborSearcher.__init__
        """

        # bin_edges is intentionally not forwarded.
        super().__init__(grid_points, bin_contents, norm_ord)


class ParametrizedNearestNeighborSearcher(BaseNearestNeighborSearcher):
    """
    Nearest-neighbor stand-in for an actual inter-/extrapolation that
    follows the call signature used for parametrized IRF components.
    """

    def __init__(self, grid_points, params, norm_ord=2):
        """
        Set up a nearest neighbor search on parametrized components.

        Parameters
        ----------
        grid_points: np.ndarray, shape=(n_points, n_dims)
            Grid points at which templates exist
        params: np.ndarray, shape=(n_points, ..., n_params)
            Corresponding parameter values at each point in grid_points.
            First dimension has to correspond to the number of grid_points.
        norm_ord: positive int
            Order of the norm which is used to compute the distances,
            passed to numpy.linalg.norm. Defaults to 2,
            which uses the euclidean norm.

        Note
        ----
        Forwards grid_points, params (as contents) and norm_ord to
        BaseNearestNeighborSearcher.__init__
        """

        super().__init__(grid_points, params, norm_ord)