In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
from torch import Tensor

import torch_geometric.typing
import torch_scatter


def segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:
    if not torch_geometric.typing.WITH_TORCH_SCATTER:
        raise ImportError("'segment' requires the 'torch-scatter' package")
    return torch_scatter.segment_csr(src, ptr, reduce=reduce)

In [16]:
from typing import Optional

from torch import Tensor

from torch_geometric.utils import scatter# , segment
from torch_geometric.utils.num_nodes import maybe_num_nodes

def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:

    if ptr is not None:
        dim = dim + src.dim() if dim < 0 else dim
        size = ([1] * dim) + [-1]
        count = ptr[1:] - ptr[:-1]
        ptr = ptr.view(size)
        src_max = segment(src.detach(), ptr, reduce='max')
        src_max = src_max.repeat_interleave(count, dim=dim)
        out = (src - src_max).exp()
        out_sum = segment(out, ptr, reduce='sum') + 1e-16
        out_sum = out_sum.repeat_interleave(count, dim=dim)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        print("N is ", N)
        src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max')
        print("src_max is ", src_max)
        out = src - src_max.index_select(dim, index)
        print("out is ", out)
        out = out.exp()
        print("out is ", out)
        out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + 1e-16
        print("out_sum is ", out_sum)
        out_sum = out_sum.index_select(dim, index)
        print("out_sum is ", out_sum)
    else:
        raise NotImplementedError

    return out / out_sum

In [19]:
src = torch.as_tensor([3,1,1,1,1,1,2,1])
index = torch.as_tensor([0,0,1,1,1,1,2,2])
ptr = torch.as_tensor([0,2,6,8])

In [20]:
softmax(src, index)

N is  3
src_max is  tensor([3, 1, 2])
out is  tensor([ 0, -2,  0,  0,  0,  0,  0, -1])
out is  tensor([1.0000, 0.1353, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.3679])
out_sum is  tensor([1.1353, 4.0000, 1.3679])
out_sum is  tensor([1.1353, 1.1353, 4.0000, 4.0000, 4.0000, 4.0000, 1.3679, 1.3679])


tensor([0.8808, 0.1192, 0.2500, 0.2500, 0.2500, 0.2500, 0.7311, 0.2689])

In [22]:
import numpy as np
(np.exp(3))/(np.exp(3)+np.exp(1))

0.8807970779778824