In [None]:
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import cProfile
from functools import reduce
import numpy as np
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Size of the synthetic data set and number of input features per sample.
numberOfSamples, numberOfFeatures = 1000, 1

In [None]:
# Synthetic 1-D regression data:
#   x: numberOfSamples evenly spaced points in [0, 10], shape (numberOfSamples, 1)
#   y: 3x + uniform noise in [0, 7) + 10*cos(1.2x), same shape as x
SEED = 0
torch.manual_seed(SEED)  # fixes the noise below and the torch.randn parameter inits later on
np.random.seed(SEED)     # batch_iter shuffles with np.random
x = torch.from_numpy(np.linspace(0, 10, numberOfSamples).reshape(-1, 1)).float()
noise = torch.rand(numberOfSamples, 1) * 7
y = x * 3 + noise + 10 * torch.cos(1.2 * x)

fig, ax = plt.subplots()
ax.scatter(x, y, s=4)
ax.set(xlabel='x', ylabel='y', title='Synthetic training data');

In [None]:
# Define the model
def model(x, w, b):
    """Linear model: return ``x @ w^T + b``.

    ``w.t()`` is applied before the matrix product, so a 2-D ``w`` of shape
    ``(1, F)`` maps ``x`` of shape ``(N, F)`` to ``(N, 1)``; ``b`` is added
    with ordinary broadcasting.
    """
    weights_t = w.t()
    return torch.matmul(x, weights_t) + b

# MSE loss
def mse(t1, t2):
    """Mean squared error: sum of squared element-wise differences / element count."""
    residual = t1 - t2
    squared = residual * residual
    return squared.sum() / residual.numel()

def batch_iter(y, tx, batch_size, num_batches=1):
    """Yield up to ``num_batches`` shuffled mini-batches ``(y_batch, tx_batch)``.

    The data is shuffled once per call with a single permutation shared by
    ``y`` and ``tx``, so rows stay paired. Consecutive slices of
    ``batch_size`` rows are then yielded, and the last one may be shorter.
    Iteration stops as soon as the shuffled data is exhausted, so at most
    ``ceil(len(y) / batch_size)`` batches are produced. It never yields an
    empty batch, whatever ``num_batches`` is.

    Args:
        y: targets, indexable by an integer index array (numpy array or tensor).
        tx: inputs with the same first dimension as ``y``.
        batch_size: number of rows per batch.
        num_batches: maximum number of batches to yield.
    """
    data_size = len(y)
    shuffle_indices = np.random.permutation(data_size)
    shuffled_y = y[shuffle_indices]
    shuffled_tx = tx[shuffle_indices]
    for batch_num in range(num_batches):
        start_index = batch_num * batch_size
        # Once start_index passes the end, min() below would clamp end_index
        # under start_index and produce empty slices (-> NaN losses); stop instead.
        if start_index >= data_size:
            break
        end_index = min(start_index + batch_size, data_size)
        yield shuffled_y[start_index:end_index], shuffled_tx[start_index:end_index]

In [None]:
def plot_solution(w, b, inputs=None, targets=None):
    """Plot the data together with the linear fit ``model(inputs, w, b)``.

    Assumes a single input feature, like the rest of the notebook.

    Args:
        w: weight tensor (e.g. as returned by ``sgd`` / ``dist_sgdq``).
        b: bias tensor.
        inputs: inputs to plot; defaults to the global ``x``.
        targets: targets to plot; defaults to the global ``y``.

    Returns:
        The matplotlib ``Axes`` containing the plot.
    """
    if inputs is None:
        inputs = x
    if targets is None:
        targets = y
    with torch.no_grad():
        preds = model(inputs, w, b).reshape(-1)
    xs = inputs.detach().reshape(-1)
    # ax.plot joins points in the given order; sort so the line runs left to right.
    order = torch.argsort(xs)
    fig, ax = plt.subplots()
    ax.scatter(xs.numpy(), targets.detach().reshape(-1).numpy(), s=4, label='data')
    ax.plot(xs[order].numpy(), preds[order].numpy(), color='C1', label='fit')
    ax.set(xlabel='x', ylabel='y', title='Linear fit')
    ax.legend()
    return ax

In [None]:
def sgd(targets, inputs, batch_size, max_iter, λ=5e-3, verbose=True):
    """Fit the linear ``model`` with plain mini-batch SGD on the MSE loss.

    Args:
        targets: target tensor, shape (N, 1).
        inputs: input tensor, shape (N, numberOfFeatures).
        batch_size: rows per mini-batch.
        max_iter: maximum number of mini-batch steps (see ``batch_iter``).
        λ: learning rate.
        verbose: print the loss of every step when True.

    Returns:
        ``(w, b)``: weights of shape (1, numberOfFeatures) and bias of shape (1,).
    """
    w = torch.randn(1, numberOfFeatures, requires_grad=True)
    # One output unit -> one bias. A (numberOfFeatures,)-shaped bias would
    # broadcast the (N, 1) prediction to (N, numberOfFeatures) when F > 1.
    b = torch.randn(1, requires_grad=True)
    for step, (ybatch, xbatch) in enumerate(batch_iter(targets, inputs, batch_size, max_iter)):
        loss = mse(model(xbatch, w, b), ybatch)
        if verbose:
            print('step', step, " loss=", loss.item())
        loss.backward()
        with torch.no_grad():
            w -= w.grad * λ
            b -= b.grad * λ
            w.grad.zero_()
            b.grad.zero_()
    return w, b
# Single-process baseline: batch size 5, at most 100 steps, default learning rate.
w_fit, b_fit = sgd(y, x, 5, 100)
plot_solution(w_fit, b_fit)

In [None]:
def quantize(tensor):
    """1-bit sign quantization.

    Returns a bool tensor of the same shape: True where ``tensor > 0``,
    False elsewhere (zeros map to False, i.e. to -1 after ``unquantize``).
    Magnitudes are discarded.
    """
    return tensor > 0
def unquantize(tensor):
    """Decode a sign mask into float32 ±1 values.

    Non-zero entries are kept as their float32 value (True -> 1.0) and zero
    entries (False) become -1.0. The input is never modified: a fresh copy is
    always returned. Note that no magnitude/scale is restored.
    """
    # copy=True matters: for a float32 input a plain conversion returns the
    # same tensor, and the masked assignment below would mutate the caller's data.
    out = tensor.to(dtype=torch.float32, copy=True)
    out[out == 0] = -1
    return out

"""
GPU i is responsible for chunk i
"""
def ms_allreduce(tensor, chunksize=1):
    """Sign-quantized all-reduce of ``tensor`` across all ranks, in place.

    Rank ``i`` is responsible for chunk ``i``: rows
    ``[i * chunksize, (i + 1) * chunksize)`` along dim 0. The exchange is a
    naive reduce-scatter followed by a naive all-gather. Every message is
    built with ``quantize`` / ``unquantize``, i.e. only signs are sent:

    1. Reduce: each rank sends the sign of every foreign chunk to that
       chunk's owner. The owner adds the received ±1 values to its own
       full-precision copy of the chunk.
    2. Gather: each owner sends the sign of its accumulated chunk to all
       other ranks, which store the received ±1 values in that chunk.

    Caveats (inherent to this scheme):
    * Signs are sent without any scale, so the result is not the exact sum.
    * The owner keeps its chunk in full precision while the other ranks only
      receive its sign, so ranks may end up holding different values.
    * Rows at or beyond ``world_size * chunksize`` are never exchanged and
      are set to 0.

    Args:
        tensor: tensor to reduce; overwritten with the result.
        chunksize: number of dim-0 rows owned by each rank.

    Requires an initialised ``torch.distributed`` default process group and
    must be called by every rank. Sent and received sizes only match if all
    ranks pass tensors of the same shape.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    def chunk(i):
        # dim-0 slice owned by rank i
        return slice(i * chunksize, (i + 1) * chunksize)

    own = chunk(rank)
    peers = [i for i in range(world_size) if i != rank]
    acc = torch.zeros_like(tensor)
    acc[own] = tensor[own]

    # --- naive reduce-scatter: send foreign chunks to their owners ---------
    # Keep the quantized send buffers referenced until their requests finish.
    send_bufs = [quantize(tensor[chunk(i)]) for i in peers]
    reqs = [dist.isend(tensor=buf, dst=i) for i, buf in zip(peers, send_bufs)]
    for i in peers:
        recv = torch.zeros_like(acc[own], dtype=torch.bool)
        dist.recv(tensor=recv, src=i)
        # Decode only the received chunk: untouched entries of a full-size
        # all-False buffer would decode to -1 and corrupt the other chunks.
        acc[own] += unquantize(recv)
    for req in reqs:
        req.wait()

    # --- naive all-gather: broadcast the sign of our accumulated chunk -----
    own_sign = quantize(acc[own])
    reqs = [dist.isend(tensor=own_sign, dst=i) for i in peers]
    for i in peers:
        recv = torch.zeros_like(acc[chunk(i)], dtype=torch.bool)
        dist.recv(tensor=recv, src=i)
        acc[chunk(i)] = unquantize(recv)
    for req in reqs:
        req.wait()

    tensor[:] = acc

In [None]:
def dist_sgdq(rank, size, group, targets, inputs, batch_size, max_iter, λ=1e-2, verbose=True):
    """Distributed mini-batch SGD with sign-quantized gradients and error feedback.

    Each step, for each parameter ``p`` (``w`` and ``b``):
      1. compute the local MSE gradient ``g`` on one mini-batch;
      2. add the carried-over error: ``c = g + e``;
      3. all-reduce ``c`` with ``ms_allreduce`` and average: ``u = allreduce(c) / size``;
      4. keep what the lossy exchange dropped for the next step: ``e = c - u``;
      5. update: ``p -= λ * u``.

    Args:
        rank: this process's rank (used for logging).
        size: number of ranks; the all-reduced gradient is divided by it.
        group: process group. Currently unused; ``ms_allreduce`` is called
            without a group argument.
        targets: local target shard, shape (n, 1).
        inputs: local input shard, shape (n, numberOfFeatures).
        batch_size: rows per mini-batch.
        max_iter: maximum number of steps (see ``batch_iter``).
        λ: learning rate.
        verbose: print the loss of every step when True.

    Returns:
        ``(w, b)`` of shapes (1, numberOfFeatures) and (1,).
        NOTE(review): with sign-only messages the aggregate can differ between
        ranks, so each rank may return different parameters.
    """
    w = torch.randn(1, numberOfFeatures, requires_grad=True)
    # One output unit -> one bias (a (numberOfFeatures,) bias mis-broadcasts for F > 1).
    b = torch.randn(1, requires_grad=True)
    # Error-feedback buffers must persist across steps; re-creating them each
    # iteration would discard the compensation entirely.
    error_w = torch.zeros_like(w)
    error_b = torch.zeros_like(b)
    for step, (ybatch, xbatch) in enumerate(batch_iter(targets, inputs, batch_size, max_iter)):
        loss = mse(model(xbatch, w, b), ybatch)
        if verbose:
            print('step(rank[', rank, '])', step, " loss=", loss.item())
        loss.backward()
        with torch.no_grad():
            for param, error in ((w, error_w), (b, error_b)):
                compensated = param.grad + error
                param.grad.copy_(compensated)
                ms_allreduce(param.grad)  # in place: param.grad now holds the (approximate) sum
                update = param.grad / size
                error.copy_(compensated - update)
                param -= update * λ
                param.grad.zero_()
    return w, b

In [None]:
def init_processes(rank, size, fn, backend='gloo', master_addr='127.0.0.1', master_port='29500'):
    """Initialise torch.distributed for one worker, run ``fn(rank, size)``, then tear down.

    Args:
        rank: rank of this worker, in ``[0, size)``.
        size: total number of participating processes (the world size). It
            must equal the number of processes the launcher starts, otherwise
            rendezvous blocks waiting for ranks that never join.
        fn: worker body, called as ``fn(rank, size)``.
        backend: torch.distributed backend name.
        master_addr, master_port: rendezvous address used by the default
            ``env://`` initialisation.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)
    try:
        fn(rank, size)
    finally:
        dist.destroy_process_group()
def run(rank, size):
    """Worker body: train ``dist_sgdq`` on this rank's contiguous shard of (x, y).

    The global data set is split into ``size`` equal contiguous shards. Rank
    ``r`` trains on shard ``r`` with batch size 5 for up to 100 steps and
    prints the parameters it ends up with.
    """
    print(rank, 'running...')
    group = dist.new_group(list(range(size)))
    assert numberOfSamples % size == 0
    shard_len = numberOfSamples // size
    shard = slice(rank * shard_len, (rank + 1) * shard_len)
    # NOTE(review): each rank keeps its own (w, b); nothing here averages them.
    w, b = dist_sgdq(rank, size, group, y[shard], x[shard], 5, 100)
    print('Solution rank', rank, {'w': w, 'b': b})
    print(w, b)

# Start one worker process per rank; each one runs `run` via `init_processes`.
# NOTE(review): the workers must be able to find `init_processes`/`run`, which
# are defined in this notebook. That works with the "fork" start method but
# presumably not with "spawn" (default on macOS/Windows) — confirm on your platform.
size = 2
processes = []
for rank in range(size):
    p = Process(target=init_processes, args=(rank, size, run))
    p.start()
    processes.append(p)

for p in processes:
    p.join()

# Without this check a crashed worker (e.g. a failed assertion) goes unnoticed.
failed = {rank: p.exitcode for rank, p in enumerate(processes) if p.exitcode != 0}
if failed:
    raise RuntimeError('distributed workers exited abnormally (rank: exitcode): {}'.format(failed))

In [None]:
# Hand-entered parameters, presumably copied from a printed solution of the
# distributed run above — TODO confirm / replace with the returned tensors.
w_manual = torch.tensor([1.0072])
b_manual = torch.tensor([3.7367])
plot_solution(w_manual, b_manual)

In [None]:
#