In [None]:
import os, sys
import torch
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from IPython.display import display
from scipy.stats import pearsonr

In [None]:
# Plot styling: seaborn "ticks" theme, with the current default palette kept as `cmap`
# and re-applied as the active palette.
sns.set_style("ticks")
cmap = sns.color_palette()
sns.set_palette(cmap)

In [None]:
# Directory where figures produced by this notebook are written.
cache_path = '../fig/'
# exist_ok avoids the check-then-create race and makes the cell idempotent.
os.makedirs(cache_path, exist_ok=True)

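Experiment (sweep config): `wandb sweep sweeps/ablation/digits_SplitMix_test.yaml`. The four sweeps below differ in how BatchNorm statistics are handled (not tracked, not tracked + refreshed, tracked + averaged, tracked locally).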

In [None]:
# Human-readable BN-handling label -> W&B sweep path (entity/project/sweeps/id).
# Insertion order is preserved and determines the collection order below.
sweep_ids = dict([
    ("no track", "jyhong/SplitMix_release/sweeps/cpzoxxq9"),
    ("no track + refresh BN", "jyhong/SplitMix_release/sweeps/2di8eygl"),
    ("track + avg BN", "jyhong/SplitMix_release/sweeps/tr78ctgv"),
    ("track + local BN", "jyhong/SplitMix_release/sweeps/x11krsq4"),
])

In [None]:
# Collect config values and the final 'avg test acc' of every finished run in each
# sweep into a single DataFrame `df` (one row per run).
cfg_keys = ['test_slim_ratio', 'loss_temp', 'rescale_layer', 'rescale_init', 'no_track_stat',
            'test_refresh_bn', 'lbn']
# Keys that may be absent from a run's config; such runs are recorded as False.
# Every other key in cfg_keys is required (KeyError if missing).
cfg_defaults = {'test_refresh_bn': False, 'lbn': False}
api = wandb.Api()  # one client serves all sweeps; no need to recreate it per sweep
df_dict = defaultdict(list)
for sweep_name, sweep_id in sweep_ids.items():
    print(f"collect: {sweep_name}")
    sweep = api.sweep(sweep_id)
    for run in sweep.runs:
        if run.state != 'finished':
            continue
        for k in cfg_keys:
            if k in cfg_defaults and k not in run.config:
                df_dict[k].append(cfg_defaults[k])
            else:
                df_dict[k].append(run.config[k])
        df_dict['avg test acc'].append(run.summary['avg test acc'])
df = pd.DataFrame(df_dict)
df = df.rename(columns={'no_track_stat': 'track_stat'})
# Flip the negative flag into a positive one. astype(bool) first: on an int/object
# column `~` is a bitwise NOT (True -> -2, 1 -> -2), not a logical negation.
df['track_stat'] = ~df['track_stat'].astype(bool)

In [None]:
# Optional inspection (disabled): show the collected runs with test_slim_ratio == 0.125.
# df[df['test_slim_ratio']==0.125]

In [None]:
# Keep runs with loss_temp == 'none', then pivot so each test_slim_ratio becomes its
# own accuracy column and each remaining config combination is one row. After
# reset_index the config columns carry '' as their second column level.
df_ = (
    df.loc[df['loss_temp'] == 'none']
    .drop(columns=['loss_temp'])
    .set_index(['test_slim_ratio', 'lbn', 'test_refresh_bn', 'track_stat',
                'rescale_init', 'rescale_layer'])
    .unstack('test_slim_ratio')
    .reset_index()
)

In [None]:
def mark(b):
    """Return the LaTeX check macro for a truthy flag, the cross macro otherwise."""
    return (r'\xmark', r'\cmark')[bool(b)]
def bn_stat(track_stat, test_refresh_bn, lbn):
    """Return a short table label describing how BN statistics were handled.

    When statistics are tracked only `lbn` matters; otherwise only
    `test_refresh_bn` matters.
    """
    if not track_stat:
        return 'post average' if test_refresh_bn else 'batch average'
    return 'locally tracked' if lbn else 'tracked'
def bn_max_text(acc, col, track_stat, test_refresh_bn, lbn, table=None):
    r"""Format one accuracy cell of the LaTeX results table.

    The accuracy is rendered as a percentage with one decimal. It is set in bold
    when it is within 0.001 of the best value in column ``col``. It is prefixed
    with ``\greycell`` when it is within 0.001 of the best value among rows with
    the same BN-statistics setting. That group is rows matching ``track_stat``
    and ``lbn`` when ``track_stat`` is truthy, otherwise rows matching
    ``track_stat`` and ``test_refresh_bn``.

    Args:
        acc: accuracy value; presumably a fraction in [0, 1], since it is
            multiplied by 100 and printed with a percent sign.
        col: label of the accuracy column in ``table``.
        track_stat, test_refresh_bn, lbn: BN configuration of the row being
            formatted.
        table: frame to compare against; defaults to the notebook-level ``df_``
            (so the original 5-argument call still works).

    Returns:
        A LaTeX string such as ``$71.3\%$``, or ``'--'`` when ``acc`` is missing
        (NaN, e.g. no finished run for that column after unstacking).
    """
    if table is None:
        table = df_
    # Don't emit "$nan\%$" into the LaTeX table for missing runs.
    if pd.isna(acc):
        return '--'
    same_track = table['track_stat'] == track_stat
    if track_stat:
        same_bn = same_track & (table['lbn'] == lbn)
    else:
        same_bn = same_track & (table['test_refresh_bn'] == test_refresh_bn)
    bn_max_acc = table.loc[same_bn, col].max()
    col_max_acc = table[col].max()
    # Raw strings: backslash-percent, backslash-m, backslash-g are invalid Python
    # escapes in normal literals (SyntaxWarning on 3.12+); output text is unchanged.
    s = rf"{acc * 100:.1f}\%"
    if np.isclose(acc, col_max_acc, atol=0.001):
        s = r"$\mathbf{" + s + "}$"
    else:
        s = "$" + s + "$"
    if np.isclose(acc, bn_max_acc, atol=0.001):
        s = r"\greycell " + s
    return s
# Print one LaTeX table row per configuration: BN-statistics label, the two rescale
# flags, then one formatted accuracy per test_slim_ratio column.
# Config columns come first after reset_index and have '' as their second level.
n_cond = len([c for c in df_.columns if c[1] == ''])
for _, row in df_.iterrows():
    bn_cfg = (row[('track_stat', '')], row[('test_refresh_bn', '')], row[('lbn', '')])
    # .iloc makes the slice explicitly positional; .items() yields (column label, value).
    accs = [bn_max_text(acc, col, *bn_cfg) for col, acc in row.iloc[n_cond:].items()]
    cond_strs = ['%13s' % bn_stat(*bn_cfg)]
    cond_strs += [mark(row[(col, '')]) for col in ['rescale_init', 'rescale_layer']]
    print(' & '.join(cond_strs + accs) + r' \\')