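# Compare the convergence curves

Validation-accuracy convergence of FedAvg, SHeteroFL and Split-Mix on the selected dataset, fetched from their W&B sweeps and plotted per communication round.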

In [None]:
import os, sys
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display

# Global plot look: seaborn "ticks" axes style plus the current default colour cycle.
sns.set_style("ticks")
cmap = sns.color_palette()  # the default palette as a list of RGB tuples
sns.set_palette(cmap)

In [None]:
cache_path = './fig/slimmable'
# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(cache_path, exist_ok=True)


def save_fig(fname, fig=None):
    """Apply tight layout to a figure and save it under ``cache_path``.

    Args:
        fname: file name relative to ``cache_path``.
        fig: figure to save; defaults to the current pyplot figure (the
            previous behaviour).

    Returns:
        The path the figure was written to.
    """
    if fig is None:
        fig = plt.gcf()
    fig.tight_layout()
    out_file = os.path.join(cache_path, fname)
    print(f"save fig => {out_file}")
    fig.savefig(out_file)
    return out_file

In [None]:
data = 'DomainNet'

# W&B sweep paths per dataset, keyed by method name.
SWEEPS_BY_DATA = {
    'Cifar10_cniid': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/dlxe994l",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/54zqzzno",
        'Split-Mix': "jyhong/SplitMix_release/sweeps/np49m5hp",
    },
    'Digits': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/057l05ow",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/ufwuoldc",
        'Split-Mix': "jyhong/SplitMix_release/sweeps/ybief82d",
    },
    'DomainNet': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/wf20oh8r",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/dqfo7crn",
        'Split-Mix': "jyhong/SplitMix_release/sweeps/naglzvcl",
    },
}
# Fail here with a clear message instead of a NameError on `sweep_dict` cells later.
if data not in SWEEPS_BY_DATA:
    raise ValueError(f"No sweeps registered for data={data!r}; "
                     f"expected one of {sorted(SWEEPS_BY_DATA)}")
sweep_dict = SWEEPS_BY_DATA[data]

In [None]:
agg_df_dict = dict()  # method name -> DataFrame of curves for the aggregated figure

In [None]:
def fetch_config_history(config_keys, history_keys, runs=None, samples=1000):
    """Flatten W&B run histories into a long-format dict of columns.

    Args:
        config_keys: run-config keys copied onto every history row of that run.
        history_keys: history columns to fetch (passed as ``keys=`` to
            ``run.history``).
        runs: iterable of W&B runs; defaults to ``sweep.runs`` from the
            notebook's global ``sweep`` (the original behaviour).
        samples: ``samples=`` argument forwarded to ``run.history``.

    Returns:
        dict mapping each of ``config_keys + history_keys + ['step']`` to a list;
        ``step`` is the 0-based row position within each run's history. Runs
        with an empty history are skipped.
    """
    if runs is None:
        runs = sweep.runs
    df_dict = {k: [] for k in config_keys + history_keys + ['step']}
    for run in runs:
        if run.state != 'finished':
            print("WARN: run not finished yet")
        run_history = run.history(samples=samples, keys=history_keys)
        # An empty history can come back without any columns, so check the length
        # before indexing columns (indexing first would raise KeyError).
        history_len = len(run_history)
        if history_len == 0 or not history_keys:
            continue
        for k in history_keys:
            df_dict[k].extend(run_history[k])
        df_dict['step'].extend(range(history_len))
        for k in config_keys:
            df_dict[k].extend([run.config[k]] * history_len)
    return df_dict

## FedAvg

In [None]:
api = wandb.Api()
mode = 'FedAvg'
sweep = api.sweep(sweep_dict[mode])  # its runs are read by fetch_config_history

In [None]:
df_dict = fetch_config_history(
    config_keys=['width_scale'],
    history_keys=['val_acc', 'train_loss', '_runtime'],
)
df = pd.DataFrame(df_dict).assign(mode=mode + ' 1-Net')

# The full per-width FedAvg frame is kept as RT_df; its `_runtime` column is reused
# to estimate Split-Mix wall time further down.
RT_df = df
# NOTE(review): only the width-0.125 FedAvg curve enters the aggregated figure — confirm.
agg_df_dict[mode] = df.loc[df['width_scale'] == 0.125]

In [None]:
df.head()  # preview the long-format FedAvg frame instead of dumping every row

In [None]:
fig, ax = plt.subplots(1, 1)
# One validation curve per FedAvg width; pass ax explicitly rather than relying on gca().
sns.lineplot(data=df, x='step', y='val_acc', hue='width_scale', ax=ax)
ax.set(xlim=(0, 400), xlabel='communication round', ylabel='validation accuracy',
       title=f'{mode} on {data}')
ax.grid(True)

## SHeteroFL

In [None]:
api = wandb.Api()
mode = 'SHeteroFL'
sweep = api.sweep(sweep_dict[mode])  # its runs are passed to fetch_config_history_HeteroFL

In [None]:
def smooth(y, box_pts):
    """Centered moving average of ``y`` over a window of ``box_pts`` samples.

    Interior points equal the plain box-filter average. Near the ends the average
    is taken over the samples that actually exist, instead of zero-padding (which
    would pull the first and last points of every curve toward 0). The output
    always has the same length as ``y``, even when ``box_pts > len(y)``.

    Args:
        y: 1-D sequence of numbers.
        box_pts: window length; values <= 1 return an unsmoothed float copy.

    Returns:
        1-D float ndarray with ``len(y)`` elements.
    """
    values = np.asarray(y, dtype=float)
    n = values.size
    if box_pts <= 1 or n == 0:
        return values.copy()
    kernel = np.ones(box_pts)
    # Same centring as np.convolve(..., mode='same') for n >= box_pts.
    offset = (box_pts - 1) // 2
    window_sum = np.convolve(values, kernel, mode='full')[offset:offset + n]
    window_count = np.convolve(np.ones(n), kernel, mode='full')[offset:offset + n]
    return window_sum / window_count

def fetch_config_history_HeteroFL(runs, config_keys, history_keys, smooth_window=0):
    """Collect per-width validation curves from SHeteroFL runs into long format.

    Each key in ``history_keys`` is expected to look like ``'slim{ratio:.2f} ...'``;
    the width ratio is parsed from the four characters after ``'slim'``.

    Args:
        runs: iterable of W&B runs (uses ``state``, ``config`` and ``history()``).
        config_keys: run-config keys copied onto every row of that run.
        history_keys: history columns to collect; one curve per (key, run).
        smooth_window: if > 0, each curve is passed through
            ``smooth(h, smooth_window)`` before being stored.

    Returns:
        defaultdict(list) with columns ``val_acc``, ``slim_ratio``, ``step``,
        ``_runtime`` plus ``config_keys``; rows are ordered key-major, then by run.
        Runs that did not log a key (or logged nothing) contribute no rows for it.
    """
    runs = list(runs)
    # Download each run's history once and reuse it for every key.
    histories = []
    for run in runs:
        if run.state != 'finished':
            print("WARN: run not finished yet")
        histories.append(run.history(samples=1000))

    df_dict = defaultdict(list)
    for k in history_keys:
        for run, history in zip(runs, histories):
            # An empty history has no columns at all, so check membership before indexing.
            if k not in history.columns:
                continue
            h = history[k]
            history_len = len(h)
            if history_len == 0:
                continue
            if smooth_window > 0:
                # Honour the requested window (previously hard-coded to 10).
                h = smooth(h, smooth_window)
            df_dict['val_acc'].extend(h)
            # Keys carry two decimals, so width 0.125 appears as '0.12'.
            slim_ratio = k[len('slim'):len('slim') + 4]
            if slim_ratio == '0.12':
                slim_ratio = '0.125'
            df_dict['slim_ratio'].extend([float(slim_ratio)] * history_len)
            df_dict['step'].extend(range(history_len))
            assert len(history['_runtime']) == history_len, f"{len(history['_runtime'])} != {history_len}"
            df_dict['_runtime'].extend(history['_runtime'])
            for c_k in config_keys:
                df_dict[c_k].extend([run.config[c_k]] * history_len)
    return df_dict

In [None]:
sel_slim_ratio = 1.
df_dict = fetch_config_history_HeteroFL(
    sweep.runs,
    config_keys=[],
    history_keys=[f'slim{r:.2f} val_sacc' for r in (1.0, 0.5, 0.25, 0.125)],
    smooth_window=0,
)
# The fetch helper already names the accuracy column `val_acc`; only the width
# column needs renaming to match the FedAvg frame.
df = pd.DataFrame(df_dict).rename(columns={
    'slim1.00 val_sacc': 'val_acc',
    'slim_ratio': 'width_scale',
})

df['mode'] = f'{mode} {sel_slim_ratio}-Net'
# Only the selected width goes into the aggregated comparison.
agg_df_dict[mode] = df[df['width_scale'] == sel_slim_ratio]

# df['mode'] = mode + f' {sel_slim_ratio}-Net'
# agg_df_dict[mode] = df  #[(df['slim_ratio'] == sel_slim_ratio) & (df['slim_sch'] == 'group_slimmable')]

In [None]:
fig, ax = plt.subplots(1, 1)
# One validation curve per SHeteroFL width; pass ax explicitly rather than relying on gca().
sns.lineplot(data=df, x='step', y='val_acc', hue='width_scale', ax=ax)
ax.set(xlim=(0, 400), xlabel='communication round', ylabel='validation accuracy',
       title=f'{mode} on {data}')
ax.grid(True)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# Stack FedAvg (RT_df, all widths) and SHeteroFL (df, all widths). ignore_index gives a
# unique index so the group-wise smoothing below aligns row for row.
_df = pd.concat((RT_df, df), ignore_index=True)
_df['mode'] = _df['mode'].apply(
    lambda n: {'FedAvg 1-Net': 'FedAvg', 'SHeteroFL 1.0-Net': 'SHeteroFL'}[n])

# Smooth each curve separately. Every curve restarts at step 0, so a running count of
# step == 0 rows labels the curves; smoothing the stacked column as a whole would blend
# the tail of one curve into the head of the next.
smooth_pts = 10 if data == 'DomainNet' else 5
curve_id = _df['step'].eq(0).cumsum()
_df['val_acc'] = _df.groupby(curve_id)['val_acc'].transform(
    lambda s: smooth(s.to_numpy(), smooth_pts))
_df = _df.rename(columns={'width_scale': 'width'})

sns.lineplot(data=_df, x='step', y='val_acc', hue='width', style='mode', ax=ax)
ax.set(ylabel='validation accuracy', xlabel='communication round', title=data)
if data == 'Digits':
    ax.set(xlim=(0, 200), ylim=(0.6, 0.9))
elif data == 'Cifar10_pct1':
    ax.set(xlim=(0, 390), ylim=(0.4, 0.93), title='CIFAR10 100%')
elif data == 'Cifar10_cniid':
    ax.set(xlim=(0, 390), ylim=(0.25, 0.57), title='CIFAR10 class non-i.i.d')
elif data == 'DomainNet':
    ax.set(xlim=(0, 390), ylim=(0.35, 0.73))
ax.grid(True)

save_fig(f'{data.lower()}_val_acc_converg_SHeteroFL_FedAvg.pdf')

plt.show()

## Split-Mix

In [None]:
api = wandb.Api()
mode = 'Split-Mix'
sweep = api.sweep(sweep_dict[mode])  # its runs are read by fetch_config_history

In [None]:
sel_atom_ratio = 0.125
if data == 'Digits':
    splitmix_cfg_keys = ['loss_temp', 'rescale_init', 'rescale_layer']
    df_dict = fetch_config_history(
        config_keys=splitmix_cfg_keys,
        history_keys=['val_acc'],
    )
    df = pd.DataFrame(df_dict)
    df = df[df['loss_temp'].eq('none') & df['rescale_init'].eq(True) & df['rescale_layer'].eq(True)]
    # DataFrame.drop returns a new frame; the result used to be discarded (no-op).
    # Re-binding also stops the `df['mode'] = ...` below from writing to a filtered slice.
    df = df.drop(columns=splitmix_cfg_keys)
else:
    df_dict = fetch_config_history(
        config_keys=[],
        history_keys=['val_acc'],
    )
    df = pd.DataFrame(df_dict)
df['mode'] = mode + f' {int(1/sel_atom_ratio)}x{sel_atom_ratio}-Net'

# NOTE(review): Split-Mix wall time is not read from its own runs. It is estimated as the
# mean FedAvg `_runtime` over widths 1, 0.5, 0.25 and 0.125 (aligned by step), times 1.05.
# Confirm this estimate is intended.
_RT_df = RT_df.set_index('step')
fedavg_widths = [1., 0.5, 0.25, 0.125]
per_width_runtime = [_RT_df.loc[_RT_df['width_scale'] == w, '_runtime'] for w in fedavg_widths]
_df = df.set_index('step')
_df['_runtime'] = sum(per_width_runtime) / len(per_width_runtime) * 1.05
_df = _df.reset_index()

agg_df_dict[mode] = _df

In [None]:
fig, ax = plt.subplots(1, 1)
sns.lineplot(data=_df, x='step', y='val_acc', ax=ax)
ax.set(xlim=(0, None), xlabel='communication round', ylabel='validation accuracy',
       title=f'{mode} on {data}')
ax.grid(True)
# Series.max skips NaN; the builtin max() returns NaN whenever the first value is NaN.
print(df['val_acc'].max())

## Aggregation

In [None]:
agg = pd.concat(list(agg_df_dict.values()))  # every collected method, in insertion order

In [None]:
# Show which methods have been collected into agg_df_dict so far.
agg_df_dict.keys()

In [None]:
# Stack the three methods in a fixed order and shorten their labels for the legend.
agg = pd.concat([agg_df_dict[k] for k in ['FedAvg', "SHeteroFL", "Split-Mix"]])
agg['mode'] = agg['mode'].apply(lambda n: {'FedAvg 1-Net': 'FedAvg', 'Split-Mix 8x0.125-Net': 'Split-Mix',
                                           'SHeteroFL 1.0-Net': 'SHeteroFL'}[n])
# .copy() makes the filtered frame independent, so the column assignment below
# does not trigger SettingWithCopyWarning on a slice of the concatenated frame.
agg = agg[agg['step'] <= 400].copy()
agg['wall time (min)'] = agg['_runtime'] / 60
# agg

In [None]:
# Labels in `agg` are already the short method names, so no further remapping is needed.
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
sns.lineplot(data=agg, x='step', y='val_acc', hue='mode', ax=ax)
ax.set(ylabel='validation accuracy', xlabel='communication round',
       title=f'{data}')
if data == 'Cifar10_pct1':
    ax.set(xlim=(0, 400), ylim=(0.5, 0.9), title='CIFAR10 100%')
elif data == 'Cifar10_cniid':
    ax.set(xlim=(0, 400), ylim=(0.2, 0.6), title='CIFAR10 class non-i.i.d')
elif data == 'Digits':
    ax.set(xlim=(0, 200), ylim=(0.5, 0.9))
elif data == 'DomainNet':
    ax.set(xlim=(0, 400), ylim=(0.2, 0.72))
ax.grid(True)

save_fig(f'{data.lower()}_val_acc_converg.pdf')

plt.show()