In [1]:
import dgl
%matplotlib inline
import networkx as nx 
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import os

import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

# Load the BA graph: a dense adjacency matrix plus one label per node.
adj = np.load('../../../data/single_graph/BA/graph_adj.npy')
rows, cols = np.where(adj == 1)
edges = zip(rows.tolist(), cols.tolist())
G = nx.Graph()
# Register every node id 0..n-1 before adding edges. add_edges_from alone only
# creates nodes that appear in some edge (in first-appearance order), so a node
# without edges would be dropped and node ids would no longer line up with the
# rows of `adj` / `labels`.
G.add_nodes_from(range(adj.shape[0]))
G.add_edges_from(edges)
labels = pd.read_csv('../../../data/single_graph/BA/data.csv').label.values


In [2]:
# Later cells index `labels` by node id, so the graph must have exactly one node
# per adjacency row and per label; fail loudly here rather than misalign silently.
assert G.number_of_nodes() == adj.shape[0] == len(labels), (
    "graph node count, adjacency size and label count disagree")
G.number_of_nodes()

200

In [3]:
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

# Built-in DGL message/reduce pair consumed by GCN.forward via update_all:
#   gcn_msg    - copy each source node's 'h' onto its edge as message 'm'
#   gcn_reduce - sum the incoming 'm' messages into the destination node's 'h'
# NOTE(review): newer DGL releases deprecate `copy_src` in favour of `copy_u`;
# confirm against the pinned DGL version before upgrading.
gcn_msg, gcn_reduce = fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h')

class NodeApplyModule(nn.Module):
    """Per-node update: linear projection of ``node.data['h']`` plus optional activation.

    Args:
        in_feats: size of the incoming ``'h'`` feature.
        out_feats: size of the produced ``'h'`` feature.
        activation: callable applied after the linear layer, or ``None`` to
            return the raw projection.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Return ``{'h': updated_features}`` for the given batch of nodes."""
        projected = self.linear(node.data['h'])
        activate = self.activation
        return {'h': projected if activate is None else activate(projected)}
    
class GCN(nn.Module):
    """One graph-convolution layer: aggregate neighbour features, then transform.

    ``forward`` writes ``feature`` into ``g.ndata['h']`` and propagates it along
    the edges with ``gcn_msg`` / ``gcn_reduce``. It then applies a
    ``NodeApplyModule`` (linear layer + optional activation) to every node.
    Finally it removes ``'h'`` from ``g.ndata`` and returns it. This mutates
    ``g``: any ``'h'`` the caller had stored on the graph is overwritten and
    then popped.

    NOTE(review): messages only travel along existing edges. A node's own
    feature therefore reaches its update only if the graph holds a self-loop
    for it, and the aggregation is not degree-normalised.

    Args:
        in_feats: input feature size.
        out_feats: output feature size.
        activation: callable applied after the linear layer, or ``None``.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        """Run one propagation step over ``g`` and return the new node features."""
        g.ndata['h'] = feature
        g.update_all(message_func=gcn_msg, reduce_func=gcn_reduce)
        g.apply_nodes(self.apply_mod)
        return g.ndata.pop('h')

In [4]:
class Net(nn.Module):
    """Two-layer GCN node classifier.

    Args:
        in_feat: input feature size. With ``use_degree_features=True`` the input
            is a single in-degree value per node, so this must be 1.
        hidden_dim: width of the hidden GCN layer (ReLU-activated).
        out_cls: number of output classes; the second layer has no activation,
            so ``forward`` returns raw logits.
        use_degree_features: when True (the default, and the original
            behaviour), ``forward`` ignores its ``features`` argument and uses
            each node's in-degree as its only input feature. Set to False to
            feed the supplied ``features`` to the first layer instead.
    """

    def __init__(self, in_feat, hidden_dim, out_cls, use_degree_features=True):
        super(Net, self).__init__()
        self.use_degree_features = use_degree_features
        self.gcn1 = GCN(in_feat, hidden_dim, F.relu)
        self.gcn2 = GCN(hidden_dim, out_cls, None)

    def forward(self, g, features):
        """Return per-node class logits of shape ``(num_nodes, out_cls)``."""
        if self.use_degree_features:
            # Replace the caller's features with the in-degree as one float
            # column, shape (num_nodes, 1).
            features = g.in_degrees().view(-1, 1).float()
        x = self.gcn1(g, features)
        return self.gcn2(g, x)

In [5]:
import torch
# Convert the node labels to a tensor and build one-hot identity node features.
# `as_tensor` keeps this cell idempotent. Re-running it on an already-converted
# tensor is a no-op, whereas `torch.tensor(tensor)` copy-constructs and warns.
labels = torch.as_tensor(labels)
inputs = torch.eye(G.number_of_nodes())

In [6]:
# Convert the networkx graph `G` into the DGL graph `S` used for training.
# NOTE(review): `DGLGraph().from_networkx(...)` is the older DGL API; newer DGL
# releases construct graphs with `dgl.from_networkx(G)` — confirm the pinned version.
S = dgl.DGLGraph()
S.from_networkx(G)
# Attach the identity features as node data 'h'. GCN.forward overwrites
# g.ndata['h'] with its own input and pops it afterwards, so this assignment
# does not actually feed the model.
S.ndata['h'] = inputs

In [7]:
# Feature matrix dimensions: (number of nodes, feature size).
inputs.size()

torch.Size([200, 200])

In [8]:
# Split the nodes: 15% labelled for training; of the remainder, 20% validation
# and the rest test. Seeded so the split is reproducible.
np.random.seed(1)
n_nodes = G.number_of_nodes()
labeled_nodes = np.random.choice(n_nodes, int(n_nodes * 0.15), replace=False)
labels_train = labels[labeled_nodes]

# Membership checks use sets: `i in ndarray` scans the whole array each time,
# which made these comprehensions accidentally quadratic.
labeled_set = set(labeled_nodes.tolist())
unlabelled_nodes = [i for i in range(n_nodes) if i not in labeled_set]
val_nodes = np.random.choice(unlabelled_nodes, int(len(unlabelled_nodes) * 0.2), replace=False)
val_set = set(val_nodes.tolist())
test_nodes = [i for i in unlabelled_nodes if i not in val_set]

# The three splits must be pairwise disjoint and together cover every node.
assert labeled_set.isdisjoint(val_set) and val_set.isdisjoint(test_nodes)
assert len(labeled_set) + len(val_set) + len(test_nodes) == n_nodes

val_label = labels[val_nodes]
test_label = labels[test_nodes]

In [18]:
# Training hyper-parameters.
HIDDEN_DIM = 64
N_CLASSES = 10
LEARNING_RATE = 0.03
N_EPOCHS = 2000
LOG_EVERY = 100

# numpy was seeded for the split, but torch was not. Seed it so the weight
# initialisation, and hence the loss curve, is reproducible.
torch.manual_seed(1)
assert int(labels.max()) < N_CLASSES, "a label exceeds the number of output classes"

# in_feat=1 because Net.forward replaces the given features with node in-degrees.
net = Net(1, HIDDEN_DIM, N_CLASSES)
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
net.train()
for epoch in range(N_EPOCHS):
    logits = net(S, inputs)
    logp = F.log_softmax(logits, 1)
    # Semi-supervised: the loss only sees the labelled (training) nodes.
    loss = F.nll_loss(logp[labeled_nodes], labels_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % LOG_EVERY == 0:
        print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))


Epoch 0 | Loss: 316.9016
Epoch 100 | Loss: 2.0407
Epoch 200 | Loss: 1.9164
Epoch 300 | Loss: 1.7446
Epoch 400 | Loss: 1.5990
Epoch 500 | Loss: 1.4944
Epoch 600 | Loss: 1.4168
Epoch 700 | Loss: 1.3567
Epoch 800 | Loss: 1.3086
Epoch 900 | Loss: 1.2687
Epoch 1000 | Loss: 1.2350
Epoch 1100 | Loss: 1.2058
Epoch 1200 | Loss: 1.1803
Epoch 1300 | Loss: 1.1577
Epoch 1400 | Loss: 1.1373
Epoch 1500 | Loss: 1.1190
Epoch 1600 | Loss: 1.1022
Epoch 1700 | Loss: 1.0868
Epoch 1800 | Loss: 1.0727
Epoch 1900 | Loss: 1.0595


In [19]:
# Inference with the trained model: eval mode, and no autograd graph since no
# backward pass follows.
net.eval()
with torch.no_grad():
    logits = net(S, inputs)
    # Per-node class log-probabilities, consumed by the prediction cell below.
    logp = F.log_softmax(logits, 1)

In [20]:
# Predicted class for each test node: the index of its largest log-probability.
argmax_Y = logp[test_nodes].argmax(dim=1)

In [21]:
# Percentage of test nodes whose predicted class matches the true label. Both
# tensors hold integer class ids, so they are compared directly without a
# float cast.
test_accuracy = (argmax_Y == test_label).float().mean().item() * 100
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(test_accuracy))

Accuracy of argmax predictions on the test set: 56.617647%
