Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions movement/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import xarray as xr
from scipy import signal

from movement.kinematics import compute_displacement
from movement.utils.logging import log_error, log_to_attrs
from movement.utils.reports import report_nan_values
from movement.utils.vector import compute_norm


@log_to_attrs
Expand Down Expand Up @@ -60,6 +62,54 @@
return data_filtered


def filter_by_displacement(
position: xr.DataArray,
threshold: float = 10.0,
print_report: bool = False,
) -> xr.DataArray:
"""Drop data points with a displacement above a certain distance threshold.

Frames in the ``position`` array that have a displacement magnitude above
the given ``threshold`` are set to NaN. In effect, if a point at time ``t``
has moved more than the ``threshold`` euclidean distance from the same
point at time ``t-1``, its value at time ``t`` is set to NaN.

Parameters
----------
position : xr.DataArray
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
threshold : float, optional
The maximum euclidean distance allowed between 2 consecutive positions.
Defaults to 10.0.
print_report : bool, optional
Whether to print a report of the number of NaN values before and after
filtering. Defaults to False.

Returns
-------
xr.DataArray
The filtered position array.

See Also
--------
movement.kinematics.compute_displacement:
The function used to compute an array of displacement vectors.
movement.utils.vector.compute_norm:
The function used to compute distance as the magnitude of
displacement vectors.

"""
displacement = compute_displacement(position)
distance = compute_norm(displacement)
position_filtered = position.where(distance < threshold)

Check warning on line 105 in movement/filtering.py

View check run for this annotation

Codecov / codecov/patch

movement/filtering.py#L103-L105

Added lines #L103 - L105 were not covered by tests

if print_report:
print(report_nan_values(position, "input"))
print(report_nan_values(position_filtered, "output"))
return position_filtered

Check warning on line 110 in movement/filtering.py

View check run for this annotation

Codecov / codecov/patch

movement/filtering.py#L107-L110

Added lines #L107 - L110 were not covered by tests


@log_to_attrs
def interpolate_over_time(
data: xr.DataArray,
Expand Down