In [1]:
import torch



In [154]:
import torch
from typing import Dict, List, Tuple


@torch.jit.script
def stable_arg_sort_long(arr):
    """Stable argsort of an integer tensor.

    Pytorch 1.5.0 has no stable sort (newer releases offer
    ``torch.sort(..., stable=True)``). Ties are broken exactly by encoding
    each element as the integer key ``value * n + position``, where ``n`` is
    the number of elements. For integer values ``a < b``, every key of ``a``
    is smaller than every key of ``b``. Among equal values, the key grows
    with the original position, so equal elements keep their input order.

    The previous float trick (``arr + linspace(0, 0.99, n)``) lost
    precision for values above ~2**24 or for long inputs.

    Note: keys are int64, so ``|value| * n`` must stay below ~9.2e18.

    :param arr: 1-D tensor of integer values (cast to int64)
    :return: int64 tensor of indices that stably sorts ``arr`` ascending
    """
    n = arr.shape[0]
    positions = torch.arange(n, dtype=torch.long, device=arr.device)
    keys = arr.to(torch.long) * n + positions
    # Keys are all distinct, so the (possibly unstable) argsort is unambiguous.
    return torch.argsort(keys)


# @torch.jit.script
def unique_with_counts(arr: torch.Tensor, grouped: Dict[int, int]):
    """
    Return the sorted unique values of ``arr`` and how often each occurs,
    similar to ``np.unique(x, return_counts=True)``.

    ``grouped`` is used as the value -> count accumulator and is modified in
    place. Any entries it already holds are included in the result, so pass
    an empty dict for a plain count.

    :param arr: tensor iterated along its first dimension; every element must
        be convertible with ``.item()``
    :param grouped: accumulator dict, updated in place
    :return: tuple ``(values, counts)``. ``values`` has ``arr.dtype`` and is
        sorted ascending; ``counts`` is int64 and aligned with ``values``
    """
    for element in arr:
        key = element.item()
        grouped[key] = grouped.get(key, 0) + 1

    n_unique = len(grouped)
    values = torch.empty(n_unique, dtype=arr.dtype)
    counts = torch.zeros(n_unique, dtype=torch.long)
    for pos, key in enumerate(grouped):
        values[pos] = key
        counts[pos] = grouped[key]

    order = torch.argsort(values)
    return values[order], counts[order]


@torch.jit.script
def _jit_scatter_group(x: torch.Tensor, idx: torch.Tensor, d: Dict[int, int]) -> Tuple[
    torch.Tensor, List[torch.Tensor]]:
    """
    Split the rows of ``x`` into groups that share the same value in ``idx``.

    ``idx`` does NOT need to be sorted. Rows of ``x`` are reordered with a
    stable argsort of ``idx``, so each group keeps the original row order.

    :param x: tensor whose first dimension is aligned with ``idx``
    :param idx: index tensor with one entry per row of ``x``
    :param d: dict passed to ``unique_with_counts``, which fills it with
        value -> count entries in place; callers should pass an empty dict
    :return: tuple of (unique values of ``idx`` in ascending order, list of
        slices of ``x``, one per unique value and in the same order)
    """
    order = stable_arg_sort_long(idx)
    x_sorted = x[order]
    groups, counts = unique_with_counts(idx, d)
    start = 0
    arr_list: List[torch.Tensor] = []
    for count in counts:
        # Convert explicitly: TorchScript types ``Tensor.item()`` as a generic
        # Scalar, but slice bounds and the ``start`` accumulator must be ints.
        size = int(count)
        arr_list.append(x_sorted[start:start + size])
        start += size
    return groups, arr_list


def scatter_group(x: torch.Tensor, idx: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """
    Group the rows of ``x`` by the values in ``idx``.

    The result matches applying ``x[torch.where(idx == value)]`` to every
    unique value of ``idx``, taken in ascending order. ``idx`` itself does
    not have to be sorted.

    Example:

    .. code-block:: python

        idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
        x = torch.tensor([0, 1, 2, 3, 4, 5, 6])

        uniq_sorted_idx, out = scatter_group(x, idx)

        # note the unique idx values come back sorted
        assert torch.all(torch.eq(out[0], torch.tensor([0, 1, 2])))

        # where idx == 0
        assert torch.all(torch.eq(out[1][0], torch.tensor([2])))

        # where idx == 1
        assert torch.all(torch.eq(out[1][1], torch.tensor([3, 4, 5])))

        # where idx == 2
        assert torch.all(torch.eq(out[1][2], torch.tensor([0, 1, 6])))

    :param x: tensor to group (first dimension aligned with ``idx``)
    :param idx: group index of each row of ``x``
    :return: tuple of the unique, sorted indices and a list of tensors, one
        group per index
    """
    # A fresh accumulator on every call, so counts never leak between calls.
    scratch_counts: Dict[int, int] = {}
    return _jit_scatter_group(x, idx, scratch_counts)


# Small demo: three groups of three rows; x tracks gradients so the
# resulting slices carry a grad_fn.
idx = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float)
x.requires_grad = True

# Was `scatter_group(x, idx)[1` (unbalanced bracket -> SyntaxError). The
# captured output shows the full (indices, groups) tuple, so display that.
scatter_group(x, idx)

(tensor([0, 1, 2]),
 [tensor([0., 1., 2.], grad_fn=<SliceBackward>),
  tensor([3., 4., 5.], grad_fn=<SliceBackward>),
  tensor([6., 7., 8.], grad_fn=<SliceBackward>)])

In [143]:
from typing import List

from pyrographnets.utils import scatter_group
from pyrographnets.utils.scatter_group import _jit_scatter_group

# NOTE(review): this import shadows the `scatter_group` defined in an earlier
# cell, so the calls below use the packaged implementation — confirm intended.

# Seed so x1/x2 are reproducible under Restart & Run All.
torch.manual_seed(0)

# this would be the graph idx: idx* gives the graph of each row of x*
idx1 = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3])
x1 = torch.randn(10, 5)
idx2 = torch.tensor([0, 1, 2])
x2 = torch.randn(3, 5)

i1, groups1 = scatter_group(x1, idx1)
i2, groups2 = scatter_group(x2, idx2)

def dict_collate(d1, d2, collate_fn):
    """Merge two dicts key by key and collate the values.

    For each key found in either dict, the values are gathered into a list:
    the value from ``d1`` comes first, then the one from ``d2``.
    ``collate_fn`` is applied to that list. Keys appear in first-seen order
    (``d1``'s keys, then keys found only in ``d2``).

    :param d1: first mapping
    :param d2: second mapping
    :param collate_fn: callable applied to each key's list of values
    :return: new dict mapping each key to ``collate_fn(values)``
    """
    gathered = {}
    for source in (d1, d2):
        for key, value in source.items():
            gathered.setdefault(key, []).append(value)
    return {key: collate_fn(values) for key, values in gathered.items()}



In [168]:
d1 = {k.item(): v for k, v in zip(i1, groups1)}
d2 = {k.item(): v for k, v in zip(i2, groups2)}

# Merge the per-graph groups of both batches. `dict_merge` was never defined
# in this notebook; the helper defined earlier is `dict_collate`.
d = dict_collate(d1, d2, torch.cat)
keys = sorted(d)
merged_x = torch.cat([d[k] for k in keys])

# The loop below reads d1[k] for every merged key, so every key of d2 must
# also be a key of d1. Fail with a clear message instead of a bare KeyError.
assert set(d2) <= set(d1), "every key of d2 must also be a key of d1"

# collect node indices
node_idx = []
delta_edges = []
i = 0
for k in keys:
    if k in d2:
        # NOTE(review): d2's rows get the value i + d2[k].shape[0]; the intent
        # of this offset is not evident here — confirm.
        delta_edges += [i] * d1[k].shape[0] + [i + d2[k].shape[0]] * d2[k].shape[0]
    else:
        delta_edges += [i] * d1[k].shape[0]
    i += 1
    # one graph key per merged row
    node_idx += [k] * d[k].shape[0]

print(idx1)
print(idx2)
print(node_idx)
delta_edges


tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3])
tensor([0, 1, 2])
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3]


[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3]

In [256]:
# Seed so the random edges/data (and the Mask weights created later) are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# 20 random edges as (2, 20) float endpoints in [0, 10); float dtype so the
# tensor can require grad.
edges = torch.randint(0, 10, (2, 20), dtype=torch.float)
edges.requires_grad = True
data = torch.randn(10)

class Mask(torch.nn.Module):
    """Select edge columns using a learned sigmoid score.

    ``x`` goes through ``Linear(10, 20) -> Sigmoid``. An output position
    whose score rounds to 1 keeps the edge column of the same index.
    """

    def __init__(self):
        super().__init__()
        layers = [torch.nn.Linear(10, 20), torch.nn.Sigmoid()]
        self.blocks = torch.nn.Sequential(*layers)

    def forward(self, edges, x):
        """Return the columns of ``edges`` whose score rounds to 1.

        NOTE(review): assumes ``x`` is 1-D with 10 features; for batched
        input only the first index tensor (rows) would be used — confirm.
        """
        scores = self.blocks(x)
        keep = torch.nonzero(scores.round().eq(1), as_tuple=True)[0]
        return edges[:, keep]

model = Mask()
model.zero_grad()
out = model(edges, data)
print(out.shape)
out.sum().backward()
# Columns are chosen through integer indices (round -> where), which carry no
# gradient back to the model's parameters; only `edges` receives a gradient.
edges.grad

torch.Size([2, 8])


tensor([[0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 0., 1., 0.,
         1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 0., 1., 0.,
         1., 0.]])

In [246]:
# A freshly created leaf tensor has never been through backward(), so its
# .grad is None (the cell therefore displays nothing).
empty_leaf = torch.Tensor([])
empty_leaf.grad