# Hierarchical Network Embedding for Community Detection

## Imports

In [None]:
import time
import collections

import numpy as np

import torch, torch.nn as nn, torch.autograd as autograd
import torch.utils.data
from torch.distributions import multivariate_normal

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

import networkx as nx

import matplotlib.pyplot as plt

from sklearn import mixture
from sklearn.metrics.cluster import normalized_mutual_info_score

## Data

In [None]:
# Cora citation graph: load the DGL dataset, wrap it in a DGLGraph and take a
# NetworkX view of it for plotting. spring_layout is called without a seed, so
# node positions may differ from run to run.
data = citegrh.load_cora()
G = DGLGraph(data.graph)
kn = G.to_networkx()
pos = nx.spring_layout(kn)

## Constants

In [None]:
# --- problem sizes -----------------------------------------------------------
N = 2708  # presumably the number of Cora nodes — TODO confirm (functions below use a local N)
D = 400   # embedding width; matches the model's input/output Linear layers
K = 7     # number of communities: GMM components and k-means clusters

# --- numerical constants -------------------------------------------------------
eps = 1e-6   # lower clamp on GMM membership probabilities before taking log
beta = 1e-4  # scale factor of the community-embedding loss

batch_size = 300

# FIXME: restrict probability density functions
max_density = 2.0

## Utility functions

In [None]:
def vectorize(row):
    """Parse a string of numbers separated by single spaces into a float64 array.

    The split is on exactly one space, so repeated spaces produce empty tokens
    that cannot be parsed as numbers.
    """
    tokens = row.split(" ")
    return np.array(tokens).astype(np.float64)

def normalize(v):
    """Min-max scale ``v`` into [0, 1] using its global minimum and maximum.

    Args:
        v: tensor of any shape; the min/max are taken over all elements.

    Returns:
        ``(v - min) / (max - min)``. When every element is equal (zero range)
        a tensor of zeros with v's shape and device is returned instead of
        dividing by zero; its dtype matches what the division would produce.
    """
    min_v = torch.min(v)
    range_v = torch.max(v) - min_v

    if range_v > 0:
        return (v - min_v) / range_v
    # Constant input: nothing to scale. result_type(v, 1.0) mirrors the dtype
    # true division would give (float for integer inputs).
    return torch.zeros_like(v, dtype=torch.result_type(v, 1.0))

In [None]:
def skipgram2dict(skipgram):
    """Convert skip-gram embedding lines into an id-sorted mapping.

    Each line is ``"<node_id> <v1> <v2> ..."``: the text before the first space
    is parsed as an int id and the remainder is parsed with ``vectorize``.
    Blank lines are ignored. If an id appears more than once, the last line wins.

    Args:
        skipgram: iterable of text lines (header line already removed).

    Returns:
        collections.OrderedDict mapping int node id -> numpy vector, in
        ascending id order.
    """
    embeddings = {}
    for line in skipgram:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        node, values = line.split(" ", 1)
        embeddings[int(node)] = vectorize(values)

    return collections.OrderedDict(sorted(embeddings.items()))

## K-Means

In [None]:
def KMeans(x, K, Niter=10, verbose=False):
    """Cluster the rows of ``x`` into ``K`` groups with Lloyd's k-means.

    Centroids are initialised deterministically with the first ``K`` rows of
    ``x``. Each iteration assigns every point to its nearest centroid (squared
    Euclidean distance) and then moves each non-empty cluster's centroid to the
    mean of its points; an empty cluster keeps its previous centroid.

    Args:
        x: 2-D tensor of shape (N, D), one point per row. It is detached, so
            tensors that require grad (e.g. model outputs) are accepted.
        K: number of clusters.
        Niter: number of assign/update iterations; must be >= 1.
        verbose: print timing information when True.

    Returns:
        cl: LongTensor of shape (N,), the nearest-centroid index per point,
            from the assignment step of the last iteration.
        pi: tensor of shape (N, K), ``1 - normalize(squared distances)`` from
            that same step. Larger means closer; the scaling is over the whole
            matrix, not per row.

    Raises:
        ValueError: if ``Niter < 1`` or ``x`` has fewer than ``K`` rows.
    """
    if Niter < 1:
        raise ValueError("Niter must be at least 1, got {}".format(Niter))

    N, D = x.shape  # number of samples, dimension of the ambient space
    if N < K:
        raise ValueError("need at least K={} points, got {}".format(K, N))

    # The clustering itself is never differentiated through.
    x = x.detach()

    start = time.time()
    c   = x[:K, :].clone()  # (K, D) centroids; clone so updates don't touch x
    x_i = x[:, None, :]     # (N, 1, D)

    for _ in range(Niter):
        D_ij = ((x_i - c[None, :, :]) ** 2).sum(-1)  # (N, K) dense squared distances
        cl   = D_ij.argmin(dim=1)                     # (N,) nearest cluster per point
        pi   = 1 - normalize(D_ij)                    # (N, K) closeness scores

        # Centroid update via per-cluster sums and counts. minlength=K keeps the
        # counts aligned with c even when the highest-indexed clusters are empty.
        counts = torch.bincount(cl, minlength=K)           # (K,)
        sums   = torch.zeros_like(c).index_add_(0, cl, x)  # (K, D)
        filled = counts > 0
        # Empty clusters keep their previous centroid instead of becoming NaN.
        c[filled] = (sums[filled] / counts[filled].unsqueeze(1)).to(c.dtype)

    end = time.time()

    if verbose:
        print("K-means example with {:,} points in dimension {:,}, K = {:,}:".format(N, D, K))
        print('Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n'.format(
                Niter, end - start, Niter, (end - start) / Niter))

    return cl, pi

## Community Embedding ([ComE](https://sentic.net/community-embedding.pdf))

In [None]:
def _gmm_log_resp(X, gmm):
    """Differentiable torch equivalent of ``log(gmm.predict_proba(X))``.

    Recomputes, from the fitted parameters of a diagonal-covariance
    GaussianMixture, each row's log posterior membership, so gradients can
    flow back into ``X``.

    Args:
        X: tensor of shape (B, D).
        gmm: fitted sklearn GaussianMixture with ``covariance_type='diag'``.

    Returns:
        Tensor of shape (B, K); each row is a log-probability vector.

    Raises:
        ValueError: if the mixture is not diagonal-covariance.
    """
    if gmm.covariance_type != 'diag':
        raise ValueError("expected a 'diag' GaussianMixture, got {!r}".format(gmm.covariance_type))

    log_weights = torch.log(torch.as_tensor(gmm.weights_, dtype=X.dtype, device=X.device))  # (K,)
    means       = torch.as_tensor(gmm.means_, dtype=X.dtype, device=X.device)               # (K, D)
    variances   = torch.as_tensor(gmm.covariances_, dtype=X.dtype, device=X.device)         # (K, D)

    diff = X[:, None, :] - means[None, :, :]  # (B, K, D)
    log_density = -0.5 * (
        (diff ** 2 / variances).sum(dim=-1)
        + torch.log(variances).sum(dim=-1)
        + float(X.size(1) * np.log(2 * np.pi))
    )  # (B, K) log N(x | mean_k, diag(var_k))

    return torch.log_softmax(log_weights + log_density, dim=1)


def compute_loss(model, gmm, X_batch, pi, psi, Sigma):
    """Compute the community-embedding loss for one batch and refit the GMM.

    Args:
        model: module mapping X_batch to embeddings of shape (B, D).
        gmm: fitted diagonal GaussianMixture; it is refitted IN PLACE on the
            new embeddings.
        X_batch: float tensor of shape (B, D_in).
        pi: (B, K) prior membership weights (e.g. the ``pi`` from KMeans).
        psi, Sigma: unused — accepted for call compatibility; the refitted
            values are returned instead.

    Returns:
        psi: float32 tensor of refitted component means.
        Sigma: float32 tensor of refitted component (diagonal) covariances.
        loss: scalar ``-(beta / K) * sum(log(pi_k * p(k | x)))`` over the batch,
            differentiable with respect to ``model``'s parameters.
    """
    X = model(X_batch)  # (B, D), carries the autograd graph

    # Membership log-probabilities under the current GMM, kept differentiable
    # (sklearn's predict_proba cannot take a tensor that requires grad).
    log_probs = _gmm_log_resp(X, gmm).clamp(min=float(np.log(eps)))  # (B, K)
    probs = log_probs.detach().exp()  # == clamp(predict_proba(X), min=eps)

    # Closed-form responsibilities and mixing weights; not back-propagated.
    gamma = compute_gamma(pi, probs)          # (B, K)
    n_k   = gamma.sum(dim=0, keepdim=True)    # (1, K)
    pi    = n_k / X_batch.size(0)             # (1, K)

    # Refit the mixture on the new embeddings (sklearn needs a plain array).
    gmm.fit(X.detach().cpu().numpy())

    psi   = torch.FloatTensor(gmm.means_)
    Sigma = torch.FloatTensor(gmm.covariances_)

    # log(pi * probs) == log(pi) + log(probs); the gradient flows via log_probs.
    loss = -(beta / K) * torch.sum(torch.log(pi) + log_probs)

    return psi, Sigma, loss

def compute_gamma(pi, probs):
    """Return responsibilities: ``pi * probs`` rescaled so each row sums to 1.

    ``pi`` and ``probs`` must broadcast together; normalisation is along dim 1.
    """
    weighted = pi * probs
    row_totals = weighted.sum(dim=1, keepdim=True)
    return weighted / row_totals

def reset_embeddings(X_batch, K):
    """Fit a fresh K-component, diagonal-covariance Gaussian mixture to X_batch.

    Returns:
        (means, covariances, gmm): float32 tensor copies of the fitted
        ``gmm.means_`` and ``gmm.covariances_``, plus the fitted mixture.
    """
    points = X_batch.clone().detach().numpy()
    fitted = mixture.GaussianMixture(n_components=K, covariance_type='diag')
    fitted.fit(points)

    means = torch.FloatTensor(fitted.means_)
    covariances = torch.FloatTensor(fitted.covariances_)
    return means, covariances, fitted

In [None]:
# MLP 400 -> 1024 -> 512 -> 256 -> 400: three (Linear, Dropout(0.05), ReLU)
# stages followed by a final Linear. Layers are created in sequence order.
model = nn.Sequential(
    *[
        layer
        for fan_in, fan_out in ((400, 1024), (1024, 512), (512, 256))
        for layer in (nn.Linear(fan_in, fan_out), nn.Dropout(0.05), nn.ReLU())
    ],
    nn.Linear(256, 400),
)

### Training

In [None]:
#!deepwalk --help

#!deepwalk --input ../graphsage/cora/cora.adjlist --representation-size 400 --walk-length 40 --output ../graphsage/cora/cora.embeddings

In [None]:
# Load the skip-gram embeddings; the first line of the file is a header.
with open("../graphsage/cora/cora.embeddings", "r", encoding="utf-8") as f:
    skipgram = f.readlines()
skipgram.pop(0)  # drop the header line
dataset = skipgram2dict(skipgram)
# Stack in numpy first: torch builds tensors from a list of ndarrays very slowly.
x = torch.from_numpy(np.stack(list(dataset.values()))).float()

In [None]:
# Shuffled mini-batches of embedding rows. DataLoader indexes `dataset` with
# 0..len-1, so this assumes node ids are contiguous from 0 — TODO confirm.
train_batch_gen = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Initial mixture fitted on all embeddings.
psi, Sigma, gmm = reset_embeddings(x, K)

In [None]:
torch.manual_seed(1234)
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.1)

num_epochs = 10
train_loss = []

for epoch in range(num_epochs):
    model.train(True)

    print("Epoch: {}".format(epoch))

    for X_batch in train_batch_gen:
        # detect_anomaly is a debugging aid: it reports the op that produced a
        # NaN/Inf gradient, at a noticeable speed cost.
        with autograd.detect_anomaly():
            # Mixed-community membership scores (hard labels are not used here)
            batch_labels, pi = KMeans(X_batch, K)
            # Update pi, mean, cov and get the batch loss
            psi, Sigma, loss = compute_loss(model, gmm, X_batch.float(), pi.float(), psi, Sigma)

            # Each batch builds a fresh graph, so it need not be retained.
            loss.backward()

            opt.step()
            opt.zero_grad()

        train_loss.append(loss.item())

    # Mean over this epoch's batches only.
    print("Training loss (in-iteration): \t{:.6f}".format(
        np.mean(train_loss[-len(train_batch_gen):]))
    )

model.eval()

### Validation

In [None]:
# Ground-truth classes. draw_networkx_nodes only draws nodes; `with_labels`
# belongs to nx.draw and is rejected by recent networkx versions.
nc = nx.draw_networkx_nodes(kn, pos, node_color=data.labels,
                            node_size=0.5, cmap=plt.cm.jet)

In [None]:
# Baseline: k-means directly on the loaded embeddings `x` (no learned mapping).
labels2, pi2 = KMeans(x, K=K, Niter=10)

In [None]:
# Baseline k-means clusters. `with_labels` is an nx.draw option, not a
# draw_networkx_nodes one, and recent networkx versions reject it.
nc = nx.draw_networkx_nodes(kn, pos, node_color=labels2,
                            node_size=0.5, cmap=plt.cm.jet)

In [None]:
# k-means on the trained model's embeddings; inference only, so skip autograd.
with torch.no_grad():
    labels3, pi3 = KMeans(model(x), K, Niter=100)

In [None]:
# Clusters found on the model embeddings. `with_labels` is an nx.draw option,
# not a draw_networkx_nodes one, and recent networkx versions reject it.
nc = nx.draw_networkx_nodes(kn, pos, node_color=labels3,
                            node_size=0.5, cmap=plt.cm.jet)

### Mutual info score

In [None]:
# NMI of baseline k-means vs. ground truth (data.labels, as in the first plot).
# Assumes rows of x follow the same node order as data.labels — TODO confirm.
normalized_mutual_info_score(data.labels.flatten(), labels2)

In [None]:
# NMI of k-means on model embeddings vs. ground truth (data.labels).
# Assumes rows of x follow the same node order as data.labels — TODO confirm.
normalized_mutual_info_score(data.labels.flatten(), labels3)