In [1]:
import os
import networkx as nx
import pandas as pd
import numpy as np

import torch

## Create dataset

In [2]:
data_dir = os.path.expanduser('/media/oanaucs/Stuff/datasets/cora')

In [3]:
edgelist = pd.read_csv(os.path.join(data_dir, "cora.cites"), sep='\t', header=None, names=["target", "source"])
edgelist["label"] = "cites"

In [24]:
edge_index = [list(edgelist['source']), list(edgelist['target'])]

In [25]:
print(edge_index)

[[1033, 103482, 103515, 1050679, 1103960, 1103985, 1109199, 1112911, 1113438, 1113831, 1114331, 1117476, 1119505, 1119708, 1120431, 1123756, 1125386, 1127430, 1127913, 1128204, 1128227, 1128314, 1128453, 1128945, 1128959, 1128985, 1129018, 1129027, 1129573, 1129683, 1129778, 1130847, 1130856, 1131116, 1131360, 1131557, 1131752, 1133196, 1133338, 1136814, 1137466, 1152421, 1152508, 1153065, 1153280, 1153577, 1153853, 1153943, 1154176, 1154459, 116552, 12576, 128540, 132806, 135130, 141342, 141347, 148170, 15670, 1688, 175291, 178727, 18582, 190697, 190706, 1956, 197054, 198443, 198653, 206371, 210871, 229635, 231249, 248425, 249421, 254923, 259701, 259702, 263279, 263498, 265203, 273152, 27510, 28290, 286500, 287787, 28851, 289779, 289780, 289781, 307015, 335733, 33904, 33907, 35061, 38205, 387795, 415693, 41714, 427606, 44368, 45599, 46079, 46431, 486840, 48766, 503883, 503893, 513189, 54129, 54131, 56119, 561238, 568857, 573964, 573978, 574009, 574264, 574462, 575077, 575292, 575331, 

In [5]:
gnx = nx.from_pandas_edgelist(edgelist, edge_attr="label")
nx.set_node_attributes(gnx, "paper", "label")

In [6]:
degrees = gnx.degree
unique_degrees = np.unique(np.asarray([deg[1] for deg in degrees]))

In [7]:
adj_mat = torch.tensor(nx.to_numpy_matrix(gnx), dtype=torch.uint8)

In [8]:
feature_names = ["w_{}".format(ii) for ii in range(1433)]
column_names =  ['id'] + feature_names + ["subject"] 
node_data = pd.read_csv(os.path.join(data_dir, "cora.content"), sep='\t', header=None, names=column_names)
node_data['label'] = node_data['subject'].rank(method='dense', ascending=False).astype(int)

In [9]:
print(node_data)

           id  w_0  w_1  w_2  w_3  w_4  w_5  w_6  w_7  w_8  ...  w_1425  \
0       31336    0    0    0    0    0    0    0    0    0  ...       0   
1     1061127    0    0    0    0    0    0    0    0    0  ...       1   
2     1106406    0    0    0    0    0    0    0    0    0  ...       0   
3       13195    0    0    0    0    0    0    0    0    0  ...       0   
4       37879    0    0    0    0    0    0    0    0    0  ...       0   
...       ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...     ...   
2703  1128975    0    0    0    0    0    0    0    0    0  ...       0   
2704  1128977    0    0    0    0    0    0    0    0    0  ...       0   
2705  1128978    0    0    0    0    0    0    0    0    0  ...       0   
2706   117328    0    0    0    0    1    0    0    0    0  ...       0   
2707    24043    0    0    0    0    0    0    0    0    0  ...       0   

      w_1426  w_1427  w_1428  w_1429  w_1430  w_1431  w_1432  \
0          1       0       0       

In [10]:
# save some important info
num_classes = len(np.unique(node_data['label']))
node_num_features = len(feature_names)
batch_size = 10

In [11]:
target = pd.DataFrame(node_data['label'])
ids = pd.DataFrame(node_data['id'])
del node_data['id']
del node_data['subject']
del node_data['label']

In [12]:
print(num_classes, node_num_features)

7 1433


In [13]:
id_data = torch.Tensor(np.array(ids))
tensor_node_data = torch.Tensor(np.array(node_data))
target_data = torch.Tensor(np.array(target))
train_dataset = torch.utils.data.TensorDataset(id_data, tensor_node_data, target_data)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

## Create network

In [14]:
in_dim = node_num_features
out_dim = 1
w_hidden_dim = 10
T = 500
lr = 0.003
layers_dim = (64, 64, 64)


In [15]:
class Net(torch.nn.Module):
    def __init__(self, in_dim, out_dim, layers_dim=layers_dim):
        super(Net, self).__init__()
        
        self.num_layers = len(layers_dim)
        module_list = [torch.nn.Linear(in_dim, layers_dim[0])]
        for i in range(1, self.num_layers - 1):
            module_list.append(torch.nn.Linear(layers_dim[i], layers_dim[i+1]))
        module_list.append(torch.nn.Linear(layers_dim[-1], out_dim))
        
        self.layers = torch.nn.ModuleList(module_list)
    
    def forward(self, x):
#         x = x.contiguous().view(-1, self.flat_size(x))
        x = torch.transpose(x, 2, 1)
        
        for i in range(self.num_layers):
            x = torch.nn.functional.relu(self.layers[i](x))
        return x

In [16]:
f_net = Net(w_hidden_dim, num_classes)

In [17]:
# message passing fct
def message_function(h_w, e_vw):
    return torch.cat([torch.unsqueeze(h_w, 0), torch.unsqueeze(e_vw, 0)], axis=1)

In [18]:
h_mat = torch.nn.Parameter(torch.randn(len(unique_degrees), node_num_features, in_dim))

In [19]:
# update fct
def update_function(H_mat, h_v, m_v):
    mat_update = torch.bmm(H_mat, torch.transpose(m_v, 1, 2))
    return mat_update

In [20]:
w_mat = torch.nn.Parameter(torch.randn(batch_size, w_hidden_dim, node_num_features))
w_mat.shape

torch.Size([10, 10, 1433])

In [21]:
# readout fct
def readout_function(h, W):
    mat = torch.bmm(W, h)
    return f_net(mat)

In [22]:
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(f_net.parameters()) + [h_mat, w_mat], lr=lr)
softmax = torch.nn.Sigmoid()

In [27]:
for step, (id_num, feat, target) in enumerate(train_loader):
#     h = []
#     h.append([feat, target])

    print(id_num, feat)

    for time_step in range(T):

        optimizer.zero_grad()

        messages = torch.zeros((batch_size, 1) + (node_num_features, ))
        
        # for all nodes in batch, we need to compute corresponding messages
        for b in range(batch_size):
            v_node_id = id_num[b]
            
            for idx_edge in range(len(edge_index[0])):
                if edge_index[0][idx_edge] == v_node_id:
                    w = edge_index[1][idx_edge]
                    
                    print(v_node_id, w)
                    
                    node_feature = h[time_step][0][w]
                
        
#         for i in range(len(edge_index[0])):
#             w = edge_index[1][i]
#             node_feature = h[time_step][0][w]
#             edge_feature = h[time_step][1][i]
#             temp_mv = message_function(node_feature, edge_feature)
#             messages[w, :, :] += temp_mv

#         h_t = torch.zeros((len(nodes_features), ) + nodes_features[0].shape + (1, ))
#         for v in range(len(nodes_features)):
#             update = update_function(h_mat, h[time_step][0][v], torch.unsqueeze(messages[v], 0))
#             h_t[v] += update[0, :, :]

#         h.append([torch.squeeze(h_t, 2), edges_features])

#         readout = torch.squeeze(readout_function(h_t, w_mat), 1)

#         output = softmax(readout)
#         _, preds = torch.max(output, 1)
#     #     print(output, preds)

#         loss = loss_function(readout, labels)
#         print(loss)

#         loss.backward(retain_graph=True)

#         optimizer.step()

tensor([[ 14545.],
        [815096.],
        [ 57764.],
        [ 12198.],
        [307336.],
        [ 67633.],
        [128202.],
        [219218.],
        [377303.],
        [159085.]]) tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([14545.]) 51934
tensor([815096.]) 205196
tensor([815096.]) 574264
tensor([57764.]) 1246
tensor([57764.]) 12337
tensor([57764.]) 36140
tensor([57764.]) 46500
tensor([57764.]) 46501
tensor([12198.]) 10177
tensor([12198.]) 12350
tensor([12198.]) 27612
tensor([307336.]) 3243
tensor([307336.]) 174418
tensor([67633.]) 379288
tensor([128202.]) 128203
tensor([377303.]) 195792
tensor([159085.]) 159084
tensor([159085.]) 241821
tensor([14545.]) 51934
tensor([815096.]) 205196
tensor([815096.]) 574264
tensor([57764.]) 1246
tensor([57764.]) 12337
tensor([5

KeyboardInterrupt: 