In [27]:
# scatter_add

import torch
import torch.nn.functional as F
import torch_scatter

# Three random inputs: a vector and two matrices, all with six rows along dim 0.
input_list = [torch.rand(*shape) for shape in ((6,), (6, 2), (6, 3))]

# Destination slot for each of the six rows; identical for every input.
index = torch.tensor([0, 2, 1, 0, 2, 1])

for x in input_list:
    # Reference result: torch_scatter sums the rows that share an index.
    out1 = torch_scatter.scatter_add(x, index, dim=0)
    print(out1)

    # Same reduction with plain torch: accumulate the rows of x into a
    # zero-initialised buffer shaped like the reference result.
    out2 = torch.zeros_like(out1).index_add_(0, index, x)
    print(out2)

    # Both implementations must agree.
    assert torch.allclose(out1, out2)


tensor([0.1580, 1.3217, 0.8653])
tensor([0.1580, 1.3217, 0.8653])
tensor([[1.2658, 1.3989],
        [0.9311, 1.2791],
        [0.2845, 0.6256]])
tensor([[1.2658, 1.3989],
        [0.9311, 1.2791],
        [0.2845, 0.6256]])
tensor([[0.9256, 0.7318, 0.6759],
        [1.7160, 1.1216, 1.3094],
        [1.1963, 1.3095, 0.3293]])
tensor([[0.9256, 0.7318, 0.6759],
        [1.7160, 1.1216, 1.3094],
        [1.1963, 1.3095, 0.3293]])


In [28]:
# scatter_mean

import torch
import torch.nn.functional as F
import torch_scatter

input_list = [
    torch.rand(6),
    torch.rand(6, 2),
    torch.rand(6, 3),
]

# Group id for each of the six rows; loop-invariant, so it is built once.
index = torch.tensor([0, 2, 1, 0, 2, 1])

for x in input_list:
    # Reference result from torch_scatter.
    out1 = torch_scatter.scatter_mean(x, index, dim=0)
    print(out1)

    # scatter_mean with raw torch: per-group sum divided by per-group count.
    out2 = torch.zeros_like(out1)
    out2.index_add_(dim=0, index=index, source=x)

    # Count how many rows land in each group. The counter uses x's dtype so
    # the division below does not promote the result to a different dtype
    # than the reference.
    count = torch.zeros(out1.size(0), dtype=x.dtype, device=x.device)
    count.index_add_(
        0, index, torch.ones(index.size(0), dtype=x.dtype, device=x.device)
    )

    # A group id that never appears in `index` has count 0. Clamp to 1 so the
    # mean of an empty group is 0 instead of 0/0 = NaN, mirroring the clamp
    # torch_scatter applies internally.
    count = count.clamp(min=1)

    # Reshape count to (groups, 1, ..., 1) so it broadcasts against out2 for
    # inputs of any rank, not only vectors and matrices.
    out2 = out2 / count.view(-1, *([1] * (x.dim() - 1)))
    print(out2)

    assert torch.allclose(out1, out2)


tensor([0.6122, 0.0702, 0.2854])
tensor([0.6122, 0.0702, 0.2854])
tensor([[0.4387, 0.4480],
        [0.6411, 0.9629],
        [0.4194, 0.4044]])
tensor([[0.4387, 0.4480],
        [0.6411, 0.9629],
        [0.4194, 0.4044]])
tensor([[0.4736, 0.8423, 0.2250],
        [0.3617, 0.5037, 0.3310],
        [0.5780, 0.4939, 0.3144]])
tensor([[0.4736, 0.8423, 0.2250],
        [0.3617, 0.5037, 0.3310],
        [0.5780, 0.4939, 0.3144]])


In [29]:
# scatter_max

import torch
import torch.nn.functional as F
import torch_scatter

# Three random inputs: a vector and two matrices, all with six rows along dim 0.
input_list = [torch.rand(*shape) for shape in ((6,), (6, 2), (6, 3))]

# Group id for each of the six rows; identical for every input.
index = torch.tensor([0, 2, 1, 0, 2, 1])

for x in input_list:
    # Reference result: scatter_max returns (values, argmax); keep the values.
    out1 = torch_scatter.scatter_max(x, index, dim=0)[0]
    print(out1)

    # scatter_reduce_ needs an index with the same shape as the source, so
    # for matrices the per-row index is repeated across the columns.
    expanded_index = index.unsqueeze(-1).expand_as(x) if x.dim() > 1 else index

    # Raw-torch version: start from -inf and reduce with "amax".
    # include_self=False keeps the -inf placeholder out of the reduction.
    out2 = torch.full_like(out1, -float("inf"))
    out2.scatter_reduce_(0, expanded_index, x, "amax", include_self=False)
    print(out2)

    # Both implementations must agree.
    assert torch.allclose(out1, out2)

tensor([0.2039, 0.4027, 0.4981])
tensor([0.2039, 0.4027, 0.4981])
tensor([[0.6035, 0.9495],
        [0.4197, 0.4375],
        [0.2978, 0.9932]])
tensor([[0.6035, 0.9495],
        [0.4197, 0.4375],
        [0.2978, 0.9932]])
tensor([[0.7330, 0.6019, 0.6966],
        [0.9969, 0.4755, 0.8426],
        [0.6177, 0.9917, 0.9389]])
tensor([[0.7330, 0.6019, 0.6966],
        [0.9969, 0.4755, 0.8426],
        [0.6177, 0.9917, 0.9389]])


In [30]:
# scatter_min
import torch
import torch.nn.functional as F
import torch_scatter

# Three random inputs: a vector and two matrices, all with six rows along dim 0.
input_list = [torch.rand(*shape) for shape in ((6,), (6, 2), (6, 3))]

# Group id for each of the six rows; identical for every input.
index = torch.tensor([0, 2, 1, 0, 2, 1])

for x in input_list:
    # Reference result: scatter_min returns (values, argmin); keep the values.
    out1 = torch_scatter.scatter_min(x, index, dim=0)[0]
    print(out1)

    # scatter_reduce_ needs an index with the same shape as the source, so
    # for matrices the per-row index is repeated across the columns.
    expanded_index = index.unsqueeze(-1).expand_as(x) if x.dim() > 1 else index

    # Raw-torch version: start from +inf and reduce with "amin".
    # include_self=False keeps the +inf placeholder out of the reduction.
    out2 = torch.full_like(out1, float("inf"))
    out2.scatter_reduce_(0, expanded_index, x, "amin", include_self=False)
    print(out2)

    # Both implementations must agree.
    assert torch.allclose(out1, out2)

tensor([0.5641, 0.5191, 0.1730])
tensor([0.5641, 0.5191, 0.1730])
tensor([[0.3776, 0.3297],
        [0.0609, 0.1721],
        [0.0796, 0.6229]])
tensor([[0.3776, 0.3297],
        [0.0609, 0.1721],
        [0.0796, 0.6229]])
tensor([[0.1060, 0.0128, 0.0849],
        [0.2093, 0.2405, 0.3260],
        [0.0378, 0.4159, 0.0293]])
tensor([[0.1060, 0.0128, 0.0849],
        [0.2093, 0.2405, 0.3260],
        [0.0378, 0.4159, 0.0293]])


In [33]:
# scatter_mul
import torch
import torch.nn.functional as F
import torch_scatter

# Three random inputs: a vector and two matrices, all with six rows along dim 0.
input_list = [torch.rand(*shape) for shape in ((6,), (6, 2), (6, 3))]

# Group id for each of the six rows; identical for every input.
index = torch.tensor([0, 2, 1, 0, 2, 1])

for x in input_list:
    # Reference result: torch_scatter multiplies the rows that share an index.
    out1 = torch_scatter.scatter_mul(x, index, dim=0)
    print(out1)

    # scatter_reduce_ needs an index with the same shape as the source, so
    # for matrices the per-row index is repeated across the columns.
    expanded_index = index.unsqueeze(-1).expand_as(x) if x.dim() > 1 else index

    # Raw-torch version: start from ones and reduce with "prod".
    # include_self=False keeps the initial ones out of the product.
    out2 = torch.ones_like(out1)
    out2.scatter_reduce_(0, expanded_index, x, "prod", include_self=False)
    print(out2)

    # Both implementations must agree.
    assert torch.allclose(out1, out2)


tensor([0.5676, 0.6368, 0.3176])
tensor([0.5676, 0.6368, 0.3176])
tensor([[0.3286, 0.0689],
        [0.5471, 0.1230],
        [0.0749, 0.0534]])
tensor([[0.3286, 0.0689],
        [0.5471, 0.1230],
        [0.0749, 0.0534]])
tensor([[0.2411, 0.3325, 0.0015],
        [0.3459, 0.0031, 0.6834],
        [0.2178, 0.4438, 0.0377]])
tensor([[0.2411, 0.3325, 0.0015],
        [0.3459, 0.0031, 0.6834],
        [0.2178, 0.4438, 0.0377]])


In [42]:
# scatter_softmax
import torch
import torch.nn.functional as F
import torch_scatter

input_list = [
    torch.rand(6),
    torch.rand(6, 2),
    torch.rand(6, 3),
]

# Group id for each of the six rows; loop-invariant, so it is built once.
index = torch.tensor([0, 2, 1, 0, 2, 1])
num_groups = int(index.max()) + 1

for x in input_list:
    # Reference result: softmax taken independently within each index group.
    out1 = torch_scatter.scatter_softmax(x, index, dim=0)
    print(out1)

    # scatter_softmax with raw torch using scatter_reduce_ in multiple steps.

    # scatter_reduce_ needs an index with the same shape as the source, so
    # for matrices the per-row index is repeated across the columns.
    expanded_index = index.unsqueeze(-1).expand_as(x) if x.dim() > 1 else index

    # Shape of every per-group buffer: one row per group, trailing dims of x.
    group_shape = (num_groups,) + x.shape[1:]

    # Step 1: per-group maximum. Subtracting it before exponentiating keeps
    # exp() from overflowing on large inputs and leaves the softmax value
    # unchanged, because the shift cancels in the ratio.
    group_max = torch.full(
        group_shape, float("-inf"), dtype=x.dtype, device=x.device
    )
    group_max.scatter_reduce_(
        0, index=expanded_index, src=x, reduce="amax", include_self=False
    )

    # Step 2: exponentials of the re-centred values.
    exp_x = torch.exp(x - group_max[index])

    # Step 3: per-group sum of exponentials, used as the normaliser.
    sum_exp = torch.zeros(group_shape, dtype=x.dtype, device=x.device)
    sum_exp.scatter_reduce_(
        0, index=expanded_index, src=exp_x, reduce="sum", include_self=False
    )

    # Step 4: divide each element by the normaliser of its own group.
    out2 = exp_x / sum_exp[index]

    print(out2)

    assert torch.allclose(out1, out2)

tensor([0.6806, 0.5431, 0.5616, 0.3194, 0.4569, 0.4384])
tensor([0.6806, 0.5431, 0.5616, 0.3194, 0.4569, 0.4384])
tensor([[0.6091, 0.5467],
        [0.4640, 0.3468],
        [0.3156, 0.6954],
        [0.3909, 0.4533],
        [0.5360, 0.6532],
        [0.6844, 0.3046]])
tensor([[0.6091, 0.5467],
        [0.4640, 0.3468],
        [0.3156, 0.6954],
        [0.3909, 0.4533],
        [0.5360, 0.6532],
        [0.6844, 0.3046]])
tensor([[0.6641, 0.3011, 0.4325],
        [0.5956, 0.5894, 0.3968],
        [0.4486, 0.4295, 0.5433],
        [0.3359, 0.6989, 0.5675],
        [0.4044, 0.4106, 0.6032],
        [0.5514, 0.5705, 0.4567]])
tensor([[0.6641, 0.3011, 0.4325],
        [0.5956, 0.5894, 0.3968],
        [0.4486, 0.4295, 0.5433],
        [0.3359, 0.6989, 0.5675],
        [0.4044, 0.4106, 0.6032],
        [0.5514, 0.5705, 0.4567]])
