In [None]:
import os, sys
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display

In [None]:
# Global plot styling: "ticks" axes style, and seaborn's current colour
# palette installed as the active palette (also kept in `cmap`).
sns.set_style("ticks")
cmap = sns.color_palette()
sns.set_palette(cmap)

In [None]:
# Directory that the figure-saving cells write their PDFs into.
cache_path = './fig/flops_acc_curve'
# exist_ok=True keeps the cell idempotent on re-runs and removes the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(cache_path, exist_ok=True)

In [None]:
# Dataset tag; used in the names of the saved figure files.
data = 'DomainNet'

# Domain names. `rearrange_by_domain` indexes this list with the leading
# digit of each accuracy column name, so the order matters.
domains = [
    'clipart',
    'infograph',
    'painting',
    'quickdraw',
    'real',
    'sketch',
]

# Method name -> W&B sweep path.
sweep_dict = dict(
    FedAvg="jyhong/SplitMix_release/sweeps/y489wn02",
    SHeteroFL="jyhong/SplitMix_release/sweeps/shs7yw8p",
    SplitMix="jyhong/SplitMix_release/sweeps/2kxrau5h",
)

In [None]:
def fetch_config_summary(runs, config_keys, summary_keys):
    """Collect selected config / summary values of each run into columns.

    For every run the following values are gathered into one row:
      * ``run.summary[k]`` for each ``k`` in ``summary_keys``,
      * every summary entry whose key ends with ``'clean test acc'``,
      * ``run.config[k]`` for each ``k`` in ``config_keys``.

    A run that lacks any of ``summary_keys`` is reported and skipped as a
    whole, so it never contributes a partial row. Runs whose state is not
    ``'finished'`` only trigger a warning and are still included.

    Args:
        runs: iterable of run objects exposing ``state``, a mapping-like
            ``summary`` and a mapping-like ``config``.
        config_keys: config entries to record; a missing one raises
            ``KeyError``.
        summary_keys: summary entries that a run must have to be kept.

    Returns:
        ``defaultdict(list)`` mapping each collected key to a list with one
        entry per kept run. All lists have the same length; an accuracy key
        absent from a particular run is filled with NaN for that run.
    """
    rows = []
    for run in runs:
        if run.state != 'finished':
            print("WARN: run not finished yet")
        # Validate before recording anything, so an incomplete run cannot
        # leave stray values behind in some of the columns.
        missing_sum_key = [k for k in summary_keys if k not in run.summary]
        if len(missing_sum_key) > 0:
            print(f"missing key: {missing_sum_key}")
            continue
        row = {k: run.summary[k] for k in summary_keys}
        for k in run.summary.keys():
            if k.endswith('clean test acc'):
                row[k] = run.summary[k]
        for k in config_keys:
            row[k] = run.config[k]
        rows.append(row)

    # Union of the keys in first-seen order, then one value per row and key
    # so every column stays aligned with the runs.
    columns = list(dict.fromkeys(k for row in rows for k in row))
    df_dict = defaultdict(list)
    for row in rows:
        for k in columns:
            df_dict[k].append(row.get(k, float('nan')))
    return df_dict

In [None]:
def rearrange_by_domain(df, reduce='mean', drop_keys=['slim_ratio', 'avg test acc']):
    """Pivot a per-run accuracy table into a (mode, width) x domain table.

    After removing ``drop_keys``, every column other than ``'mode'`` and
    ``'width'`` is treated as an accuracy column. The first character of its
    name is read as an integer index into the module-level ``domains`` list.

    Args:
        df: frame with ``'mode'`` and ``'width'`` columns plus the accuracy
            columns.
        reduce: with ``'mean'``, values are averaged per
            (mode, width, domain) and the domains are spread into columns;
            any other value returns the un-aggregated groupby object.
        drop_keys: columns removed before reshaping.

    Returns:
        The averaged, domain-unstacked frame, or the groupby object when
        ``reduce`` is not ``'mean'``.
    """
    indexed = df.drop(drop_keys, axis=1).set_index(['mode', 'width'])
    # Wide -> long: one row per (mode, width, accuracy column).
    long_df = indexed.stack().reset_index()
    long_df = long_df.rename(columns={'level_2': 'domain', 0: 'acc mean'})
    display(long_df)
    long_df['domain'] = long_df['domain'].apply(lambda key: domains[int(key[0])])
    grouped = long_df.groupby(['mode', 'width', 'domain'])
    if reduce != 'mean':
        return grouped
    return grouped.mean().unstack('domain')

In [None]:
# Results keyed by method name:
#   all_df      -> table returned by rearrange_by_domain for that method
#   all_df_diff -> FedAvg table minus that method's table, times 100
all_df, all_df_diff = {}, {}

In [None]:
# Select the FedAvg sweep; later cells subtract other methods from its table.
mode = 'FedAvg'
api = wandb.Api()
sweep_path = sweep_dict[mode]
sweep = api.sweep(sweep_path)

In [None]:
# Gather every "* clean test acc" summary entry of the sweep's runs together
# with the `width_scale` config value of each run.
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['width_scale'],
    summary_keys=[],
)

# Tag the rows with the method name and expose the config value as `width`,
# the column name rearrange_by_domain indexes on.
df = (
    pd.DataFrame(df_dict)
    .assign(mode=mode)
    .rename(columns={'width_scale': 'width'})
)

df = rearrange_by_domain(df, drop_keys=[])
all_df[mode] = df
df

In [None]:
# Select the SHeteroFL sweep for the cells that follow.
mode = 'SHeteroFL'
api = wandb.Api()
sweep_path = sweep_dict[mode]
sweep = api.sweep(sweep_path)

In [None]:
# Gather the "* clean test acc" summary entries of each run together with its
# `test_slim_ratio` config value.
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['test_slim_ratio'],
    summary_keys=[],
)

# `width` mirrors `test_slim_ratio`; the original column is dropped inside
# rearrange_by_domain via `drop_keys`.
df = pd.DataFrame(df_dict).assign(
    width=lambda frame: frame['test_slim_ratio'],
    mode=mode,
)
df = rearrange_by_domain(df, drop_keys=['test_slim_ratio'])
all_df[mode] = df

# Difference to FedAvg, times 100. Dropping the outer row level (mode) and
# outer column level leaves width x domain frames that align on labels.
# Requires all_df['FedAvg'] from the earlier FedAvg cell.
fedavg_acc = all_df['FedAvg'].droplevel(0, axis=0).droplevel(0, axis=1)
mode_acc = df.droplevel(0, axis=0).droplevel(0, axis=1)
all_df_diff[mode] = (fedavg_acc - mode_acc) * 100
df

In [None]:
# Select the SplitMix sweep for the cells that follow.
mode = 'SplitMix'
api = wandb.Api()
sweep_path = sweep_dict[mode]
sweep = api.sweep(sweep_path)

In [None]:
# Gather the "* clean test acc" summary entries of each run together with its
# `test_slim_ratio` config value.
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['test_slim_ratio'],
    summary_keys=[],
)

# `width` mirrors `test_slim_ratio`; the original column is dropped inside
# rearrange_by_domain via `drop_keys`.
df = pd.DataFrame(df_dict).assign(
    width=lambda frame: frame['test_slim_ratio'],
    mode=mode,
)
df = rearrange_by_domain(df, drop_keys=['test_slim_ratio'])
all_df[mode] = df

# Difference to FedAvg, times 100. Dropping the outer row level (mode) and
# outer column level leaves width x domain frames that align on labels.
# Requires all_df['FedAvg'] from the earlier FedAvg cell.
fedavg_acc = all_df['FedAvg'].droplevel(0, axis=0).droplevel(0, axis=1)
mode_acc = df.droplevel(0, axis=0).droplevel(0, axis=1)
all_df_diff[mode] = (fedavg_acc - mode_acc) * 100
df

In [None]:
# Heatmap of all_df_diff['SplitMix'] (FedAvg minus SplitMix, times 100):
# rows are widths, columns are domains.
mode = 'SplitMix'
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
# Draw on the axes created above instead of relying on pyplot's implicit
# "current axes".
sns.heatmap(
    all_df_diff[mode],
    vmin=0., vmax=30.,
    annot=True, fmt='.1f',
    cbar=False, square=True,
    cmap='OrRd',
    ax=ax,
)
ax.set(title=mode, xlabel='')
fig.autofmt_xdate()  # used here only to rotate the x tick labels

fig.tight_layout()
out_file = os.path.join(cache_path, f'{data}_domain_width_mat_{mode}.pdf')
print(f"save fig => {out_file}")
# Save through the figure handle so the right figure is written even if
# another figure has become current in the meantime.
fig.savefig(out_file)

plt.show()

In [None]:
# Heatmap of all_df_diff['SHeteroFL'] (FedAvg minus SHeteroFL, times 100):
# rows are widths, columns are domains.
mode = 'SHeteroFL'
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
# Draw on the axes created above instead of relying on pyplot's implicit
# "current axes".
sns.heatmap(
    all_df_diff[mode],
    vmin=0., vmax=30.,
    annot=True, fmt='.1f',
    cbar=False, square=True,
    cmap='OrRd',
    ax=ax,
)
ax.set(title=mode, xlabel='')
fig.autofmt_xdate()  # used here only to rotate the x tick labels

fig.tight_layout()
out_file = os.path.join(cache_path, f'{data}_domain_width_mat_{mode}.pdf')
print(f"save fig => {out_file}")
# Save through the figure handle so the right figure is written even if
# another figure has become current in the meantime.
fig.savefig(out_file)

plt.show()