In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
%load_ext autoreload
%autoreload 2

In [104]:
from algorithms import ChainTopology, RingTopology, QuadraticsTask, relaysum_model, relaysum_grad, relaysum_mix, all_reduce, gossip, d2, gradient_tracking, BinaryTreeTopology, relaysum_grad_overlap, Iterate, StarTopology
from random_quadratics import RandomQuadraticsTask
from collections import defaultdict
import torch
import math
import tuning
import pandas as pd
from warehouse import Warehouse

In [53]:
# Use double precision as the default dtype for floating-point tensors created
# after this point (e.g. the torch.zeros / torch.randn_like calls below).
torch.set_default_dtype(torch.float64)

In [145]:
# Value the workers try to estimate. Read as a module-level global: the *_dme
# generators draw samples as `randn + target`, and error() measures the squared
# distance of the estimates to it.
target = 1

In [77]:
def relaysum_dme(world, num_steps):
    """Distributed mean estimation that relays running-mean updates between workers.

    In every round each worker draws one sample (standard normal shifted by the
    module-level ``target``), folds it into its local running mean, and the
    resulting change of that mean is spread over the topology with relay
    messages: what ``worker`` sends to ``neighbor`` is its own update plus
    everything it received in the previous round from its *other* neighbors.

    Args:
        world: Topology object. This function reads ``num_workers``,
            ``workers``, ``neighbors(worker)`` and ``max_degree``. Workers are
            used directly as tensor indices.
        num_steps: Number of sampling / communication rounds.

    Yields:
        ``Iterate(step, state, num_messages_sent)`` before every round, plus one
        more after the last round. The yielded ``state`` is a copy, so it stays
        valid if the consumer stores it.

        NOTE(review): the trailing iterate reuses the last step index
        (``num_steps - 1``) although its state already contains that round's
        update. This labeling is kept unchanged for backward compatibility.
        With ``num_steps == 0`` a single iterate with step 0 is produced.
    """
    state = torch.zeros([world.num_workers])
    local_numbers = torch.zeros([world.num_workers])

    # Relay messages keyed by (sender, receiver). The defaultdict makes every
    # message read as 0 during the first round. From the second round on this is
    # a plain dict, which assumes `neighbors` is symmetric so that every
    # (n, worker) key read below was written in the previous round.
    messages = defaultdict(float)

    num_messages_sent = 0

    # Pre-bind `step` so the trailing yield is well defined for num_steps == 0.
    step = 0

    for step in range(num_steps):
        # Yield a snapshot: `state` is updated in place further down.
        yield Iterate(step, state.clone(), num_messages_sent)
        samples = torch.randn_like(state) + target

        # Incremental running mean of each worker's own samples; only the
        # change of that mean is communicated.
        local_numbers_update = (samples - local_numbers) / (step + 1)
        local_numbers += local_numbers_update

        new_messages = {}
        for worker in world.workers:
            neighbors = world.neighbors(worker)
            for neighbor in neighbors:
                # Exclude what came from `neighbor` itself so nothing is echoed
                # straight back to its sender.
                new_messages[worker, neighbor] = local_numbers_update[worker] + sum(
                    messages[n, worker] for n in neighbors if n != neighbor
                )

        messages = new_messages

        for worker in world.workers:
            neighbors = world.neighbors(worker)
            sum_updates = local_numbers_update[worker] + sum(new_messages[n, worker] for n in neighbors)
            state[worker] += sum_updates / world.num_workers

        num_messages_sent += world.max_degree

    yield Iterate(step, state.clone(), num_messages_sent)

In [138]:
def relaysum_dme_corr(world, num_steps):
    """Distributed mean estimation that relays raw samples together with counts.

    Every worker accumulates a sum of received samples (``state``) and the
    number of samples that sum contains (``counts``). Its estimate is the ratio
    of the two. Both quantities travel through the topology as relay messages:
    what ``worker`` sends to ``neighbor`` is its own contribution plus
    everything it received in the previous round from its *other* neighbors.

    Args:
        world: Topology object. This function reads ``num_workers``,
            ``workers``, ``neighbors(worker)`` and ``max_degree``. Workers are
            used directly as tensor indices.
        num_steps: Number of sampling / communication rounds.

    Yields:
        ``Iterate(step, estimate, num_messages_sent)`` before every round, plus
        one more after the last round. Before any sample has been received the
        estimate is 0 (not NaN), matching the zero initial state of the other
        estimators in this notebook.

        NOTE(review): the trailing iterate reuses the last step index
        (``num_steps - 1``) although its state already contains that round's
        update. This labeling is kept unchanged for backward compatibility.
        With ``num_steps == 0`` a single iterate with step 0 is produced.
    """
    state = torch.zeros([world.num_workers])
    counts = torch.zeros([world.num_workers])

    def current_estimate():
        # `counts` is 0 only before the first round, where `state` is 0 as well.
        # Clamping the divisor turns that 0 / 0 (NaN) into 0 and leaves every
        # later estimate untouched, because counts grow by at least 1 per round.
        return state / counts.clamp(min=1)

    # Relay messages keyed by (sender, receiver, "sample" | "count"). The
    # defaultdict makes every message read as 0 during the first round. From the
    # second round on this is a plain dict, which assumes `neighbors` is
    # symmetric so that every key read below was written in the previous round.
    messages = defaultdict(float)

    num_messages_sent = 0

    # Pre-bind `step` so the trailing yield is well defined for num_steps == 0.
    step = 0

    for step in range(num_steps):
        yield Iterate(step, current_estimate(), num_messages_sent)
        samples = torch.randn_like(state) + target

        new_messages = {}
        for worker in world.workers:
            neighbors = world.neighbors(worker)
            for neighbor in neighbors:
                # Exclude what came from `neighbor` itself so nothing is echoed
                # straight back to its sender.
                new_messages[worker, neighbor, "sample"] = samples[worker] + sum(
                    messages[n, worker, "sample"] for n in neighbors if n != neighbor
                )
                new_messages[worker, neighbor, "count"] = 1 + sum(
                    messages[n, worker, "count"] for n in neighbors if n != neighbor
                )

        messages = new_messages

        for worker in world.workers:
            neighbors = world.neighbors(worker)
            sum_updates = samples[worker] + sum(new_messages[n, worker, "sample"] for n in neighbors)
            count_updates = 1 + sum(new_messages[n, worker, "count"] for n in neighbors)
            state[worker] += sum_updates
            counts[worker] += count_updates

        num_messages_sent += world.max_degree

    yield Iterate(step, current_estimate(), num_messages_sent)

In [122]:
def gossip_dme(world, num_steps):
    """Distributed mean estimation with gossip averaging.

    In every round each worker draws one sample (standard normal shifted by the
    module-level ``target``), moves its estimate towards that sample with step
    size ``1 / (step + 1)``, and the estimates are then mixed by multiplying
    with the gossip matrix of ``world``.

    Args:
        world: Topology object. This function reads ``num_workers``,
            ``gossip_matrix()`` and ``max_degree``.
        num_steps: Number of sampling / communication rounds.

    Yields:
        ``Iterate(step, state, num_messages_sent)`` before every round, plus one
        more after the last round. The yielded ``state`` is a copy, so it stays
        valid if the consumer stores it.

        NOTE(review): the trailing iterate reuses the last step index
        (``num_steps - 1``) although its state already contains that round's
        update. This labeling is kept unchanged for backward compatibility.
        With ``num_steps == 0`` a single iterate with step 0 is produced.
    """
    state = torch.zeros([world.num_workers])

    W = world.gossip_matrix()

    num_messages_sent = 0

    # Pre-bind `step` so the trailing yield is well defined for num_steps == 0.
    step = 0

    for step in range(num_steps):
        # Yield a snapshot: the running-mean update below modifies `state` in
        # place before the gossip product rebinds it.
        yield Iterate(step, state.clone(), num_messages_sent)
        samples = torch.randn_like(state) + target

        # Running-mean style update towards the fresh local sample.
        state += (samples - state) / (step + 1)

        # Gossip: mix the estimates according to the gossip matrix.
        state = W @ state

        num_messages_sent += world.max_degree

    yield Iterate(step, state.clone(), num_messages_sent)

In [None]:
def error(state):
    """Return the mean squared deviation of ``state`` from ``target`` as a Python float."""
    deviation = state - target
    squared = deviation ** 2
    return torch.mean(squared).item()

w = Warehouse()

# Presumably removes previously logged RelaySum errors so that re-running this
# cell does not append duplicates (verify against Warehouse.clear).
w.clear("error", {"method": "RelaySum"})
for n in [8, 16, 32, 64, 128]:
    for seed in range(40):
        # Actually use the seed: without this the 40 repetitions depend on
        # whatever state the global RNG happens to be in, so results change
        # between runs of the notebook.
        torch.manual_seed(seed)
        for iterate in relaysum_dme_corr(ChainTopology(n), 400):
            # Log every fifth step only, to keep the result table small.
            if iterate.step % 5 == 0:
                w.log_metric("error", {"value": error(iterate.state), "step": iterate.step}, {"n": n, "method": "RelaySum"})

In [None]:
# Mirror the RelaySum cell: presumably removes previously logged Gossip errors
# so that re-running this cell does not append duplicates (verify against
# Warehouse.clear).
w.clear("error", {"method": "Gossip"})
for n in [8, 16, 32, 64, 128]:
    for seed in range(40):
        # Actually use the seed: without this the 40 repetitions depend on
        # whatever state the global RNG happens to be in, so results change
        # between runs of the notebook.
        torch.manual_seed(seed)
        for iterate in gossip_dme(RingTopology(n), 400):
            # Log every fifth step only, to keep the result table small.
            if iterate.step % 5 == 0:
                w.log_metric("error", {"value": error(iterate.state), "step": iterate.step}, {"n": n, "method": "Gossip"})

In [None]:
# Reference curve 1 / (n * T) with T = step + 1, logged at the same steps as the
# experiments above (every fifth step of a 400-step run).
# Mirror the RelaySum cell: presumably removes previously logged reference
# values so that re-running this cell does not append duplicates (verify
# against Warehouse.clear).
w.clear("error", {"method": "1 / nT"})
for n in [8, 16, 32, 64, 128]:
    # The curve is deterministic, so enumerate the logged steps directly
    # instead of running a random gossip simulation just to obtain step indices.
    for step in range(0, 400, 5):
        w.log_metric("error", {"value": 1/(n * (step + 1)), "step": step}, {"n": n, "method": "1 / nT"})

In [None]:
# Plot styling for the figure below: seaborn "paper" context with a white grid,
# LaTeX-rendered text in a Times serif font, and a custom five-colour palette.
# Statement order matters here: the seaborn theme is applied first and the
# matplotlib rcParams overrides come afterwards.
# NOTE(review): text.usetex presumably requires a LaTeX installation providing
# amsmath, amssymb and newtxmath on the machine running the notebook - confirm.
import seaborn as sns
sns.set_theme("paper")
sns.set_style("whitegrid")
from matplotlib import pyplot as plt
import matplotlib
# NOTE(review): looks redundant with the "text.usetex" entry in the update below.
matplotlib.rcParams['text.usetex'] = True
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    'text.latex.preamble' : r'\usepackage{amsmath}\usepackage{amssymb}\usepackage{newtxmath}'
})
# NOTE(review): mtick is not used in the visible cells.
import matplotlib.ticker as mtick
# Entries 7, 0, 1, 9 and 3 of the tab10 palette, in that order.
colors = [sns.color_palette("tab10")[i] for i in [7, 0, 1, 9, 3]]
sns.set_palette(colors)

In [None]:
# Fetch every logged "error" record and fix the left-to-right order of the
# facets by making `method` an explicitly ordered categorical column.
df = w.query("error")
df["method"] = pd.Categorical(df["method"], categories=["1 / nT", "RelaySum", "Gossip"])

# One panel per method, one line per number of workers n, log-scaled y axis.
g = (
    sns.FacetGrid(df, col="method", hue="n")
    .map(sns.lineplot, "step", "value")
    .set(yscale="log")
    .set_xlabels("Steps")
    .set_ylabels("Mean squared error")
)
g.add_legend();

In [None]:
# Write the facet grid `g` built in the previous cell to a PDF in the current
# working directory, using a tight bounding box.
g.savefig("distributed-mean-estimation-results.pdf", bbox_inches="tight")