In [1]:
import torch

In [3]:
"""
    1) Each node v ∈ V aggregates the representations of the nodes in its immediate neighborhood, into a single vector (Note that this aggregation
step depends on the representations generated at the previous iteration of the outer loop and the k = 0 (“base case”) representations are defined as the input node features)
    2) After aggregating
the neighboring feature vectors, GraphSAGE then concatenates the node’s current representation with the aggregated neighborhood vector ... torch.cat((node_features, neighbor_features), dim = 1)
    3) This concatenated vector is fed through a fully connected layer with nonlinear activation function σ, which transforms the representations to be used at the next step of the algorithm
"""

'\n    1) Each node v ∈ V aggregates the representations of the nodes in its immediate neighborhood, into a single vector (Note that this aggregation\nstep depends on the representations generated at the previous iteration of the outer loop and the k = 0 (“base case”) representations are defined as the input node features)\n    2) After aggregating\nthe neighboring feature vectors, GraphSAGE then concatenates the node’s current representation with the aggregated neighborhood vector ... torch.cat((node_features, neighbor_features), dim = 1)\n    3) This concatenated vector is fed through a fully connected layer with nonlinear activation function σ, which transforms the representations to be used at the next step of the algorithm\n'

In [141]:
x = torch.rand(5,2, requires_grad = True, dtype = torch.float32)
x

tensor([[0.9714, 0.9065],
        [0.6259, 0.3734],
        [0.7635, 0.8703],
        [0.1528, 0.7398],
        [0.6117, 0.7816]], requires_grad=True)

In [147]:
adj_list = torch.tensor(
                        [[1,2,3,-99],
                         [0,2,3,-99],
                         [0,3,4,-99],
                         [0,1,2,4],
                         [2,3,-99,-99]]
                         
                         )
adj_list

tensor([[  1,   2,   3, -99],
        [  0,   2,   3, -99],
        [  0,   3,   4, -99],
        [  0,   1,   2,   4],
        [  2,   3, -99, -99]])

In [232]:
def one_degree_hotter(adj_list):

    new_adj_list = torch.ones(adj_list.shape[0], torch.max(adj_list) * torch.max(adj_list))

    for i in range(0, adj_list.shape[0]): 
        cur_nodes = adj_list[i]
        cur_stack = torch.tensor([-99]) 
        for j in cur_nodes:
            if j != -99:
                nodes_to_add_in = adj_list[j]
                for n in nodes_to_add_in:
                    if n not in cur_stack and n not in cur_nodes and n != i:
                        cur_stack = torch.cat((cur_stack, n.unsqueeze(0)))
        ready_to_stack = torch.nn.functional.pad(cur_stack, (0, adj_list.shape[0] - cur_stack.shape[0]), value = -99).unsqueeze(0)
        try:
            final_stack = torch.cat((final_stack, ready_to_stack), dim = 0)
        except:
            final_stack = ready_to_stack

    return final_stack


In [234]:
test = one_degree_hotter(adj_list)
test


tensor([[-99,   4, -99, -99, -99],
        [-99,   4, -99, -99, -99],
        [-99,   1, -99, -99, -99],
        [-99, -99, -99, -99, -99],
        [-99,   0,   1, -99, -99]])

In [235]:
test2 = one_degree_hotter(test)
test2


tensor([[-99,   1, -99, -99, -99],
        [-99,   0, -99, -99, -99],
        [-99,   4, -99, -99, -99],
        [-99, -99, -99, -99, -99],
        [-99, -99, -99, -99, -99]])

In [183]:
for i in adj_list:
    for j in i:
        print(j)

tensor(1)
tensor(2)
tensor(3)
tensor(-99)
tensor(0)
tensor(2)
tensor(3)
tensor(-99)
tensor(0)
tensor(3)
tensor(4)
tensor(-99)
tensor(0)
tensor(1)
tensor(2)
tensor(4)
tensor(2)
tensor(3)
tensor(-99)
tensor(-99)


In [265]:
class Sage(torch.nn.Module):

    def __init__(self, embed_dim, feature_dim, num_classes, K): 
        super(Sage, self).__init__()
        '''weights is of shape [embed_dim (arbitrary), feature_dim * 2]'''
        self.w1 = torch.nn.Parameter(torch.rand(embed_dim, feature_dim * 2), requires_grad = True)
        self.w2 = torch.nn.Parameter(torch.rand(num_classes, embed_dim), requires_grad = True)
        self.K = K


    def forward(self, x, adj_list):

        # for degree in range(self.K):
        
        #     if degree == 1:

        #         adj_list = adj_list
            
        #     else:

        #         adj_list = one_degree_hotter(adj_list)

        for node in adj_list:
            cur_neigh_feats = torch.mean(torch.index_select(x, 0, node[node > 0]), dim = 0).unsqueeze(0)
            try:
                neigh_feats = torch.cat((neigh_feats, cur_neigh_feats), dim = 0)
            except:
                neigh_feats = cur_neigh_feats

        combined = torch.cat((x, neigh_feats), dim = 1)
        combined = torch.mm(self.w1, combined.t())
        out = torch.mm(self.w2, combined).t()

        return out

In [290]:
class sage_net(torch.nn.Module):

    def __init__(self, embed_dim, feature_dim, num_classes, K): 
        super(sage_net, self).__init__()
        self.s = Sage(embed_dim, feature_dim, num_classes, K)
        # self.fc = torch.nn.Linear(4, 1)

    def forward(self, x):

        x, adj_list = x[0], x[1]

        out = self.s(x, adj_list)
        out = torch.mean(out)

        return out

In [291]:
# torch.tensor([[3, 3, 0, 3, 1, 3, 2],
#         [3, 0, 3, 1, 3, 2, 3]]).t()




In [292]:
x = torch.tensor([[1.0000, 1.0000, 0.0396, 0.9787],
        [0.5366, 0.5700, 0.1209, 0.5313],
        [0.2313, 0.2818, 0.1446, 0.4948],
        [0.2336, 0.0585, 0.2096, 0.8928]], dtype = torch.float32, requires_grad = True)
y = torch.tensor([345.], dtype = torch.float32, requires_grad = True)

adj_list = torch.tensor([[1,3,-99],
            [0,2,3],
            [1,3,-99],
            [0,1,2]])



In [300]:
model = sage_net(embed_dim = 16, feature_dim = 4, num_classes = 1, K = 2)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.1)
# y = torch.tensor([1,2,3,4,5], dtype = torch.float32, requires_grad = True).view(-1, 1)

input = (x, adj_list)

for i in range(0, 200):

    pred = model(input)

    loss = criterion(pred, y)

    print(loss.item())
    # print(pred.item())

    grad = torch.autograd.grad(outputs = loss, inputs = input[0], retain_graph = True)

    loss.backward()

    optimizer.step()

print(pred)

108985.1953125
104965.4609375
100576.9375
95755.4609375
90493.078125
84802.515625
78709.9765625
72253.6171875
65483.82421875
58464.1796875
51272.62890625
44002.859375
36765.76953125
29690.908203125
22927.86328125
16647.52734375
11043.125
6330.82275390625
2749.729736328125
560.896728515625
44.884437561035156
1497.08154296875
5219.64794921875
11508.3369140625
20631.5078125
32797.37890625
48103.99609375
66464.734375
87503.1953125
110420.15625
133863.046875
155882.59375
174108.5625
186224.4375
190602.4375
186776.28125
175489.609375
158354.84375
137368.5
114507.71875
91492.5859375
69690.75
50108.265625
33423.22265625
20033.669921875
10107.7841796875
3630.128662109375
442.8656311035156
281.5325012207031
2805.91162109375
7626.43603515625
14326.6064453125
22481.70703125
31674.201171875
41505.94140625
51607.4765625
61644.75390625
71323.375
80390.8515625
88636.90625
95892.625
102028.09375
106949.640625
110596.2734375
112936.140625
113962.8203125
113692.078125
112158.6796875
109414.09375
105524.5

In [46]:
try: del neigh_feats
except: pass


for node in adj_list:
    cur_neigh_feats = torch.mean(torch.index_select(x, 0, node[node > 0]), dim = 0).unsqueeze(0)
    try:
        neigh_feats = torch.cat((neigh_feats, cur_neigh_feats), dim = 0)
    except:
        neigh_feats = cur_neigh_feats

print(neigh_feats)


    

    # print("\n")

tensor([[0.5918, 0.3897],
        [0.6808, 0.2922],
        [0.5245, 0.2379],
        [0.5572, 0.3854],
        [0.6808, 0.2922]])


In [49]:
neigh_feats

tensor([[0.5480, 0.6868],
        [0.3765, 0.5615],
        [0.1179, 0.6393],
        [0.5240, 0.5442],
        [0.3765, 0.5615]])

In [50]:
x

tensor([[0.7804, 0.7554],
        [0.8911, 0.9375],
        [0.5990, 0.2697],
        [0.1539, 0.8532],
        [0.0819, 0.4254]])

In [53]:
cat = torch.cat((x, neigh_feats), dim = 1)
cat

tensor([[0.7804, 0.7554, 0.5480, 0.6868],
        [0.8911, 0.9375, 0.3765, 0.5615],
        [0.5990, 0.2697, 0.1179, 0.6393],
        [0.1539, 0.8532, 0.5240, 0.5442],
        [0.0819, 0.4254, 0.3765, 0.5615]])

In [66]:
cat.t()

tensor([[0.7804, 0.8911, 0.5990, 0.1539, 0.0819],
        [0.7554, 0.9375, 0.2697, 0.8532, 0.4254],
        [0.5480, 0.3765, 0.1179, 0.5240, 0.3765],
        [0.6868, 0.5615, 0.6393, 0.5442, 0.5615]])

In [86]:
l = torch.nn.Linear(4, 1)
l(cat)

tensor([[-0.3884],
        [-0.3401],
        [-0.3273],
        [-0.1176],
        [-0.1519]], grad_fn=<AddmmBackward>)

In [87]:
param = torch.rand(1, 4)
param

tensor([[0.7348, 0.4718, 0.1690, 0.1551]])

In [74]:
param.mm(cat.t())

tensor([[0.8810, 0.8597, 0.5320, 0.7657, 0.5579]])

In [93]:
t1 = torch.rand(128, 2866)
t2 = torch.rand(724, 2866)
# t3 = torch.rand(1)

# torch.cat((t1, t2, t3))

In [94]:
t1

tensor([[0.1119, 0.2839, 0.9802,  ..., 0.1941, 0.5874, 0.9189],
        [0.7242, 0.6031, 0.9000,  ..., 0.2070, 0.0143, 0.3608],
        [0.2858, 0.7533, 0.4856,  ..., 0.2480, 0.1502, 0.6680],
        ...,
        [0.5362, 0.1371, 0.0212,  ..., 0.9871, 0.5192, 0.6612],
        [0.3483, 0.7225, 0.5843,  ..., 0.5180, 0.6524, 0.9938],
        [0.5438, 0.2904, 0.3011,  ..., 0.1303, 0.5327, 0.6136]])

In [95]:
param = torch.rand(1, 2)
param

tensor([[0.0219, 0.1333]])

In [97]:
t1.mm(t2.t()).shape

torch.Size([128, 724])

In [14]:
from collections import defaultdict
# converts from adjacency matrix to adjacency list
def convert(a):
    adjList = defaultdict(list)
    for i in range(len(a)):
        for j in range(len(a[i])):
                       if a[i][j]== 1:
                           adjList[i].append(j)
    return adjList
  
# driver code
a =[[0, 0, 1], [0, 0, 1], [1, 1, 0]] # adjacency matrix
AdjList = convert(a)
print("Adjacency List:")
# print the adjacency list
for i in AdjList:
    print(i, end ="")
    for j in AdjList[i]:
        print(" -> {}".format(j), end ="")
    print()
   

Adjacency List:
0 -> 2
1 -> 2
2 -> 0 -> 1
