In [None]:
import os, sys
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display

In [None]:
# Global plotting style for every figure in this notebook: seaborn "ticks"
# theme, the current seaborn palette as the colour cycle, and 12pt axis labels.
sns.set_style("ticks")
cmap = sns.color_palette()
sns.set_palette(cmap)
plt.rcParams['axes.labelsize'] = 12

In [None]:
# Output directory for the saved figures. `exist_ok=True` makes this cell
# idempotent, unlike a separate exists-check followed by makedirs (which can
# race and raise FileExistsError).
cache_path = './fig/joint'
os.makedirs(cache_path, exist_ok=True)

In [None]:
def fetch_config_summary(runs, config_keys, summary_keys, config_defaults=None):
    """Collect per-run config values and summary metrics as column lists.

    For every run that defines *all* of ``summary_keys`` in its summary, one
    value per summary key and one value per config key is appended to the
    returned mapping. The result can therefore be passed straight to
    ``pd.DataFrame``. Runs missing any summary key are reported and skipped
    entirely, so every column keeps the same length. Runs whose state is not
    ``'finished'`` are reported but still included.

    Args:
        runs: iterable of run objects exposing ``.state``, ``.summary`` and
            ``.config`` (in this notebook, ``wandb.Api().sweep(...).runs``).
        config_keys: config entries to record for each kept run.
        summary_keys: summary metrics to record for each kept run.
        config_defaults: optional mapping of fallback values for config keys a
            run does not define. Defaults to
            ``{'adv_lmbd': False, 'test_noise': 'none'}``.

    Returns:
        ``defaultdict(list)`` mapping each key to its list of per-run values.

    Raises:
        KeyError: if a run lacks a config key that has no fallback in
            ``config_defaults``.
    """
    if config_defaults is None:
        config_defaults = {'adv_lmbd': False, 'test_noise': 'none'}
    df_dict = defaultdict(list)
    for run in runs:
        if run.state != 'finished':
            print(f"Non-fin run w/ stat: {run.state}")
        # Validate every summary key *before* appending anything. Appending
        # eagerly and bailing out at the first miss would leave earlier
        # columns one entry longer than the rest and misalign the DataFrame.
        missing_sum_key = [k for k in summary_keys if k not in run.summary]
        if missing_sum_key:
            print(f"missing key: {missing_sum_key}")
            continue
        for k in summary_keys:
            df_dict[k].append(run.summary[k])
        for k in config_keys:
            if k not in run.config and k in config_defaults:
                df_dict[k].append(config_defaults[k])
            else:
                df_dict[k].append(run.config[k])
    return df_dict

In [None]:
all_df_dict = dict()  # per-method result frames for the current dataset, keyed by mode

# DomainNet

In [None]:
data = 'DomainNet'
domains = [
    'clipart', 'infograph', 'painting',
    'quickdraw', 'real', 'sketch',
]

# W&B sweep paths (passed to `wandb.Api().sweep`) for each method on DomainNet.
sweep_dict = dict(
    FedAvg="jyhong/SplitMix_release/sweeps/tft5h80j",
    SplitMix="jyhong/SplitMix_release/sweeps/ctuur0ey",
)

## FedAvg

In [None]:
api = wandb.Api()
mode = 'FedAvg'
# Sweep handle for this method; its runs are read in the next cell.
sweep = api.sweep(sweep_dict[mode])

In [None]:
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['width_scale', 'adv_lmbd', 'test_noise'],
    summary_keys=['avg test acc'],
)
# One row per (width_scale, adv_lmbd): mean 'avg test acc' over runs, split
# into an 'SA' column (test_noise 'none') and an 'RA' column (test_noise 'LinfPGD').
df = (
    pd.DataFrame(df_dict)
    .assign(test_noise=lambda frame: frame['test_noise'].map(
        lambda noise: {'LinfPGD': 'RA', 'none': 'SA'}[noise]))
    .groupby(['width_scale', 'adv_lmbd', 'test_noise'])
    .mean()
    .unstack('test_noise')
    .droplevel(0, axis=1)
)
df

In [None]:
# Long-format table for plotting; 'width_scale' is renamed to the shared 'width' column.
df_ = df.reset_index().rename(columns={'width_scale': 'width'})
df_['mode'] = mode
# Only the width == 0.125 FedAvg rows are carried into the cross-method aggregate.
# Copy so later edits to the aggregate never touch df_.
all_df_dict[mode] = df_[df_['width'] == 0.125].copy()

fig, ax = plt.subplots(1, 1, figsize=(4, 3))
# Draw on the axes created above explicitly rather than pyplot's implicit "current axes".
sns.lineplot(data=df_, x='SA', y='RA', hue='width', marker='o', ax=ax)
ax.set(title=f'{data} {mode}')
ax.grid(True)

plt.show()

## SplitMix

In [None]:
api = wandb.Api()
mode = 'SplitMix'
# Sweep handle for this method; its runs are read in the next cell.
sweep = api.sweep(sweep_dict[mode])

In [None]:
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['test_slim_ratio', 'test_adv_lmbd', 'test_noise'],
    summary_keys=['avg test acc'],
)
# One row per (test_slim_ratio, test_adv_lmbd): mean 'avg test acc' over runs,
# split into 'SA' (test_noise 'none') and 'RA' (test_noise 'LinfPGD') columns.
df = (
    pd.DataFrame(df_dict)
    .assign(test_noise=lambda frame: frame['test_noise'].map(
        lambda noise: {'LinfPGD': 'RA', 'none': 'SA'}[noise]))
    .groupby(['test_slim_ratio', 'test_adv_lmbd', 'test_noise'])
    .mean()
    .unstack('test_noise')
    .droplevel(0, axis=1)
)

df

In [None]:
# Long-format table; the test-time slimming ratio plays the role of 'width'.
df_ = df.reset_index().rename(columns={'test_slim_ratio': 'width'})
df_['mode'] = mode
all_df_dict[mode] = df_

fig, ax = plt.subplots(1, 1, figsize=(4, 3))
# Draw on the axes created above explicitly rather than pyplot's implicit "current axes".
sns.lineplot(data=df_, x='SA', y='RA', hue='width', marker='o', ax=ax)
ax.set(title=f'{data} {mode}')
ax.grid(True)

plt.show()

## Aggregate

In [None]:
# Stack the per-method frames (in insertion order) and rename modes to the
# labels used in the figure legend.
agg = pd.concat(list(all_df_dict.values()), ignore_index=True)
agg['mode'] = agg['mode'].replace({'RT': 'FedAvg+AT', 'SplitMix': 'SplitMixDAT'})

In [None]:
df_ = agg

fig, ax = plt.subplots(1, 1, figsize=(3, 2.5))
# Explicit `ax=ax` so the plot does not depend on pyplot's implicit current axes.
# (The previous `cmap = plt.get_cmap()` line was dropped: it silently replaced
# the global seaborn palette `cmap` with a matplotlib Colormap, and only
# commented-out code used it.)
sns.lineplot(data=df_, x='SA', y='RA', hue='width', marker='o',
             style='mode', ax=ax)
# Fixed limits; y ticks label the 0.05 floor plus the 0.25-0.45 band only.
ax.set(xlim=(.42, 0.64), ylim=(0.05, 0.46),
       yticks=[0.05, 0.25, 0.3, 0.35, 0.4, 0.45], ylabel='')
ax.grid(True)
legend = ax.get_legend()
if legend is not None:
    legend.get_frame().set(alpha=0.1)

fig.tight_layout()
out_file = os.path.join(cache_path, f'{data}_joint.pdf')
print(f"save fig => {out_file}")
fig.savefig(out_file)

plt.show()

# CIFAR10

In [None]:
data = 'CIFAR10'
domains = ['cifar10']

# Previously used sweeps, kept commented out for reference:
#     'RT': "jyhong/FOAL_AT_Cifar10/sweeps/75rvm0co",
#     "RT 1.0-Net": "jyhong/FOAL_AT_Cifar10/sweeps/smank7jt",
#     "RT slim lmbd0": "jyhong/FOAL_slimmable_Cifar10/sweeps/10igr3ul",
#     'SplitMix': "jyhong/FOAL_AT_Cifar10/sweeps/kgm6e7k7",  # 220 epochs
#     'FedAvg lmbd0': "jyhong/SplitMix_release/sweeps/d6ua8kbt",
#     'FedAvg x0.5, x1': "jyhong/SplitMix_release/sweeps/4mv1qxnp",
#     'SplitMix': "jyhong/SplitMix_release/sweeps/tdzwg05m",  # 230 epochs
sweep_dict = {}
sweep_dict['FedAvg x0.5, x1'] = "jyhong/SplitMix_release/sweeps/g8xmc74v"
sweep_dict['SplitMix'] = "jyhong/SplitMix_release/sweeps/d26ifudn"

# Start a fresh per-method collection for this dataset.
all_df_dict = dict()

## FedAvg

In [None]:
api = wandb.Api()
mode = 'FedAvg x0.5, x1'
sweep = api.sweep(sweep_dict[mode])

df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['width_scale', 'adv_lmbd', 'test_noise'],
    summary_keys=['avg test acc'],
)
# One row per (width_scale, adv_lmbd): mean 'avg test acc' over runs, split
# into 'SA' (test_noise 'none') and 'RA' (test_noise 'LinfPGD') columns.
df = (
    pd.DataFrame(df_dict)
    .assign(test_noise=lambda frame: frame['test_noise'].map(
        lambda noise: {'LinfPGD': 'RA', 'none': 'SA'}[noise]))
    .groupby(['width_scale', 'adv_lmbd', 'test_noise'])
    .mean()
    .unstack('test_noise')
    .droplevel(0, axis=1)
)
df_1net = df
df_1net

In [None]:
# Long-format table; the legend label is plain 'FedAvg' even though the sweep
# key (`mode`) carries the width suffix.
df_ = df.reset_index().rename(columns={'width_scale': 'width'})
df_['mode'] = 'FedAvg'
# Only the width == 0.125 FedAvg rows are carried into the cross-method aggregate.
# Copy so later edits to the aggregate never touch df_.
all_df_dict[mode] = df_[df_['width'] == 0.125].copy()

fig, ax = plt.subplots(1, 1, figsize=(4, 3))
# Draw on the axes created above explicitly rather than pyplot's implicit "current axes".
sns.lineplot(data=df_, x='SA', y='RA', hue='width', marker='o', ax=ax)
ax.set(title=f'{data} {mode}')
ax.grid(True)

plt.show()

## SplitMix

In [None]:
api = wandb.Api()
mode = 'SplitMix'
# Sweep handle for this method; its runs are read in the next cell.
sweep = api.sweep(sweep_dict[mode])

In [None]:
df_dict = fetch_config_summary(
    sweep.runs,
    config_keys=['test_slim_ratio', 'test_adv_lmbd', 'test_noise'],
    summary_keys=['avg test acc'],
)
# One row per (test_slim_ratio, test_adv_lmbd): mean 'avg test acc' over runs,
# split into 'SA' (test_noise 'none') and 'RA' (test_noise 'LinfPGD') columns.
df = (
    pd.DataFrame(df_dict)
    .assign(test_noise=lambda frame: frame['test_noise'].map(
        lambda noise: {'LinfPGD': 'RA', 'none': 'SA'}[noise]))
    .groupby(['test_slim_ratio', 'test_adv_lmbd', 'test_noise'])
    .mean()
    .unstack('test_noise')
    .droplevel(0, axis=1)
)
df

In [None]:
# Long-format table; the test-time slimming ratio plays the role of 'width'.
df_ = df.reset_index().rename(columns={'test_slim_ratio': 'width'})
df_['mode'] = mode
all_df_dict[mode] = df_

fig, ax = plt.subplots(1, 1, figsize=(4, 3))
# Draw on the axes created above explicitly rather than pyplot's implicit "current axes".
sns.lineplot(data=df_, x='SA', y='RA', hue='width', marker='o', ax=ax)
ax.set(title=f'{data} {mode}')
ax.grid(True)

plt.show()

## Aggregate

In [None]:
# Stack the per-method frames (in insertion order). The 'RT' rename only
# matters for the older sweeps; no active cell sets mode 'RT'.
agg = pd.concat(list(all_df_dict.values()), ignore_index=True)
agg['mode'] = agg['mode'].replace({'RT': 'FedAvg+AT'})

In [None]:
# Display the stacked per-method table that feeds the CIFAR10 joint plot.
agg

In [None]:
df_ = agg

fig, ax = plt.subplots(1, 1, figsize=(3, 2.5))
# Explicit `ax=ax` so the plot does not depend on pyplot's implicit current axes.
sns.lineplot(data=df_, x='SA', y='RA', hue='width', marker='o',
             style='mode', ax=ax)
ax.set(ylabel='')
# This figure is saved without a legend (presumably the DomainNet figure's
# legend is shared). Reuse the fetched handle instead of querying twice, and
# guard against seaborn not having created one.
legend = ax.get_legend()
if legend is not None:
    legend.remove()
ax.grid(True)

fig.tight_layout()
out_file = os.path.join(cache_path, f'{data}_joint.pdf')
print(f"save fig => {out_file}")
fig.savefig(out_file)

plt.show()