# RA/SA Trade-off curves

In [None]:
import os, sys
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display

In [None]:
# Global plot styling shared by the figures in this notebook.
sns.set_style("ticks")
# Capture seaborn's current palette once and install that same palette
# as the active colour cycle.
cmap = sns.color_palette()
sns.set_palette(cmap)
# Font size for axis labels.
plt.rc('axes', labelsize=12)

In [None]:
# Directory the generated figures are saved into.
cache_path = './fig/oat'
# exist_ok=True makes creation idempotent and avoids the race between an
# existence check and the actual mkdir.
os.makedirs(cache_path, exist_ok=True)

In [None]:
def summarize_sweep(sweep, include_elapsed=False, acc_type='', app_keys=None):
    """Average the 'avg test acc' summary of a sweep's finished runs per config group.

    Args:
        sweep: Object exposing ``runs``; every run must provide ``state``,
            a ``summary`` mapping and a ``config`` mapping.
        include_elapsed: Unused; kept so existing calls keep working.
        acc_type: Unused; kept so existing calls keep working.
        app_keys: Extra config keys to group by in addition to
            ``"test_noise"``. ``None`` (the default) means no extra keys.

    Returns:
        A DataFrame indexed by the group keys whose ``'avg test acc'`` column
        holds the mean over the runs of each group, or an empty DataFrame
        when no run contributed a value.

    Raises:
        KeyError: If a contributing run's config lacks one of the group keys.
    """
    summary_key = 'avg test acc'
    # Copy into a fresh list so the caller's sequence is never aliased.
    keys = ["test_noise"] + list(app_keys or [])
    df_dict = defaultdict(list)
    for run in sweep.runs:
        if run.state != 'finished':
            continue
        try:
            value = run.summary[summary_key]
        except KeyError:
            # Finished run that never logged the metric: skip it instead of
            # aborting the whole summary. Reading the value before touching
            # df_dict keeps all columns the same length.
            continue
        df_dict[summary_key].append(value)
        for k in keys:
            df_dict[k].append(run.config[k])
    if not df_dict:
        return pd.DataFrame()
    return pd.DataFrame(df_dict).groupby(keys).mean()

## CIFAR-10

In [None]:
# Dataset tag plus, for each compared method, the sweep path that is later
# handed to the wandb API.
data = 'cifar10'
sweep_id = dict(
    FedAvg='jyhong/SplitMix_release/sweeps/g8xmc74v',
    SplitMix='jyhong/SplitMix_release/sweeps/dsmxxbkc',
    OAT='jyhong/FOAL_AT_Cifar10/sweeps/znbftq21',
)

In [None]:
# Summary DataFrame per method; filled in by the sweep loop that follows,
# keyed by the method name.
all_df = {}

In [None]:
# Fetch every method's sweep and store its averaged-accuracy table under the
# method name.
api = wandb.Api()  # one client serves all sweeps; no need to rebuild it per iteration
for mode, sweep_path in sweep_id.items():
    print(f'mode: {mode}')
    sweep = api.sweep(sweep_path)
    # FedAvg results are grouped by adv_lmbd only; every other method is
    # additionally grouped by test_adv_lmbd.
    app_keys = ['adv_lmbd']
    if mode != 'FedAvg':
        app_keys.append('test_adv_lmbd')
    df = summarize_sweep(sweep, app_keys=app_keys)
    df['mode'] = mode
    all_df[mode] = df

In [None]:
# Stack the per-method tables into one frame indexed by the evaluation noise,
# relabelled as RA (LinfPGD) / SA (none).
frames = [table.reset_index() for table in all_df.values()]
agg = pd.concat(frames, axis=0, ignore_index=True)
agg = agg.set_index('test_noise').rename(index={'LinfPGD': 'RA', 'none': 'SA'})
# Swap the internal method keys for the names shown in the figure; any other
# value passes through unchanged.
display_names = {
    'FedAvg': 'FedAvg+AT',
    'OAT': 'FedAvg+OAT',
    'SplitMix': 'Split-Mix+DAT',
}
agg['mode'] = agg['mode'].map(lambda n: display_names.get(n, n))

In [None]:
# Pivot so that every (mode, adv_lmbd, test_adv_lmbd) row carries the values
# for each test_noise label (RA / SA) as separate columns.
_df = (
    agg.reset_index()
    .set_index(['test_noise', 'mode', 'adv_lmbd', 'test_adv_lmbd'])
    .unstack('test_noise')
    .droplevel(0, axis=1)
)

display(_df)

fig, ax = plt.subplots(1, 1, figsize=(3, 2.5))
# Draw on the axes created above instead of relying on the implicit
# "current axes".
sns.lineplot(data=_df, y='RA', x='SA', marker='o', hue='mode', ax=ax)
ax.grid(True)
ax.set(xlim=(None, 0.91))
ax.get_legend().remove()

plt.tight_layout()
# Build the file name from `data` (same pattern as the digits figure) rather
# than a placeholder-free f-string with the dataset name hard-coded.
out_file = os.path.join(cache_path, f'{data}_ra_sa_tradeoff.pdf')
print(f"save fig => {out_file}")
plt.savefig(out_file)

## Digits (all noised)

In [None]:
# Dataset tag plus, for each compared method, the sweep path that is later
# handed to the wandb API.
data = 'digits'
sweep_id = dict(
    FedAvg='jyhong/SplitMix_release/sweeps/d3gmza1k',
    OAT='jyhong/FOAL_Digits_bmk/sweeps/xepuell2',
    SplitMix='jyhong/SplitMix_release/sweeps/zql6s714',  # lbn
)

In [None]:
# Reset the per-method summary tables before loading the digits sweeps;
# filled in by the sweep loop that follows, keyed by the method name.
all_df = {}

In [None]:
# Fetch every listed method's sweep and store its averaged-accuracy table
# under the method name.
api = wandb.Api()  # one client serves all sweeps; no need to rebuild it per iteration
for mode in ['FedAvg', 'OAT', 'SplitMix']:
    print(f'mode: {mode}')
    sweep = api.sweep(sweep_id[mode])
    # FedAvg results are grouped by adv_lmbd only; every other method is
    # additionally grouped by test_adv_lmbd.
    app_keys = ['adv_lmbd']
    if mode != 'FedAvg':
        app_keys.append('test_adv_lmbd')
    df = summarize_sweep(sweep, app_keys=app_keys)
    df['mode'] = mode
    all_df[mode] = df

In [None]:
# Stack the per-method tables into one frame indexed by the evaluation noise,
# relabelled as RA (LinfPGD) / SA (none).
frames = [table.reset_index() for table in all_df.values()]
agg = pd.concat(frames, axis=0, ignore_index=True)
agg = agg.set_index('test_noise').rename(index={'LinfPGD': 'RA', 'none': 'SA'})
# Swap the internal method keys for the names shown in the figure; any other
# value passes through unchanged.
display_names = {
    'FedAvg': 'FedAvg+AT',
    'OAT': 'FedAvg+OAT',
    'SplitMix': 'Split-Mix+DAT',
}
agg['mode'] = agg['mode'].map(lambda n: display_names.get(n, n))

In [None]:
# Pivot so that every (mode, adv_lmbd, test_adv_lmbd) row carries the values
# for each test_noise label (RA / SA) as separate columns.
index_cols = ['test_noise', 'mode', 'adv_lmbd', 'test_adv_lmbd']
_df = agg.reset_index().set_index(index_cols)
_df = _df.unstack('test_noise').droplevel(0, axis=1)

display(_df)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 2.5))
# sort=False connects the points in row order instead of sorting them first.
sns.lineplot(data=_df, x='SA', y='RA', hue='mode', marker='o', sort=False, ax=ax)
# Dashed guide lines at the largest SA and the largest RA in the table.
max_sa = _df['SA'].max()
max_ra = _df['RA'].max()
ax.axvline(max_sa, linestyle='--', alpha=0.5)
ax.axhline(max_ra, linestyle='--', alpha=0.5)
ax.set(ylabel='', ylim=(0.45, 0.66))
ax.grid(True)

plt.tight_layout()
out_file = os.path.join(cache_path, f'{data}_ra_sa_tradeoff.pdf')
print(f"save fig => {out_file}")
plt.savefig(out_file)