In [70]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from agents import GlobalConsensus, EventGlobalConsensus, EventGlobalConsensusTorch
from models import NN, Dummy
from utils import add_params, scale_params, subtract_params, average_params, sum_params
import torch

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### ADMM Global Consensus

In [71]:
rho = 1

# Initial lambdas must sum to 0! Stacking the 2x2 draw with its negation
# gives 4 rows that cancel pairwise, one per agent.
lam = np.random.randn(2,2)
lambdas = np.vstack([lam, -lam])

agents = [
    GlobalConsensus(
        rho=rho,
        x_init=np.random.randn(2),
        lam_init=lam_i
    )
    for lam_i in lambdas
]

# compute initial average
avg = 0
for agent in agents:
    avg += agent.x/len(agents)

# broadcast average to all agents. Use `primal_avg`, the attribute the
# iteration cell below (and EventGlobalConsensus) reads/writes; the previous
# `primal_average` name was never consumed by the update loop.
for agent in agents:
    agent.primal_avg = avg
    print(agent.x)
print(f'average = {agents[0].primal_avg}')

[-0.6650954   0.52366207]
[-0.49326449  0.12240181]
[1.12682555 0.36038851]
[0.43149256 1.36076102]
average = [0.09998955 0.59180335]


In [72]:
# 50 synchronous rounds: primal step on every agent, recompute and share the
# global average, then the dual step on every agent.
for iteration in range(50):
    for agent in agents:
        agent.primal_update()

    # Global average, accumulated term by term as (0 + x_0/N + x_1/N + ...).
    avg = sum(agent.x / len(agents) for agent in agents)

    # Every agent must see the new average before its dual step.
    for agent in agents:
        agent.primal_avg = avg

    for agent in agents:
        agent.dual_update()

# Report the final primal and dual variables of each agent.
for idx, agent in enumerate(agents):
    print(f'agent {idx}: x = {agent.x}, lam = {agent.lam}')

agent 0: x = [-7.39906180e-11 -1.13503031e-09], lam = [1.47981236e-10 2.27006062e-09]
agent 1: x = [6.72227309e-10 1.45999945e-10], lam = [-1.34445462e-09 -2.91999890e-10]
agent 2: x = [7.39906180e-11 1.13503031e-09], lam = [-1.47981236e-10 -2.27006062e-09]
agent 3: x = [-6.72227309e-10 -1.45999945e-10], lam = [1.34445462e-09 2.91999890e-10]


### Event-Based ADMM Global Consensus

In [73]:
rho = 1
deltas = [1e-3]  # each value is passed as `delta` to EventGlobalConsensus in the sweep below
t_max = 30       # iterations per run

# Initial lambdas must sum to 0! Pairing the draw with its negation
# makes the rows cancel.
lam = 5 * np.random.randn(2, 2)
lambdas = np.concatenate((lam, -lam), axis=0)

# One 2-d starting point per agent, and their mean as the initial global average.
x_init = 5 * np.random.randn(len(lambdas), 2)
initial_avg = x_init.mean(axis=0)


In [75]:
for delta in deltas:

    # Initialise agents (one per row of lambdas / x_init)
    agents = [
        EventGlobalConsensus(
            N=len(lambdas),
            rho=rho,
            delta=delta,
            x_init=x,
            lam_init=lam_i
        )
        for lam_i, x in zip(lambdas, x_init)
    ]

    # Broadcast a private copy of the initial average. Without the copy every
    # agent aliases `initial_avg`, and any in-place update would corrupt it for
    # the next delta in the sweep (or a re-run of this cell).
    primal_avg = initial_avg.copy()
    for agent in agents:
        agent.primal_avg = primal_avg

    # Run event-based ADMM
    comm = 0

    for step in range(t_max):

        for agent in agents:
            agent.primal_update()

        # Only agents whose `broadcast` flag is set communicate their residual.
        broadcasting = [agent for agent in agents if agent.broadcast]
        comm += len(broadcasting)
        sum_of_res = sum(agent.residual for agent in broadcasting)

        # Update the global average and hand it explicitly to every agent,
        # instead of relying on array aliasing through the leaked loop variable.
        primal_avg = primal_avg + sum_of_res
        for agent in agents:
            agent.primal_avg = primal_avg

        for agent in agents:
            agent.dual_update()

    # Total L1 distance between each agent's x and agent.C
    # (presumably the consensus optimum - TODO confirm in agents.EventGlobalConsensus).
    total_error = np.sum([np.linalg.norm(agent.x - agent.C, ord=1) for agent in agents])
    # Fraction of (iteration, agent) slots in which a broadcast happened.
    load = comm/(t_max*len(agents))
    print(f'Accuracy = {1-total_error:.6f}, load = {load}')
    

Accuracy = 0.999858, load = 1.0
