In [7]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp
import pandas as pd

In [8]:
def load_zachery():
    nodes_data = pd.read_csv('https://raw.githubusercontent.com/myeonghak/DGL-tutorial/master/data/nodes.csv')
    edges_data = pd.read_csv('https://raw.githubusercontent.com/myeonghak/DGL-tutorial/master/data/edges.csv')
    src = edges_data['Src'].to_numpy()
    dst = edges_data['Dst'].to_numpy()
    g = dgl.graph((src, dst))
    club = nodes_data['Club'].to_list()
    # Convert to categorical integer values with 0 for 'Mr. Hi', 1 for 'Officer'.
    club = torch.tensor([c == 'Officer' for c in club]).long()
    # We can also convert it to one-hot encoding.
    club_onehot = F.one_hot(club)
    g.ndata.update({'club' : club, 'club_onehot' : club_onehot})
    return g

In [9]:
g = load_zachery()
print(g)

node_embed = nn.Embedding(g.number_of_nodes(), 5)
inputs = node_embed.weight
nn.init.xavier_uniform_(inputs)

Graph(num_nodes=34, num_edges=156,
      ndata_schemes={'club': Scheme(shape=(), dtype=torch.int64), 'club_onehot': Scheme(shape=(2,), dtype=torch.int64)}
      edata_schemes={})


Parameter containing:
tensor([[-0.1116, -0.3177,  0.1696, -0.3869,  0.0662],
        [ 0.1972, -0.1982, -0.1538, -0.2523,  0.0666],
        [ 0.2542,  0.3176, -0.1036, -0.2525,  0.1133],
        [ 0.0819, -0.3058,  0.0683,  0.1108,  0.0149],
        [ 0.3564,  0.3323, -0.1487, -0.3778,  0.3357],
        [-0.3220,  0.0835, -0.1055, -0.1706,  0.3823],
        [-0.3167, -0.0880,  0.1005, -0.0164,  0.2311],
        [-0.3896, -0.1688, -0.0241, -0.0062, -0.2727],
        [-0.2042,  0.2344,  0.0678, -0.2126, -0.2960],
        [-0.1155,  0.1459,  0.2736, -0.0022, -0.1045],
        [ 0.3178, -0.2798,  0.1331, -0.3657, -0.3147],
        [ 0.2133, -0.0470,  0.2917,  0.3371, -0.3458],
        [ 0.2020,  0.3376,  0.1120, -0.3678,  0.2820],
        [ 0.0112, -0.1472, -0.3233, -0.2581, -0.2420],
        [-0.2245, -0.0516,  0.3246, -0.1678, -0.3326],
        [-0.1589,  0.3127,  0.2259, -0.3833, -0.1675],
        [ 0.1974,  0.2712,  0.2815, -0.1448,  0.0526],
        [-0.2842, -0.0454, -0.0130, -0.1463

In [10]:
u, v = g.edges()
eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_pos_u, test_pos_v = u[eids[:50]], v[eids[:50]]
train_pos_u, train_pos_v = u[eids[50:]], v[eids[50:]]

In [11]:
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = 1 - adj.todense() - np.eye(34)
neg_u, neg_v = np.where(adj_neg != 0)
neg_eids = np.random.choice(len(neg_u), 200)
test_neg_u, test_neg_v = neg_u[neg_eids[:50]], neg_v[neg_eids[:50]]
train_neg_u, train_neg_v = neg_u[neg_eids[50:]], neg_v[neg_eids[50:]]

In [12]:
train_u = torch.cat([torch.as_tensor(train_pos_u), torch.as_tensor(train_neg_u)])
train_v = torch.cat([torch.as_tensor(train_pos_v), torch.as_tensor(train_neg_v)])
train_label = torch.cat([torch.zeros(len(train_pos_u)), torch.ones(len(train_neg_u))])

test_u = torch.cat([torch.as_tensor(test_pos_u), torch.as_tensor(test_neg_u)])
test_v = torch.cat([torch.as_tensor(test_pos_v), torch.as_tensor(test_neg_v)])
test_label = torch.cat([torch.zeros(len(test_pos_u)), torch.ones(len(test_neg_u))])


In [16]:
from dgl.nn import SAGEConv

class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

# input layer : 5
# hidden layer : 16
net = GraphSAGE(5, 16)

In [17]:
optimizer = torch.optim.Adam(itertools.chain(net.parameters(), node_embed.parameters()), lr=0.01)

all_logits = []
for e in range(100):
    logits = net(g, inputs)
    pred = torch.sigmoid((logits[train_u] * logits[train_v]).sum(dim=1))

    loss = F.binary_cross_entropy(pred, train_label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    all_logits.append(logits.detach())

    if e % 5 == 0: 
        print('In peoch {}, loss: {}'.format(e, loss))



In peoch 0, loss: 0.6724940538406372
In peoch 5, loss: 0.5050308704376221
In peoch 10, loss: 0.38803035020828247
In peoch 15, loss: 0.34129464626312256
In peoch 20, loss: 0.3013409376144409
In peoch 25, loss: 0.26346445083618164
In peoch 30, loss: 0.23212091624736786
In peoch 35, loss: 0.19596540927886963
In peoch 40, loss: 0.15028861165046692
In peoch 45, loss: 0.10778073966503143
In peoch 50, loss: 0.07224243879318237
In peoch 55, loss: 0.04359419643878937
In peoch 60, loss: 0.02438025176525116
In peoch 65, loss: 0.013142098672688007
In peoch 70, loss: 0.006081700790673494
In peoch 75, loss: 0.0024116893764585257
In peoch 80, loss: 0.0011544295120984316
In peoch 85, loss: 0.0006790601182729006
In peoch 90, loss: 0.0004566108400467783
In peoch 95, loss: 0.00033353964681737125


In [18]:
pred = torch.sigmoid((logits[test_u] * logits[test_v]).sum(dim=1))
print('Accuracy', ((pred >= 0.5) == test_label).sum().item() / len(pred))

Accuracy 0.86
