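NOTE: Before running this notebook, run `FLOPs Acc trade off.ipynb` once for every `data` in `['Cifar10_cniid', 'Digits', 'DomainNet']`; the cells below read the `<data>_res_df.csv` files cached under `./fig/flops_acc_curve`.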

In [None]:
import os, sys
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display

In [None]:
# Global plot styling: seaborn "ticks" theme, the current seaborn palette,
# and an axis-label size of 12.
sns.set_style("ticks")
cmap = sns.color_palette()
sns.set_palette(cmap)
plt.rcParams['axes.labelsize'] = 12

In [None]:
# Directory the cached result CSVs are read from and the figures are saved to.
cache_path = './fig/flops_acc_curve'
# exist_ok=True replaces the check-then-create pair: no race between the check
# and the creation, and a no-op when the directory is already there.
os.makedirs(cache_path, exist_ok=True)

In [None]:
# Dataset identifiers processed by every cell below.
datasets = [
    'Cifar10_cniid',
    'Digits',
    'DomainNet',
]

In [None]:
def _read_cached_csv(name):
    """Read the cached `<name>_res_df.csv` from `cache_path`, logging the path."""
    csv_path = os.path.join(cache_path, f'{name}_res_df.csv')
    print(f'read csv from {csv_path}')
    return pd.read_csv(csv_path)


# One result table per dataset, keyed by dataset name.
data_dfs = {name: _read_cached_csv(name) for name in datasets}

In [None]:
# Stack the per-dataset tables (in insertion order) into one long frame.
agg = pd.concat(list(data_dfs.values()))

In [None]:
# Bare expression: shows the combined frame as the notebook cell's output.
agg

In [None]:
# Replace internal identifiers with display names; any value without an entry
# in the lookup table is kept unchanged.
mode_labels = {'RT': 'FedAvg', 'SplitMix': 'Split-Mix'}
data_labels = {'Cifar10': 'CIFAR10'}
agg['mode'] = agg['mode'].apply(lambda name: mode_labels.get(name, name))
agg['data'] = agg['data'].apply(lambda name: data_labels.get(name, name))

In [None]:
# Rescale the cached columns into the quantities that get plotted.
mflops_per_batch = agg['MFLOPs/batch']
agg['GFLOPs/batch'] = mflops_per_batch / 1e3
agg['FLOPs/batch'] = mflops_per_batch * 1e6
# presumably 'params/rnd' is stored in millions -- confirm against the CSV producer
agg['normal params/rnd'] = agg['params/rnd'] * 1e6

In [None]:
# Box plot of parameters per round: one group per dataset, one box per mode.
fig, ax = plt.subplots(figsize=(4, 3))
sns.boxplot(x='data', y='normal params/rnd', hue='mode', data=agg, ax=ax)
ax.set_yscale('log')
ax.set_ylabel(r'#parameters/round')
ax.set_xlabel('')
ax.grid(True)
ax.get_legend().remove()  # this figure is saved without a legend

fig.tight_layout()
out_file = os.path.join(cache_path, 'all_data_param_rnd.pdf')
print(f"save fig => {out_file}")
fig.savefig(out_file)

plt.show()

In [None]:
# Box plot of compute per batch: one group per dataset, one box per mode.
# NOTE(review): the plotted column is 'FLOPs/batch' while the axis is labelled
# 'MACs/batch' -- confirm which unit is intended.
fig, ax = plt.subplots(figsize=(4, 3))
sns.boxplot(x='data', y='FLOPs/batch', hue='mode', data=agg, ax=ax)
ax.set_yscale('log')
ax.set_xlabel('')
ax.set_ylabel(r'MACs/batch')
ax.grid(True)
ax.get_legend().remove()  # this figure is saved without a legend

fig.tight_layout()
out_file = os.path.join(cache_path, 'all_data_flops_batch.pdf')
print(f"save fig => {out_file}")
fig.savefig(out_file)

plt.show()

## Per-user test acc

In [None]:
def fetch_config_summary(runs, config_keys, summary_keys):
    """Collect selected config and summary values of runs into columns.

    A run is kept only when its summary contains every key in
    ``summary_keys``; otherwise it is skipped as a whole, so no column
    receives a value from a skipped run. Besides the requested keys, every
    summary entry whose name ends with ``'clean test acc'`` is collected.

    Args:
        runs: Iterable of run objects exposing ``state``, ``summary``
            (supports ``in``, indexing and ``keys()``) and ``config``.
        config_keys: Config entries to copy from every kept run.
        summary_keys: Summary entries a run must have to be kept.

    Returns:
        A ``defaultdict(list)`` mapping each collected key to its values, in
        run order. The ``'clean test acc'`` columns only line up with the
        others if all kept runs expose the same set of such keys.

    Raises:
        KeyError: If a kept run lacks one of ``config_keys``.
    """
    df_dict = defaultdict(list)
    for run in runs:
        if run.state != 'finished':
            # Unfinished runs are still collected; the warning is informational.
            print("WARN: run not finished yet")
        summary = run.summary
        # Validate before appending anything: appending while checking would
        # leave earlier columns one entry longer whenever a later key is absent.
        missing_sum_key = [k for k in summary_keys if k not in summary]
        if missing_sum_key:
            print(f"missing key: {missing_sum_key}")
            continue
        for k in summary_keys:
            df_dict[k].append(summary[k])
        for k in summary.keys():
            if k.endswith('clean test acc'):
                df_dict[k].append(summary[k])
        for k in config_keys:
            df_dict[k].append(run.config[k])
    return df_dict

In [None]:
# W&B sweep path for every (dataset, training mode) pair.
SWEEP_PATHS = {
    'Digits': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/8g8s7kp4",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/0lh7d73x",
        'SplitMix': "jyhong/SplitMix_release/sweeps/3wr7bsxb",
    },
    'DomainNet': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/y489wn02",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/shs7yw8p",
        'SplitMix': "jyhong/SplitMix_release/sweeps/2kxrau5h",
    },
    'Cifar10_cniid': {
        'FedAvg': "jyhong/SplitMix_release/sweeps/6ua8jh9x",
        'SHeteroFL': "jyhong/SplitMix_release/sweeps/fvg0045z",
        'SplitMix': "jyhong/SplitMix_release/sweeps/g71nb2yv",
    },
}


def _load_sweep_df(api, sweep_path, width_key, mode, data):
    """Fetch one sweep's runs into a frame tagged with mode, data and width."""
    df = pd.DataFrame(fetch_config_summary(
        api.sweep(sweep_path).runs,
        config_keys=[width_key],
        summary_keys=['avg test acc'],
    ))
    df['mode'] = mode
    df['data'] = data
    df['width'] = df[width_key]
    return df


def _per_user_acc(df, width_key, width):
    """Keep the runs at `width` and turn the remaining accuracy columns into rows."""
    kept = df[df['width'] == width]
    wide = kept.drop(['avg test acc', width_key, 'width'], axis=1)
    long = wide.set_index(['mode', 'data']).stack().reset_index()
    return long.rename(columns={'level_2': 'user', 0: 'Acc'})


all_df = {}
# The API client does not depend on the dataset or mode, so build it once.
api = wandb.Api()

for data in datasets:
    if data not in SWEEP_PATHS:
        raise ValueError(f"no sweeps registered for dataset {data!r}")
    sweep_dict = SWEEP_PATHS[data]

    # FedAvg runs are selected by `width_scale`; only width 0.125 is kept.
    mode = 'FedAvg'
    df = _load_sweep_df(api, sweep_dict[mode], 'width_scale', mode, data)
    all_df[mode] = df  # unfiltered FedAvg frame; overwritten for each dataset
    key = mode + '@' + data
    all_df[key] = _per_user_acc(df, 'width_scale', 0.125)
    print(key, f": {len(all_df[key])}")

    # The other modes are selected by `test_slim_ratio`; only ratio 1.0 is kept.
    for mode in ['SHeteroFL', 'SplitMix']:
        df = _load_sweep_df(api, sweep_dict[mode], 'test_slim_ratio', mode, data)
        key = mode + '@' + data
        all_df[key] = _per_user_acc(df, 'test_slim_ratio', 1.)
        print(key, f": {len(all_df[key])}")


## Aggregate

In [None]:
# Gather the table names dataset by dataset; a name is selected when it
# contains the dataset identifier.
keys = []
for data in datasets:
    keys.extend(k for k in all_df if data in k)
keys

In [None]:
# Stack the selected tables and switch the mode column to display names;
# modes without an entry in the lookup table are kept unchanged.
agg = pd.concat([all_df[key] for key in keys])
display_names = {'RT': 'FedAvg', 'SplitMix': 'Split-Mix'}
agg['mode'] = agg['mode'].apply(lambda name: display_names.get(name, name))

In [None]:
# Box plot of the 'Acc' column: one group per dataset, one box per mode.
# Unlike the two figures above, the legend is kept.
fig, ax = plt.subplots(figsize=(4, 3))
sns.boxplot(x='data', y='Acc', hue='mode', data=agg, ax=ax)
ax.set_ylabel(r'accuracy')
ax.set_xlabel('')
ax.grid(True)

fig.tight_layout()
out_file = os.path.join(cache_path, 'all_data_acc_user.pdf')
print(f"save fig => {out_file}")
fig.savefig(out_file)

plt.show()