In [1]:
import torch

In [2]:
"""
    1) Each node v ∈ V aggregates the representations of the nodes in its immediate neighborhood, into a single vector (Note that this aggregation
step depends on the representations generated at the previous iteration of the outer loop and the k = 0 (“base case”) representations are defined as the input node features)
    2) After aggregating the neighboring feature vectors, GraphSAGE then concatenates the node’s current representation with the aggregated neighborhood vector ... torch.cat((node_features, neighbor_features), dim = 1)
    3) This concatenated vector is fed through a fully connected layer with nonlinear activation function σ, which transforms the representations to be used at the next step of the algorithm
"""

'\n    1) Each node v ∈ V aggregates the representations of the nodes in its immediate neighborhood, into a single vector (Note that this aggregation\nstep depends on the representations generated at the previous iteration of the outer loop and the k = 0 (“base case”) representations are defined as the input node features)\n    2) After aggregating the neighboring feature vectors, GraphSAGE then concatenates the node’s current representation with the aggregated neighborhood vector ... torch.cat((node_features, neighbor_features), dim = 1)\n    3) This concatenated vector is fed through a fully connected layer with nonlinear activation function σ, which transforms the representations to be used at the next step of the algorithm\n'

In [3]:
x = torch.rand(5,2, requires_grad = True, dtype = torch.float32)
adj_list = torch.tensor(
                        [[1,2,3,-99],
                         [0,3,-99,-99],
                         [0,3,4,-99],
                         [0,1,2,4],
                         [2,3,-99,-99]]
                         
                         )

In [4]:
def one_degree_hotter(adj_list, already_visited):

    # new_adj_list = torch.ones(adj_list.shape[0], torch.max(adj_list) * torch.max(adj_list))

    for i in range(0, adj_list.shape[0]): 
        cur_nodes = adj_list[i]
        cur_stack = torch.tensor([-99]) 
        for j in cur_nodes:
            if j != -99:
                nodes_to_add_in = adj_list[j]
                for n in nodes_to_add_in:
                    if n not in cur_stack and n not in cur_nodes and n != i and n not in already_visited[i]:
                        cur_stack = torch.cat((cur_stack, n.unsqueeze(0)))
        ready_to_stack = torch.nn.functional.pad(cur_stack, (0, adj_list.shape[0] - cur_stack.shape[0]), value = -99).unsqueeze(0)
        try:
            final_stack = torch.cat((final_stack, ready_to_stack), dim = 0)
        except:
            final_stack = ready_to_stack

    return final_stack


In [5]:
class Sage(torch.nn.Module):

    def __init__(self, embed_dim, feature_dim, num_classes, K): 
        super(Sage, self).__init__()
        '''weights is of shape [embed_dim (arbitrary), feature_dim * 2]'''
        self.w2 = torch.nn.Parameter(torch.rand(num_classes, embed_dim), requires_grad = True)
        self.relu = torch.nn.ReLU()
        self.K = K

        self.params  = torch.nn.ParameterDict({})

        for i in range(K):
            if i == 0:
                feature_dim = feature_dim * 2
            self.params[str(i)] = torch.nn.Parameter(torch.rand(embed_dim, feature_dim), requires_grad = True)
            feature_dim = embed_dim * 2



    def forward(self, x, adj_list):
        
        already_visited = adj_list

        for degree in range(self.K):

            if degree == 0:

                for node in adj_list:
                    cur_neigh_feats = torch.mean(torch.index_select(x, 0, node[node >= 0]), dim = 0).unsqueeze(0)
                    try:
                        neigh_feats = torch.cat((neigh_feats, cur_neigh_feats), dim = 0)
                    except:
                        neigh_feats = cur_neigh_feats

                x = torch.cat((x, neigh_feats), dim = 1)
                x = torch.mm(self.params[str(degree)], x.t())
                x = self.relu(x).t()


            else:

                adj_list = one_degree_hotter(adj_list, already_visited)
                already_visited = torch.cat((already_visited, adj_list), dim = 1)

                for node in adj_list:
                    cur_neigh_feats = torch.mean(torch.index_select(x, 0, node[node >= 0]), dim = 0).unsqueeze(0)
                    try:
                        neigh_feats = torch.cat((neigh_feats, cur_neigh_feats), dim = 0)
                    except:
                        neigh_feats = cur_neigh_feats

                neigh_feats[neigh_feats != neigh_feats] = 0

                x = torch.cat((x, neigh_feats), dim = 1)
                x = torch.mm(self.params[str(degree)], x.t())
                x = self.relu(x).t()

            del neigh_feats

            # print(x.shape)

        # out = torch.mm(self.w2, x).t()

        return x

In [6]:
class sage_net(torch.nn.Module):

    def __init__(self, embed_dim, feature_dim, num_classes, K): 
        super(sage_net, self).__init__()
        self.s1 = Sage(embed_dim, feature_dim, num_classes, K)
        self.s2 = Sage(32, embed_dim, num_classes, K)
        # self.fc = torch.nn.Linear()

    def forward(self, x):

        x, adj_list = x[0], x[1]

        out = self.s1(x, adj_list)
        out = self.s2(out, adj_list)
        # print(out.shape)
        out = torch.mean(out)

        # print(out)

        return out

In [7]:
x = torch.tensor([[1.0000, 1.0000, 0.0396, 0.9787],
        [0.5366, 0.5700, 0.1209, 0.5313],
        [0.2313, 0.2818, 0.1446, 0.4948],
        [0.2336, 0.0585, 0.2096, 0.8928]], dtype = torch.float32, requires_grad = True)
y = torch.tensor([345.], dtype = torch.float32, requires_grad = True)

adj_list = torch.tensor([[1,3,-99],
            [0,2,3],
            [1,3,-99],
            [0,1,2]])



In [8]:
model = sage_net(embed_dim = 16, feature_dim = 4, num_classes = 1, K = 4)
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.1)
lr = .1

input = (x, adj_list)

for i in range(0, 100):



    pred = model(input)

    loss = criterion(pred, y)

    print(loss.item())

    if loss < 300:
        print('ehre')
        lr = lr * .1
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
        # break


    if loss < 100:
        print('ehre')
        lr = lr * .5
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
        break
    

    loss.backward()

    optimizer.step()

print(pred)

124355240.0
20678108.0
2055595.625
81643.2421875
155.550537109375
ehre
89.9403076171875
ehre
ehre
tensor(255.0597, grad_fn=<MeanBackward0>)
  return F.l1_loss(input, target, reduction=self.reduction)
