In [1]:
import torch

In [None]:
k denotes the current step in the outer loop (or the depth of the search)
h_k denotes a node’s representation at this step:

In [None]:
"""
    1) Each node v ∈ V aggregates the representations of the nodes in its immediate neighborhood, into a single vector (Note that this aggregation
step depends on the representations generated at the previous iteration of the outer loop and the k = 0 (“base case”) representations are defined as the input node features)
    2) After aggregating
the neighboring feature vectors, GraphSAGE then concatenates the node’s current representation with the aggregated neighborhood vector ... torch.cat((node_features, neighbor_features), dim = 1)
    3) This concatenated vector is fed through a fully connected layer with nonlinear activation function σ, which transforms the representations to be used at the next step of the algorithm
"""

In [6]:
x = torch.rand(5,2)
x

tensor([[0.7804, 0.7554],
        [0.8911, 0.9375],
        [0.5990, 0.2697],
        [0.1539, 0.8532],
        [0.0819, 0.4254]])

In [15]:
adj_list = torch.tensor(
                        [[1,2,3,-99],
                         [0,2,3,-99],
                         [0,3,4,-99],
                         [0,1,2,4],
                         [2,3,-99,-99]]
                         
                         )
adj_list

tensor([[  1,   2,   3, -99],
        [  0,   2,   3, -99],
        [  0,   3,   4, -99],
        [  0,   1,   2,   4],
        [  2,   3, -99, -99]])

In [83]:
class Sage(torch.nn.Module):

    def __init__(self, feature_dim): 
        super(Sage, self).__init__()
        self.w = torch.rand(1, feature_dim)


    def forward(self, x, adj_list):
        for node in adj_list:
            cur_neigh_feats = torch.mean(torch.index_select(x, 0, node[node > 0]), dim = 0).unsqueeze(0)
            try:
                neigh_feats = torch.cat((neigh_feats, cur_neigh_feats), dim = 0)
            except:
                neigh_feats = cur_neigh_feats

        cat = torch.cat((x, neigh_feats), dim = 1)
        print("weight: ", self.w)
        print("cat: ", cat)
        print("cat.t()", cat.t())
        out = self.w.mm(cat.t())

        return out



In [84]:
s = Sage(2 * 2)
s(x, adj_list)

weight:  tensor([[0.1129, 0.8911, 0.6820, 0.4293]])
cat:  tensor([[0.7804, 0.7554, 0.5480, 0.6868],
        [0.8911, 0.9375, 0.3765, 0.5615],
        [0.5990, 0.2697, 0.1179, 0.6393],
        [0.1539, 0.8532, 0.5240, 0.5442],
        [0.0819, 0.4254, 0.3765, 0.5615]])
cat.t() tensor([[0.7804, 0.8911, 0.5990, 0.1539, 0.0819],
        [0.7554, 0.9375, 0.2697, 0.8532, 0.4254],
        [0.5480, 0.3765, 0.1179, 0.5240, 0.3765],
        [0.6868, 0.5615, 0.6393, 0.5442, 0.5615]])


tensor([[1.4299, 1.4338, 0.6628, 1.3687, 0.8861]])

In [48]:
try: del neigh_feats
except: pass


for node in adj_list:
    cur_neigh_feats = torch.mean(torch.index_select(x, 0, node[node > 0]), dim = 0).unsqueeze(0)
    try:
        neigh_feats = torch.cat((neigh_feats, cur_neigh_feats), dim = 0)
    except:
        neigh_feats = cur_neigh_feats

print(neigh_feats)


    

    # print("\n")

tensor([[0.5480, 0.6868],
        [0.3765, 0.5615],
        [0.1179, 0.6393],
        [0.5240, 0.5442],
        [0.3765, 0.5615]])


In [49]:
neigh_feats

tensor([[0.5480, 0.6868],
        [0.3765, 0.5615],
        [0.1179, 0.6393],
        [0.5240, 0.5442],
        [0.3765, 0.5615]])

In [50]:
x

tensor([[0.7804, 0.7554],
        [0.8911, 0.9375],
        [0.5990, 0.2697],
        [0.1539, 0.8532],
        [0.0819, 0.4254]])

In [53]:
cat = torch.cat((x, neigh_feats), dim = 1)
cat

tensor([[0.7804, 0.7554, 0.5480, 0.6868],
        [0.8911, 0.9375, 0.3765, 0.5615],
        [0.5990, 0.2697, 0.1179, 0.6393],
        [0.1539, 0.8532, 0.5240, 0.5442],
        [0.0819, 0.4254, 0.3765, 0.5615]])

In [66]:
cat.t()

tensor([[0.7804, 0.8911, 0.5990, 0.1539, 0.0819],
        [0.7554, 0.9375, 0.2697, 0.8532, 0.4254],
        [0.5480, 0.3765, 0.1179, 0.5240, 0.3765],
        [0.6868, 0.5615, 0.6393, 0.5442, 0.5615]])

In [86]:
l = torch.nn.Linear(4, 1)
l(cat)

tensor([[-0.3884],
        [-0.3401],
        [-0.3273],
        [-0.1176],
        [-0.1519]], grad_fn=<AddmmBackward>)

In [87]:
param = torch.rand(1, 4)
param

tensor([[0.7348, 0.4718, 0.1690, 0.1551]])

In [74]:
param.mm(cat.t())

tensor([[0.8810, 0.8597, 0.5320, 0.7657, 0.5579]])

In [93]:
t1 = torch.rand(128, 2866)
t2 = torch.rand(724, 2866)
# t3 = torch.rand(1)

# torch.cat((t1, t2, t3))

In [94]:
t1

tensor([[0.1119, 0.2839, 0.9802,  ..., 0.1941, 0.5874, 0.9189],
        [0.7242, 0.6031, 0.9000,  ..., 0.2070, 0.0143, 0.3608],
        [0.2858, 0.7533, 0.4856,  ..., 0.2480, 0.1502, 0.6680],
        ...,
        [0.5362, 0.1371, 0.0212,  ..., 0.9871, 0.5192, 0.6612],
        [0.3483, 0.7225, 0.5843,  ..., 0.5180, 0.6524, 0.9938],
        [0.5438, 0.2904, 0.3011,  ..., 0.1303, 0.5327, 0.6136]])

In [95]:
param = torch.rand(1, 2)
param

tensor([[0.0219, 0.1333]])

In [97]:
t1.mm(t2.t()).shape

torch.Size([128, 724])

In [14]:
from collections import defaultdict
# converts from adjacency matrix to adjacency list
def convert(a):
    adjList = defaultdict(list)
    for i in range(len(a)):
        for j in range(len(a[i])):
                       if a[i][j]== 1:
                           adjList[i].append(j)
    return adjList
  
# driver code
a =[[0, 0, 1], [0, 0, 1], [1, 1, 0]] # adjacency matrix
AdjList = convert(a)
print("Adjacency List:")
# print the adjacency list
for i in AdjList:
    print(i, end ="")
    for j in AdjList[i]:
        print(" -> {}".format(j), end ="")
    print()
   

Adjacency List:
0 -> 2
1 -> 2
2 -> 0 -> 1
