In [None]:
import csv
import numpy as np
import matplotlib.pyplot as plt
import os
import seaborn as sns
import torch
from torchvision import datasets, transforms

# Sample the first 20 lookup-table entries of each qualitative colormap once,
# so every plotting cell can pick colours by index.
select = range(20)
colors1, colors2, colors3, colorsd, colorsb, colorsc, colorsa = (
    plt.get_cmap(cmap_name)(select)
    for cmap_name in ('Set1', 'Set2', 'Set3', 'Dark2', 'tab20b', 'tab20c', 'Accent')
)

# Shared font settings: font0 is the bold variant, font1 the regular one.
font0 = dict(family='Times New Roman', weight='bold', size=25)
font1 = dict(family='Times New Roman', weight='normal', size=25)

In [None]:
# Index the result tree: ret_dir maps the third path component (the
# "<dataset>_<model>" folder under ../ret) to the set of fourth-level
# sub-folder names that contain "alpha".
paths = os.walk(r'../ret')
ret_dir = {}
for path, dir_lst, file_lst in paths:
    for dir_name in dir_lst:
        parts = os.path.join(path, dir_name).replace("\\", "/").split("/")
        # Guard the depth before indexing: a shallow folder whose name contains
        # "alpha" (e.g. ../ret/alpha_x) used to raise IndexError on parts[3].
        if len(parts) > 3 and "alpha" in parts[3]:
            ret_dir.setdefault(parts[2], set()).add(parts[3])

print(ret_dir)

In [None]:
def reading_acc(name):
    """Read an accuracy log written as CSV rows of ``epoch, acc, loss, time``.

    Args:
        name: Path of the CSV file.

    Returns:
        Four parallel lists ``(epoch, acc, loss, time)``. A loss cell equal to
        the string ``"nan"`` is stored as the sentinel ``-1``.
    """
    epoch = []
    acc = []
    loss = []
    time = []
    # The context manager closes the handle (the original leaked it);
    # newline="" is the mode the csv module asks for.
    with open(name, newline="") as csv_file:
        for row in csv.reader(csv_file):
            if not row:
                # Blank lines come back as [] and would crash on row[0].
                continue
            epoch.append(int(row[0]))
            acc.append(float(row[1]))
            loss.append(-1 if row[2] == "nan" else float(row[2]))
            time.append(float(row[3]))
    return epoch, acc, loss, time

def reading_staleness(name):
    """Read a staleness log written as CSV rows of ``epoch, stale_list, cum_stale``.

    Args:
        name: Path of the CSV file.

    Returns:
        Three parallel lists ``(epoch, stale_list, cum_stale)``; each
        ``stale_list`` entry is the evaluated Python expression from column 1.
    """
    epoch = []
    stale_list = []
    cum_stale = []
    # The context manager closes the handle (the original leaked it).
    with open(name, newline="") as csv_file:
        for row in csv.reader(csv_file):
            if not row:
                # Blank lines come back as [] and would crash on row[0].
                continue
            epoch.append(int(row[0]))
            # SECURITY: eval() executes whatever expression the file contains.
            # Only use this on logs you produced yourself; if the column is
            # always a plain literal, ast.literal_eval would be the safe choice.
            stale_list.append(eval(row[1]))
            cum_stale.append(int(row[2]))
    return epoch, stale_list, cum_stale

def mkdir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent without the separate
    exists-then-create check, which could race with another process creating
    the same directory in between.
    """
    os.makedirs(path, exist_ok=True)


def oscillation(acc, threshold = -1):
    """Count consecutive steps whose change ``acc[i+1] - acc[i]`` is below ``threshold``.

    A positive ``threshold`` is rejected and yields ``None``; otherwise the
    number of qualifying steps is returned (0 for fewer than two values).
    """
    if threshold > 0:
        return None
    return sum(
        1
        for step in range(len(acc) - 1)
        if acc[step + 1] - acc[step] < threshold
    )

def get_target(acc, target):
    """Locate where ``acc`` reaches ``target``.

    Returns:
        ``(tf, ts)`` where ``tf`` is the first index with ``acc[i] >= target``
        (``-1`` if there is none) and ``ts`` is one past the last index with
        ``acc[i] < target`` (``-1`` if no value is below the target).
    """
    # First index that meets the target, scanning forwards.
    tf = next((idx for idx, value in enumerate(acc) if value >= target), -1)
    # One past the last index still below the target, scanning backwards.
    ts = next(
        (idx + 1 for idx in reversed(range(len(acc))) if acc[idx] < target),
        -1,
    )
    return tf, ts

In [None]:
# Experiment selection; an empty string for model or dataset matches everything.
target_model = "resnet18"  #"CNN", "resnet18", "vgg", "LSTM"
target_dataset = "cifar10"  #"Shake" "cifar10" "cifar100" "femnist"
target_distribution = "Hetero_Dirchlet" # "Unbalance_Dirchlet" "Hetero_Dirchlet" "Shards" "100"

# Collect every result folder whose "<dataset>_<model>" parent and whose own
# name match the selection above.
draw_path = []
for dataset_model, alpha_dirs in ret_dir.items():
    name_parts = dataset_model.split("_")
    dataset, model = name_parts[0], name_parts[1]
    if target_dataset not in ("", dataset):
        continue
    if target_model not in ("", model):
        continue
    draw_path.extend(
        os.path.join("../ret", dataset_model, alpha_dir).replace("\\", "/")
        for alpha_dir in alpha_dirs
        if target_distribution in alpha_dir
    )

print(draw_path)

In [None]:
# Oscillation thresholds (accuracy drops) and the accuracy target for get_target.
ths = [-5, -10, -15, -20, -30]
tar = 50

# csv_ret maps "model|dataset|distribution|<7th filename field>" to a list of
# records [algorithm, epoch, acc, loss, time, oscillation counts, tf, ts].
csv_ret = {}
for p in draw_path:
    for f in os.listdir(p):
        # Only accuracy logs are used; anything with "Prox" in its name is skipped.
        if "Prox" in f or "acc" not in f:
            continue
        info = f.replace(".csv", "").split("_")
        label = "|".join([target_model, target_dataset, target_distribution, info[6]])
        filename = os.path.join(p, f).replace("\\", "/")
        epoch, acc, loss, time = reading_acc(filename)
        O_ths = [oscillation(acc, threshold=drop) for drop in ths]
        tf, ts = get_target(acc, tar)
        record = [info[0], epoch, acc, loss, time, O_ths, tf, ts]
        csv_ret.setdefault(label, []).append(record)


In [None]:
# Index of each selectable x-axis series inside a csv_ret record
# ([algorithm, epoch, acc, loss, time, oscillations, tf, ts]).
X_COLUMNS = {"Epoch": 1, "Time": 4}

# Legend label, colour and line style for every algorithm that gets a curve.
CURVE_STYLES = {
    "sflAvg": {"label": "SFL-Avg", "color": colors2[1], "linestyle": "solid"},
    "sflSGD": {"label": "SFL-SGD", "color": colorsa[4], "linestyle": "solid"},
    "aflSGD": {"label": "SAFL-SGD", "color": colorsd[2], "linestyle": "dashed"},
    "aflAvg": {"label": "SAFL-Avg", "color": colors1[0], "linestyle": "dashed"},
}

# Per-metric settings: record column, y-axis label, legend corner and the
# per-algorithm line opacity.
CURVE_PANELS = {
    "Acc": {
        "column": 2,
        "ylabel": "Accuracy",
        "legend_loc": 4,
        "alpha": {"sflAvg": 1, "sflSGD": 0.8, "aflSGD": 0.8, "aflAvg": 0.6},
    },
    "Loss": {
        "column": 3,
        "ylabel": "Loss",
        "legend_loc": 1,
        "alpha": {"sflAvg": 0.9, "sflSGD": 0.9, "aflSGD": 0.8, "aflAvg": 0.6},
    },
}


def plot_training_curves(results, x_name, y_name):
    """Save one curve figure per experiment label in ``results``.

    Args:
        results: Mapping ``"model|dataset|distribution|tag"`` -> list of
            records ``[algorithm, epoch, acc, loss, time, ...]`` (csv_ret).
        x_name: ``"Epoch"`` or ``"Time"``.
        y_name: ``"Acc"`` or ``"Loss"``.

    Raises:
        ValueError: if ``x_name`` or ``y_name`` is not one of the names above
            (previously this silently reused stale data or hit a NameError).

    Each figure is written to
    ``./pic/<model>/<dataset>/<distribution>/<x_name>_<y_name>/<tag>.png``.
    """
    if x_name not in X_COLUMNS:
        raise ValueError("unknown x axis: " + repr(x_name))
    if y_name not in CURVE_PANELS:
        raise ValueError("unknown y axis: " + repr(y_name))
    x_col = X_COLUMNS[x_name]
    panel = CURVE_PANELS[y_name]

    for label, runs in results.items():
        info = label.split("|")

        if y_name == "Loss":
            # Zero reference line. axhline interprets xmin/xmax as fractions of
            # the axes width (0..1), not data values, and takes `color` (not
            # `colors`), so draw it across the whole axes with the defaults.
            plt.axhline(y=0, color="black")

        for run in runs:
            style = CURVE_STYLES.get(run[0])
            if style is None:
                # Algorithms without a configured style are not drawn.
                continue
            plt.plot(run[x_col], run[panel["column"]],
                     alpha=panel["alpha"][run[0]], **style)

        plt.xlabel(x_name, font=font1)
        plt.ylabel(panel["ylabel"], font=font1)
        plt.xticks(size=13)
        plt.yticks(size=13)
        plt.grid(False)
        plt.legend(loc=panel["legend_loc"], prop={'size': 17})
        rp = "./pic/" + info[0] + "/" + info[1] + "/" + info[2] + "/" + x_name + "_" + y_name
        print(rp)
        mkdir(rp)
        plt.savefig(rp + "/" + info[3] + ".png", bbox_inches='tight')
        plt.close()


target_x = "Epoch" #"Epoch" "Time"
target_y = "Acc" #"Acc" "Loss"
plot_training_curves(csv_ret, target_x, target_y)

target_x = "Epoch" #"Epoch" "Time"
target_y = "Loss" #"Acc" "Loss"
plot_training_curves(csv_ret, target_x, target_y)

In [None]:
target_x = "Threshold"
target_y = "# Oscillations"
bar_width = 0.2

# Legend label, colour and hatch for every algorithm that gets a bar series.
BAR_STYLES = {
    "sflAvg": {"label": "SFL-Avg", "color": colorsd[3], "hatch": "o"},
    "sflSGD": {"label": "SFL-SGD", "color": colorsd[2], "hatch": "\\"},
    "aflSGD": {"label": "SAFL-SGD", "color": colorsd[1], "hatch": "+"},
    "aflAvg": {"label": "SAFL-Avg", "color": colorsd[0], "hatch": "/"},
}

# One bar group per oscillation threshold. Deriving positions and labels from
# `ths` keeps them in step with the counts stored in each record's r[5].
x_data = np.arange(len(ths))
x_labels = [str(abs(t)) for t in ths]

for l, a_s_r in csv_ret.items():
    info = l.split("|")
    bias = 0
    n_bars = 0
    for r in a_s_r:
        style = BAR_STYLES.get(r[0])
        if style is None:
            # Unknown algorithms are not drawn and must not leave a gap.
            continue
        plt.bar(x_data + bias, r[5], bar_width, alpha=0.8, **style)
        bias += bar_width
        n_bars += 1

    plt.xlabel(target_x, font=font1)
    plt.ylabel("# Oscillations", font=font1)
    # Bars are centred at x, x + w, ..., x + (n_bars - 1) * w, so the middle of
    # the group is x + (n_bars - 1) * w / 2 (the old fixed w / 2 offset was only
    # right for exactly two bars).
    plt.xticks(x_data + max(n_bars - 1, 0) * bar_width / 2, x_labels, size=13)
    plt.yticks(size=13)
    plt.grid(False)
    plt.legend(loc=1, prop={'size': 17})
    rp = "./pic/" + info[0] + "/" + info[1] + "/" + info[2] + "/" + target_x + "_" + target_y
    print(rp)
    mkdir(rp)
    plt.gcf().set_size_inches(10, 5)
    plt.savefig(rp + "/" + info[3] + ".png", bbox_inches='tight')
    plt.close()

In [None]:
# Histogram of the per-client resource values stored in the last non-empty row
# of the resources CSV.
resources_path = "../resources_No_100_max_50.csv"
# The context manager closes the handle (the original leaked it).
with open(resources_path, newline="") as resources_file:
    resource_rows = [row for row in csv.reader(resources_file) if row]
if not resource_rows:
    raise ValueError("no data rows found in " + resources_path)
r = resource_rows[-1]

# The first 100 columns are read as integers.
truth_row = [int(r[i]) for i in range(100)]
n, b, patches = plt.hist(truth_row, 10)

# Colour each bar by its left bin edge (one palette entry per 5 units); the
# index is clamped so large values cannot run past the end of the palette.
for c, p in zip(b, patches):
    plt.setp(p, 'facecolor', colorsb[min(int(c / 5), len(colorsb) - 1)])
plt.xlabel("Resources", font=font0)
plt.ylabel("Amounts of client", font=font0)
plt.xticks(size = 13)
plt.yticks(size = 13)
plt.title("Resources Distribution", font=font1)
plt.grid(False)
#plt.legend()
plt.savefig("Resources_Distribution.png")

In [None]:
#b = torch.load("../data_partition/data_partition_with_HeteroDiri_dataset_cifar10_1_and_100_models.pt")
# Group every directory found under ../data_partition by its third path
# component. NOTE: this rebinds the notebook-level name `ret_dir`.
paths = os.walk(r'../data_partition')
ret_dir = {}
for path, dir_lst, file_lst in paths:
    for dir_name in dir_lst:
        ret = os.path.join(path, dir_name).replace("\\", "/")
        data_class = ret.split("/")[2]
        ret_dir.setdefault(data_class, set()).add(ret)

target_dataset = "cifar10"  #"cifar10" "cifar100" "femnist"

# All three supported datasets use the same augmentation pipeline, so it is
# built once and only the torchvision dataset class differs. Any other name
# leaves transform_train / train_dataset untouched.
if target_dataset in ("cifar10", "cifar100", "femnist"):
    transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    if target_dataset == "cifar10":
        train_dataset = datasets.CIFAR10("../../data/", train=True, download=True, transform=transform_train)
    elif target_dataset == "cifar100":
        train_dataset = datasets.CIFAR100("../../data/", train=True, download=True, transform=transform_train)
    else:
        train_dataset = datasets.EMNIST("../../data/", train=True, download=True, transform=transform_train,split = 'byclass' )

target_distribution = "HeteroDiri" # "Unbalance_Diri" "HeteroDiri" "Shards" "100"

# Collect partition files whose name contains the chosen distribution.
# NOTE(review): each hit is joined onto the walk root (target_ret), not the
# directory os.walk is currently visiting, so this presumably expects the
# files to sit directly inside target_ret -- confirm the on-disk layout.
draw_path = []
for dataset, partition_dirs in ret_dir.items():
    if target_dataset not in ("", dataset):
        continue
    for target_ret in partition_dirs:
        for _walk_dir, _sub_dirs, file_names in os.walk(target_ret):
            draw_path.extend(
                os.path.join(target_ret, file_name).replace("\\", "/")
                for file_name in file_names
                if target_distribution in file_name
            )
hatch_par = ['/', '-', 'x', 'o', 'O', '.', '*']

def _label_counts_per_client(partition, targets, n_clients, n_classes):
    """Count labels per client for an index-list partition.

    ``partition[i]`` is iterated as sample indices into ``targets``. Returns an
    array with one row per client: ``n_classes`` label counts followed by the
    client's total number of samples (``len(partition[i])``).
    """
    rows = []
    for client in range(n_clients):
        counts = [0] * n_classes
        for sample_idx in partition[client]:
            counts[int(targets[sample_idx])] += 1
        counts.append(len(partition[client]))
        rows.append(counts)
    return np.array(rows)


def _plot_class_stack(user_class, shown, n_classes, ylabel, title, out_dir, stem):
    """Save a stacked bar chart of per-class counts for the first ``shown`` clients."""
    x_pos = list(range(shown))
    # Must be an ndarray: with a plain list, `bottom += ndarray` extends the
    # list instead of adding element-wise, and the next plt.bar call fails with
    # a shape mismatch.
    bottom = np.zeros(shown)
    for cls in range(n_classes):
        heights = user_class[0:shown, cls]
        plt.bar(x_pos, heights, bottom=bottom, label="Class" + str(cls))
        bottom = bottom + heights
    plt.xlabel("Client", font=font1)
    plt.ylabel(ylabel, font=font1)
    plt.xticks(x_pos, np.array(range(shown)) + 1, size=13)
    plt.yticks(size=13)
    plt.title(title, font=font0)
    plt.grid(False)
    plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0, prop={'family': 'Times New Roman', 'size': 17})
    print(out_dir)
    mkdir(out_dir)
    plt.savefig(out_dir + "/" + stem + ".png", bbox_inches='tight')
    plt.close()


def _plot_client_histogram(user_num, color_step, xlabel, out_dir, stem):
    """Save a 10-bin histogram of the per-client sample totals.

    Each bar is coloured by ``int(left_bin_edge / color_step)``; the index is
    clamped to the palette so large totals cannot raise IndexError.
    """
    _, bin_edges, bars = plt.hist(user_num, 10)
    last_color = len(colorsc) - 1
    for left_edge, bar in zip(bin_edges, bars):
        color_idx = int(left_edge / color_step) if color_step > 0 else 0
        plt.setp(bar, 'facecolor', colorsc[min(max(color_idx, 0), last_color)])
    plt.xlabel(xlabel, font=font1)
    plt.ylabel("Numbers of client", font=font1)
    plt.xticks(size=13)
    plt.yticks(size=13)
    plt.title("Data Amount Distribution of Clients", font=font0)
    plt.grid(False)
    plt.savefig(out_dir + "/" + stem + "Client_distribution.png")
    plt.close()


# Number of clients drawn in the stacked bar charts.
num_clients = 10

# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load partition files you created yourself.
if target_distribution == "HeteroDiri":
    for p in draw_path:
        b = torch.load(p)
        info = p.split("/")[3].split("_")
        # b is indexed as b[class][client]; 100 clients and 10 classes are read.
        user_class = np.array([[len(b[j][i]) for j in range(10)] for i in range(100)])
        rp = "./pic/Distribution/" + info[5] + "/" + info[3]
        _plot_class_stack(user_class, num_clients, 10, "Ammount",
                          "Distribution of " + info[5] + '|' + info[3] + '|' + info[6],
                          rp, info[6])

        user_num = user_class[:, :10].sum(axis=1).tolist()
        _plot_client_histogram(user_num, 100, "Data Scales", rp, info[6])
elif target_distribution == "Shards":
    for p in draw_path:
        b = torch.load(p)
        info = p.split("/")[3].split("_")
        user_class = _label_counts_per_client(b, train_dataset.targets, 100, 10)
        rp = "./pic/Distribution/" + info[5] + "/" + info[3]
        _plot_class_stack(user_class, num_clients, 10, "Ammount of data",
                          "Distribution of " + info[5] + '|' + info[3] + '|' + info[6],
                          rp, info[6])
elif target_distribution == "Unbalance_Diri":
    for p in draw_path:
        b = torch.load(p)
        info = p.split("/")[3].split("_")
        user_class = _label_counts_per_client(b, train_dataset.targets, 100, 10)
        rp = "./pic/Distribution/" + info[6] + "/" + info[3]
        _plot_class_stack(user_class, num_clients, 10, "Ammount of data",
                          "Distribution of " + info[6] + '|' + info[3] + '|' + info[7],
                          rp, info[7])

        # Sum only the 10 class columns: the last column already holds the
        # client's total, so summing the whole row doubled every value.
        user_num = user_class[:, :10].sum(axis=1).tolist()
        _plot_client_histogram(user_num, max(user_num) / 10, "Data Scale", rp, info[7])