In [2]:
import torch
import dgl
import dgl.nn.pytorch as dglnn

In [13]:
num_nodes = 4

# Two sets of 3-dimensional embeddings for the same 4-node graph.
# embeddings1 holds the values 1..12 laid out row by row (one row per node);
# embeddings2 is an exact copy except that node index 3 has its second
# coordinate zeroed, so only that node's embedding differs between the sets.
embeddings1 = torch.arange(1, 13, dtype=torch.float32).reshape(num_nodes, 3)
embeddings2 = embeddings1.clone()
embeddings2[3, 1] = 0.0

# Row-wise (dim=1) cosine similarity between matching nodes of the two sets:
# identical rows score 1.0, the perturbed node index 3 scores lower.
cosine_similarity = torch.nn.CosineSimilarity(dim=1)
out = cosine_similarity(embeddings1, embeddings2)[:, None]  # (4,) -> (4, 1): one scalar feature per node

out

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [0.8176]])

In [14]:
# "Diagonal" graph: the only edges are self-loops i -> i, so the adjacency
# matrix is the identity. No node is connected to any *other* node, but every
# node does have one edge (to itself).
self_loop_nodes = torch.arange(num_nodes)
edges_diagonal = (self_loop_nodes, self_loop_nodes.clone())
g_diagonal = dgl.graph(edges_diagonal, num_nodes=num_nodes)

# Attach embeddings1 as the 'h' node feature of this graph
g_diagonal.ndata['h'] = embeddings1

g_diagonal.ndata['h']

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])

In [15]:
# Fully connected graph: one edge u -> v for every ordered pair (u, v),
# self-loops u -> u included, i.e. num_nodes ** 2 edges in total.
# Edge order matches a nested loop over (u, v): sources run 0,0,..,1,1,..
# while destinations cycle 0,1,..,0,1,..
all_nodes = torch.arange(num_nodes)
src = all_nodes.repeat_interleave(num_nodes)
dst = all_nodes.repeat(num_nodes)
g_fully_connected = dgl.graph((src, dst), num_nodes=num_nodes)

# Attach embeddings2 (not embeddings1) as the 'h' node feature of this graph
g_fully_connected.ndata['h'] = embeddings2

g_fully_connected.ndata['h']

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10.,  0., 12.]])

In [16]:
# Global mean readout. Per the DGL docs, AvgPooling averages the *given* node
# feature tensor over the nodes of each graph; edges are not consulted.
# Both graphs have num_nodes nodes and receive the same tensor `out`, so they
# are expected to pool to the same value. Note that the graphs' ndata['h'] is
# not used here at all -- only `out` is pooled.
avg_pooling = dglnn.glob.AvgPooling()

# Apply AvgPooling to both graphs (a single graph each -> result shape (1, 1))
avg_similarity_diagonal = avg_pooling(g_diagonal, out)
avg_similarity_fully_connected = avg_pooling(g_fully_connected, out)

# Sanity checks making the point of this comparison explicit: graph structure
# does not affect the readout, which is just the mean of the per-node scores.
expected_avg_similarity = out.mean(dim=0, keepdim=True)
assert torch.allclose(avg_similarity_diagonal, avg_similarity_fully_connected)
assert torch.allclose(avg_similarity_diagonal, expected_avg_similarity)

# NOTE: despite the "Node Features" labels, the values printed are the mean
# cosine similarity held in `out`, not a mean of the graphs' ndata['h'].
print("Average Node Features (Diagonal Graph):", avg_similarity_diagonal)
print("Average Node Features (Fully Connected Graph):", avg_similarity_fully_connected)


Average Node Features (Diagonal Graph): tensor([[0.9544]])
Average Node Features (Fully Connected Graph): tensor([[0.9544]])
