In [None]:
# Setup library
import os
import csv
import pickle
import re

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from dataset_util import get_tfrecord, get_index, get_weight, get_datatick, get_sample_size

In [None]:
def _load_or_build(cache_path, build):
    """Return the object cached at ``cache_path``, building and caching it if absent.

    Args:
        cache_path: Path of the pickle file used as the cache.
        build: Zero-argument callable, invoked only when ``cache_path`` does
            not exist. Its return value is pickled to ``cache_path``.

    Returns:
        The unpickled cache content, or the freshly built object.
    """
    # NOTE(review): pickle.load can execute code stored in the file; only use
    # cache files that were produced by this notebook.
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    built = build()
    with open(cache_path, 'wb') as f:
        pickle.dump(built, f)
    return built


def _build_index():
    """Compute the (idx2lab, lab2cnt) pair from the loaded dataset."""
    idx2lab, lab2cnt = get_index(dataset)
    return idx2lab, lab2cnt


dataset = get_tfrecord('./electronics/dataset')
print(f'Dataset: {dataset}')

# Each derived object is read from its pickle when present; otherwise it is
# computed from the objects above it and written to that pickle.
idx2lab, lab2cnt = _load_or_build(
    './electronics/dataset/get_index.pickle',
    _build_index)
print(f'lab2cnt: {lab2cnt}')

weights = _load_or_build(
    './electronics/dataset/get_weight.pickle',
    lambda: get_weight('./electronics/dataset/weights.csv', idx2lab))
print(f'weights: {weights}')

datatick = _load_or_build(
    './electronics/dataset/get_datatick.pickle',
    lambda: get_datatick(lab2cnt, weights))
print(f'datatick: {datatick}')

sample_size = _load_or_build(
    './electronics/dataset/get_sample_size.pickle',
    lambda: get_sample_size(weights, datatick))
print(f'sample_size: {sample_size}')

In [None]:
# Turn the two label -> count mappings into (label, count) tables, each ordered
# from the smallest count to the largest.
df_cnt = (
    pd.DataFrame(lab2cnt.items(), columns=('label', 'count'))
    .sort_values(by=['count'], ascending=True)
)
df_sam = (
    pd.DataFrame(sample_size.items(), columns=('label', 'count'))
    .sort_values(by=['count'], ascending=True)
)

In [None]:
# Join both count tables on the label; the merge suffixes the two 'count'
# columns as count_x (from df_cnt) and count_y (from df_sam).
df = df_cnt.merge(df_sam, on='label')
x = df['label'].tolist()
height1 = df['count_x'].tolist()
height2 = df['count_y'].tolist()

fig, ax = plt.subplots(figsize=(8+4, 6))
bars1 = ax.bar(x, height1, color='turquoise', label='ISCXVPN2016')
bars2 = ax.bar(x, height2, color='teal', label='Weighted Sampled')
# Fix the tick positions explicitly before replacing their labels.
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels([f'{label} ({weights[label]!s})' for label in x],
                   rotation=35, fontsize='large')
ax.set_ylim((1, 1e7))
ax.set_yscale('log')
ax.set_xlabel('Applications (Sampling weight)', fontsize='x-large')
ax.set_ylabel('The number of packet with payload', fontsize='x-large')

# Hand-tuned vertical factor applied to the bar height when placing the
# first series' count label; bars that are not listed use a factor of 1.
label_scale = {10: 0.8, 12: 0.8, 13: 0.8, 14: 1.2}
for pos, bar in enumerate(bars1):
    ax.annotate(f'{height1[pos]:0,d}',
                xy=(bar.get_x()+0.4, bar.get_height()*label_scale.get(pos, 1)),
                ha='center',
                va='bottom')
# Second series: the label sits at half the bar height.
for count, bar in zip(height2, bars2):
    ax.annotate(f'{count:,d}',
                xy=(bar.get_x()+0.4, bar.get_height()/2),
                ha='center',
                va='bottom')
ax.legend()
fig.savefig('dataset.png', bbox_inches='tight')

### CL vs. FL

In [None]:
def _read_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    # NOTE(review): pickle.load can execute code stored in the file; only
    # load result files you trust.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Recorded results of the centralized run and of the federated run.
cl_pkl = _read_pickle('electronics/cl_epoch100/centralized_history.pickle')
fl_pkl = _read_pickle('electronics/fl_client32_round100_epoch1/ckpt/fl/fl_metrics.pickle')

In [None]:
# NOTE(review): the positions below are treated as accuracy / loss only
# because of how they are plotted later — confirm against the code that
# wrote these pickles.
cl_accuracy = cl_pkl[1]
cl_loss = cl_pkl[3]
fl_accuracy = fl_pkl[4]
fl_loss = fl_pkl[3]

In [None]:
# --- Accuracy: centralized vs. federated -----------------------------------
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(cl_accuracy, color='C1', label='Centralized')
ax.plot(fl_accuracy, color='C2', label='Federated (32 Clients)')
ax.set_xlabel('Epochs, Rounds', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.set_ylim([-0.05, 1.05])
ax.tick_params(axis='both', labelsize='large')
ax.legend(loc='center right', fontsize='large')

# Reference levels (value, downward label shift): a dashed line at each one
# plus its value rounded to two decimals, written left of the axis at x=-7.5.
# NOTE(review): the centralized series is read at index 99 and the federated
# one at index 98 — confirm the differing index is intentional.
reference_levels = ((cl_accuracy[99], 0.000), (fl_accuracy[98], 0.015))
for level, _ in reference_levels:
    ax.axhline(y=level, linestyle='--', color='gray')
for level, shift in reference_levels:
    ax.text(-7.5, level-shift, round(level, 2),
            fontsize='medium',
            horizontalalignment='center', verticalalignment='center')
fig.savefig('clvsfl-acc.png', bbox_inches='tight')

# --- Loss: centralized vs. federated ---------------------------------------
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(cl_loss, color='C1', label='Centralized')
ax.plot(fl_loss, color='C2', label='Federated (32 Clients)')
ax.set_xlabel('Epochs, Rounds', fontsize='x-large')
ax.set_ylabel('Loss', fontsize='x-large')
ax.tick_params(axis='both', labelsize='large')
ax.legend(loc='center right', fontsize='large')

# Dashed marker at x=45. axvline's ymax is a fraction of the axes height;
# dividing by 3 presumably assumes the loss axis tops out near 3 — TODO confirm.
ax.axvline(x=45, ymax=cl_loss[45]/3, linestyle='--', color='gray')
ax.text(45, 0+0.025, 45,
        fontsize='medium',
        horizontalalignment='center', verticalalignment='center')

fig.savefig('clvsfl-loss.png', bbox_inches='tight')

# Clients

In [None]:
# Client counts to compare; each has its own results directory.
xdata = [5, 10, 15, 20, 25, 30]
# For every client count, keep the maximum of entry 4 of the stored metrics
# (that entry is plotted as accuracy further down).
ydata = []
for n_clients in xdata:
    metrics_path = f'electronics/fl_client{n_clients}_round10_epoch1/ckpt/fl/fl_metrics.pickle'
    with open(metrics_path, 'rb') as fh:
        metrics = pickle.load(fh)
    ydata.append(max(metrics[4]))

In [None]:
# NOTE(review): this path has no 'dataset/' component, unlike the
# get_sample_size.pickle cache written by the dataset cell at the top of the
# notebook — confirm this is the intended file. It also rebinds `sample_size`.
with open('./electronics/get_sample_size.pickle', 'rb') as f:
    sample_size = pickle.load(f)
total = sum(sample_size.values())

# Records per client: (1-0.1) of the total, floor-divided by the client count.
# NOTE(review): the 0.1 is presumably a held-out fraction — confirm.
ydata2 = [int(total*(1-0.1)//n_clients) for n_clients in xdata]

print(ydata2)

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(x=xdata, height=ydata, width=3, color='C5')
ax.set_xlabel('The number of clients', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.set_ylim([0, 1])
# Value label on top of each bar; bars are 3 wide, so +1.5 is the bar centre.
for acc, bar in zip(ydata, bars):
    ax.annotate(f'{acc:.2f}',
                xy=(bar.get_x()+1.5, bar.get_height()),
                ha='center',
                va='bottom')

# Second y-axis: records per client drawn as unconnected diamond markers.
ax2 = ax.twinx()
ax2.plot(xdata, ydata2, linestyle='', color='C3', marker='D')
ax2.set_ylim([0, 35000])
ax2.set_ylabel('The number of records per clients', fontsize='x-large')
for n_clients, n_records in zip(xdata, ydata2):
    ax2.text(n_clients, n_records+1000, n_records,
             fontsize='medium',
             horizontalalignment='center', verticalalignment='center')

fig.savefig('clientsfixeddata.png', bbox_inches='tight')

In [None]:
# Client counts evaluated in the clients_max30 experiment.
xdata = [5, 10, 15, 20, 25, 30]
# Maximum of entry 4 of each run's metrics (plotted as accuracy below).
ydata = []
for n_clients in xdata:
    metrics_path = f'electronics/clients_max30/ckpt_clients{n_clients}/fl/fl_metrics.pickle'
    with open(metrics_path, 'rb') as fh:
        metrics = pickle.load(fh)
    ydata.append(max(metrics[4]))

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(x=xdata, height=ydata, width=3, color='C8')
ax.set_xlabel('The number of clients', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.set_ylim([0, 1])
# Value label on top of each bar; bars are 3 wide, so +1.5 is the bar centre.
for acc, bar in zip(ydata, bars):
    ax.annotate(f'{acc:.2f}',
                xy=(bar.get_x()+1.5, bar.get_height()),
                ha='center',
                va='bottom')
# NOTE: a later cell in this notebook also saves to 'clients.png'.
fig.savefig('clients.png', bbox_inches='tight')

In [None]:
# Client counts 5, 10, ..., 100.
xdata = list(range(5, 100+5, 5))
ydata = list()   # per client count: maximum of entry 4 of the metrics
ydata2 = list()  # per client count: 1-based position of the first value >= 0.8
for n_clients in xdata:
    path = f'electronics/clients_max100/ckpt_clients{n_clients}/fl/fl_metrics.pickle'
    with open(path, 'rb') as f:
        data = pickle.load(f)
    accuracies = data[4]
    ydata.append(max(accuracies))
    # Distinct names are used here so the scan does not shadow the outer loop
    # variable. None is recorded when no value ever reaches 0.8.
    first_hit = next(
        (pos for pos, acc in enumerate(accuracies, start=1) if 0.8 <= acc),
        None)
    ydata2.append(first_hit)

In [None]:
fig, ax = plt.subplots(figsize=(8*2, 6))
bars = ax.bar(x=xdata, height=ydata, width=3, color='C8')
ax.set_xlabel('The number of clients', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.set_ylim([0, 1])
# Value label on top of each bar; bars are 3 wide, so +1.5 is the bar centre.
for acc, bar in zip(ydata, bars):
    ax.annotate(f'{acc:.2f}',
                xy=(bar.get_x()+1.5, bar.get_height()),
                ha='center',
                va='bottom')

# Second y-axis: the ydata2 series as a connected line with round markers.
ax2 = ax.twinx()
ax2.plot(xdata, ydata2, linestyle='-', marker='o', color='C3')
ax2.set_ylim([0, 100])
ax2.set_ylabel('Rounds to exceed 0.8 accuracy', fontsize='x-large')

# NOTE: this writes the same 'clients.png' file name as the clients_max30
# bar chart saved earlier in the notebook.
fig.savefig('clients.png', bbox_inches='tight')

In [None]:
# --- Accuracy vs. number of clients ----------------------------------------
fig, ax = plt.subplots(figsize=(8, 6))

# Only a few hand-picked points get a dashed guide line and a value label.
for pos, (n_clients, acc) in enumerate(zip(xdata, ydata)):
    if pos in (0, 1, 11, 19):
        # ymax is a fraction of the axes height; it equals the data value
        # here because the y-axis is set to [0, 1] below.
        ax.axvline(x=n_clients, ymax=acc, linestyle='--', color='gray')
        ax.annotate(f'{acc:.2f}',
                    xy=(n_clients, acc+0.01),
                    ha='center',
                    va='bottom')
        if pos in (0, 1):
            # These two points also get their x value written on the baseline.
            ax.annotate(f'{n_clients}',
                        xy=(n_clients, 0),
                        ha='center',
                        va='bottom')

ax.plot(xdata, ydata, marker='o', color='C8')
ax.set_xlabel('The number of clients', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.set_xlim([0, 105])
ax.set_ylim([0, 1])

fig.savefig('clients-acc.png', bbox_inches='tight')


# --- Rounds vs. number of clients ------------------------------------------
fig, ax = plt.subplots(figsize=(8, 6))

for pos, (n_clients, n_rounds) in enumerate(zip(xdata, ydata2)):
    if pos in (1, 11, 19):
        # The y-axis spans [0, 100], so the axes fraction is value / 100.
        ax.axvline(x=n_clients, ymax=n_rounds/100, linestyle='--', color='gray')
        ax.annotate(f'{n_rounds:d}',
                    xy=(n_clients, n_rounds+1),
                    ha='center',
                    va='bottom')
        if pos == 1:
            ax.annotate(f'{n_clients}',
                        xy=(n_clients, 0),
                        ha='center',
                        va='bottom')

ax.plot(xdata, ydata2, marker='o', color='C3')
ax.set_xlabel('The number of clients', fontsize='x-large')
ax.set_ylabel('Rounds', fontsize='x-large')
ax.set_xlim([0, 105])
ax.set_ylim([0, 100])

fig.savefig('clients-round.png', bbox_inches='tight')

# Data class

In [None]:
# (class value, round value) pairs; both are substituted into the name of the
# results directory of the corresponding run.
class_round_pairs = [(1, 1000), (2, 200), (3, 100), (4, 200), (5, 200), (6, 100), (7, 100), (8, 100)]
# Maximum of entry 4 of each run's metrics (plotted as accuracy below).
ydata = []
for n_classes, n_rounds in class_round_pairs:
    metrics_path = f'electronics/fl_class{n_classes}_round{n_rounds}_epoch1/ckpt/fl/fl_metrics.pickle'
    with open(metrics_path, 'rb') as fh:
        metrics = pickle.load(fh)
    ydata.append(max(metrics[4]))
# Only the class value is used as the x coordinate from here on.
xdata = [n_classes for n_classes, _ in class_round_pairs]

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(x=xdata, height=ydata, color='C9')
ax.set_xlabel('The number of classes distributed to one client', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.tick_params(axis='both', labelsize='large')
ax.set_ylim([0, 1])
# Value label on top of each bar, 0.4 to the right of the bar's left edge.
for acc, bar in zip(ydata, bars):
    ax.annotate(f'{acc:.2f}',
                xy=(bar.get_x()+0.4, bar.get_height()),
                ha='center',
                va='bottom')

# NOTE: the line-chart cell that follows also saves to 'class.png'.
fig.savefig('class.png', bbox_inches='tight')

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# Only a few hand-picked points get a dashed guide line and a value label.
for pos, (n_classes, acc) in enumerate(zip(xdata, ydata)):
    if pos in (0, 1, 3, 7):
        # ymax is a fraction of the axes height; it equals the data value
        # here because the y-axis is set to [0, 1] below.
        ax.axvline(x=n_classes, ymax=acc, linestyle='--', color='gray')
        ax.annotate(f'{acc:.2f}',
                    xy=(n_classes, acc+0.01),
                    ha='center',
                    va='bottom')

ax.plot(xdata, ydata, marker='v', markersize=8, color='C9')
ax.set_xlabel('The number of classes distributed to each client', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.tick_params(axis='both', labelsize='large')
ax.set_ylim([0, 1])

# NOTE: this writes the same 'class.png' file name as the bar chart saved by
# the previous cell.
fig.savefig('class.png', bbox_inches='tight')

# Dynamic class label

In [None]:
# Metrics of the dynamic-label run; entry 4 is the series plotted as accuracy.
with open('electronics/ckpt_dynamic/fl/fl_metrics.pickle', 'rb') as fh:
    data = pickle.load(fh)
ydata = data[4]

In [None]:
# Label indices in a fixed order; consecutive indices are grouped two by two
# and translated to label names through idx2lab.
reorder = [7, 0, 14, 8, 11, 2, 15, 10, 4, 13, 5, 12, 1, 3, 9, 6]
datalab = [
    (idx2lab[first], idx2lab[second])
    for first, second in zip(reorder[0::2], reorder[1::2])
]

In [None]:
fig, ax = plt.subplots(figsize=(8*2, 6))

# Seven 10-wide background bands whose opacity grows from 0.2 to 0.8 ...
for band in range(7):
    left = 10*band
    ax.fill_between([left, left+10], [1, 1], color='khaki', alpha=0.1*(band+1)+0.1)
# ... followed by a final band from x=70 to the end of the series.
ax.fill_between([10*7, len(ydata)], [1, 1], color='khaki', alpha=0.9)
ax.plot(range(1, len(ydata)+1), ydata, color='C5')
# Triangle markers on the baseline at x = 0, 10, ..., 70.
ax.plot(range(0, 80, 10), [0]*8, linestyle='', marker='v', color='gray')
# Next to each marker, name the label pair stored at the same position.
for slot in range(8):
    first, second = datalab[slot]
    ax.text(slot*10, 0.025, f'+{first}, {second}', rotation=15, fontsize='xx-large')
ax.set_xlabel('Rounds', fontsize='xx-large')
ax.set_ylabel('Accuracy', fontsize='xx-large')
ax.tick_params(axis='both', labelsize='xx-large')
# NOTE: a later cell in this notebook also saves to 'dynamic.png'.
fig.savefig('dynamic.png', bbox_inches='tight')

In [None]:
# Step series: the values 2, 4, ..., 16, each repeated for 10 consecutive
# positions (80 entries in total).
ydata2 = [2*step for step in range(1, 9) for _ in range(10)]

In [None]:
fig, ax = plt.subplots(figsize=(8*2, 6))

# Seven 10-wide background bands whose opacity grows from 0.2 to 0.8 ...
for band in range(7):
    left = 10*band
    ax.fill_between([left, left+10], [1, 1], color='khaki', alpha=0.1*(band+1)+0.1)
# ... followed by a final band from x=70 to the end of the series.
ax.fill_between([10*7, len(ydata)], [1, 1], color='khaki', alpha=0.9)
ax.plot(range(1, len(ydata)+1), ydata, color='C5')
# Triangle markers on the baseline at x = 0, 10, ..., 70.
ax.plot(range(0, 80, 10), [0]*8, linestyle='', marker='v', color='gray')
ax.set_xlabel('Rounds', fontsize='xx-large')
ax.set_ylabel('Accuracy', fontsize='xx-large')
ax.tick_params(axis='both', labelsize='xx-large')

# Second y-axis: the ydata2 step series, with integer ticks every 4 units.
ax2 = ax.twinx()
ax2.plot(range(1, len(ydata)+1), ydata2, color='C1')
ax2.set_ylim([0, 20])
ax2.set_yticks(range(0, 24, 4))
ax2.set_ylabel('The number of clients', fontsize='xx-large')
ax2.yaxis.set_major_formatter(plt.FormatStrFormatter('%d'))
ax2.tick_params(axis='both', labelsize='xx-large')

# Next to each marker, name the label pair stored at the same position.
for slot in range(8):
    first, second = datalab[slot]
    ax.text(slot*10, 0.025, f'+{first}, {second}', rotation=15, fontsize='xx-large')

# NOTE: this writes the same 'dynamic.png' file name as the previous
# dynamic figure cell.
fig.savefig('dynamic.png', bbox_inches='tight')

# Epoch 1 / 5

In [None]:
# Collect every '*.pickle' file below path_root with an explicit stack
# (depth-first walk, no recursion).
path_root = os.path.abspath(os.path.expanduser('./ieeeaccess/epoch_exp'))
ndir = [path_root]
items = list()
while ndir:
    cdir = ndir.pop()
    # The context manager closes the scandir iterator promptly.
    with os.scandir(cdir) as entries:
        for entry in entries:
            # Skip hidden entries altogether. They used to be pushed onto the
            # directory stack, so a hidden *file* (e.g. '.DS_Store') crashed
            # the next os.scandir() call with NotADirectoryError.
            if entry.name.startswith('.'):
                continue
            if entry.is_dir():
                ndir.append(entry.path)
            elif entry.is_file() and entry.name.endswith('.pickle'):
                items.append(entry.path)

In [None]:
# Raw string: in a normal literal '\d' is an invalid escape sequence, which
# Python warns about and will eventually reject. The pattern is unchanged.
regex = r'c(\d+)_e(\d+)_r(\d+)'
# Maps the 'e' field of each path (kept as a string) to that file's
# 'val_accuracy' entry. If several files share the same 'e' value, the one
# visited last wins.
result = dict()
for item in items:
    match = re.search(regex, item)
    if match is None:
        # Fail with the offending path instead of an opaque IndexError.
        raise ValueError(f'cannot find a c<N>_e<N>_r<N> tag in path: {item}')
    e = match.group(2)
    # NOTE(review): pickle.load can execute code stored in the file; only
    # load result files you trust.
    with open(item, 'rb') as f:
        data = pickle.load(f)
    result[e] = data['val_accuracy']

In [None]:
def _first_round_reaching(history, threshold):
    """Return the 1-based position of the first value >= threshold, or None.

    Args:
        history: Iterable of numeric values, in order.
        threshold: Value that must be reached.

    Returns:
        The 1-based position of the first element ``>= threshold``, or
        ``None`` when no element reaches it.
    """
    return next(
        (pos for pos, value in enumerate(history, start=1) if value >= threshold),
        None)


# For every entry of `result` (key: epoch count as a string, value: accuracy
# history) find how many rounds it takes to reach 0.95, then express each
# count relative to the '1' entry ("mag" = base rounds / rounds).
epochs = list()
rounds = list()
base = None
for k, v in result.items():
    r95 = _first_round_reaching(v, 0.95)
    if r95 is None:
        # Never reached 0.95: record NaN rather than silently reusing the
        # value left over from the previous entry.
        r95 = float('nan')
    if k == '1':
        base = r95
    print(k, r95)
    epochs.append(int(k))
    rounds.append(r95)
if rounds and base is None:
    # Without the '1' entry there is no reference to normalise against.
    raise ValueError("result has no entry for key '1'; cannot compute 'mag'")
mag = [base/r for r in rounds]
df = pd.DataFrame({'epochs': epochs, 'rounds': rounds, 'mag': mag})
df = df.sort_values(by=['epochs'])
df = df.set_index('epochs')

In [None]:
# Display the table (index: epochs; columns: rounds, mag) as the cell output.
df

In [None]:
# Print the same table serialized as CSV text.
csv_text = df.to_csv()
print(csv_text)