In [11]:
import torch

# Create a toy graph with 4 nodes and 5 edges.
# Row 0 of edge_index holds the source node of each edge, row 1 the target.
edge_index = torch.tensor([[0, 1, 1, 2, 3],
                           [1, 2, 3, 0, 1]], dtype=torch.long)
edge_weight = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float)
edge_attr = torch.tensor([[1, 2, 3],
                          [4, 5, 6],
                          [7, 8, 9],
                          [10, 11, 12],
                          [13, 14, 15]], dtype=torch.float)

# Infer the node count from the largest node id referenced by an edge.
# NOTE: an isolated node whose id exceeds every id in edge_index would get no
# row here; real code should pass num_nodes explicitly.
num_nodes = edge_index.max().item() + 1
num_features = edge_attr.size(1)

# Accumulator: one row per node, one column per edge feature. new_zeros keeps
# the same dtype/device as edge_attr instead of hard-coding the width (3).
incoming_edge_weighted_sum = edge_attr.new_zeros((num_nodes, num_features))

# Scale each edge's attribute row by that edge's scalar weight: the 1-D
# weights are viewed as a column so they broadcast across the feature columns.
# Computed once and reused for both the scatter and the printout.
weighted_edge_attr = edge_weight.view(-1, 1) * edge_attr

# scatter_add_ requires an index tensor shaped like src, so turn the target
# ids into a column and broadcast it across all feature columns below.
edge_index_reshaped = edge_index[1].view(-1, 1)

# For every edge e and feature f:
#   incoming_edge_weighted_sum[edge_index[1][e], f] += weighted_edge_attr[e, f]
# i.e. each node receives the weighted sum of its incoming edges' attributes.
incoming_edge_weighted_sum.scatter_add_(
    0, edge_index_reshaped.expand_as(weighted_edge_attr), weighted_edge_attr
)

# Print the results
print("Original edge_index:\n", edge_index)
print("Original edge_index[1]:\n", edge_index[1])
print("Reshaped edge_index_reshaped:\n", edge_index_reshaped)
print("Original edge_attr:\n", edge_attr)
print("Original edge_weights:\n", edge_weight)
print('multiplied edge_weight * edge_attr:\n', weighted_edge_attr)

print("Incoming_edge_weighted_sum after scatter_add:")
print(incoming_edge_weighted_sum)


Original edge_index:
 tensor([[0, 1, 1, 2, 3],
        [1, 2, 3, 0, 1]])
Original edge_index[1]:
 tensor([1, 2, 3, 0, 1])
Reshaped edge_index_reshaped:
 tensor([[1],
        [2],
        [3],
        [0],
        [1]])
Original edge_attr:
 tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.]])
Original edge_weights:
 tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000])
multiplied edge_weight * edge_attr:
 tensor([[0.1000, 0.2000, 0.3000],
        [0.8000, 1.0000, 1.2000],
        [2.1000, 2.4000, 2.7000],
        [4.0000, 4.4000, 4.8000],
        [6.5000, 7.0000, 7.5000]])
Incoming_edge_weighted_sum after scatter_add:
tensor([[4.0000, 4.4000, 4.8000],
        [6.6000, 7.2000, 7.8000],
        [0.8000, 1.0000, 1.2000],
        [2.1000, 2.4000, 2.7000]])
