## Imports

In [165]:
import jax
import networkx
import pandas as pd
import numpy as onp
import jax.numpy as np
from dgl import data
from functools import partial
from tqdm import notebook as tqdm
from sklearn.preprocessing import OneHotEncoder
from jax import random as r
from jax.experimental import optimizers
from jax.experimental.stax import Dropout
from jax.flatten_util import ravel_pytree

## Dataset

In [166]:
cora = data.citation_graph.CoraDataset()
X = cora.features
y = cora.labels

In [167]:
G = cora.graph
A = networkx.to_numpy_matrix(G).A

In [168]:
Aself = A + np.eye(A.shape[0])
d = Aself.sum(axis=1)
W = np.dot(np.diag(1.0 / d), Aself)

In [169]:
train_idx = cora.train_mask.astype(onp.bool)
val_idx = cora.val_mask.astype(onp.bool)
test_idx = cora.test_mask.astype(onp.bool)

In [170]:
y_onehot = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))

In [171]:
input_size = X.shape[1]
hidden_size = 64
output_size = onp.max(y) + 1

## Network

In [172]:
N_agents = 6
P = 0.75
key = r.PRNGKey(42)

In [173]:
def check_connected(M_agents):
    tot = M_agents
    tmp = M_agents
    for i in range(M_agents.shape[0]):
        tmp = np.matmul(tmp,M_agents)
        tot += tmp
    
    return (tot>0.).all()

In [174]:
def communication_mask(N_agents, P, key):
    M_agents = np.zeros((N_agents,N_agents))
    indices = np.triu_indices(N_agents,k=1)
    agents_mask = r.bernoulli(p=P,
                              shape=(int(N_agents*(N_agents-1)/2),),
                              key=key)
    M_agents = jax.ops.index_update(M_agents,indices,agents_mask)
    M_agents += M_agents.T + np.eye(N_agents)
    if check_connected(M_agents):
        return M_agents
    else:
        print("Disconnected graph, retrying...")
        _, subkey = r.split(key)
        return  communication_mask(N_agents,P, subkey)

In [175]:
M_agents = communication_mask(N_agents,P,key)

In [176]:
M_agents

DeviceArray([[1., 1., 1., 1., 0., 1.],
             [1., 1., 0., 1., 1., 1.],
             [1., 0., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1.],
             [0., 1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1.]], dtype=float32)

In [177]:
node_to_agent = r.randint(minval=0, maxval=N_agents, shape=(X.shape[0], ), key=key)

In [178]:
def prune_connections(A, M_agents, node_to_agent):
    new_A = A
    if not M_agents.all():
        #remove connections between agents that cannot communicate
        for i in range(N_agents):
            I = (node_to_agent==i).astype('int32')
            for j in range(i+1,N_agents):
                J = (node_to_agent==j).astype('int32')
                if M_agents[i,j]==0:
                    IJ = np.outer(I,J)
                    new_A -= (A*IJ)
    return new_A

In [179]:
new_A = prune_connections(A, M_agents, node_to_agent)
new_Aself = new_A + onp.eye(new_A.shape[0])
new_d = new_Aself.sum(axis=1)
new_W = np.dot(onp.diag(1.0 / new_d), new_A + new_Aself)

In [180]:
def compute_consensus_matrix(new_A, N_agents, node_to_agent):
    W_agents = np.zeros((N_agents,N_agents))
    if N_agents > 1:
    # count numbere of edges from one agent to the other
        for i in range(N_agents):
            I = (node_to_agent==i).astype('int16')
            for j in range(i,N_agents):
                J = (node_to_agent==j).astype('int16')
                IJ = np.outer(I, J)
                edges = (new_A * IJ).sum()
                W_agents = jax.ops.index_update(W_agents,(i,j),edges)
                W_agents = jax.ops.index_update(W_agents,(j,i),edges)

        W_agents = W_agents / W_agents.sum(axis=0) 
    else:
        W_agents = np.ones((1,1))

    return W_agents

In [181]:
agents_W = compute_consensus_matrix(new_A, N_agents, node_to_agent)

In [182]:
agents_W

DeviceArray([[0.18981159, 0.19436814, 0.21210305, 0.15715244, 0.        ,
              0.15239096],
             [0.19748779, 0.19917582, 0.        , 0.17595702, 0.19868995,
              0.18286915],
             [0.24703419, 0.        , 0.22768125, 0.19677636, 0.2125182 ,
              0.18392013],
             [0.1632938 , 0.17994505, 0.17555423, 0.12760241, 0.17248908,
              0.1434577 ],
             [0.        , 0.1875    , 0.17495507, 0.15916723, 0.20232897,
              0.1544929 ],
             [0.20237264, 0.23901099, 0.20970641, 0.18334453, 0.2139738 ,
              0.18286915]], dtype=float32)

## Training

In [183]:
def cross_entropy(y, ypred):
    return - np.mean(np.sum(y * np.log(ypred), axis=1))

In [184]:
def accuracy(y, ypred):
    return np.mean(np.argmax(y, axis=1) == np.argmax(ypred, axis=1))

In [185]:
def init_params(key):
    _, subkey1 = r.split(key)
    theta_l1 = jax.nn.initializers.lecun_normal()(key=key,
                                                  shape=(input_size, hidden_size))
    theta_l2 = jax.nn.initializers.lecun_normal()(key=subkey1,
                                                  shape=(hidden_size, output_size))
    b_l1 = np.zeros((hidden_size))
    b_l2 = np.zeros((output_size))
    return [theta_l1, b_l1, theta_l2, b_l2]

In [186]:
def init_params_dist(key):
    _, *subkeys = r.split(key, num=N_agents+1)
    params = []
    for i in range(N_agents):
        params.append(init_params(subkeys[i]))
    return params

In [187]:
@partial(jax.jit, static_argnums=(3))
def dist_gcn(params, L, X, mode, key=r.PRNGKey(0)):
    _, dropout_fcn = Dropout(0.5, mode=mode)
    _, subkey = r.split(key)
    X = dropout_fcn(_, X, rng=key)

    H = np.zeros((X.shape[0], hidden_size))
    for i in range(N_agents):
        H = jax.ops.index_update(H, node_to_agent == i,
                                 np.dot(X[node_to_agent == i],
                                 params[i][0]) + params[i][1])

    H = np.dot(L, H)
    H = jax.nn.relu(H)
    H = dropout_fcn(_, H, rng=subkey)
 
    Y = np.zeros((X.shape[0], output_size))
    for i in range(N_agents):
        Y = jax.ops.index_update(Y, node_to_agent == i,
                                 np.dot(H[node_to_agent == i],
                                 params[i][2]) + params[i][3])
  
    return jax.nn.softmax(np.dot(L, Y), axis=1)

In [188]:
def loss(params, L, key):
    ypred = dist_gcn(params, L, X, 'train', key=key)
    losses = cross_entropy(y_onehot[train_idx, :], ypred[train_idx, :])
    return np.mean(losses), (ypred, losses)
loss_and_grad = jax.jit(jax.value_and_grad(loss, has_aux=True))

In [189]:
@jax.jit
def consensus_step(params, W):
    flatten = [ravel_pytree(p) for p in params]
    p_flat = np.vstack(f[0] for f in flatten)
    p_flat = np.dot(W, p_flat)
    return [flatten[i][1](p_flat[i]) for i in range(p_flat.shape[0])]

## Initialization

In [190]:
params = init_params_dist(key)
agents_opts = []
agents_opts_states = []
for i in range(N_agents):
    agents_opts.append(optimizers.adam(1e-3))
    agents_opts_states.append(agents_opts[i][0](params[i]))

## Training

In [191]:
iters = 200
bar = tqdm.tqdm(range(iters))
_, subkey = r.split(key)
best_acc = best_val_acc = best_epoch = 0
results =  {'tr_loss':[],'tr_acc':[],'te_acc':[],'best_acc':[],'best_epoch':[]}
for i in bar:
    subkey, subkey2 = r.split(subkey)
    (loss_i, ypred), grads = loss_and_grad(params, new_W, subkey2)

    if i>0:
        agents_opts_states = consensus_step(agents_opts_states, agents_W)

    for k in range(N_agents):
        agents_opts_states[k] = agents_opts[k][1](i, grads[k], agents_opts_states[k])
        params[k] = agents_opts[k][2](agents_opts_states[k])

    ypred = dist_gcn(params, new_W, X, 'test')

    train_acc_i = accuracy(y_onehot[train_idx, :], ypred[train_idx, :])
    val_acc_i = accuracy(y_onehot[val_idx, :], ypred[val_idx, :])
    test_acc_i = accuracy(y_onehot[test_idx, :], ypred[test_idx, :])

    if val_acc_i > best_val_acc:
        best_val_acc = val_acc_i
        best_acc = test_acc_i
        best_epoch = i

    results['tr_loss'].append(loss_i)
    results['tr_acc'].append(train_acc_i)
    results['te_acc'].append(test_acc_i)

results['best_acc'] = best_acc
results['best_epoch'] = best_epoch

  0%|          | 0/200 [00:00<?, ?it/s]

## Results

In [192]:
print(f"Best epoch:{results['best_epoch']}, Best Accuracy:{results['best_acc']}")

Best epoch:195, Best Accuracy:0.6679999828338623
