In [1]:
# scatter_add

import torch
import torch.nn.functional as F
from typing import Optional, Tuple
import torch_scatter


# torch_scatter/utils.py
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Expand ``src`` so that it has exactly the shape of ``other``.

    A 1-D ``src`` is treated as running along axis ``dim`` of ``other``:
    singleton axes are prepended so its single axis lands at ``dim``.
    Singleton axes are then appended until ``src`` has as many dimensions as
    ``other``, and the result is expanded to ``other.size()``.

    Args:
        src: tensor to broadcast (in this notebook, an index tensor).
        other: tensor whose shape is the broadcast target.
        dim: axis of ``other`` that a 1-D ``src`` aligns with; negative
            values count from the end of ``other``.

    Returns:
        ``src`` expanded (via ``Tensor.expand``) to ``other.size()``.
    """
    axis = dim if dim >= 0 else other.dim() + dim
    if src.dim() == 1:
        # Leading singleton axes shift src's only axis to position `axis`.
        src = src.reshape((1,) * axis + tuple(src.shape))
    # Trailing singleton axes cover every dimension of `other` after `axis`.
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)
    return src.expand(other.size())


In [4]:
from torch_scatter import scatter_max

# value only
def scatter_max_raw(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
) -> torch.Tensor:
    """Scatter-max of ``src`` along ``dim`` using native ``Tensor.scatter_reduce``.

    Value-only counterpart of ``torch_scatter.scatter_max``. It returns the
    per-group maxima but not the argmax indices that ``torch_scatter`` also
    returns.

    Args:
        src: values to reduce.
        index: group id for each element of ``src`` along ``dim``; it is
            expanded to ``src``'s shape with ``broadcast``.
        dim: axis of ``src`` to reduce along (negative values allowed).
        out: optional tensor supplying the output shape and the values used
            for groups that receive no element. Its values at positions that
            do receive elements are ignored (``include_self=False``), and
            ``out`` itself is not modified (``scatter_reduce`` is out-of-place).
        dim_size: output size along ``dim`` when ``out`` is None. Defaults to
            ``index.max() + 1``, or 0 for an empty index.

    Returns:
        Tensor shaped like ``src`` except along ``dim``, holding each group's
        maximum. Groups that receive no element keep their initial value
        (0 when ``out`` is None).
    """
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    # NOTE(review): torch_scatter's own handling of a user-supplied `out` may
    # differ (in-place write, existing values folded in) — confirm if callers
    # rely on passing `out`.
    # Reduce along the caller's `dim`; a fixed axis 0 is only correct for
    # 1-D inputs.
    return out.scatter_reduce(
        dim, index=index, src=src, reduce="amax", include_self=False
    )

# test the function against the torch_scatter reference

src = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=torch.float32)
index = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0], dtype=torch.int64)
dim = 0
out = None
dim_size = None

# torch_scatter.scatter_max returns a (values, argmax) tuple;
# scatter_max_raw returns only the values tensor.
out1 = scatter_max(src, index, dim, out, dim_size)
out2 = scatter_max_raw(src, index, dim, out, dim_size)

print(out1)
print(out2)

# Verify the re-implementation instead of eyeballing the printed tensors.
assert torch.equal(out1[0], out2), "scatter_max_raw disagrees with torch_scatter.scatter_max"



(tensor([10.,  8.,  9.]), tensor([9, 7, 8]))
tensor([10.,  8.,  9.])
