We aim to evaluate the methods under different values of `pr_nuser` (the number of users per round).

In [None]:
import os, sys
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display

In [None]:
# Use seaborn's "ticks" theme and keep a handle on the active colour cycle.
sns.set_style("ticks")
cmap = sns.color_palette()
# Install that same colour cycle as the palette used for subsequent plots.
sns.set_palette(cmap)

In [None]:
# Directory that receives the figures saved by this notebook.
cache_path = './fig/flops_acc_curve'
# exist_ok=True creates the directory tree only when needed and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(cache_path, exist_ok=True)

In [None]:
# Presumably the name of the dataset these sweeps were run on -- TODO confirm.
data = 'Digits'

# Maps a method name to the W&B sweep path that later cells resolve with
# `api.sweep(sweep_dict[mode])`.
# NOTE(review): paths look like "<entity>/<project>/sweeps/<sweep id>" -- confirm.
sweep_dict = {
    'SplitMix': 'jyhong/SplitMix_release/sweeps/80ewd3yq',
    'SHeteroFL': 'jyhong/SplitMix_release/sweeps/0qjd6qdr',
}

In [None]:
# Per-method result frames keyed by mode name; filled by the fetch cells and
# concatenated in the aggregation cell.
agg_df_dict = {}

In [None]:
def get_slimmabe_ratios(mode: str):
    """Parse a '-'-separated slimmable-mode string into width ratios.

    Each token is either ``"P"`` or ``"PdQ"`` with integer ``P`` and ``Q``.
    A plain ``"P"`` yields the ratio ``1 / P``; ``"PdQ"`` yields
    ``1 / (P / Q)``.  (Per the original author's note, ``P`` denotes a
    1/P-net and ``Q`` the weight of that net in samples.)

    Args:
        mode: Mode string such as ``"8-4-2-1"`` or ``"8d3-4"``.

    Returns:
        A list of floats, one ratio per token, in token order.
    """
    def _denominator(token):
        # Plain tokens are integer denominators; "PdQ" tokens divide P by Q.
        if 'd' not in token:
            return int(token)
        numer, weight = token.split('d')
        return int(numer) * 1. / int(weight)

    return [1. / _denominator(token) for token in mode.split('-')]

In [None]:
def fetch_config_summary(runs, config_keys, summary_keys):
    """Collect selected summary and config values from runs into columns.

    A run contributes a row only if *all* ``summary_keys`` are present in its
    summary; otherwise it is reported and skipped entirely, so every returned
    column has the same length and rows stay aligned across columns.

    Args:
        runs: Iterable of run objects exposing ``state``, ``summary``
            (supporting ``in`` and item access) and ``config``.
        config_keys: Keys read from ``run.config`` for every kept run.
        summary_keys: Keys read from ``run.summary`` for every kept run.

    Returns:
        A ``defaultdict(list)`` mapping each requested key to a list holding
        one value per kept run (summary keys first, then config keys).

    Raises:
        KeyError: If a kept run's config lacks one of ``config_keys``.
    """
    df_dict = defaultdict(list)
    for run in runs:
        if run.state != 'finished':
            print("WARN: run not finished yet")

        # Validate before appending anything, so a skipped run cannot leave
        # partially-filled (unequal-length) columns behind.
        missing_sum_key = [k for k in summary_keys if k not in run.summary]
        if missing_sum_key:
            print(f"missing key: {missing_sum_key}")
            continue

        # Build the complete row first; a KeyError from run.config therefore
        # propagates before df_dict is modified for this run.
        row = [(k, run.summary[k]) for k in summary_keys]
        row.extend((k, run.config[k]) for k in config_keys)
        for k, value in row:
            df_dict[k].append(value)
    return df_dict

## (S)HeteroFL

In [None]:
# for mode in ['SHeteroFL pr nuser', 'SHeteroFL pr nuser=-1']:
for mode in ['SHeteroFL']:
    # Resolve the sweep registered for this method and collect its runs.
    api = wandb.Api()
    sweep = api.sweep(sweep_dict[mode])
    df_dict = fetch_config_summary(
        sweep.runs,
        summary_keys=['avg test acc', 'GFLOPs', 'model size (MB)'],
        config_keys=['test_slim_ratio', 'pr_nuser'],
    )

    df = pd.DataFrame(df_dict)
    # Scale the tested slim ratio by 100 and expose it as `width`.
    df['test_slim_ratio'] = df['test_slim_ratio'].mul(100)
    df['width'] = df['test_slim_ratio']
    # Any pr_nuser that is not positive is replaced by 50.
    df['pr_nuser'] = df['pr_nuser'].apply(
        lambda n_user: 50 if not (n_user > 0) else n_user
    )
    df['mode'] = mode

    agg_df_dict[mode] = df

In [None]:
# Accuracy against width for the frame built above, one line per pr_nuser.
sns.lineplot(x='width', y='avg test acc', hue='pr_nuser', marker='o', data=df)
plt.gca().grid(True)

## Split-Mix 0.125atom

In [None]:
for mode in ['SplitMix']:
    # 'SplitMix step=0.25 non-exp'
    print(f"mode: {mode}")

    # Resolve the sweep exactly once per mode; repeating the lookup only
    # repeats the same W&B request.
    api = wandb.Api()
    sweep = api.sweep(sweep_dict[mode])

    df_dict = fetch_config_summary(
        sweep.runs,
        config_keys=['test_slim_ratio', 'pr_nuser'],
        summary_keys=['avg test acc', 'GFLOPs', 'model size (MB)'],
    )
    df = pd.DataFrame(df_dict)
    # Every variant listed above is labelled 'SplitMix' so that it matches the
    # style_order used by the aggregation plot.
    df['mode'] = 'SplitMix'
    # Scale the tested slim ratio by 100 and expose it as `width`.
    df['test_slim_ratio'] = df['test_slim_ratio'] * 100
    df['width'] = df['test_slim_ratio']
    # Any pr_nuser that is not positive is replaced by 50.
    df['pr_nuser'] = df['pr_nuser'].apply(lambda pn: pn if pn > 0 else 50)
    # Modes whose name contains ' ex' keep only runs with pr_nuser > 10.
    if ' ex' in mode:
        df = df[df['pr_nuser'] > 10]
    agg_df_dict[mode] = df

In [None]:
# Plot every row; narrow `df_` here to restrict the pr_nuser values shown.
df_ = df

fig, ax = plt.subplots(1, 1)
sns.lineplot(
    data=df_, x='width', y='avg test acc', hue='pr_nuser', marker='o', ax=ax,
)
# One x tick per distinct test_slim_ratio value.
ax.set_xticks(df['test_slim_ratio'].unique())
ax.grid(True)

## Aggregation

In [None]:
# Rebind `cmap`; the aggregation cell slices it to pick one colour per hue level.
# NOTE(review): with no palette name, as_cmap=True appears to return the current
# colour cycle as a plain sequence (it must support len() and slicing) --
# confirm against the installed seaborn version.
cmap = sns.color_palette(as_cmap=True)
# Notebook display: number of colours available.
len(cmap)

more budget-sufficient clients

In [None]:
# Stack the per-method frames collected in agg_df_dict.
agg = pd.concat(list(agg_df_dict.values()))
agg = agg.reset_index()
# Keep only the pr_nuser settings shown in the figure.  The explicit .copy()
# makes the filtered frame independent, so the column assignments below do not
# raise SettingWithCopyWarning.
agg = agg[agg['pr_nuser'].isin([2, 5, 20, 50])].copy()
agg['avg test acc'] = agg['avg test acc'] * 100
agg['MFLOPs'] = agg['GFLOPs'] * 1e3
# Display name of each method; the 'RT' mode is shown as 'Ind. FedAvg'.
agg['method'] = agg['mode'].apply(lambda n: n if n != 'RT' else 'Ind. FedAvg')
agg['#user/round'] = agg['pr_nuser']

fig, ax = plt.subplots(1, 1, figsize=(5, 3))
# One palette colour per distinct #user/round value.
n_hues = len(agg['#user/round'].unique())
sns.lineplot(
    data=agg, x='width', y='avg test acc', marker='o',
    style='method', hue='#user/round',
    style_order=['SplitMix', 'SHeteroFL'],
    palette=cmap[:n_hues],
    ax=ax,
)
ax.set(
    xticks=agg['width'].unique(),
    ylabel='average test accuracy (%)',
    xlabel='width (%)',
)
# Anchor the legend to the right of the axes.
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax.grid(True)

plt.tight_layout()
out_file = os.path.join(cache_path, 'Digits_pr_nuser.pdf')
print(f"save fig => {out_file}")
plt.savefig(out_file)

plt.show()