In [71]:
from caldera.utils import scatter_coo
import torch
from typing import Dict


from caldera.utils import torch_scatter_group


k, v = torch_scatter_group(values, m._indices()[-1])
data = dict(zip(k.tolist(), v))
data

# @torch.jit.script
def _gather(data: Dict[int, torch.Tensor], indices: torch.Tensor):
    arrs = []
    for i in range(indices.shape[0]):
        key = indices[i].item()
        if key in data:
            arrs.append(data[key])
    if len(arrs):
        return torch.cat(arrs)
    return torch.Tensor([])

gather = torch.jit.trace(_gather, ({0: torch.FloatTensor([0.])}, torch.LongTensor([0])))

from typing import Tuple, List

# @torch.jit.script
def stable_arg_sort_long(arr):
    """Stable sort of long tensors.

    Note that Pytorch 1.5.0 does not have a stable sort implementation.
    Here we simply add a delta value between 0 and 1 (exclusive) and
    assuming we are using integers, call torch.argsort to get a stable
    sort.
    """
    if not (arr.dtype == torch.long or arr.dtype == torch.int):
        raise ValueError("only torch.Long or torch.Int allowed")
    dim = -1
    if not dim == -1:
        raise ValueError("only last dimension sort is supported. Try reshaping tensor.")
    delta_shape = list(arr.shape)
    delta_shape[dim] = 1
    mn = 0.
    mx = 0.9
    delta = torch.linspace(mn, mx, arr.shape[dim])
    delta = delta.repeat(delta_shape)
    return torch.argsort(arr + delta, dim=dim)

traced_arg_sort_long = torch.jit.trace(stable_arg_sort_long, (torch.randint(1, 10, (10,))))

def torch_scatter_group(
    x: torch.Tensor, idx: torch.Tensor
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Group a tensor by indices. This is equivalent to successive applications
    of `x[torch.where(x == index)]` for all provided sorted indices.

    Example:

    .. code-block:: python

        idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
        x = torch.tensor([0, 1, 2, 3, 4, 5, 6])

        uniq_sorted_idx, out = scatter_group(x, idx)

        # node the idx is sorted
        assert torch.all(torch.eq(out[0], torch.tensor([0, 1, 2])))

        # where idx == 0
        assert torch.all(torch.eq(out[1][0], torch.tensor([2])))

        # where idx == 1
        assert torch.all(torch.eq(out[1][1], torch.tensor([3, 4, 5])))

        # where idx == 2
        assert torch.all(torch.eq(out[1][2], torch.tensor([0, 1, 6])))

    :param x: tensor to group
    :param idx: indices
    :return: tuple of unique, sorted indices and a list of tensors corresponding to the groups
    """
    arg = traced_arg_sort(idx)
    x = x[arg]
    groups, b = torch.unique(idx, return_counts=True)
    i_a = 0
    arr_list = []
    for i_b in b:
        arr_list.append(x[i_a : i_a + i_b.item()])
        i_a += i_b.item()
    return groups, arr_list

traced_torch_scatter_group = torch.jit.trace(torch_scatter_group, (torch.randn(10), torch.randint(0, 10, (10,))))

values = torch.randn(10000)
idx = torch.randint(0, 10, (values.shape[0],))


%timeit -n100 torch_scatter_group(values, idx)
%timeit -n100 traced_torch_scatter_group(values, idx)

  app.launch_new_instance()
  _force_outplace)


830 µs ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
794 µs ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [155]:
ij = torch.randint(1, 10, (4, 100))
values = torch.arange(ij.shape[1])

m = scatter_coo(ij, values)

k, v = traced_torch_scatter_group(m._values(), m._indices()[-1])
data = dict(zip(k.tolist(), v))

_gather(data, torch.LongTensor([0, 1, 2, 3]))

def where_is_in(idx: torch.LongTensor, a: torch.LongTensor):
    values = torch.arange(a.shape[-1])
    m = scatter_coo(a, values)
    k, v = torch_scatter_group(m._values(), m._indices()[-1])
    data = dict(zip(k.tolist(), v))
    return _gather(data, idx)

def other_where_is_in(a: torch.Tensor, b: torch.Tensor):
    result = torch.empty_like(a)
    for i in range(a.shape[0]):
        _a = a[i]
        for j in range(b.shape[0]):
            _b = b[j]
            if _a == _b:
                result[i] = j
    return result

a = torch.randint(1, 10, (10,))
b = torch.randint(1, 10, (10,))

traced_other_where_is_in = torch.jit.script(other_where_is_in) # (a, b))


a = torch.randint(1, 10, (10,))
b = torch.randint(1, 10, (10,))
print('calling')
print(a)
c = traced_other_where_is_in(a, b)
print(c)

a = torch.randint(1, 10, (100,))
b = torch.randint(1, 10, (100,))
print('calling')
print(a)
c = traced_other_where_is_in(a, b)
print(c)
%timeit -n100 where_is_in(a, b)
%timeit -n10 other_where_is_in(a, b)
%timeit -n10 traced_other_where_is_in(a, b)
print(a)
print(b)
where_is_in(a, b)

calling
tensor([5, 9, 3, 7, 2, 5, 4, 8, 6, 8])
tensor([              1, 140634073316320,               2,               6,
                      9,               1,               0,               8,
         94206487370624,               8])
calling
tensor([9, 9, 6, 3, 4, 8, 4, 4, 9, 6, 7, 7, 3, 8, 5, 2, 9, 5, 6, 1, 6, 9, 2, 8,
        7, 8, 1, 4, 8, 1, 1, 2, 7, 5, 8, 3, 9, 2, 7, 9, 5, 4, 6, 6, 8, 8, 4, 4,
        6, 8, 5, 7, 3, 4, 6, 2, 8, 7, 3, 2, 6, 1, 2, 4, 2, 5, 7, 7, 2, 5, 1, 6,
        8, 8, 3, 6, 6, 3, 5, 7, 9, 2, 8, 5, 8, 9, 5, 7, 4, 3, 4, 5, 3, 2, 3, 3,
        2, 5, 6, 3])
tensor([93, 93, 89, 80, 98, 87, 98, 98, 93, 89, 99, 99, 80, 87, 76, 84, 93, 76,
        89, 97, 89, 93, 84, 87, 99, 87, 97, 98, 87, 97, 97, 84, 99, 76, 87, 80,
        93, 84, 99, 93, 76, 98, 89, 89, 87, 87, 98, 98, 89, 87, 76, 99, 80, 98,
        89, 84, 87, 99, 80, 84, 89, 97, 84, 98, 84, 76, 99, 99, 84, 76, 97, 89,
        87, 87, 80, 89, 89, 80, 76, 99, 93, 84, 87, 76, 87, 93, 76, 99, 98, 80,
        9

tensor([ 4, 12, 17,  ..., 34, 67, 80])

In [193]:
# a = torch.randint(1, 20, (10,))
# b = torch.randint(1, 20, (10,))

# k, v = torch_scatter_group(a, b)

# data = dict(zip(k.tolist(), v))

# nodes = torch.tensor([1, 2, 3, 4])

# %timeit -n100 _gather(data, nodes)
# %timeit -n100 long_isin(a, b)
a = torch.tensor([0, 1, 2, 3, 4])
b = torch.tensor([
    [0, 1, 2, 3],
    [3, 3, 3, 4]
])

k, v = torch_scatter_group(a.repeat(2), b.flatten())



In [70]:
import torch_sparse

m = scatter_coo(ij, values)
m
# from torch_sparse import SparseTensor?

# SparseTensor()
# torch_sparse.index_select(m, 0, torch.LongTensor([1]))

RuntimeError: indices and values must have same nnz, but got nnz from indices: 100, nnz from values: 10000

In [7]:
from caldera.utils import torch_scatter_group

torch_scatter_group(m._values(), m._indices())

(tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]),
 [tensor([[ 0, 10, 23, 40, 53, 60, 69, 80, 85, 86, 91, 96,  1,  9, 11, 39, 43, 52,
           66, 67, 76, 93, 95,  6, 14, 18, 26, 42, 49, 56, 57, 59, 61, 72, 73, 77,
           78, 79, 89,  3,  7, 13, 19, 21, 45, 82,  2, 12, 16, 17, 34, 37, 44, 51,
           54, 55, 81, 94, 99,  8, 22, 24, 29, 30, 35, 36, 46, 47, 48, 75, 83, 84,
           97, 28, 31, 38, 68, 87, 90,  5, 20, 25, 27, 33, 41, 58, 62, 65, 74,  4,
           15, 32, 50, 63, 64, 70, 71, 88, 92, 98],
          [14, 16, 17, 24, 48, 90, 92,  1,  6, 11, 15, 30, 35, 37, 40, 55, 62, 69,
           74, 81, 85, 89,  3,  8, 19, 25, 26, 28, 39, 43, 52, 73, 79, 82, 97, 10,
           34, 38, 44, 45, 59, 61, 63, 68, 86, 93,  5, 12, 20, 23, 41, 47, 50, 60,
           64, 78, 95, 99, 21, 27, 36, 42, 46, 71, 84, 87, 94,  0,  9, 13, 22, 29,
           31, 32, 49, 54, 57, 67, 72, 75, 76, 80, 91, 96,  2,  4, 18, 33, 51, 56,
           77, 88, 98,  7, 53, 58, 65, 66, 70, 83]]),
  tensor([], size=(0, 100)