Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
227 lines (185 sloc) 8.66 KB
import torch
def sort_clusters(x, lab) :
r"""Sorts a list of points and labels to make sure that the clusters are contiguous in memory.
On the GPU, **contiguous memory accesses** are key to high performances.
By making sure that points in the same cluster are stored next
to each other in memory, this pre-processing routine allows
KeOps to compute block-sparse reductions with maximum efficiency.
Warning:
For unknown reasons, ``torch.bincount`` is much more efficient
on *unsorted* arrays of labels... so make sure not to call ``bincount``
on the output of this routine!
Args:
x ((M,D) Tensor or tuple/list of (M,..) Tensors): List of points :math:`x_i \in \mathbb{R}^D`.
lab ((M,) IntTensor): Vector of class labels :math:`\ell_i\in\mathbb{N}`.
Returns:
(M,D) Tensor or tuple/list of (M,..) Tensors, (M,) IntTensor:
Sorted **point cloud(s)** and **vector of labels**.
Example:
>>> x = torch.Tensor( [ [0.], [5.], [.4], [.3], [2.] ])
>>> lab = torch.IntTensor([ 0, 2, 0, 0, 1 ])
>>> x_sorted, lab_sorted = sort_clusters(x, lab)
>>> print(x_sorted)
tensor([[0.0000],
[0.4000],
[0.3000],
[2.0000],
[5.0000]])
>>> print(lab_sorted)
tensor([0, 0, 0, 1, 2], dtype=torch.int32)
"""
lab, perm = torch.sort(lab.view(-1))
if type(x) is tuple:
x_sorted = tuple( a[perm] for a in x )
elif type(x) is list:
x_sorted = list( a[perm] for a in x )
else:
x_sorted = x[perm]
return x_sorted, lab
def cluster_ranges(lab, Nlab=None) :
r"""Computes the ``[start,end)`` indices that specify clusters in a sorted point cloud.
If **lab** denotes a vector of labels :math:`\ell_i\in[0,C)`,
:func:`sort_clusters` allows us to sort our point clouds and make sure
that points that share the same label are stored next to each other in memory.
:func:`cluster_ranges` is simply there to give you the **slice indices**
that correspond to each of those :math:`C` classes.
Args:
x ((M,D) Tensor): List of points :math:`x_i \in \mathbb{R}^D`.
lab ((M,) IntTensor): Vector of class labels :math:`\ell_i\in\mathbb{N}`.
Keyword Args:
Nlab ((C,) IntTensor, optional): If you have computed it already,
you may specify the number of points per class through
this integer vector of length :math:`C`.
Returns:
(C,2) IntTensor:
Stacked array of :math:`[\\text{start}_k,\\text{end}_k)` indices in :math:`[0,M]`,
for :math:`k\in[0,C)`.
Example:
>>> x = torch.Tensor( [ [0.], [5.], [.4], [.3], [2.] ])
>>> lab = torch.IntTensor([ 0, 2, 0, 0, 1 ])
>>> x_sorted, lab_sorted = sort_clusters(x, lab)
>>> print(x_sorted)
tensor([[0.0000],
[0.4000],
[0.3000],
[2.0000],
[5.0000]])
>>> print(lab_sorted)
tensor([0, 0, 0, 1, 2], dtype=torch.int32)
>>> ranges_i = cluster_ranges(lab)
>>> print( ranges_i )
tensor([[0, 3],
[3, 4],
[4, 5]], dtype=torch.int32)
--> cluster 0 = x_sorted[0:3, :]
--> cluster 1 = x_sorted[3:4, :]
--> cluster 2 = x_sorted[4:5, :]
"""
if Nlab is None : Nlab = torch.bincount(lab).float()
pivots = torch.cat( (torch.Tensor([0.]).to(Nlab.device), Nlab.cumsum(0)) )
return torch.stack( (pivots[:-1], pivots[1:]) ).t().int()
def cluster_centroids(x, lab, Nlab=None, weights=None, weights_c=None) :
r"""Computes the (weighted) centroids of classes specified by a vector of labels.
If points :math:`x_i \in\mathbb{R}^D` are assigned to :math:`C` different classes
by the vector of integer labels :math:`\ell_i \in [0,C)`,
this function returns a collection of :math:`C` centroids
.. math::
c_k ~=~ \\frac{\sum_{i, \ell_i = k} w_i\cdot x_i}{\sum_{i, \ell_i=k} w_i},
where the weights :math:`w_i` are set to 1 by default.
Args:
x ((M,D) Tensor): List of points :math:`x_i \in \mathbb{R}^D`.
lab ((M,) IntTensor): Vector of class labels :math:`\ell_i\in\mathbb{N}`.
Keyword Args:
Nlab ((C,) IntTensor): Number of points per class. Recomputed if None.
weights ((N,) Tensor): Positive weights :math:`w_i` of each point.
weights_c ((C,) Tensor): Total weight of each class. Recomputed if None.
Returns:
(C,D) Tensor:
List of centroids :math:`c_k \in \mathbb{R}^D`.
Example:
>>> x = torch.Tensor([ [0.], [1.], [4.], [5.], [6.] ])
>>> lab = torch.IntTensor([ 0, 0, 1, 1, 1 ])
>>> weights = torch.Tensor([ .5, .5, 2., 1., 1. ])
>>> centroids = cluster_centroids(x, lab, weights=weights)
>>> print(centroids)
tensor([[0.5000],
[4.7500]])
"""
if Nlab is None : Nlab = torch.bincount(lab).float()
if weights is not None and weights_c is None:
weights_c = torch.bincount(lab, weights=weights).view(-1,1)
c = torch.zeros( (len(Nlab), x.shape[1]), dtype=x.dtype,device=x.device)
for d in range(x.shape[1]):
if weights is None:
c[:,d] = torch.bincount(lab,weights=x[:,d]) / Nlab
else :
c[:,d] = torch.bincount(lab,weights=x[:,d]*weights.view(-1)) / weights_c.view(-1)
return c
def cluster_ranges_centroids(x, lab, weights=None) :
r"""Computes the cluster indices and centroids of a (weighted) point cloud with labels.
If **x** and **lab** encode a cloud of points :math:`x_i\in\mathbb{R}^D`
with labels :math:`\ell_i\in[0,C)`, for :math:`i\in[0,M)`, this routine returns:
- Ranges :math:`[\\text{start}_k,\\text{end}_k)` compatible with
:func:`sort_clusters` for :math:`k\in[0,C)`.
- Centroids :math:`c_k` for each cluster :math:`k`, computed as barycenters
using the weights :math:`w_i \in \mathbb{R}_{>0}`:
.. math::
c_k = \\frac{\sum_{i, \ell_i=k} w_i\cdot \ell_i}{\sum_{i, \ell_i=k} w_i}
- Total weights :math:`\sum_{i, \ell_i=k} w_i`, for :math:`k\in[0,C)`.
The weights :math:`w_i` can be given through a vector **weights**
of size :math:`M`, and are set by default to 1 for all points in the cloud.
Args:
x ((M,D) Tensor): List of points :math:`x_i \in \mathbb{R}^D`.
lab ((M,) IntTensor): Vector of class labels :math:`\ell_i\in\mathbb{N}`.
Keyword Args:
weights ((M,) Tensor): Positive weights :math:`w_i` that can be used to compute
our barycenters.
Returns:
(C,2) IntTensor, (C,D) Tensor, (C,) Tensor:
**ranges** - Stacked array of :math:`[\\text{start}_k,\\text{end}_k)` indices in :math:`[0,M]`,
for :math:`k\in[0,C)`, compatible with the :func:`sort_clusters` routine.
**centroids** - List of centroids :math:`c_k \in \mathbb{R}^D`.
**weights_c** - Total weight of each cluster.
Example:
>>> x = torch.Tensor([ [0.], [.5], [1.], [2.], [3.] ])
>>> lab = torch.IntTensor([ 0, 0, 1, 1, 1 ])
>>> ranges, centroids, weights_c = cluster_ranges_centroids(x, lab)
>>> print(ranges)
tensor([[0, 2],
[2, 5]], dtype=torch.int32)
--> cluster 0 = x[0:2, :]
--> cluster 1 = x[2:5, :]
>>> print(centroids)
tensor([[0.2500],
[2.0000]])
>>> print(weights_c)
tensor([2., 3.])
>>> weights = torch.Tensor([ 1., .5, 1., 1., 10. ])
>>> ranges, centroids, weights_c = cluster_ranges_centroids(x, lab, weights=weights)
>>> print(ranges)
tensor([[0, 2],
[2, 5]], dtype=torch.int32)
--> cluster 0 = x[0:2, :]
--> cluster 1 = x[2:5, :]
>>> print(centroids)
tensor([[0.1667],
[2.7500]])
>>> print(weights_c)
tensor([1.5000, 12.0000])
"""
Nlab = torch.bincount(lab).float()
if weights is not None :
w_c = torch.bincount(lab, weights=weights).view(-1)
return cluster_ranges(lab, Nlab), cluster_centroids(x, lab, Nlab, weights=weights, weights_c=w_c), w_c
else :
return cluster_ranges(lab, Nlab), cluster_centroids(x, lab, Nlab), Nlab
def swap_axes(ranges) :
r"""
Swaps the ":math:`i`" and ":math:`j`" axes of a reduction's optional **ranges** parameter.
This function returns **None** if **ranges** is **None**,
and swaps the :math:`i` and :math:`j` arrays of indices otherwise.
"""
if ranges is None :
return None
else :
return (*ranges[3:6],*ranges[0:3])
You can’t perform that action at this time.