I'd like to add `partition` and `argpartition` to this lib; that would help with one of my PRs in scikit-learn: scikit-learn/scikit-learn#32288 😄
I'd love to do it myself, but I need some guidance. Here is what I want to do:
- JAX, NumPy and CuPy already implement these functions, so we can just call them.
- PyTorch has `torch.topk`, which can be used to implement them in O(n) instead of the O(n log n) of a full sort. TensorFlow has the same with `tf.math.top_k`, and so does Dask with `dask.array.topk`.
- For every other library, the default would be to fall back on `sort`/`argsort`.
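Why the default case is safe: a fully sorted array already satisfies the partition contract for every `kth` (the element at `kth` is the one a sort would put there, nothing before it is larger, nothing after it is smaller), so `sort`/`argsort` are always correct, just slower. A minimal sketch, assuming NumPy as the stand-in namespace; `partition_fallback`, `argpartition_fallback` and `is_partitioned` are hypothetical names, not part of any library:

```python
import numpy as np


def partition_fallback(a, kth, /, *, xp=np):
    """Partition along the last axis via a full sort (O(n log n), always valid)."""
    # kth is unused on purpose: a sorted array puts every element in its
    # final position, so the contract holds for any kth
    return xp.sort(a, axis=-1)


def argpartition_fallback(a, kth, /, *, xp=np):
    """Argpartition along the last axis via a full argsort."""
    return xp.argsort(a, axis=-1)


def is_partitioned(p, kth) -> bool:
    """NumPy-only check of the partition contract along the last axis of p."""
    pivot = p[..., kth : kth + 1]
    return bool(np.all(p[..., :kth] <= pivot) and np.all(p[..., kth + 1 :] >= pivot))


a = np.array([3, 7, 1, 9, 4, 2])
assert is_partitioned(partition_fallback(a, 2), 2)
assert is_partitioned(np.partition(a, 2), 2)  # native result: same contract
```

A checker like `is_partitioned` is also handy in tests, because native implementations only guarantee the contract, not a particular order on either side of `kth`.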
My main question is: how should I implement the control flow ("if torch: do this; if dask: do that; [...]")?
Thank you! 🙏