In [4]:
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tikzplotlib

In [5]:
def unpack_data(directory_path, datatype='losses.log', epochs=200, num_workers=10):
    """Load per-worker log files into an ``(epochs, num_workers)`` array.

    Walks ``directory_path`` (sub-directories included) and reads every file
    whose name ends with ``datatype``. The column a file fills is parsed from
    its name: the text before the first ``-`` with its first character dropped
    (e.g. ``r3-losses.log`` fills column 3). Each non-blank line of the file
    is converted to ``float`` and stored in successive rows of that column.

    Args:
        directory_path: Directory containing the log files.
        datatype: File-name suffix selecting which logs to read.
        epochs: Number of rows in the result (maximum values read per file).
        num_workers: Number of columns in the result.

    Returns:
        A float ``numpy`` array of shape ``(epochs, num_workers)``. Cells for
        which no value was found stay at ``0.0``.

    Raises:
        Exception: If ``directory_path`` is not an existing directory.
        ValueError: If a file name does not encode an integer index, or a
            non-blank line is not a valid number.
        IndexError: If a file holds more than ``epochs`` values or encodes an
            index outside ``[-num_workers, num_workers)``.
    """
    directory = os.path.join(directory_path)
    if not os.path.isdir(directory):
        raise Exception(f"custom no directory {directory}")
    data = np.zeros((epochs, num_workers))
    for root, _dirs, files in os.walk(directory):
        for file in files:
            if not file.endswith(datatype):
                continue
            worker = int(file.split('-')[0][1:])
            # os.walk also yields files found in sub-directories, so the path
            # must be built from `root`; joining with the top-level directory
            # would fail (or read the wrong file) for nested logs.
            with open(os.path.join(root, file), 'r') as f:
                row = 0
                for line in f:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    if not line.strip():
                        continue
                    data[row, worker] = float(line)
                    row += 1
    return data

def time_order(directory_path, epochs=200, num_workers=10):
    """Order every (epoch, worker) entry by cumulative elapsed time.

    Reads the ``total-time.log`` files through ``unpack_data``, divides the
    values by 60 (presumably seconds to minutes -- confirm against the logger)
    and accumulates them down the epoch axis, independently per worker. The
    resulting ``(epochs, num_workers)`` grid is flattened row by row and
    argsorted.

    Args:
        directory_path: Directory holding the ``total-time.log`` files.
        epochs: Number of rows read per worker.
        num_workers: Number of workers (columns).

    Returns:
        A tuple ``(time_stamps, sorted_time)``. ``sorted_time`` holds the
        flattened indices in ascending order of cumulative time, and
        ``time_stamps`` holds the cumulative times found at positions
        ``num_workers - 1, 2 * num_workers - 1, ...`` of that ordering.
    """
    raw_times = unpack_data(
        directory_path,
        datatype='total-time.log',
        epochs=epochs,
        num_workers=num_workers,
    )
    elapsed = (raw_times / 60).cumsum(axis=0).flatten()

    order = np.argsort(elapsed)
    # Keep every num_workers-th entry of the ordering, starting with the
    # num_workers-th smallest cumulative time.
    every_nth = order[num_workers - 1::num_workers]

    return elapsed[every_nth], order

In [7]:
# Number of repeated tests (test1 .. test<ntest>) read per configuration.
ntest = 5
# Values of exp_type[2] that select a slowdown experiment.
slowdowns = {"noslowdown", "slowdown2", "slowdown4"}


# exp_type defines:
# 0: the statistic you want (e.g. epoch-time or commtime)
# 1: the graph topology (ring, clique-ring, 2c-clique-ring, 4c-clique-ring)
# 2: the degree of non-iid-ness or slowdown (noniid-0.25, noniid-0.9, noniid-0.7, noniid-0.5, iid, noslowdown, slowdown2, slowdown4)
# 3: which communication algorithms to include (all, dsgd)
# 4: "vary" selects the varying-topology ("-VT") results for the iid ring
exp_type = ["xxx", "ring", "iid", "all", "vary"]
# specify number of epochs
epochs = 200
# specify number of epochs to plot
plot_epochs = 200
# specify number of workers
num_work = 16


# Algorithms to compare and their display labels.
if exp_type[3] == "all":
    communicators = ['dsgd', 'ldsgd', 'pdsgd', 'swift', '2swift']
    labels = ['D-SGD', 'LD-SGD', 'PA-SGD', 'SWIFT', 'SWIFT (2-SGD)']
elif exp_type[3] == "dsgd":
    communicators = ['dsgd', 'swift']
    labels = ['D-SGD', 'SWIFT']
else:
    # Fail fast: continuing would either hit a NameError or silently reuse
    # `communicators`/`labels` left in the kernel by an earlier run.
    raise ValueError(
        f"unsupported exp_type[3]={exp_type[3]!r}; expected 'all' or 'dsgd'"
    )


# Sub-directory of Output/ that holds the runs for this topology/setting.
if exp_type[1] == "ring" and exp_type[2] == "iid":
    if exp_type[4] == "vary":
        base = f"Random-{num_work}-Ring-VT"
    else:
        base = f"Random-{num_work}-Ring"
elif exp_type[1] == "clique-ring" and exp_type[2][:6] == "noniid":
    # exp_type[2][7:] is the part after "noniid-", e.g. "0.25".
    base = f"Random-{num_work}-{exp_type[2][7:]}-Noniid"
elif exp_type[1] == "2c-clique-ring":
    base = f"Random-{num_work}-ROC/2Cluster"
elif exp_type[1] == "4c-clique-ring":
    base = f"Random-{num_work}-ROC/4Cluster"
elif exp_type[1] == "ring" and exp_type[2] in slowdowns:
    base = f"Slowdown-{num_work}-Ring"
else:
    # Fail fast for the same reason as above: `base` would be undefined or stale.
    raise ValueError(
        f"unsupported topology/setting combination: "
        f"exp_type[1]={exp_type[1]!r}, exp_type[2]={exp_type[2]!r}"
    )


# Statistics to summarise; each one is read from per-worker "<name>.log" files.
log_type = ["epoch-time", "commtime"]

for stat in log_type:
    exp_type[0] = stat
    over_all = {}
    for comm in communicators:
        # Both SWIFT variants are stored under a "swift-" prefix and differ in
        # the "1sgd"/"2sgd" path component. Outside the slowdown experiments
        # their directories also carry a "no_mem-" tag before the topology.
        is_swift = comm in ("swift", "2swift")
        algo = "swift" if is_swift else comm
        sgd_count = 2 if comm == "2swift" else 1
        mem_tag = "no_mem-" if is_swift and exp_type[2] not in slowdowns else ""

        test_means = []
        for test_id in range(1, ntest + 1):
            run_dir = (
                f"Output/{base}/{algo}-{exp_type[2]}-test{test_id}-{num_work}W-"
                f"{mem_tag}{exp_type[1]}-{sgd_count}sgd-{epochs}epochs"
            )
            stats = unpack_data(
                run_dir,
                datatype=f"{stat}.log",
                epochs=epochs,
                num_workers=num_work,
            )
            # Mean over every epoch and worker of this test.
            test_means.append(stats.mean())
        # Mean of the per-test means.
        over_all[comm] = np.mean(test_means)

    print(f"avg {exp_type[0]}, {exp_type[1]}, {exp_type[2]}, {exp_type[3]}, {epochs}, {num_work} workers")
    for name, value in over_all.items():
        print(name, round(value, 3))
    print()

avg epoch-time, ring, iid, all, 200, 16 workers
dsgd 2.241
ldsgd 2.172
pdsgd 1.982
swift 1.367
2swift 1.348

avg commtime, ring, iid, all, 200, 16 workers
dsgd 0.962
ldsgd 0.517
pdsgd 0.5
swift 0.121
2swift 0.085

