This notebook provides an analysis of the accuracy/FLOPs trade-off.
Set `data` to choose a different dataset.

In [None]:
import os, sys
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display

In [None]:
# Plot theme: "ticks" axes style, with the current colour cycle kept in `cmap`
# and re-installed as the active palette.
sns.set_style("ticks")
cmap = sns.color_palette()
sns.set_palette(cmap)

In [None]:
# Directory that receives the figures and CSV files written by this notebook.
cache_path = './fig/flops_acc_curve'
# exist_ok=True creates the directory when needed without a separate
# existence check, so there is no window between "check" and "create".
os.makedirs(cache_path, exist_ok=True)

**NOTE**: Set `data` to choose a different dataset.

In [None]:
# TODO Set data here
data = 'Digits'

# Sweep path of each method, grouped by dataset. The paths are passed to
# `wandb.Api().sweep(...)` further down.
_SWEEPS_BY_DATA = {
    'Digits': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/8g8s7kp4",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/0lh7d73x",
        'SplitMix': "jyhong/SplitMix_release/sweeps/3wr7bsxb",
    },
    'DomainNet': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/y489wn02",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/shs7yw8p",
        'SplitMix': "jyhong/SplitMix_release/sweeps/2kxrau5h",
    },
    'Cifar10_cniid': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/6ua8jh9x",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/fvg0045z",
        'SplitMix': "jyhong/SplitMix_release/sweeps/g71nb2yv",
    },
}

if data not in _SWEEPS_BY_DATA:
    # Same exception type as before, but say what was wrong and what is allowed.
    raise ValueError(
        f"Unknown data: {data!r}. Expected one of {sorted(_SWEEPS_BY_DATA)}."
    )
sweep_dict = _SWEEPS_BY_DATA[data]

In [None]:
# Per-method result tables keyed by method name; the per-method cells below
# store their DataFrame here and the "Aggregation" section concatenates them.
agg_df_dict = {}

In [None]:
def fetch_config_summary(runs, config_keys, summary_keys):
    """Collect selected summary metrics and config values from a set of runs.

    Args:
        runs: Iterable of run objects exposing ``state``, ``summary``
            (supporting ``in`` and item lookup) and ``config`` (supporting
            item lookup).
        config_keys: Config entries to copy from every kept run.
        summary_keys: Summary entries every kept run must provide.

    Returns:
        A ``defaultdict(list)`` mapping each requested key to a list holding
        one value per kept run. A run that lacks any of ``summary_keys`` is
        reported and skipped as a whole, so all lists stay the same length
        and line up row by row.

    Raises:
        KeyError: If a kept run lacks one of ``config_keys``. Nothing from
            that run is recorded before the error is raised.
    """
    df_dict = defaultdict(list)
    for run in runs:
        if run.state != 'finished':
            # Unfinished runs are still collected; this is only a heads-up.
            print("WARN: run not finished yet")

        # Validate the whole run *before* recording anything. Appending key by
        # key and bailing out midway would leave the columns misaligned.
        missing_sum_key = [k for k in summary_keys if k not in run.summary]
        if missing_sum_key:
            print(f"missing key: {missing_sum_key}")
            continue

        row = [(k, run.summary[k]) for k in summary_keys]
        row += [(k, run.config[k]) for k in config_keys]
        for key, value in row:
            df_dict[key].append(value)
    return df_dict

## FedAvg

In [None]:
# Open the W&B API and look up the sweep registered for FedAvg.
api = wandb.Api()
mode = 'FedAvg'
sweep = api.sweep(sweep_dict[mode])

In [None]:
# One row per FedAvg run: final metrics plus the width the run was trained at.
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['width_scale'],
    summary_keys=['avg test acc', 'GFLOPs', 'model size (MB)'],
)
df = pd.DataFrame(df_dict).assign(mode=mode)
# Express the width as a percentage and expose it under the shared 'width' name.
df['width_scale'] = df['width_scale'].mul(100)
df['width'] = df['width_scale']

agg_df_dict[mode] = df

In [None]:
# Accuracy as a function of width for FedAvg, one tick per observed width.
fig, ax = plt.subplots(1, 1)
sns.lineplot(data=df, x='width', y='avg test acc', marker='o', ax=ax)
ax.set_xticks(df['width'].unique())
# ax.set(xlim=(0, 150), ylim=(0.3, 0.9))
ax.grid(True)

## SHeteroFL

In [None]:
# Open the W&B API and look up the sweep registered for SHeteroFL.
api = wandb.Api()
mode = 'SHeteroFL'
sweep = api.sweep(sweep_dict[mode])

In [None]:
# One row per SHeteroFL run: final metrics plus the width it was tested at.
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['test_slim_ratio'],
    summary_keys=['avg test acc', 'GFLOPs', 'model size (MB)'],
)
df = pd.DataFrame(df_dict)
# Express the width as a percentage and expose it under the shared 'width' name.
df['test_slim_ratio'] = df['test_slim_ratio'].mul(100)
df['width'] = df['test_slim_ratio']
df['mode'] = mode

agg_df_dict[mode] = df

In [None]:
# Accuracy as a function of test width for SHeteroFL, one tick per observed width.
fig, ax = plt.subplots(1, 1)
sns.lineplot(data=df, x='test_slim_ratio', y='avg test acc', marker='o', ax=ax)
ax.set_xticks(df['test_slim_ratio'].unique())
# ax.set(xlim=(0, 150), ylim=(0.3, 0.9))
ax.grid(True)

## Split-Mix 0.125atom

In [None]:
# Fetch the Split-Mix runs. The loop form is kept so further variants can be
# appended to the list of modes.
dfs = []
for mode in ['SplitMix']:
    print(f"mode: {mode}")
    api = wandb.Api()
    sweep = api.sweep(sweep_dict[mode])

    df_dict = fetch_config_summary(
        sweep.runs,
        config_keys=['test_slim_ratio', 'atom_slim_ratio'],
        summary_keys=['avg test acc', 'GFLOPs', 'model size (MB)'],
    )
    df = pd.DataFrame(df_dict).assign(mode=mode)
    # Express the width as a percentage and expose it under the shared 'width' name.
    df['test_slim_ratio'] = df['test_slim_ratio'].mul(100)
    df['width'] = df['test_slim_ratio']

    agg_df_dict[mode] = df
    dfs.append(df)

df = pd.concat(dfs)

In [None]:
# Accuracy as a function of test width for Split-Mix, one tick per observed width.
fig, ax = plt.subplots(1, 1)
sns.lineplot(data=df, x='test_slim_ratio', y='avg test acc', marker='o', ax=ax)
ax.set_xticks(df['test_slim_ratio'].unique())
# ax.set(xlim=(0, 150), ylim=(0.3, 0.9))
ax.grid(True)

## Aggregation

In [None]:
# Stack the per-method tables collected so far into one long table.
agg = pd.concat(list(agg_df_dict.values()))

In [None]:
# Overlay the accuracy-vs-width curves of all methods in one figure.
# drop=True discards the duplicated per-method row labels instead of adding an
# 'index' column, so this cell can be re-run without reset_index() failing.
agg = agg.reset_index(drop=True)
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
sns.lineplot(data=agg, x='width', y='avg test acc', marker='o', hue='mode', ax=ax)
# Take the ticks from the table being plotted, not from whichever per-method
# `df` an earlier cell happened to leave behind.
ax.set(xticks=agg['width'].unique(), ylabel='average test accuracy')
# ax.set(xlim=(0, 150), ylim=(0.3, 0.9))
ax.grid(True)

# plt.tight_layout()
# out_file = os.path.join(cache_path, f'digits_width_acc_curve.pdf')
# print(f"save fig => {out_file}")
# plt.savefig(out_file)

plt.show()

In [None]:
# Build the width-by-method summary table: accuracy in %, MFLOPs and model size.
agg = pd.concat([agg_df_dict[k] for k in ['FedAvg', "SHeteroFL", "SplitMix"]])

agg['avg test acc'] = agg['avg test acc'] * 100
agg['MFLOPs'] = agg['GFLOPs'] * 1e3

# Drop every per-method width/config column, including FedAvg's 'width_scale',
# so only metrics remain as table columns. The width itself is kept as the
# row index via 'width'.
agg = agg.drop(['test_slim_ratio', 'atom_slim_ratio', 'GFLOPs', 'width_scale'],
               axis=1).set_index(['mode', 'width']).unstack('mode')
# Put the method on the outer column level and sort so each method's metrics
# sit next to each other.
agg.columns = agg.columns.swaplevel(0, 1)
agg = agg.sort_index(axis=1, level=0)
agg

In [None]:
# Emit the table as LaTeX source, formatting every float with one decimal place.
print(agg.to_latex(float_format="{:0.1f}".format))

## Analysis of training

In [None]:
# Long-format metrics table for the training-cost analysis: one row per
# (width, mode) with accuracy in %, MFLOPs and model size as columns.
agg = pd.concat([agg_df_dict[name] for name in ('FedAvg', 'SHeteroFL', 'SplitMix')])

agg['avg test acc'] = agg['avg test acc'].mul(100)
agg['MFLOPs'] = agg['GFLOPs'].mul(1e3)

agg = (
    agg.drop(columns=['test_slim_ratio', 'atom_slim_ratio', 'GFLOPs', 'width_scale'])
    .set_index(['mode', 'width'])
    .unstack('mode')            # widths as rows, (metric, mode) as columns
    .swaplevel(0, 1, axis=1)    # -> (mode, metric)
    .sort_index(axis=1, level=0)
    .stack('mode')              # back to one row per (width, mode)
)

In [None]:
# Mean of every metric per (mode, width) pair; shown as the cell output.
agg.reset_index().groupby(['mode', 'width']).mean()

In [None]:
# Accumulators filled by the per-method cells below.
param_per_domain_dict = defaultdict(list)  # {'domains': [], 'params': [], 'mode': []}
res_per_user_dict = defaultdict(list)

# Per-dataset setup: domain names, users per domain and batch size.
if data == 'Digits':
    domains = ['MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST_M']
    pd_nuser = 10
    batch_size = 32
elif data == 'DomainNet':
    domains = ['real',      'clipart',   'infograph', 'painting',  'quickdraw', 'sketch']
    pd_nuser = 5
    batch_size = 32
elif 'Cifar10' in data:
    domains = ['cifar10']
    pd_nuser = 100
    batch_size = 128
else:
    # Without this branch an unknown dataset would only surface as a NameError
    # on `domains` a few lines further down.
    raise ValueError(f"Unknown data: {data!r}")
n_domain = len(domains)
n_user = pd_nuser * n_domain

In [None]:
# FedAvg baseline: every domain is credited with 100% of the parameters, and
# every user with the smallest FedAvg model size and its per-batch compute.
df_ = agg.reset_index()
FedAvg_params = df_.loc[df_['mode'] == 'FedAvg', 'model size (MB)'].sum()
param_per_domain_dict['domains'].extend(domains)
param_per_domain_dict['params'].extend([1.] * len(domains))
param_per_domain_dict['mode'].extend(['FedAvg'] * len(domains))

res_per_user_dict['mode'].extend(['FedAvg'] * n_user)
res_per_user_dict['params/rnd'].extend(
    [df_.loc[df_['mode'] == 'FedAvg', 'model size (MB)'].min()] * n_user
)
# x3 per batch: forward pass plus a backward pass counted as two forwards.
res_per_user_dict['MFLOPs/batch'].extend(
    [df_.loc[df_['mode'] == 'FedAvg', 'MFLOPs'].min() * 3 * batch_size] * n_user
)

In [None]:
# Per-domain trained-parameter share and per-user upload / compute cost for
# SHeteroFL, derived from the metrics table `agg`.
mode = 'SHeteroFL'

df_ = agg.reset_index()
widths = df_['width'].unique()
# User i is given the width at position floor(i / n_user * len(widths)) and the
# domain at position floor(i / n_user * len(domains)).
# NOTE(review): max_widths and user_domain are not read again in this cell; the
# loop below recomputes the same values per user.
max_widths = [widths[int((i*1./n_user)*len(widths))] for i in range(n_user)]
user_domain = [int((i*1./n_user)*len(domains)) for i in range(n_user)]
param_per_domain = [0 for _ in domains]
param_per_user = [0 for _ in range(n_user)]
flops_per_user = [0 for _ in range(n_user)]

for u in range(n_user):
    max_width = widths[int((u*1./n_user)*len(widths))]
    domain = int((u*1./n_user)*len(domains))
    # Each user adds (size of its max-width model) / pd_nuser to its domain.
    param_per_domain[domain] += df_[(df_['mode'] == mode) & (max_width == df_['width'])]['model size (MB)'].values[0] / pd_nuser
    # upload the max-width model
    param_per_user[u] += df_[(df_['mode'] == mode) & (max_width == df_['width'])]['model size (MB)'].values[0]
    # train FLOPs
    # Sums the cost of every width up to and including max_width.
    # NOTE(review): the early `break` assumes df_['width'].unique() is in
    # ascending order - TODO confirm.
    for width in df_['width'].unique():
        if width > max_width:
            break
        flops_per_user[u] += df_[(df_['mode'] == mode) & (width == df_['width'])]['MFLOPs'].values[0] * 3 * batch_size # gradient descent = backward (=2*forward) + forward
    
# Normalise by the largest SHeteroFL model size so the values become fractions.
param_per_domain = [ppd*1./df_[(df_['mode'] == mode)]['model size (MB)'].max() for ppd in param_per_domain]

param_per_domain_dict['domains'] += domains
param_per_domain_dict['params'] += param_per_domain
param_per_domain_dict['mode'] += [mode] * len(domains)

res_per_user_dict['mode'] += [mode] * n_user
res_per_user_dict['params/rnd'] += param_per_user
res_per_user_dict['MFLOPs/batch'] += flops_per_user

In [None]:
# Per-domain trained-parameter share and per-user upload / compute cost for
# SplitMix, derived from the metrics table `agg`.
mode = 'SplitMix'

df_ = agg.reset_index()
widths = df_['width'].unique()
# The smallest width in the table is used as the size of one atom.
atom_width = np.min(widths)
# User i is given the width at position floor(i / n_user * len(widths)) and the
# domain at position floor(i / n_user * len(domains)).
# NOTE(review): user_domain is not read again in this cell.
max_widths = [widths[int((i*1./n_user)*len(widths))] for i in range(n_user)]
user_domain = [int((i*1./n_user)*len(domains)) for i in range(n_user)]
param_per_domain = [0 for _ in domains]
param_per_user = [0 for _ in range(n_user)]
flops_per_user = [0 for _ in range(n_user)]

for u in range(n_user):
    max_width = widths[int((u*1./n_user)*len(widths))]
    domain = int((u*1./n_user)*len(domains))
    # Every user adds the same amount here: the first SplitMix row's model size
    # scaled by (largest assigned width / smallest width), independent of the
    # user's own max_width.
    # NOTE(review): presumably because all atoms get trained in every domain;
    # this also relies on the first SplitMix row being the atom - confirm.
    param_per_domain[domain] += df_[(df_['mode'] == mode)]['model size (MB)'].values[0] * max(max_widths)/df_['width'].min() / pd_nuser
    # upload the max-width model
    # Counted as int(max_width / atom_width) copies of the smallest SplitMix model.
    param_per_user[u] += df_[(df_['mode'] == mode)]['model size (MB)'].min() * int(max_width/atom_width)
    # train FLOPs
    # Cost of the smallest-width model times the number of atoms the user holds.
    for width in [df_['width'].min()]:
        flops_per_user[u] += df_[(df_['mode'] == mode) & (width == df_['width'])]['MFLOPs'].values[0] * 3 * batch_size * int(max_width/atom_width) # gradient descent = backward (=2*forward) + forward

# Normalise by the largest SplitMix model size so the values become fractions.
param_per_domain = [ppd*1./df_[(df_['mode'] == mode)]['model size (MB)'].max() for ppd in param_per_domain]
    
param_per_domain_dict['domains'] += domains
param_per_domain_dict['params'] += param_per_domain
param_per_domain_dict['mode'] += [mode] * len(domains)

res_per_user_dict['mode'] += [mode] * n_user
res_per_user_dict['params/rnd'] += param_per_user
res_per_user_dict['MFLOPs/batch'] += flops_per_user

In [None]:
# Per-domain parameter shares as a table, converted from fractions to percent.
df = pd.DataFrame(param_per_domain_dict)
df['params'] = df['params'].mul(100)
# df

In [None]:
# Grouped bar chart of the trained-parameter percentage per domain and method,
# saved as a PDF under cache_path.
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
sns.barplot(data=df, x='domains', y='params', hue='mode', ax=ax)
ax.set_title("percentage of trained parameters")
ax.set_ylabel("")
ax.grid(True)
fig.autofmt_xdate()  # slant the domain names so they do not overlap

plt.tight_layout()
out_file = os.path.join(cache_path, f'{data}_domain_pct_param.pdf')
print(f"save fig => {out_file}")
plt.savefig(out_file)

plt.show()

In [None]:
# Per-user resource table: one row per (method, user), tagged with the dataset.
res_df = pd.DataFrame(res_per_user_dict)
res_df['data'] = data
display(res_df)

out_file = os.path.join(cache_path, f'{data}_res_df.csv')
print(f"save df => {out_file}")
res_df.to_csv(out_file)

# Aggregate the numeric resource columns only. The constant string column
# 'data' cannot be averaged, and recent pandas raises instead of silently
# dropping it.
numeric_cols = list(res_df.select_dtypes(include='number').columns)
group_df = res_df.groupby('mode')[numeric_cols]
for stat in ['mean', 'std', 'max', 'min']:
    print(stat)
    # Look the aggregation up by name rather than eval()-ing a built string.
    display(getattr(group_df, stat)())