In [None]:
import os, sys
import torch
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display
from scipy.stats import pearsonr

In [None]:
# Global plot styling: "ticks" axes style plus the currently active colour
# palette. `cmap` keeps a handle on that palette (it is not used in the cells
# visible here).
sns.set_style("ticks")
cmap = sns.color_palette()
# Re-apply the palette that was just read back.
sns.set_palette(cmap)

In [None]:
# Output directory (presumably for saved figures, given the name), relative to
# the notebook's working directory. `exist_ok=True` makes the cell safe to
# re-run and avoids the check-then-create race of a separate exists() test.
cache_path = '../fig/'
os.makedirs(cache_path, exist_ok=True)

In [None]:
# Per-method result frames, keyed by the method name stored in `mode`.
# Re-running this cell discards everything collected so far.
all_dfs = dict()

## FedAvg

In [None]:
# Method label used for the 'mode' column and as the key into `all_dfs`.
mode = 'FedAvg'
# W&B sweep path for each benchmark task of this method (task label -> sweep).
sweep_ids = dict([
    ('cifar10 100%', "jyhong/SplitMix_release/sweeps/d6ua8kbt"),
    ('cifar10 50%', "jyhong/SplitMix_release/sweeps/jbn34q4n"),
    ('cifar10_cniid', "jyhong/SplitMix_release/sweeps/6ua8jh9x"),
    ('Digits', "jyhong/SplitMix_release/sweeps/8g8s7kp4"),
    ('DomainNet', "jyhong/SplitMix_release/sweeps/y489wn02"),
])

In [None]:
# Download the results of every run in the sweeps listed in `sweep_ids` and
# build one long-format frame with a row per run:
#   acc, GFLOPs, model size (MB)  <- run.summary
#   width_scale                   <- run.config
#   task                          <- the sweep's label in `sweep_ids`
#   mode                          <- the current method name
df_dict = defaultdict(list)
cfg_keys = ['width_scale']
# A single API client serves all sweeps; building it inside the loop only
# repeated the same client setup once per sweep.
api = wandb.Api()
for name, sweep_id in sweep_ids.items():
    sweep = api.sweep(sweep_id)
    for run in sweep.runs:
        df_dict['acc'].append(run.summary['avg test acc'])
        for k in ['GFLOPs', 'model size (MB)']:
            df_dict[k].append(run.summary[k])
        for k in cfg_keys:
            df_dict[k].append(run.config[k])
        df_dict['task'].append(name)
df = pd.DataFrame(df_dict)
df['mode'] = mode
# Register the frame under the method name for later aggregation.
all_dfs[mode] = df

In [None]:
# Preview the current `df` indexed by (mode, task, width_scale).
df.set_index(keys=['mode', 'task', 'width_scale'])

## SHeteroFL

In [None]:
# Method label used for the 'mode' column and as the key into `all_dfs`.
mode = 'SHeteroFL'
# W&B sweep path for each benchmark task of this method (task label -> sweep).
sweep_ids = dict([
    ('cifar10 100%', "jyhong/SplitMix_release/sweeps/13li9grh"),
    ('cifar10 50%', "jyhong/SplitMix_release/sweeps/6bbo3mwi"),
    ('cifar10_cniid', "jyhong/SplitMix_release/sweeps/fvg0045z"),
    ('Digits', "jyhong/SplitMix_release/sweeps/0lh7d73x"),
    ('DomainNet', "jyhong/SplitMix_release/sweeps/shs7yw8p"),
])

In [None]:
# Download the results of every run in the sweeps listed in `sweep_ids` and
# build one long-format frame with a row per run:
#   acc, GFLOPs, model size (MB)  <- run.summary
#   test_slim_ratio               <- run.config (renamed to width_scale below)
#   task                          <- the sweep's label in `sweep_ids`
#   mode                          <- the current method name
df_dict = defaultdict(list)
cfg_keys = ['test_slim_ratio']
# A single API client serves all sweeps; building it inside the loop only
# repeated the same client setup once per sweep.
api = wandb.Api()
for name, sweep_id in sweep_ids.items():
    sweep = api.sweep(sweep_id)
    for run in sweep.runs:
        df_dict['acc'].append(run.summary['avg test acc'])
        for k in ['GFLOPs', 'model size (MB)']:
            df_dict[k].append(run.summary[k])
        for k in cfg_keys:
            df_dict[k].append(run.config[k])
        df_dict['task'].append(name)
df = pd.DataFrame(df_dict)
df['mode'] = mode
# Expose the ratio under the column name 'width_scale' so frames stored in
# `all_dfs` share one schema.
df = df.rename(columns={'test_slim_ratio': 'width_scale'})
# Register the frame under the method name for later aggregation.
all_dfs[mode] = df

In [None]:
# Preview the current `df` with one row per width scale and the tasks spread
# across the columns.
df.set_index(keys=['task', 'width_scale']).unstack(level='task')

## SplitMix

In [None]:
# Method label used for the 'mode' column and as the key into `all_dfs`.
mode = 'SplitMix'
# W&B sweep path for each benchmark task of this method (task label -> sweep).
# An alternative sweep for 'cifar10 100%' was left commented out in the
# original cell: "jyhong/SplitMix_release/sweeps/rio0lk4l".
sweep_ids = dict([
    ('cifar10 100%', "jyhong/SplitMix_release/sweeps/fjt4nczs"),
    ('cifar10 50%', "jyhong/SplitMix_release/sweeps/y6e7r33c"),
    ('cifar10_cniid', "jyhong/SplitMix_release/sweeps/g71nb2yv"),
    ('Digits', "jyhong/SplitMix_release/sweeps/3wr7bsxb"),
    ('DomainNet', "jyhong/SplitMix_release/sweeps/2kxrau5h"),
])

In [None]:
# Download the results of every run in the sweeps listed in `sweep_ids` and
# build one long-format frame with a row per run:
#   acc, GFLOPs, model size (MB)  <- run.summary
#   test_slim_ratio               <- run.config (renamed to width_scale below)
#   task                          <- the sweep's label in `sweep_ids`
#   mode                          <- the current method name
df_dict = defaultdict(list)
cfg_keys = ['test_slim_ratio']
# A single API client serves all sweeps; building it inside the loop only
# repeated the same client setup once per sweep.
api = wandb.Api()
for name, sweep_id in sweep_ids.items():
    sweep = api.sweep(sweep_id)
    for run in sweep.runs:
        df_dict['acc'].append(run.summary['avg test acc'])
        for k in ['GFLOPs', 'model size (MB)']:
            df_dict[k].append(run.summary[k])
        for k in cfg_keys:
            df_dict[k].append(run.config[k])
        df_dict['task'].append(name)
df = pd.DataFrame(df_dict)
df['mode'] = mode
# Expose the ratio under the column name 'width_scale' so frames stored in
# `all_dfs` share one schema.
df = df.rename(columns={'test_slim_ratio': 'width_scale'})
# Register the frame under the method name for later aggregation.
all_dfs[mode] = df

In [None]:
# Preview the current `df` with one row per width scale and the tasks spread
# across the columns.
df.set_index(keys=['task', 'width_scale']).unstack(level='task')

## Aggregate

In [None]:
# Stack the per-method frames into one long frame (original row labels kept).
df = pd.concat(list(all_dfs.values()))

In [None]:
# Wide layout for the table: one row per (task, width_scale), with every
# metric split into one column per method, i.e. columns are (metric, mode)
# pairs. After reset_index() the key columns are ('task', '') and
# ('width_scale', '').
df_ = (
    df
    .set_index(keys=['task', 'mode', 'width_scale'])
    .unstack(level=['mode'])
    .reset_index()
)

In [None]:
# Metric columns and method names, in the order used for the table columns.
metrics = ['acc', 'GFLOPs', 'model size (MB)']
algs = ['FedAvg', 'SHeteroFL', 'SplitMix']
# LaTeX caption printed above each task's group of rows (task label -> text).
task_names = {
    'Digits': r'Digits feature non-\emph{i.i.d} FL',
    'DomainNet': r'DomainNet feature non-\emph{i.i.d} FL',
    'cifar10 100%': r'CIFAR10 \emph{i.i.d} FL (100\%)',
    'cifar10 50%': r'CIFAR10 \emph{i.i.d} FL (50\%)',
    'cifar10_cniid': r'CIFAR10 class non-\emph{i.i.d} FL',
}
def greytext(s):
    """Wrap the string `s` in a LaTeX ``\\greytext{...}`` command."""
    opening, closing = r"\greytext{", "}"
    return opening + s + closing
def textbf(s):
    """Wrap the string `s` in a LaTeX ``\\textbf{...}`` command."""
    opening, closing = r"\textbf{", "}"
    return opening + s + closing
def cond_bold(s, metric):
    """Return `s` in bold when the current cell holds the row's best `metric`.

    NOTE: this reads three module-level names that must be set by the caller
    before each call: `d` (the current row), `alg` (the current method) and
    `bold_dict` (best value per metric for the row). The text is bolded only
    when ``d[(metric, alg)]`` equals ``bold_dict[metric]``.
    """
    if d[(metric, alg)] == bold_dict[metric]:
        return textbf(s)
    return s
# Print the LaTeX table body. For every task: a \midrule, a \multicolumn
# caption, then one row per width scale holding (accuracy, GFLOPs-derived
# cost, model size) for each method in `algs`.
#
# Rules visible in the code:
#   * FedAvg cells at width scales above 0.125 are rendered with greytext()
#     and are excluded from the best-value comparison.
#   * Accuracy is bolded (via cond_bold) for the best method of the row;
#     GFLOPs and model size are bolded only at width scales above 0.125.
#
# NOTE: `d`, `alg` and `bold_dict` must keep exactly these module-level names
# because cond_bold() reads them as globals.
# NOTE(review): the GFLOPs value is multiplied by 100 but suffixed with "M";
# confirm the intended unit.
for task in df_[('task', '')].unique():
    df__ = df_[df_[('task', '')] == task]
    # Raw strings: '\m' and '\%' are invalid escape sequences in normal
    # literals and trigger a SyntaxWarning on recent Python versions.
    print(r'\midrule')
    print(" "*17 + r"& \multicolumn{9}{c}{" + task_names[task] + r"} \\")
    for row in range(len(df__)):
        d = df__.iloc[row]
        width = d[('width_scale', '')]
        above_min_width = width > 0.125
        ws = rf"$\times {width:g}$"
        info = [f"{ws:<15s}"]

        # Best value per metric among the methods that take part in the
        # comparison: highest accuracy, lowest GFLOPs, lowest model size.
        bold_dict = {}
        for alg in algs:
            if alg == 'FedAvg' and above_min_width:
                continue
            if 'acc' not in bold_dict or d[('acc', alg)] > bold_dict['acc']:
                bold_dict['acc'] = d[('acc', alg)]
            if 'GFLOPs' not in bold_dict or d[('GFLOPs', alg)] < bold_dict['GFLOPs']:
                bold_dict['GFLOPs'] = d[('GFLOPs', alg)]
            if 'model size (MB)' not in bold_dict or d[('model size (MB)', alg)] < bold_dict['model size (MB)']:
                bold_dict['model size (MB)'] = d[('model size (MB)', alg)]

        for alg in algs:
            acc_txt = rf"{d[('acc', alg)]*100:.1f}\%"
            flops_txt = f"{d[('GFLOPs', alg)]*100:.1f}M"
            size_txt = f"{d[('model size (MB)', alg)]:.1f}M"
            if alg == 'FedAvg' and above_min_width:
                info.append("{:10s}".format(greytext(acc_txt)))
                info.append("{:10s}".format(greytext(flops_txt)))
                info.append("{:14s}".format(greytext(size_txt)))
            elif above_min_width:
                info.append("{:5s}".format(cond_bold(acc_txt, 'acc')))
                info.append("{:5s}".format(cond_bold(flops_txt, 'GFLOPs')))
                info.append("{:5s}".format(cond_bold(size_txt, 'model size (MB)')))
            else:
                info.append("{:5s}".format(cond_bold(acc_txt, 'acc')))
                info.append("{:5s}".format(flops_txt))
                info.append("{:5s}".format(size_txt))

        # The leading empty string keeps the single-space indent of each row.
        print("", ' & '.join(info), r"\\")