In [1]:
!pip install dgl

Collecting dgl
  Downloading dgl-1.0.1-cp39-cp39-macosx_10_10_x86_64.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 6.5 MB/s eta 0:00:01
Collecting networkx>=2.1
  Downloading networkx-3.0-py3-none-any.whl (2.0 MB)
[K     |████████████████████████████████| 2.0 MB 68.5 MB/s eta 0:00:01
Installing collected packages: networkx, dgl
  Attempting uninstall: networkx
    Found existing installation: networkx 1.11
    Uninstalling networkx-1.11:
      Successfully uninstalled networkx-1.11
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fa2l 0.2 requires networkx<2.0.0, but you have networkx 3.0 which is incompatible.[0m
Successfully installed dgl-1.0.1 networkx-3.0


In [5]:
%matplotlib inline

In [12]:

import dgl
import torch
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')


def build_karate_club_graph():
    g = dgl.DGLGraph()
    # add 34 nodes into the graph; nodes are labeled from 0~33
    g.add_nodes(34)
    # all 78 edges as a list of tuples
    edge_list = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2),
                 (4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
                 (7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4),
                 (10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
                 (13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1),
                 (21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23),
                 (27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8),
                 (31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8),
                 (32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23),
                 (32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13),
                 (33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22),
                 (33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30),
                 (33, 31), (33, 32)]
    # add edges two lists of nodes: src and dst
    src, dst = tuple(zip(*edge_list))
    g.add_edges(src, dst)
    # edges are directional in DGL; make them bi-directional
    g.add_edges(dst, src)

    return g


In [13]:

import torch
import torch.nn as nn


def gcn_message(edges):
    """
    compute a batch of message called 'msg' using the source nodes' feature 'h'
    """
    return {'msg': edges.src['h']}


def gcn_reduce(nodes):
    """
    compute the new 'h' features by summing received 'msg' in each node's mailbox.
    """
    return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}


class GCNLayer(nn.Module):
    """
    Define the GCNLayer module.
    """

    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, inputs):
        # g is the graph and the inputs is the input node features
        # first set the node features
        g.ndata['h'] = inputs
        # trigger message passing on all edges
        g.send_and_recv(g.edges(), gcn_message, gcn_reduce)
        # trigger aggregation at all nodes
        g.send_and_recv(g.nodes(), gcn_message,gcn_reduce)
        # get the result node features
        h = g.ndata.pop('h')
        # perform linear transformation
        return self.linear(h)


class GCN(nn.Module):
    """
    Define a 2-layer GCN model.
    """
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(in_feats, hidden_size)
        self.gcn2 = GCNLayer(hidden_size, num_classes)

    def forward(self, g, inputs):
        h = self.gcn1(g, inputs)
        h = torch.relu(h)
        h = self.gcn2(g, h)
        return h

In [36]:
# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F

import networkx as nx
import matplotlib.animation as animation
import matplotlib.pyplot as plt

# from model import GCN
# from graph_builder import build_karate_club_graph

import warnings
warnings.filterwarnings('ignore')


net = GCN(34, 5, 2)
print(net)
G = build_karate_club_graph()

inputs = torch.eye(34)
labeled_nodes = torch.tensor([0, 33])  # only the instructor and the president nodes are labeled
labels = torch.tensor([0, 1])

optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
all_logits = []

for epoch in range(20):
    logits = net(G, inputs)
    all_logits.append(logits.detach())
    logp = F.log_softmax(logits, 1)

    # compute loss for labeled nodes
    loss = F.nll_loss(logp[labeled_nodes], labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))


def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = pos[v].argmax()
        colors.append(cls1color if cls else cls2color)
    # ax.cla()
    # ax.axis('off')
    # ax.set_title('Epoch: %d' % i)
    nx.draw_networkx(nx_G.to_undirected(), pos, node_color=colors, with_labels=True, node_size=300)


nx_G = G.to_networkx().to_undirected()
print(G.to_networkx())
# fig, ax = plt.subplots()
ax = plt.figure(dpi=150)
# fig.clf()
# ax = fig.subplots()
# draw(19)  # draw the prediction of the first epoch

ani = animation.FuncAnimation(ax, draw, frames=len(all_logits), interval=200)
plt.show()


GCN(
  (gcn1): GCNLayer(
    (linear): Linear(in_features=34, out_features=5, bias=True)
  )
  (gcn2): GCNLayer(
    (linear): Linear(in_features=5, out_features=2, bias=True)
  )
)
Epoch 0 | Loss: 1.0009
Epoch 1 | Loss: 0.7500
Epoch 2 | Loss: 0.5536
Epoch 3 | Loss: 0.4020
Epoch 4 | Loss: 0.3247
Epoch 5 | Loss: 0.2439
Epoch 6 | Loss: 0.1992
Epoch 7 | Loss: 0.1615
Epoch 8 | Loss: 0.1252
Epoch 9 | Loss: 0.0932
Epoch 10 | Loss: 0.0679
Epoch 11 | Loss: 0.0484
Epoch 12 | Loss: 0.0337
Epoch 13 | Loss: 0.0233
Epoch 14 | Loss: 0.0160
Epoch 15 | Loss: 0.0109
Epoch 16 | Loss: 0.0075
Epoch 17 | Loss: 0.0052
Epoch 18 | Loss: 0.0036
Epoch 19 | Loss: 0.0026
MultiDiGraph with 34 nodes and 156 edges


<Figure size 900x600 with 0 Axes>

In [42]:
!pip install karateclub


Collecting karateclub
  Downloading karateclub-1.3.3.tar.gz (64 kB)
[K     |████████████████████████████████| 64 kB 4.8 MB/s eta 0:00:011
Collecting networkx<2.7
  Downloading networkx-2.6.3-py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 9.2 MB/s eta 0:00:01
[?25hCollecting decorator==4.4.2
  Downloading decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)
Collecting python-louvain
  Downloading python-louvain-0.16.tar.gz (204 kB)
[K     |████████████████████████████████| 204 kB 49.5 MB/s eta 0:00:01
Collecting pygsp
  Downloading PyGSP-0.5.1-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 22.0 MB/s eta 0:00:01
Collecting pandas<=1.3.5
  Downloading pandas-1.3.5-cp39-cp39-macosx_10_9_x86_64.whl (11.3 MB)
[K     |████████████████████████████████| 11.3 MB 94.3 MB/s eta 0:00:01
Collecting python-Levenshtein
  Downloading python_Levenshtein-0.20.9-py3-none-any.whl (9.4 kB)
Collecting Levenshtein==0.20.9
  Downloading Levenshtein-0.2

In [44]:
import networkx as nx

In [43]:

from karateclub import EgoNetSplitter

g = nx. newman_watts_strogatz_graph(1000, 20, 0.05)
splitter = EgoNetSplitter (1.0)
splitter.fit(g)
print(splitter.get_memberships())

{0: [0], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0], 6: [0], 7: [0], 8: [0], 9: [0], 10: [0], 11: [0], 12: [0], 13: [0], 14: [0], 15: [0], 16: [0], 17: [0], 18: [0], 19: [0], 20: [0], 21: [0], 22: [0], 23: [0], 24: [0], 25: [0], 26: [0], 27: [0], 28: [0], 29: [0], 30: [0], 31: [0], 32: [0], 33: [0], 34: [0], 35: [0], 36: [0], 37: [0], 38: [0], 39: [0], 40: [0], 41: [0], 42: [0], 43: [0], 44: [0], 45: [0], 46: [0], 47: [0], 48: [0], 49: [0], 50: [0], 51: [0], 52: [0], 53: [0], 54: [0], 55: [0], 56: [0], 57: [0], 58: [0], 59: [0], 60: [0], 61: [0], 62: [0], 63: [0], 64: [0], 65: [0], 66: [0], 67: [0], 68: [0], 69: [0], 70: [0], 71: [0], 72: [0], 73: [0], 74: [0], 75: [0], 76: [0], 77: [0], 78: [5], 79: [5], 80: [5], 81: [5], 82: [5], 83: [5], 84: [5], 85: [5], 86: [5], 87: [5], 88: [5], 89: [5], 90: [5], 91: [5], 92: [5], 93: [5], 94: [5], 95: [5], 96: [5], 97: [5], 98: [5], 99: [5], 100: [5], 101: [5], 102: [5], 103: [5], 104: [5], 105: [5], 106: [5], 107: [5], 108: [5], 109: [5], 110: [5],