In [None]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from deepwalk.models import negative_sampling
from deepwalk.walks import construct_co_occurrence_matrix

In [None]:
# Sample a graph from a stochastic block model (SBM) with K communities:
# each pair of nodes in the same community is linked with probability p,
# each pair in different communities with probability q.
sizes = [200, 200, 200]   # nodes per community; K = len(sizes)
K = len(sizes)
q = 0.1                   # between-community edge probability
p = 4*q                   # within-community edge probability
# K x K block probability matrix: p on the diagonal, q elsewhere.
block_probs = [[p if i == j else q for j in range(K)] for i in range(K)]

G = nx.stochastic_block_model(sizes=sizes, p=block_probs, seed=1234)
# Dense adjacency matrix; rows/columns follow G's node order.
A = nx.to_numpy_array(G)


In [None]:
# Build the co-occurrence matrix C from the adjacency matrix A, presumably
# from random walks (per the n_walks / walk_length / window_size settings).
# NOTE(review): no seed is passed here; if the walks are random, C and every
# embedding below may differ between runs — confirm whether
# construct_co_occurrence_matrix accepts a seed.
walk_config = {'n_walks': 100, 'walk_length': 100, 'window_size': 3}
C = construct_co_occurrence_matrix(A, **walk_config)

In [None]:
# Fit 2-D SGNS embeddings for a range of negative-sampling values b and plot
# each embedding coloured by SBM community; one PNG is saved per b.
ns = [1/10, 1/6, 1/3, 1, 2, 3, 5, 10]
# Community of each node, in G's node order (the row order of A). Read from
# the SBM rather than hard-coding the 200-node block boundaries.
# Assumes negative_sampling returns one row of X per row of C/A, in that order.
block_labels = np.array([G.nodes[v]['block'] for v in G])
block_colors = ['r', 'b', 'g']
for b in ns:
    X, Y = negative_sampling(C=C, d=2, b=b, n_iter=1000, eta=.00001)
    fig, ax = plt.subplots()
    for block in np.unique(block_labels):
        members = block_labels == block
        ax.scatter(X[members, 0], X[members, 1],
                   color=block_colors[block % len(block_colors)])
    fig.savefig(f'sgns_sensitivity_b{b}.png')
    # Title is set after saving so the exported figure is unchanged while the
    # inline display still shows which b it belongs to.
    ax.set_title(f'b={b}')
    plt.show()
    # Close explicitly: outside the inline backend plt.show() does not release
    # the figure, so points and memory would accumulate across iterations.
    plt.close(fig)

In [None]:
# Run b=10 again using a smaller step-size and more iterations so it converges.
# Saved under the same filename as the b=10 run in the sweep, overwriting it.
b_rerun = 10
X, Y = negative_sampling(C=C, d=2, b=b_rerun, n_iter=10000, eta=.000001)
# Community of each node in G's node order (the row order of A); assumes X
# has one row per node in that order.
block_labels = np.array([G.nodes[v]['block'] for v in G])
block_colors = ['r', 'b', 'g']
fig, ax = plt.subplots()
for block in np.unique(block_labels):
    members = block_labels == block
    ax.scatter(X[members, 0], X[members, 1],
               color=block_colors[block % len(block_colors)])
fig.savefig(f'sgns_sensitivity_b{b_rerun}.png')
# Title added after saving so the exported figure is unchanged.
ax.set_title(f'b={b_rerun}')
plt.show()
plt.close(fig)

In [None]:
# Run a larger value, b=50, with the same smaller step-size and longer
# iteration budget as the b=10 re-run. This figure is displayed only (not saved).
# b is bound to a dedicated name so nothing here picks up the stale loop
# variable `b` left over from the sweep.
b_large = 50
X, Y = negative_sampling(C=C, d=2, b=b_large, n_iter=10000, eta=.000001)
# Community of each node in G's node order (the row order of A); assumes X
# has one row per node in that order.
block_labels = np.array([G.nodes[v]['block'] for v in G])
block_colors = ['r', 'b', 'g']
fig, ax = plt.subplots()
for block in np.unique(block_labels):
    members = block_labels == block
    ax.scatter(X[members, 0], X[members, 1],
               color=block_colors[block % len(block_colors)])
ax.set_title(f'b={b_large}')
plt.show()
plt.close(fig)