Skip to content

Commit

Permalink
Enforce linting of docstrings (#165)
Browse files Browse the repository at this point in the history
* enforce pydocstyle linting

* correct a typo

* ignore rule D301 on docstring containing newlines
  • Loading branch information
niksirbi committed Apr 22, 2024
1 parent 6702100 commit 2744de2
Show file tree
Hide file tree
Showing 25 changed files with 316 additions and 222 deletions.
5 changes: 2 additions & 3 deletions examples/compute_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ruff: noqa: E402
"""
Compute and visualise kinematics
=================================
"""Compute and visualise kinematics.
====================================
Compute displacement, velocity and acceleration data on an example dataset and
visualise the results.
Expand Down
3 changes: 1 addition & 2 deletions examples/filter_and_interpolate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Filtering and interpolation
"""Filtering and interpolation
============================
Filter out points with low confidence scores and interpolate over
Expand Down
3 changes: 1 addition & 2 deletions examples/load_and_explore_poses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Load and explore pose tracks
"""Load and explore pose tracks
============================
Load and explore an example dataset of pose tracks.
Expand Down
40 changes: 24 additions & 16 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Functions for computing kinematic variables."""

import numpy as np
import xarray as xr

from movement.logging import log_error


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
"""Compute the displacement between consecutive positions
of each keypoint for each individual across time.
"""Compute displacement between consecutive positions.
This is the difference between consecutive positions of each keypoint for
each individual across time. At each time point ``t``, it's defined as a
vector in cartesian ``(x,y)`` coordinates, pointing from the previous
``(t-1)`` to the current ``(t)`` position.
Parameters
----------
Expand All @@ -17,6 +23,7 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:
-------
xarray.DataArray
An xarray DataArray containing the computed displacement.
"""
_validate_time_dimension(data)
result = data.diff(dim="time")
Expand All @@ -25,8 +32,11 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:


def compute_velocity(data: xr.DataArray) -> xr.DataArray:
"""Compute the velocity between consecutive positions
of each keypoint for each individual across time.
"""Compute the velocity in cartesian ``(x,y)`` coordinates.
Velocity is the first derivative of position for each keypoint
and individual across time. It's computed using numerical differentiation
and assumes equidistant time spacing.
Parameters
----------
Expand All @@ -38,17 +48,16 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:
xarray.DataArray
An xarray DataArray containing the computed velocity.
Notes
-----
This function computes velocity using numerical differentiation
and assumes equidistant time spacing.
"""
return _compute_approximate_derivative(data, order=1)


def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
"""Compute the acceleration between consecutive positions
of each keypoint for each individual across time.
"""Compute acceleration in cartesian ``(x,y)`` coordinates.
Acceleration represents the second derivative of position for each keypoint
and individual across time. It's computed using numerical differentiation
and assumes equidistant time spacing.
Parameters
----------
Expand All @@ -60,19 +69,16 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
xarray.DataArray
An xarray DataArray containing the computed acceleration.
Notes
-----
This function computes acceleration using numerical differentiation
and assumes equidistant time spacing.
"""
return _compute_approximate_derivative(data, order=2)


def _compute_approximate_derivative(
data: xr.DataArray, order: int
) -> xr.DataArray:
"""Compute velocity or acceleration using numerical differentiation,
assuming equidistant time spacing.
"""Compute the derivative using numerical differentiation.
This assumes equidistant time spacing.
Parameters
----------
Expand All @@ -86,6 +92,7 @@ def _compute_approximate_derivative(
-------
xarray.DataArray
An xarray DataArray containing the derived variable.
"""
if not isinstance(order, int):
raise log_error(
Expand Down Expand Up @@ -119,6 +126,7 @@ def _validate_time_dimension(data: xr.DataArray) -> None:
------
ValueError
If the input data does not contain a ``time`` dimension.
"""
if "time" not in data.dims:
raise log_error(
Expand Down
28 changes: 16 additions & 12 deletions movement/filtering.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Functions for filtering and interpolating pose tracks in xarray datasets."""

import logging
from datetime import datetime
from functools import wraps
Expand All @@ -7,9 +9,9 @@


def log_to_attrs(func):
"""
Decorator that logs the operation performed by the wrapped function
and appends the log entry to the xarray.Dataset's "log" attribute.
"""Log the operation performed by the wrapped function.
This decorator appends log entries to the xarray.Dataset's "log" attribute.
For the decorator to work, the wrapped function must accept an
xarray.Dataset as its first argument and return an xarray.Dataset.
"""
Expand Down Expand Up @@ -37,18 +39,18 @@ def wrapper(*args, **kwargs):


def report_nan_values(ds: xr.Dataset, ds_label: str = "dataset"):
"""
Report the number and percentage of points that are NaN for each individual
and each keypoint in the provided dataset.
"""Report the number and percentage of points that are NaN.
Numbers are reported for each individual and keypoint in the dataset.
Parameters
----------
ds : xarray.Dataset
Dataset containing pose tracks, confidence scores, and metadata.
ds_label : str
Label to identify the dataset in the report. Default is "dataset".
"""
"""
# Compile the report
nan_report = f"\nMissing points (marked as NaN) in {ds_label}:"
for ind in ds.individuals.values:
Expand Down Expand Up @@ -77,8 +79,7 @@ def interpolate_over_time(
max_gap: Union[int, None] = None,
print_report: bool = True,
) -> Union[xr.Dataset, None]:
"""
Fill in NaN values by interpolating over the time dimension.
"""Fill in NaN values by interpolating over the time dimension.
Parameters
----------
Expand All @@ -100,6 +101,7 @@ def interpolate_over_time(
ds_interpolated : xr.Dataset
The provided dataset (ds), where NaN values have been
interpolated over using the parameters provided.
"""
ds_interpolated = ds.copy()
position_interpolated = ds.position.interpolate_na(
Expand All @@ -118,9 +120,10 @@ def filter_by_confidence(
threshold: float = 0.6,
print_report: bool = True,
) -> Union[xr.Dataset, None]:
"""
Drop all points where the associated confidence value falls below a
user-defined threshold.
"""Drop all points below a certain confidence threshold.
Position points with an associated confidence value below the threshold are
converted to NaN.
Parameters
----------
Expand Down Expand Up @@ -150,6 +153,7 @@ def filter_by_confidence(
datasets and does not have the same meaning across pose estimation
frameworks. We advise users to inspect the confidence values
in their dataset and adjust the threshold accordingly.
"""
ds_thresholded = ds.copy()
ds_thresholded.update(
Expand Down
Loading

0 comments on commit 2744de2

Please sign in to comment.