themachinefan/torch-geometric-median

torch-geometric-median

A simplified version of the geom-median Python library, updated for better performance in PyTorch and with full type hints. Thanks to @themachinefan!

Installation

pip install torch-geometric-median

Usage

This library exports a single function, geometric_median, which takes a tensor of shape (N, D), where N is the number of samples and D is the size of each sample. It returns a result object whose median attribute holds the geometric median of the points in the tensor.

import torch
from torch_geometric_median import geometric_median

# Create a tensor of points
points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])

# Compute the geometric median
median = geometric_median(points).median
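For reference, the geometric median of points x_1, ..., x_N (with optional weights w_i, all equal to 1 by default) is the point that minimizes the weighted sum of Euclidean distances to the points, whereas the mean minimizes the sum of squared distances. A classical way to compute it is Weiszfeld's fixed-point iteration, written out below as background. The smoothing constant ε, which keeps the weights finite when the iterate lands exactly on a data point, is an assumption of this sketch, and this library's exact update may differ.

```latex
\hat{z} = \arg\min_{z \in \mathbb{R}^{D}} \sum_{i=1}^{N} w_i \, \lVert z - x_i \rVert_2

z^{(t+1)} = \frac{\sum_{i=1}^{N} \beta_i^{(t)} \, x_i}{\sum_{i=1}^{N} \beta_i^{(t)}},
\qquad
\beta_i^{(t)} = \frac{w_i}{\max\left( \lVert z^{(t)} - x_i \rVert_2, \ \varepsilon \right)}
```

For the five evenly spaced collinear points in the example, the minimizer is the middle point (2, 2).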

Backprop

Like the original geom-median library, this library supports backpropagation through the geometric median computation.

points.requires_grad_(True)  # track gradients with respect to the input points
median = geometric_median(points).median
torch.linalg.norm(median).backward()
# The gradient of the median's norm with respect to the input points is now in `points.grad`

Extra options

The geometric_median function also supports a few extra options:

  • maxiter: The maximum number of iterations to run the optimization for. Default is 100.
  • ftol: If the objective value does not improve by at least this fraction between iterations, terminate the algorithm. Default is 1e-20.
  • weights: A tensor of shape (N,) containing the weights for each point, where N is the number of samples. Default is None, which means all points are weighted equally.
  • show_progress: If True, show a progress bar for the optimization. Default is False.
  • log_objective_values: If True, record the objective value at each iteration under the key objective_values_log. Default is False.

median = geometric_median(
    points,
    maxiter=1000,
    ftol=1e-10,
    weights=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]),
    show_progress=True,
    log_objective_values=True
).median
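To make these options concrete, here is a minimal pure-Python sketch of a weighted Weiszfeld loop that honours maxiter, ftol, weights and objective logging. It is not this library's implementation: the function weiszfeld_sketch, the MedianResult class, the eps clamp and the exact form of the ftol test are hypothetical choices made to mirror the option descriptions above, and plain lists stand in for tensors so the sketch runs without PyTorch.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional, Sequence


@dataclass
class MedianResult:
    # Hypothetical stand-in for the library's result object; only the
    # `median` attribute and the objective_values_log name come from the README.
    median: List[float]
    objective_values_log: List[float] = field(default_factory=list)


def objective(points: Sequence[Sequence[float]], weights: Sequence[float], z: List[float]) -> float:
    """Weighted sum of Euclidean distances from z to every point."""
    return sum(w * math.dist(p, z) for p, w in zip(points, weights))


def weiszfeld_sketch(
    points: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
    maxiter: int = 100,
    ftol: float = 1e-20,
    eps: float = 1e-8,  # hypothetical clamp so a zero distance never divides by zero
    log_objective_values: bool = False,
) -> MedianResult:
    w = [1.0] * len(points) if weights is None else list(weights)
    dim = len(points[0])

    # Start from the weighted mean of the points.
    total = sum(w)
    z = [sum(wi * p[d] for p, wi in zip(points, w)) / total for d in range(dim)]
    obj = objective(points, w, z)
    log = [obj] if log_objective_values else []

    for _ in range(maxiter):
        # Reweight each point by w_i / distance(z, x_i), clamping tiny distances to eps.
        beta = [wi / max(math.dist(p, z), eps) for p, wi in zip(points, w)]
        b = sum(beta)
        # The next iterate is the beta-weighted average of the points.
        z = [sum(bi * p[d] for p, bi in zip(points, beta)) / b for d in range(dim)]

        new_obj = objective(points, w, z)
        if log_objective_values:
            log.append(new_obj)
        improvement = obj - new_obj
        obj = new_obj
        # Stop once the objective improves by no more than ftol times its current value.
        if improvement <= ftol * new_obj:
            break

    return MedianResult(median=z, objective_values_log=log)
```

With the four corners of a unit square plus a far outlier at (100, 100), the sketch settles near (0.789, 0.789), while the mean is pulled out to (20.4, 20.4). This robustness to outliers is the usual reason to prefer the geometric median over the mean.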

Why does this library exist?

It appears that the original geom-median library is no longer maintained, and, as @themachinefan pointed out, it is not very performant with PyTorch. This library repackages @themachinefan's improvements to the original geom-median library, simplifying the code to support only PyTorch, improving PyTorch performance, and adding full type hints.

Acknowledgements

This library repackages the work of the original geom-median library and of @themachinefan in their PR, and all credit goes to these incredible authors. If you use this library, please cite the original geom-median paper.

License

This library is licensed under the GPL, as is the original geom-median library.

Contributing

Contributions are welcome! Please open an issue or a PR if you have any suggestions or improvements. This library uses PDM for dependency management, Ruff for linting, Pyright for type-checking, and Pytest for tests.

To contribute to the repo, first install dependencies with pdm install. Tests are run with pdm run pytest. Formatting is done with pdm run ruff format and linting with pdm run ruff check. Type-checking is done with pdm run pyright. Please ensure that all tests pass and that the code is formatted, linted, and type-checked before opening a PR.