In [None]:
import random

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy.spatial import distance

# Seed both random number generators the notebook draws from: numpy's global
# generator (BP message initialisation, random update order) and Python's
# `random` module, which networkx generators such as nx.erdos_renyi_graph use
# when they are called without an explicit `seed`.
SEED = 67
np.random.seed(SEED)
random.seed(SEED)

# Tutorial 10: Belief Propagation for graph coloring

In [None]:
# Generate graph: Erdős–Rényi G(N, p) with p = c/(N-1), so the expected degree is c.
# The graph is made directed so every undirected edge {i, j} is present as both
# (i, j) and (j, i).
N = 100    # number of nodes
beta = 2   # enters the BP weights through theta below
q = 3      # number of colours
c = 5      # average degree
G = nx.erdos_renyi_graph(n=N, p=c / (N - 1)).to_directed()

theta = 1 - np.exp(-beta)

In [None]:
# Checking the convergence
def convergence(el1, el2, abs_tol):
    """Compare two snapshots of BP messages and decide whether BP has converged.

    Parameters
    ----------
    el1, el2 : mapping
        Map each directed edge (i, j) to its message vector (array-like of length q).
        el2 must contain every key of el1.
    abs_tol : float
        Convergence threshold.

    Returns
    -------
    (err, converged) : (float, bool)
        err is the mean, over the edges of el1, of the L1 distance between the two
        messages on that edge; converged is ``err < abs_tol``.
        An empty el1 yields (0.0, True) instead of dividing by zero.
    """
    if len(el1) == 0:
        return (0.0, True)
    err = 0.0
    for e in el1:
        err += np.abs(np.asarray(el1[e]) - np.asarray(el2[e])).sum()
    err /= len(el1)
    return (float(err), bool(err < abs_tol))

In [None]:
def _initial_message(init, q, alpha):
    """Return one freshly initialised, normalised length-q message for scheme `init`."""
    if init == 'perturb':
        # uniform 1/q plus a uniform perturbation of relative amplitude alpha
        msg = 1 / q + np.random.uniform(-alpha / q, alpha / q, size=q)
    elif init == 'random':
        msg = np.random.uniform(size=q)
    elif init == 'first-color':
        # all the mass on colour 0
        msg = np.zeros(q)
        msg[0] = 1.0
    else:
        raise ValueError("unknown init {0!r}; expected 'perturb', 'random' or 'first-color'".format(init))
    return msg / msg.sum()


def _bp_message(G, i, j, theta, q):
    """Return the normalised BP message i -> j built from the current messages k -> i, k != j."""
    msg = np.ones(q)
    for k in G.predecessors(i):
        if k != j:
            # sum_{s'} (1 - theta * delta(s, s')) chi^{k->i}_{s'} = 1 - theta * chi^{k->i}_s
            msg *= 1 - theta * G.edges()[(k, i)]['message_t']
    return msg / msg.sum()


def BP(G, beta, q, init='perturb', update='random', max_it=1000, abs_tol=1e-4, alpha=0.1, report=False):
    """Run belief propagation for soft q-colouring of G.

    The weight of a colouring is prod_{(ij)} exp(-beta * delta(s_i, s_j))
    = prod_{(ij)} (1 - theta * delta(s_i, s_j)) with theta = 1 - exp(-beta), giving
        chi^{i->j}_s  ∝  prod_{k in ∂i \\ j} (1 - theta * chi^{k->i}_s).
    The message i -> j is stored as a length-q array in
    ``G.edges()[(i, j)]['message_t']``; messages are re-initialised on every call
    and left in G on return.

    Parameters
    ----------
    G : networkx.DiGraph
        Directed graph holding both (i, j) and (j, i) for each undirected edge
        (e.g. the result of ``Graph.to_directed()``).
    beta : float
        Inverse temperature; theta is computed from this argument.
    q : int
        Number of colours.
    init : {'perturb', 'random', 'first-color'}
        Message initialisation scheme.
    update : {'parallel', 'random'}
        'parallel': all new messages are computed from the previous sweep's messages.
        'random': messages are updated one at a time, in a fresh random order each
        sweep, always reading the latest values.
    max_it : int
        Maximum number of sweeps.
    abs_tol : float
        Threshold passed to ``convergence``.
    alpha : float
        Perturbation amplitude used by init='perturb'.
    report : bool
        If True, print the number of sweeps performed.

    Returns
    -------
    it : int
        Number of sweeps performed.
    differences : list of float
        Convergence error after each sweep (len(differences) == it).

    Raises
    ------
    ValueError
        If ``init`` or ``update`` is not a supported value.
    """
    if init not in ('perturb', 'random', 'first-color'):
        raise ValueError("unknown init {0!r}; expected 'perturb', 'random' or 'first-color'".format(init))
    if update not in ('parallel', 'random'):
        raise ValueError("unknown update {0!r}; expected 'parallel' or 'random'".format(update))

    theta = 1 - np.exp(-beta)
    edges = list(G.edges())

    # Initialization BP messages
    for e in edges:
        G.edges()[e]['message_t'] = _initial_message(init, q, alpha)

    # Iterating
    conv, it = False, 0
    differences = []
    while not conv and it < max_it:
        # Messages are always replaced by new arrays (never mutated in place),
        # so this dict is a true snapshot of the previous sweep.
        old = {e: G.edges()[e]['message_t'] for e in edges}
        if update == 'parallel':
            new = {(i, j): _bp_message(G, i, j, theta, q) for i, j in edges}
            for e, msg in new.items():
                G.edges()[e]['message_t'] = msg
        else:  # 'random'
            for idx in np.random.permutation(len(edges)):
                i, j = edges[idx]
                G.edges()[(i, j)]['message_t'] = _bp_message(G, i, j, theta, q)
            new = {e: G.edges()[e]['message_t'] for e in edges}
        err, conv = convergence(old, new, abs_tol)
        differences.append(err)
        it += 1

    if report:
        print('Number of iterations: {0}'.format(it))

    return(it, differences)

## Point b)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: all messages updated at once (parallel).
# Right panel: messages updated one by one in random order.
panels = [('parallel', 'Parallel update'), ('random', 'Random update')]
for axis, (update_rule, panel_title) in zip(ax, panels):
    x, y = BP(G, beta, q, update=update_rule, report=True)
    axis.plot(np.arange(x), y)
    axis.set_title(panel_title, size=16)
    axis.set_xlabel('Number of iterations', size=12)
    axis.set_ylabel('$err$', size=12)

plt.savefig('tutorial10_point_b.png')
plt.show()

## Point c)

In [None]:
# Median number of BP sweeps over 5 fresh random graphs, for each (N, c) pair.
c_choices = np.linspace(2, 7, 50)
N_choices = [50, 100, 200]
result = np.zeros((len(c_choices), len(N_choices) + 1))
result[:, 0] = c_choices
# The loop variables have their own names so this cell does not overwrite the
# notebook-level N, c and G that other cells rely on.
for j, n_nodes in enumerate(N_choices):
    print(n_nodes)
    for i, c_mean in enumerate(c_choices):
        iterations = []
        for _ in range(5):
            graph = nx.erdos_renyi_graph(n=n_nodes, p=c_mean / (n_nodes - 1)).to_directed()
            n_it, _diffs = BP(graph, beta, q)
            iterations.append(n_it)
        # runs that do not converge are counted at BP's max_it cap
        result[i, j + 1] = np.median(iterations)

In [None]:
# One curve per graph size N: median number of BP sweeps vs. average degree.
# Runs that reach BP's max_it are counted at that cap.
plt.figure(figsize=(10, 5))
for n_nodes, median_its in zip(N_choices, result[:, 1:].T):
    plt.plot(result[:, 0], median_its, label='N = {0}'.format(n_nodes))
plt.legend(fontsize=12)
plt.xlabel('Average degree $c$', size=12)
plt.ylabel('Iterations to convergence (median)', size=12)
plt.savefig('tutorial10_point_c.png')
plt.show()

## Point d)

In [None]:
# Attaching marginal distributions to each node
def marginals_one_point(G, beta=None):
    """Store the BP one-point marginal of every node in ``G.nodes()[i]['marginal']``.

    The marginal of node i combines all of its incoming messages k -> i:
        chi^i_s  ∝  prod_{k in ∂i} (1 - theta * chi^{k->i}_s),
    normalised to sum to 1. The 'message_t' edge attribute must already be
    populated (by BP).

    Parameters
    ----------
    G : networkx.DiGraph
        Graph whose directed edges carry BP messages in 'message_t'.
    beta : float, optional
        If given, theta = 1 - exp(-beta) is computed from it; otherwise the
        notebook-level ``theta`` is used.

    NOTE: the number of colours is read from the notebook-level ``q``.
    """
    th = theta if beta is None else 1 - np.exp(-beta)
    for i in G.nodes():
        prod = np.ones(q)
        for k in G.predecessors(i):
            prod *= 1 - th * G.edges()[(k, i)]['message_t']
        # an isolated node keeps the uniform marginal 1/q
        G.nodes()[i]['marginal'] = prod / prod.sum()
        
# Assessing the coloring mapping
def accuracy(G, colors=None):
    """Count the undirected edges of G whose two endpoints share a colour.

    Parameters
    ----------
    G : networkx.DiGraph
        Directed graph holding both (i, j) and (j, i) for each undirected edge.
    colors : mapping, optional
        Node -> colour label. When omitted, each node's colour is the argmax of its
        'marginal' node attribute (as stored by ``marginals_one_point``); this lets
        callers use the one-argument form ``accuracy(G)``.

    Returns
    -------
    float
        Number of violated undirected edges: the directed count halved, because
        each undirected edge appears twice in G.
    """
    if colors is None:
        colors = {i: int(np.argmax(G.nodes()[i]['marginal'])) for i in G.nodes()}
    errors = sum(1 for i, j in G.edges() if colors[i] == colors[j])
    return(errors/2)

#### d.1)

In [None]:
# Fraction of violated edges of the BP (argmax-marginal) colouring vs. average degree,
# mean and std over 5 random graphs per value of c.
c_choices = np.linspace(0.1, 7, 30)
result1 = np.zeros((len(c_choices), 3))
result1[:, 0] = c_choices
N = 500
q = 3  # pinned explicitly: the d.2 cell below reuses `q` as its loop variable
for i, c in enumerate(c_choices):
    errors = []
    for _ in range(5):
        G = nx.erdos_renyi_graph(n=N, p=c / (N - 1)).to_directed()
        BP(G, beta, q)
        marginals_one_point(G)
        colors = {node: int(np.argmax(G.nodes()[node]['marginal'])) for node in G.nodes()}
        n_edges = G.number_of_edges() / 2  # undirected edge count
        errors.append(accuracy(G, colors) / n_edges if n_edges else 0.0)
    result1[i, 1] = np.mean(errors)
    result1[i, 2] = np.std(errors)

In [None]:
# Mean fraction of violated edges vs. average degree; error bars are the std over repetitions.
fig, ax = plt.subplots(figsize=(10, 5))
ax.errorbar(result1[:, 0], result1[:, 1], yerr=result1[:, 2])
ax.set_xlabel('Average degree $c$', size=12)
ax.set_ylabel('Fraction of violations', size=12)
fig.savefig('tutorial10_point_d.png')
plt.show()

#### d.2)

In [None]:
# Fraction of violated edges of the BP (argmax-marginal) colouring vs. number of colours,
# mean and std over 5 random graphs per value of q.
q_choices = np.arange(2, 12)
result1b = np.zeros((len(q_choices), 3))
result1b[:, 0] = q_choices
N = 500
c = 5
# `q` is deliberately the module-level loop variable: marginals_one_point reads the global q.
for i, q in enumerate(q_choices):
    errors = []
    for _ in range(5):
        G = nx.erdos_renyi_graph(n=N, p=c / (N - 1)).to_directed()
        BP(G, beta, q)
        marginals_one_point(G)
        colors = {node: int(np.argmax(G.nodes()[node]['marginal'])) for node in G.nodes()}
        n_edges = G.number_of_edges() / 2  # undirected edge count
        errors.append(accuracy(G, colors) / n_edges if n_edges else 0.0)
    result1b[i, 1] = np.mean(errors)
    result1b[i, 2] = np.std(errors)

In [None]:
# Mean fraction of violated edges vs. number of colours; error bars are the std over repetitions.
fig, ax = plt.subplots(figsize=(10, 5))
ax.errorbar(result1b[:, 0], result1b[:, 1], yerr=result1b[:, 2])
ax.set_xlabel('Number of colors $q$', size=12)
ax.set_ylabel('Fraction of violations', size=12)
fig.savefig('tutorial10_point_d_color.png')
plt.show()

## Point e)

In [None]:
N = 1000
q = 3
c = 5

G = nx.erdos_renyi_graph(n=N, p=c/(N-1))
G = G.to_directed()

# Convergence traces of BP on one fixed graph, one curve per initialisation scheme.
# NOTE(review): assumes point e asks to compare BP's initialisations
# ('perturb', 'random', 'first-color') — confirm against the assignment text.
plt.figure(figsize=(17,5))
for init in ['perturb', 'random', 'first-color']:
    n_it, diffs = BP(G, beta, q, init=init, report=True)
    plt.plot(np.arange(n_it), diffs, label=init)
plt.legend(fontsize=12)
plt.xlabel('Number of iterations',size=12)
plt.ylabel('$err$', size=12)
plt.savefig('tutorial10_point_e.png')
plt.show()