# Plot paper graphs using precomputed evaluation results

In [None]:
import sys
import numpy as np
from matplotlib import rc
import matplotlib.pyplot as plt
import sklearn.metrics
from collections import defaultdict, OrderedDict
import os
from tqdm import tqdm
import pandas as pd
import seaborn as sns
# Display pandas floats with two decimal places in the result tables below.
pd.options.display.float_format = '{:0.2f}'.format
# Use a serif font family for matplotlib text.
rc('font', **{'family': 'serif'})

from data import data_celebahq

%matplotlib inline

In [None]:
! mkdir -p pdfs

# Utility functions

In [None]:
### plot format utilities ### 

# Seaborn white-grid theme. The font family is set to serif again afterwards,
# presumably because sns.set() overrides the earlier rc font setting — TODO confirm.
sns.set(style='whitegrid')
sns.set_style({'font.family': 'serif'})

def save(f, filename, extra_artists=None):
    """Write figure ``f`` to ``pdfs/<filename>`` at 300 dpi with a tight bounding box.

    ``extra_artists`` is forwarded as ``bbox_extra_artists`` so that artists
    placed outside the axes are taken into account when tightening the bbox.
    """
    out_path = os.path.join('pdfs', filename)
    f.savefig(out_path,
              bbox_inches='tight',
              dpi=300,
              bbox_extra_artists=extra_artists)
    
def adjust_saturation(palette, s):
    """Return a new list of colors equal to ``palette`` but with HLS saturation ``s``.

    Hue and lightness are passed as None, so seaborn keeps them unchanged.
    """
    adjusted = []
    for color in palette:
        adjusted.append(sns.set_hls_values(color=color, h=None, l=None, s=s))
    return adjusted

def bar_offset(group_size, n_groups, barwidth):
    """Compute x positions for the bars of a grouped bar plot.

    Groups are centred on the integer x values ``1..n_groups``. Within a group,
    ``group_size`` bars of width ``barwidth`` sit side by side, shifted so the
    group as a whole is centred on its tick.

    Args:
        group_size: number of bars in each group.
        n_groups: number of groups (x ticks).
        barwidth: width of a single bar.

    Returns:
        A list of ``group_size`` numpy arrays. Entry ``i`` holds the x position
        of bar ``i`` in every group (one value per group).
    """
    xvals = np.arange(1, n_groups + 1)
    # Bars start at 0, w, 2w, ..., (k-1)w, so their centre is (k-1)w/2.
    # This closed form covers both odd and even group sizes.
    middle = (group_size - 1) * barwidth / 2
    return [xvals + (i * barwidth - middle) for i in range(group_size)]

def get_list_stats(l):
    """Summarise a sequence of numbers as mean, standard error and count.

    The standard error is numpy's default (population, ``ddof=0``) standard
    deviation divided by ``sqrt(n)``.
    """
    count = len(l)
    return {
        'mean': np.mean(l),
        'stderr': np.std(l) / np.sqrt(count),
        'n': count,
    }

def _light_palette_from_rgb(rgb, n):
    # Shared helper: seaborn light palette ending at the given RGB base color.
    return sns.light_palette(rgb, n_colors=n)

def make_green_palette(n):
    """Seaborn light palette of ``n`` colors ending at a green base color."""
    return _light_palette_from_rgb([0.39215686, 0.61960784, 0.45098039], n)

def make_blue_palette(n):
    """Seaborn light palette of ``n`` colors ending at a blue base color."""
    return _light_palette_from_rgb([0.29803922, 0.44705882, 0.69019608], n)

def make_purple_palette(n):
    """Seaborn light palette of ``n`` colors ending at a purple base color."""
    return _light_palette_from_rgb([0.5058823529411764, 0.4470588235294118, 0.7019607843137254], n)

def make_yellow_palette(n):
    """Seaborn light palette of ``n`` colors ending at the "yellow" base color."""
    return _light_palette_from_rgb([0.8666666666666667, 0.5176470588235295, 0.3215686274509804], n)

def make_diverging_palette(n):
    """``n`` colors sampled from seaborn's "vlag" diverging palette."""
    return sns.color_palette("vlag", n_colors=n)

In [None]:
### data evaluation utilities ### 

def softmax_to_prediction(softmax_prediction):
    """Convert classifier output probabilities to discrete class labels.

    Args:
        softmax_prediction: array of shape ``N x ensembles`` (binary
            classification; each entry is thresholded at 0.5) or
            ``N x ensembles x classes`` (multi-class; argmax over classes).

    Returns:
        Binary case: an int array with the same ``N x ensembles`` shape.
        Multi-class case: argmax labels of shape ``(N,)`` when
        ``ensembles == 1``, otherwise ``N x ensembles``.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """
    ndim = np.ndim(softmax_prediction)
    if ndim == 2:
        # N x ensembles binary prediction
        return (np.asarray(softmax_prediction) > 0.5).astype(int)
    if ndim == 3:
        # N x ensembles x classes
        labels = np.argmax(softmax_prediction, axis=-1)
        # Drop only a size-1 ensemble axis. A bare .squeeze() would also drop
        # the sample axis when N == 1, producing a 0-d result that no longer
        # lines up with a length-1 label vector.
        if labels.shape[1] == 1:
            labels = labels[:, 0]
        return labels
    # An explicit raise (rather than `assert False`) still fires under `python -O`.
    raise ValueError(
        'expected a 2-D (N x ensembles) or 3-D (N x ensembles x classes) '
        'array, got %d dimension(s)' % ndim)
        
def get_accuracy_from_image_ensembles(data_file, key, resample=False, seed=0, 
                                      n_resamples=20, ens_size=32, verbose=True):
    """Accuracy on original images versus an ensemble of image augmentations.

    Reads an npz file (e.g. image_ensemble_imcolor.npz or
    image_ensemble_imcrop.npz) with entries 'label', 'original' (predictions on
    the un-augmented image) and ``key`` (predictions on augmented copies).

    Args:
        data_file: path to the npz file.
        key: npz entry holding the augmented-view predictions (e.g. 'imcrop').
        resample: if True, also bootstrap the ensemble ``n_resamples`` times,
            drawing ``ens_size`` views with replacement each time.
        seed: seed for the bootstrap RNG.
        n_resamples: number of bootstrap draws.
        ens_size: expected total number of views (original + augmented). It is
            also the number of views drawn per bootstrap sample.
        verbose: print a '*' per bootstrap draw and 'done' at the end.

    Returns:
        dict with 'acc_original' (original view only), 'acc_ensembled' (mean
        over all views) and 'resamples' (list of bootstrap accuracies, or None
        when ``resample`` is False). All accuracies are percentages.
    """
    # Read everything up front so the npz file handle is closed right away.
    with np.load(data_file) as encoded_data:
        labels = encoded_data['label']
        original = encoded_data['original']
        augmented = encoded_data[key]

    preds_original = softmax_to_prediction(original)
    acc_original = sklearn.metrics.accuracy_score(labels, preds_original) * 100

    # All views along axis 1: the original image followed by its augmentations.
    jitters = np.concatenate([original, augmented], axis=1)
    preds_ensembled = softmax_to_prediction(np.mean(jitters, axis=1, keepdims=True))
    acc_ensembled = sklearn.metrics.accuracy_score(labels, preds_ensembled) * 100

    resamples = None
    if resample:
        # sample n_resamples ensembles of views with replacement, compute accuracy
        assert(jitters.shape[1] == ens_size) # sanity check
        resamples = []
        rng = np.random.RandomState(seed)
        for _ in range(n_resamples):
            if verbose:
                print('*', end='')
            indices = rng.choice(jitters.shape[1], ens_size, replace=True)
            jitters_resampled = np.mean(jitters[:, indices], axis=1, keepdims=True)
            preds_resampled = softmax_to_prediction(jitters_resampled)
            resamples.append(sklearn.metrics.accuracy_score(labels, preds_resampled) * 100)
        if verbose:
            print("done")
    return {'acc_original': acc_original, 'acc_ensembled': acc_ensembled, 'resamples': resamples}
    
def sample_ensemble(raw_preds, ens_size=None, seed=None):
    """Select ``ens_size`` ensemble members from ``raw_preds`` along axis 1.

    Args:
        raw_preds: ``N x ens_size`` predictions for binary classification, or
            ``N x ens_size x classes`` for multi-class.
        ens_size: number of members to keep; None keeps all of them.
        seed: None takes the first ``ens_size`` members in order. Otherwise
            members are drawn with replacement using
            ``np.random.RandomState(seed)``.

    Returns:
        A new array whose axis 1 has length ``ens_size``. It is a copy, because
        the selection uses integer-array indexing.
    """
    n_available = raw_preds.shape[1]
    if ens_size is None:
        ens_size = n_available
    if seed is not None:
        picks = np.random.RandomState(seed).choice(n_available, ens_size, replace=True)
    else:
        picks = np.arange(ens_size)
    return raw_preds[:, picks]

def get_accuracy_from_npz(data_file, expt_name, weight=None, ens_size=None, seed=None, return_preds=False,
                          add_aug=False, aug_name='image_ensemble_imcrop', aug_key='imcrop'):
    """Accuracy of a weighted ensemble of image views and GAN-generated views.

    The image view is one of two things:
      * ``add_aug=False``: the single original image prediction.
      * ``add_aug=True``: the mean over image augmentations loaded from
        ``<dirname(data_file)>/<aug_name>.npz``.

    It is blended with the mean prediction over the GAN views as
    ``(1 - w) * image_view + w * mean(gan_views)`` for each weight ``w``.

    Args:
        data_file: npz with 'label', 'reconstructed', ``expt_name`` and (when
            ``add_aug`` is False) 'original'.
        expt_name: npz entry holding the perturbed GAN-view predictions.
        weight: single blending weight to evaluate; None sweeps
            ``np.linspace(0, 1, 21)`` (0, 0.05, ..., 1).
        ens_size: number of views to sample (see ``sample_ensemble``). 0 means
            "use no GAN views": the image view stands in for them.
        seed: sampling seed forwarded to ``sample_ensemble``.
        return_preds: also return raw and ensembled predictions. Requires a
            single ``weight``.
        add_aug, aug_name, aug_key: select the augmentation npz and its entry.

    Returns:
        dict with 'expt_settings' (file stem of ``data_file``), 'acc_original',
        'acc_reconstructed' and 'ensemble_table'. The table is a DataFrame with
        columns acc/weight/expt_name and one row per weight. Accuracies are
        percentages. When ``return_preds`` is True, returns a
        ``(return_data, preds_dict)`` tuple instead.
    """
    # Read every array needed in one pass. Each NpzFile lookup re-reads the
    # member from the archive, and the context manager closes the file handle.
    with np.load(data_file) as encoded_data:
        labels = encoded_data['label']
        reconstructed = encoded_data['reconstructed']
        perturbed = encoded_data[expt_name] # N x ens_size x softmax distribution
        original = None if add_aug else encoded_data['original'] # full softmax distribution

    expt_settings = os.path.basename(data_file).split('.')[0]
    weights = [weight] if weight is not None else np.linspace(0, 1, 21)

    # --- image-view prediction ---
    if add_aug:
        # ensemble the image augmentations instead of the single original image
        print('.', end='')
        # os.path.dirname (instead of rsplit('/')) also handles a bare filename
        # with no directory component.
        aug_file = os.path.join(os.path.dirname(data_file), '%s.npz' % aug_name)
        with np.load(aug_file) as im_aug_data:
            im_aug_ens = np.concatenate([im_aug_data['original'], im_aug_data[aug_key]], axis=1)
        im_aug_ens = sample_ensemble(im_aug_ens, ens_size, seed)
        original = np.mean(im_aug_ens, axis=1, keepdims=True) # full softmax distribution
    preds_original = softmax_to_prediction(original)
    acc_original = sklearn.metrics.accuracy_score(labels, preds_original) * 100

    # --- single GAN reconstruction ---
    preds_reconstructed = softmax_to_prediction(reconstructed)
    acc_reconstructed = sklearn.metrics.accuracy_score(labels, preds_reconstructed) * 100

    # --- GAN-view ensemble: the reconstruction followed by its perturbations ---
    if ens_size == 0:
        gan_ens = original # dummy case: don't use gan reconstructed images
    else:
        gan_ens = sample_ensemble(np.concatenate((reconstructed, perturbed), axis=1), ens_size, seed)
    # The GAN-view mean does not depend on the weight, so compute it once.
    gan_mean = np.mean(gan_ens, axis=1, keepdims=True)

    rows = defaultdict(list)
    for w in weights: # alpha weighting hyperparameter
        # binary: original is N x 1; multi-class: original is N x 1 x classes
        # (gan_mean has the same shape thanks to keepdims)
        ensembled = (1 - w) * original + w * gan_mean
        preds_ensembled = softmax_to_prediction(ensembled)
        acc_ensembled = sklearn.metrics.accuracy_score(labels, preds_ensembled) * 100
        rows['acc'].append(acc_ensembled)
        rows['weight'].append(w)
        rows['expt_name'].append(expt_name)

    # table of expt_name x weight
    df = pd.DataFrame.from_dict(rows)
    return_data = {'expt_settings': expt_settings, 
                   'acc_original': acc_original,
                   'acc_reconstructed': acc_reconstructed,
                   'ensemble_table': df}
    if return_preds:
        assert(len(weights) == 1)
        preds = {
            'original': original, # image-view softmax
            'reconstruction': gan_ens, # softmax of all gan views
            'ensembled': ensembled, # softmax of the weighted ensemble
            'pred_original': preds_original,
            'pred_reconstruction': preds_reconstructed,
            'pred_ensemble': preds_ensembled,
            'label': labels,
        }
        return return_data, preds
    return return_data

def compute_best_weight(val_data_file, test_data_file, expt_name, 
                        verbose=True, ens_size=None, seed=None,
                        add_aug=False, aug_name='image_ensemble_imcrop', aug_key='imcrop'):
    """Tune the image/GAN blending weight on val and evaluate it on test.

    The full weight grid is swept on ``val_data_file``. The weight with the
    highest val accuracy is then evaluated alone on ``test_data_file``.

    Returns:
        dict with 'val_info' / 'test_info' (outputs of ``get_accuracy_from_npz``)
        and 'val_setting' / 'test_setting' (the chosen table row on each split).
    """
    # sanity checks: the files must come from the matching splits
    assert('val' in val_data_file)
    assert('test' in test_data_file)
    shared_kwargs = dict(ens_size=ens_size, seed=seed, add_aug=add_aug,
                         aug_name=aug_name, aug_key=aug_key)

    val_accuracy_info = get_accuracy_from_npz(val_data_file, expt_name, weight=None, **shared_kwargs)
    val_table = val_accuracy_info['ensemble_table']
    # Position of the highest val accuracy.
    # NOTE(review): ties are resolved by whatever order argsort gives equal
    # values, which is not guaranteed to be the first or last tied weight.
    best_position = val_table['acc'].argsort().iloc[-1]
    best_val_setting = val_table.iloc[best_position, :]
    if verbose:
        print("Val original %0.4f Val reconstructed %0.4f" %
              (val_accuracy_info['acc_original'], val_accuracy_info['acc_reconstructed']))
        print("%0.4f @ %0.4f %s" % (best_val_setting['acc'], best_val_setting['weight'],
                                    best_val_setting['expt_name']))

    test_accuracy_info = get_accuracy_from_npz(test_data_file, expt_name,
                                               weight=best_val_setting['weight'], **shared_kwargs)
    test_table = test_accuracy_info['ensemble_table']
    assert(test_table.shape[0] == 1) # only the val-selected weight is evaluated
    test_setting_from_val = test_table.iloc[0, :]
    if verbose:
        print("Test original %0.4f Test reconstructed %0.4f" %
              (test_accuracy_info['acc_original'], test_accuracy_info['acc_reconstructed']))
        print("%0.4f @ %0.4f %s" % (test_setting_from_val['acc'], test_setting_from_val['weight'],
                                    test_setting_from_val['expt_name']))

    return {'val_info': val_accuracy_info,
            'test_info': test_accuracy_info,
            'val_setting': best_val_setting,
            'test_setting': test_setting_from_val}


def resample_wrapper(val_file, test_file, expt_name, ens_size, add_aug, n_resamples=20, verbose=False,
                     aug_name='image_ensemble_imcrop', aug_key='imcrop'):
    """Average ``compute_best_weight`` over several random ensemble samples.

    Sampling views is random, so the val/test accuracies are averaged over
    ``n_resamples`` seeds (0..n_resamples-1) for stability.

    Returns:
        dict with:
          * val/test mean accuracy and standard error (population std / sqrt(n)),
          * the per-seed selected weights,
          * image-only and reconstruction accuracies. These are taken from the
            last seed; with ``add_aug`` the image-only accuracy depends on the
            sampled augmentations.

    Raises:
        ValueError: for an unsupported ``ens_size``/``add_aug`` combination, or
            when ``n_resamples`` is not positive.
    """
    # ens_size=31 plus the original image gives 32 views in total; with
    # add_aug, 16 image views are combined with 16 GAN views.
    if not (ens_size == 31 or (ens_size == 16 and add_aug)):
        raise ValueError('unsupported ens_size=%r with add_aug=%r' % (ens_size, add_aug))
    if n_resamples < 1:
        # otherwise `res` below would never be assigned
        raise ValueError('n_resamples must be >= 1, got %r' % (n_resamples,))

    val_samples = []
    test_samples = []
    weights = []
    for s in range(n_resamples):
        res = compute_best_weight(val_file, test_file, expt_name, verbose=verbose, add_aug=add_aug, 
                                  ens_size=ens_size, seed=s, aug_name=aug_name, aug_key=aug_key)
        val_samples.append(res['val_setting']['acc'])
        test_samples.append(res['test_setting']['acc'])
        weights.append(res['test_setting']['weight'])
    sqrt_n = np.sqrt(n_resamples)
    return {'val_avg': np.mean(val_samples),
            'test_avg': np.mean(test_samples), 
            'val_stderr': np.std(val_samples) / sqrt_n,
            'test_stderr': np.std(test_samples) / sqrt_n,
            'weights': weights,
            'val_acc_original': res['val_info']['acc_original'],
            'test_acc_original': res['test_info']['acc_original'],
            'val_acc_rec': res['val_info']['acc_reconstructed'],
            'test_acc_rec': res['test_info']['acc_reconstructed'],
           }

# Cars domain

In [None]:
# sample 32 crops of images, compare to combination of 16 crops of images and 16 crops of gan
# Main-paper cars experiment. For each classifier and each GAN perturbation
# family, pick the perturbation strength with the best resample-averaged val
# accuracy and record its resample-averaged test accuracy (and stderr) in `df`.
# Rows are appended classifier-major (3 classifiers x 5 families); the plot
# cell below relies on that order.
df = defaultdict(list)

for i, classifier in enumerate(['imageclassifier', 'latentclassifier', 
                                'latentclassifier_stylemix_fine']):
    print(classifier)
    # (val npz path, candidate expt_names inside it, display group name)
    val_expts = [
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_isotropic_coarse_tensortransform.npz',
         ('isotropic_coarse_1.00', 'isotropic_coarse_1.50', 'isotropic_coarse_2.00'), 'Isotropic Coarse'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_isotropic_fine_tensortransform.npz',
         ('isotropic_fine_0.30', 'isotropic_fine_0.50', 'isotropic_fine_0.70'), 'Isotropic Fine'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_pca_coarse_tensortransform.npz',
         ('pca_coarse_1.00', 'pca_coarse_2.00', 'pca_coarse_3.00'), 'PCA Coarse'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_pca_fine_tensortransform.npz',
         ('pca_fine_1.00', 'pca_fine_2.00', 'pca_fine_3.00'), 'PCA Fine'),
#         (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_stylemix_coarse_tensortransform.npz',
#          ('stylemix_coarse',), 'Style-mix Coarse'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_stylemix_fine_tensortransform.npz',
         ('stylemix_fine',), 'Style-mix Fine'),
    ]
    # test files mirror the val files, with the split directory swapped
    test_expts = [(x.replace('_val/', '_test/'), y, z) for x, y, z in val_expts]

    for val, test in zip(val_expts, test_expts):
        expt_settings = []
        print(val[-1])
        for expt_name in val[1]:
            # 16 sampled image crops + 16 sampled GAN views per resample
            resampled_accs = resample_wrapper(val[0], test[0], expt_name, ens_size=16, 
                                              add_aug=True, aug_name='image_ensemble_imcrop', verbose=False)            
            resampled_accs['expt_name'] = expt_name
            expt_settings.append(resampled_accs)
            print("done")            
            
        best_expt = max(expt_settings, key=lambda x: x['val_avg']) # take the val accuracy, avged over samples
        df['classifier'].append(classifier+'_crop')    
        df['acc'].append(best_expt['test_avg'])
        df['stderr'].append(best_expt['test_stderr'])
        df['expt'].append(best_expt['expt_name'])
        df['expt_group'].append(test[2])

df = pd.DataFrame.from_dict(df)

In [None]:
# Show the results table built by the previous cell.
df

In [None]:
# plot it
# Main-paper cars figure:
#   * one group of bars per classifier training distribution (x ticks 1..n_groups),
#   * one bar per GAN perturbation family in each group,
#   * a dotted reference line at the image classifier's 32-view crop-ensemble accuracy.
f, ax = plt.subplots(1, 1, figsize=(7, 5))

data_file = 'results/precomputed_evaluations/car/output/imageclassifier_test/image_ensemble_imcrop.npz'
im_crops = get_accuracy_from_image_ensembles(data_file, 'imcrop', resample=True)

group_size = 5  # bars per group: one per expt_group in `df`
bar_width=0.15
n_groups = 3  # classifiers in `df`
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = make_blue_palette(3)[1:] + make_green_palette(3)[1:] + make_purple_palette(3)[1:]

resample_stats = get_list_stats(im_crops['resamples'])  # not used by this figure
ind = 0.2
ax.axhline(im_crops['acc_ensembled'], color='k', linestyle=':', label='Original Images')

xticklabels = []
for i in range(group_size):
    # `df` rows are classifier-major, so rows i, i+group_size, ... share one
    # expt_group (checked by the assert).
    indices = np.arange(i, n_groups*group_size, group_size)
    bar_height = df.iloc[indices]['acc']
    bar_err = df.iloc[indices]['stderr']
    assert(all([x == df.iloc[indices[0]]['expt_group'] for x in df.iloc[indices]['expt_group']]))
    ax.bar(bar_offsets[i], bar_height, width=bar_width, color=palette[i], yerr=bar_err,
           label=df.iloc[indices[0]]['expt_group'], edgecolor=(0.5, 0.5, 0.5), capsize=5)
    xticklabels.append(df.iloc[indices[0]]['classifier'].replace('_', '\n'))

ax.set_ylim([94, 99])
ax.set_xticks(list(range(1, n_groups+1)))
handles,labels = ax.get_legend_handles_labels()
# reorder it so it looks nicer
order = [0, 3, 1, 4, 2, 5]
handles = [handles[i] for i in order]
labels = [labels[i] for i in order]
ax.legend(handles, labels, loc='upper center', ncol=3, prop={'size': 11})

ax.set_xticklabels(['Original\nImages', 'GAN\nReconstructions', 'Style-mix Fine\nAugmentations'], fontsize=12)
ax.set_xlabel('Classifier training distribution', fontsize=16)
ax.set_ylabel('Classification Accuracy', fontsize=16)
# Tick.label was removed in recent Matplotlib releases; tick_params works across versions.
ax.tick_params(axis='y', labelsize=14)
ax.set_title('Cars', fontsize=16)

f.tight_layout()
save(f, 'graph_cars_v2.pdf')

In [None]:
# sample 32 crops of images, compare to combination of 16 crops of images and 16 crops of gan
# using all experiment settings for supplemental
# For each of 8 classifiers, this cell:
#   * stores the image-crop ensemble accuracies in `im_crop_data` (one entry
#     per classifier);
#   * for each of the 6 GAN perturbation families, appends one row to `df`
#     with the val-selected strength's resample-averaged test accuracy.
# Rows are appended classifier-major (8 x 6); the plot cell below relies on it.

df = defaultdict(list)
im_crop_data = []

for i, classifier in enumerate(['imageclassifier', 'latentclassifier', 
                                'latentclassifier_isotropic_fine', 'latentclassifier_isotropic_coarse',
                                'latentclassifier_pca_fine', 'latentclassifier_pca_coarse',
                                'latentclassifier_stylemix_fine', 'latentclassifier_stylemix_coarse']):
    print(classifier)
    # (val npz path, candidate expt_names inside it, display group name)
    val_expts = [
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_isotropic_coarse_tensortransform.npz', 
         ('isotropic_coarse_1.00', 'isotropic_coarse_1.50', 'isotropic_coarse_2.00'), 'Isotropic Coarse'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_isotropic_fine_tensortransform.npz', 
         ('isotropic_fine_0.30', 'isotropic_fine_0.50', 'isotropic_fine_0.70'), 'Isotropic Fine'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_pca_coarse_tensortransform.npz', 
         ('pca_coarse_1.00', 'pca_coarse_2.00', 'pca_coarse_3.00'), 'PCA Coarse'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_pca_fine_tensortransform.npz',
         ('pca_fine_1.00', 'pca_fine_2.00', 'pca_fine_3.00'), 'PCA Fine'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_stylemix_coarse_tensortransform.npz', 
         ('stylemix_coarse',), 'Style-mix Coarse'),
        (f'results/precomputed_evaluations/car/output/{classifier}_val/gan_ensemble_stylemix_fine_tensortransform.npz', 
         ('stylemix_fine',), 'Style-mix Fine'),
    ]

    # test files mirror the val files, with the split directory swapped
    test_expts = [(x.replace('_val/', '_test/'), y, z) for x, y, z in val_expts]
    
    # image-only baselines (single crop and bootstrapped multi-crop) for this classifier
    data_file = f'results/precomputed_evaluations/car/output/{classifier}_test/image_ensemble_imcrop.npz'
    im_crop_data.append(get_accuracy_from_image_ensembles(data_file, 'imcrop', resample=True))

    for val, test in zip(val_expts, test_expts):
        expt_settings = []
        print(val[-1])
        for expt_name in val[1]:
            # 16 sampled image crops + 16 sampled GAN views per resample
            resampled_accs = resample_wrapper(val[0], test[0], expt_name, ens_size=16, 
                                              add_aug=True, aug_name='image_ensemble_imcrop', verbose=False)            
            resampled_accs['expt_name'] = expt_name
            expt_settings.append(resampled_accs)
            print("done")            
            
        best_expt = max(expt_settings, key=lambda x: x['val_avg']) # take the val accuracy, avged over samples
        df['classifier'].append(classifier+'_crop')    
        df['acc'].append(best_expt['test_avg'])
        df['stderr'].append(best_expt['test_stderr'])
        df['expt'].append(best_expt['expt_name'])
        df['expt_group'].append(test[2])

df = pd.DataFrame.from_dict(df)

In [None]:
# Show the results table built by the previous cell.
df

In [None]:

# plot it
# Supplemental cars figure. Each of the 8 classifier training distributions
# gets, in order:
#   * an image single-crop bar,
#   * an image multi-crop bar (bootstrapped mean +- stderr),
#   * one bar per GAN perturbation family.
f, ax = plt.subplots(1, 1, figsize=(14, 6))

n_expt_groups = 6  # GAN perturbation families per classifier in `df`
group_size = 8  # 2 image-only bars + n_expt_groups GAN bars
bar_width=0.1
n_groups = 8  # classifiers
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = make_yellow_palette(3)[1:] + make_blue_palette(3)[1:] + make_green_palette(3)[1:] + make_purple_palette(3)[1:]

ind = 0.2
ax.bar(bar_offsets[0], [x['acc_original'] for x in im_crop_data], width=bar_width, color=palette[0], 
       label='Image Single Crop', edgecolor=(0.5, 0.5, 0.5), capsize=5)
ax.bar(bar_offsets[1], [get_list_stats(x['resamples'])['mean'] for x in im_crop_data],
       width=bar_width, color=palette[1], yerr=[get_list_stats(x['resamples'])['stderr'] for x in im_crop_data],
       label='Image Multi Crop', edgecolor=(0.5, 0.5, 0.5), capsize=5)

xticklabels = []
for i in range(n_expt_groups):
    # `df` rows are classifier-major, so every n_expt_groups-th row shares an expt_group
    indices = np.arange(i, n_groups*n_expt_groups, n_expt_groups)
    bar_height = df.iloc[indices]['acc']
    bar_err = df.iloc[indices]['stderr']
    assert(all([x == df.iloc[indices[0]]['expt_group'] for x in df.iloc[indices]['expt_group']]))
    ax.bar(bar_offsets[i+2], bar_height, width=bar_width, color=palette[i+2], yerr=bar_err,
           label=df.iloc[indices[0]]['expt_group'], edgecolor=(0.5, 0.5, 0.5), capsize=5)
    xticklabels.append(df.iloc[indices[0]]['classifier'].replace('_', '\n'))

ax.set_ylim([94, 100])
ax.set_xticks(list(range(1, n_groups+1)))
handles,labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, loc='upper center', ncol=4, prop={'size': 11})
ax.set_xticklabels(['Original\nImages', 'GAN\nReconstructions',
                    'Isotropic Fine\nAugmentations', 'Isotropic Coarse\nAugmentations',
                    'PCA Fine\nAugmentations', 'PCA Coarse\nAugmentations', 
                    'Style-mix Fine\nAugmentations', 'Style-mix Coarse\nAugmentations'], fontsize=12)
ax.set_xlabel('Classifier training distribution', fontsize=16)
ax.set_ylabel('Classification Accuracy', fontsize=16)
# Tick.label was removed in recent Matplotlib releases; tick_params works across versions.
ax.tick_params(axis='y', labelsize=14)
ax.set_title('Cars', fontsize=16)

f.tight_layout()
save(f, 'sm_graph_cars_all_settings.pdf')

# Cats domain (cat face classifier)

In [None]:
# figure for main: cat face augmentations (does not use crop)
# For each classifier and GAN perturbation family, choose the perturbation
# strength with the best resample-averaged val accuracy and record its
# resample-averaged test accuracy in `df`. Unlike the cars cells, the image
# view is the single original image (add_aug=False), blended with 31 sampled
# GAN views. Rows are appended classifier-major (3 classifiers x 5 families).
df = defaultdict(list)

for i, classifier in enumerate(['imageclassifier', 'latentclassifier', 
                                'latentclassifier_stylemix_coarse']):
    print(classifier)
    # (val npz path, candidate expt_names inside it, display group name)
    val_expts = [
        # also tried without _tensortransform, it's similar
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_isotropic_coarse_tensortransform.npz', 
         ('isotropic_coarse_0.50', 'isotropic_coarse_0.70', 'isotropic_coarse_1.00'), 'Isotropic Coarse'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_isotropic_fine_tensortransform.npz', 
         ('isotropic_fine_0.10', 'isotropic_fine_0.20', 'isotropic_fine_0.30'), 'Isotropic Fine'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_pca_coarse_tensortransform.npz', 
         ('pca_coarse_0.50', 'pca_coarse_0.70', 'pca_coarse_1.00'), 'PCA Coarse'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_pca_fine_tensortransform.npz',  
         ('pca_fine_0.50', 'pca_fine_0.70', 'pca_fine_1.00'), 'PCA Fine'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_stylemix_coarse_tensortransform.npz', 
         ('stylemix_coarse',), 'Style-mix Coarse'),
#         (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_stylemix_fine_tensortransform.npz', 
#          ('stylemix_fine',), 'Style-mix Fine'),
    ]
    # test files mirror the val files, with the split directory swapped
    test_expts = [(x.replace('_val/', '_test/'), y, z) for x, y, z in val_expts]

    for val, test in zip(val_expts, test_expts):
        expt_settings = []
        print(val[-1])
        for expt_name in val[1]:
            resampled_accs = resample_wrapper(val[0], test[0], expt_name, ens_size=31, 
                                              add_aug=False, verbose=False)
            resampled_accs['expt_name'] = expt_name
            expt_settings.append(resampled_accs)
            print("done")

        best_expt = max(expt_settings, key=lambda x: x['val_avg']) # take the val accuracy, avged over samples
        # No crop augmentation is used here (add_aug=False), so no '_crop' suffix;
        # this matches the cat supplemental cell.
        df['classifier'].append(classifier)
        df['acc'].append(best_expt['test_avg'])
        df['stderr'].append(best_expt['test_stderr'])
        df['expt'].append(best_expt['expt_name'])
        df['expt_group'].append(test[2])

df = pd.DataFrame.from_dict(df)

In [None]:
# Show the results table built by the previous cell.
df

In [None]:
# plot it
# Main-paper cats figure:
#   * one group of bars per classifier training distribution,
#   * one bar per GAN perturbation family in each group,
#   * a dotted reference line at the image classifier's single original-image accuracy.
f, ax = plt.subplots(1, 1, figsize=(7, 5))

data_file = 'results/precomputed_evaluations/cat/output/imageclassifier_test/image_ensemble_imcrop.npz'
im_s = get_accuracy_from_image_ensembles(data_file, 'imcrop', resample=True)

group_size = 5  # bars per group: one per expt_group in `df`
bar_width=0.15
n_groups = 3  # classifiers in `df`
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = make_blue_palette(3)[1:] + make_green_palette(3)[1:] + make_purple_palette(3)[1:]

resample_stats = get_list_stats(im_s['resamples'])  # not used by this figure
ind = 0.2
# note: using acc_original here, as it's better
ax.axhline(im_s['acc_original'], color='k', linestyle=':', label='Original Images')

xticklabels = []
for i in range(group_size):
    # `df` rows are classifier-major, so rows i, i+group_size, ... share one expt_group
    indices = np.arange(i, n_groups*group_size, group_size)
    bar_height = df.iloc[indices]['acc']
    bar_err = df.iloc[indices]['stderr']
    assert(all([x == df.iloc[indices[0]]['expt_group'] for x in df.iloc[indices]['expt_group']]))
    ax.bar(bar_offsets[i], bar_height, width=bar_width, color=palette[i], yerr=bar_err,
           label=df.iloc[indices[0]]['expt_group'], edgecolor=(0.5, 0.5, 0.5), capsize=5)
    xticklabels.append(df.iloc[indices[0]]['classifier'].replace('_', '\n'))

ax.set_ylim([90, 95])
ax.set_xticks(list(range(1, n_groups+1)))
handles,labels = ax.get_legend_handles_labels()
# reorder it so it looks nicer
order = [0, 3, 1, 4, 2, 5]
handles = [handles[i] for i in order]
labels = [labels[i] for i in order]
ax.legend(handles, labels, loc='upper center', ncol=3, prop={'size': 10.8})
ax.set_xticklabels(['Original\nImages', 'GAN\nReconstructions', 'Style-mix Coarse\nAugmentations'], fontsize=12)
ax.set_xlabel('Classifier training distribution', fontsize=16)
ax.set_ylabel('Classification Accuracy', fontsize=16)
# Tick.label was removed in recent Matplotlib releases; tick_params works across versions.
ax.tick_params(axis='y', labelsize=14)
ax.set_title('Cats', fontsize=16)

f.tight_layout()
save(f, 'graph_cats_v2.pdf')

In [None]:
# all settings for the supplemental
# For each of 8 classifiers, this cell:
#   * stores the image-crop ensemble accuracies in `im_crop_data`;
#   * for each of the 6 GAN perturbation families, appends one row to `df`
#     with the val-selected strength's resample-averaged test accuracy.
# The GAN ensembles here use the single original image plus 31 GAN views
# (add_aug=False). Rows are appended classifier-major (8 x 6).
df = defaultdict(list)
im_crop_data = []

for i, classifier in enumerate(['imageclassifier', 'latentclassifier', 
                                'latentclassifier_isotropic_fine', 'latentclassifier_isotropic_coarse',
                                'latentclassifier_pca_fine', 'latentclassifier_pca_coarse',
                                'latentclassifier_stylemix_fine', 'latentclassifier_stylemix_coarse']):
    print(classifier)
    # (val npz path, candidate expt_names inside it, display group name)
    val_expts = [
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_isotropic_coarse_tensortransform.npz', 
         ('isotropic_coarse_0.50', 'isotropic_coarse_0.70', 'isotropic_coarse_1.00'), 'Isotropic Coarse'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_isotropic_fine_tensortransform.npz', 
         ('isotropic_fine_0.10', 'isotropic_fine_0.20', 'isotropic_fine_0.30'), 'Isotropic Fine'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_pca_coarse_tensortransform.npz', 
         ('pca_coarse_0.50', 'pca_coarse_0.70', 'pca_coarse_1.00'), 'PCA Coarse'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_pca_fine_tensortransform.npz',  
         ('pca_fine_0.50', 'pca_fine_0.70', 'pca_fine_1.00'), 'PCA Fine'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_stylemix_coarse_tensortransform.npz', 
         ('stylemix_coarse',), 'Style-mix Coarse'),
        (f'results/precomputed_evaluations/cat/output/{classifier}_val/gan_ensemble_stylemix_fine_tensortransform.npz', 
         ('stylemix_fine',), 'Style-mix Fine'),
    ]
    # test files mirror the val files, with the split directory swapped
    test_expts = [(x.replace('_val/', '_test/'), y, z) for x, y, z in val_expts]
    # image-only baselines (single crop and bootstrapped multi-crop) for this classifier
    data_file = f'results/precomputed_evaluations/cat/output/{classifier}_test/image_ensemble_imcrop.npz'
    im_crop_data.append(get_accuracy_from_image_ensembles(data_file, 'imcrop', resample=True))

    for val, test in zip(val_expts, test_expts):
        expt_settings = []
        print(val[-1])
        for expt_name in val[1]:
            resampled_accs = resample_wrapper(val[0], test[0], expt_name, ens_size=31, 
                                              add_aug=False, verbose=False)            
            resampled_accs['expt_name'] = expt_name
            expt_settings.append(resampled_accs)
            print("done")            
            
        best_expt = max(expt_settings, key=lambda x: x['val_avg']) # take the val accuracy, avged over samples
        df['classifier'].append(classifier)    
        df['acc'].append(best_expt['test_avg'])
        df['stderr'].append(best_expt['test_stderr'])
        df['expt'].append(best_expt['expt_name'])
        df['expt_group'].append(test[2])

df = pd.DataFrame.from_dict(df)

In [None]:
# Show the results table built by the previous cell.
df

In [None]:
# plot it
# Supplemental cats figure. Each of the 8 classifier training distributions
# gets, in order:
#   * an image single-crop bar,
#   * an image multi-crop bar (bootstrapped mean +- stderr),
#   * one bar per GAN perturbation family.
f, ax = plt.subplots(1, 1, figsize=(14, 6))

n_expt_groups = 6  # GAN perturbation families per classifier in `df`
group_size = 8  # 2 image-only bars + n_expt_groups GAN bars
bar_width=0.1
n_groups = 8  # classifiers
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = make_yellow_palette(3)[1:] + make_blue_palette(3)[1:] + make_green_palette(3)[1:] + make_purple_palette(3)[1:]

ind = 0.2
ax.bar(bar_offsets[0], [x['acc_original'] for x in im_crop_data], width=bar_width, color=palette[0], 
       label='Image Single Crop', edgecolor=(0.5, 0.5, 0.5), capsize=5)
ax.bar(bar_offsets[1], [get_list_stats(x['resamples'])['mean'] for x in im_crop_data],
       width=bar_width, color=palette[1], yerr=[get_list_stats(x['resamples'])['stderr'] for x in im_crop_data],
       label='Image Multi Crop', edgecolor=(0.5, 0.5, 0.5), capsize=5)

xticklabels = []
for i in range(n_expt_groups):
    # `df` rows are classifier-major, so every n_expt_groups-th row shares an expt_group
    indices = np.arange(i, n_groups*n_expt_groups, n_expt_groups)
    bar_height = df.iloc[indices]['acc']
    bar_err = df.iloc[indices]['stderr']
    assert(all([x == df.iloc[indices[0]]['expt_group'] for x in df.iloc[indices]['expt_group']]))
    ax.bar(bar_offsets[i+2], bar_height, width=bar_width, color=palette[i+2], yerr=bar_err,
           label=df.iloc[indices[0]]['expt_group'], edgecolor=(0.5, 0.5, 0.5), capsize=5)
    xticklabels.append(df.iloc[indices[0]]['classifier'].replace('_', '\n'))

ax.set_ylim([90, 94])
ax.set_xticks(list(range(1, n_groups+1)))
handles,labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, loc='upper center', ncol=4, prop={'size': 11})
ax.set_xticklabels(['Original\nImages', 'GAN\nReconstructions',
                    'Isotropic Fine\nAugmentations', 'Isotropic Coarse\nAugmentations',
                    'PCA Fine\nAugmentations', 'PCA Coarse\nAugmentations', 
                    'Style-mix Fine\nAugmentations', 'Style-mix Coarse\nAugmentations'], fontsize=12)
ax.set_xlabel('Classifier training distribution', fontsize=16)
ax.set_ylabel('Classification Accuracy', fontsize=16)
# Tick.label was removed in recent Matplotlib releases; tick_params works across versions.
ax.tick_params(axis='y', labelsize=14)
ax.set_title('Cats', fontsize=16)

f.tight_layout()
save(f, 'sm_graph_cats_all_settings.pdf')

# stylegan faces 40 attributes

In [None]:
# For each CelebA-HQ attribute, compare test-time augmentation strategies:
# - the single original image,
# - color-jitter and crop-jitter image ensembles,
# - the style-mix (fine) GAN ensemble,
# - the GAN ensemble with color/crop jitter ("tensortransform" files).
# table_dict stores accuracy differences vs. the original image; table_accs
# stores the absolute accuracies. val_labels / test_labels name those columns
# and are read by the table-building cells that follow.
# NOTE(review): every temporary here is a notebook global. If editing, keep
# the call order of the evaluation helpers, in case they share random state
# (not visible from this cell).

# Column means of the attribute table. The last column is dropped (presumably
# not an attribute -- confirm against data_celebahq).
attr_mean = data_celebahq.attr_celebahq.mean(axis=0)[:-1]
# Sort by |mean - 0.5|, so the most balanced attributes come first.
attr_order = sorted([(abs(v-0.5), v, k) for k, v in attr_mean.to_dict().items()])

table_dict = OrderedDict([])
table_accs = OrderedDict([])


for i, (_, _, attr) in enumerate(tqdm(attr_order[:40])):
    # print('========== %s ==========' % attr)
    
    # gan jitter: style-mix (fine) GAN ensemble, evaluated on val and test
    val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine.npz'
    test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/gan_ensemble_stylemix_fine.npz'
    expt_name = 'stylemix_fine'
    # resample; test_avg is presumably the test accuracy at the setting chosen
    # on val (per the variable name below) -- confirm in resample_wrapper
    resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                      add_aug=False, verbose=False)   
    val_orig = resampled_accs['val_acc_original']
    val_top1 = resampled_accs['val_avg']
    test_orig = resampled_accs['test_acc_original']
    test_top1_from_val = resampled_accs['test_avg']
    

    # gan jitter with color/crop jitter ("tensortransform" evaluation files)
    val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine_tensortransform.npz'
    test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/gan_ensemble_stylemix_fine_tensortransform.npz'
    expt_name = 'stylemix_fine'
    # resample
    resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                      add_aug=False, verbose=False)   
    val_orig_mix = resampled_accs['val_acc_original']
    val_top1_mix = resampled_accs['val_avg']
    test_orig_mix = resampled_accs['test_acc_original']
    test_top1_from_val_mix = resampled_accs['test_avg']
    
    # color jitter image ensemble: the mean over the resampled ensembles is
    # used instead of the full-ensemble accuracy ('acc_ensembled')
    val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/image_ensemble_imcolor.npz'
    im_ensemble = get_accuracy_from_image_ensembles(val_file, 'imcolor', resample=True, verbose=False)
    val_color_orig = im_ensemble['acc_original']
    val_color_ens = np.mean(im_ensemble['resamples']) # im_ensemble['acc_ensembled']
    test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/image_ensemble_imcolor.npz'
    im_ensemble = get_accuracy_from_image_ensembles(test_file, 'imcolor', resample=True, verbose=False)
    test_color_orig = im_ensemble['acc_original']
    test_color_ens = np.mean(im_ensemble['resamples']) # im_ensemble['acc_ensembled']
    
    # crop jitter image ensemble (same treatment as color jitter)
    val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/image_ensemble_imcrop.npz'
    im_ensemble = get_accuracy_from_image_ensembles(val_file, 'imcrop', resample=True, verbose=False)
    val_crop_orig = im_ensemble['acc_original']
    val_crop_ens = np.mean(im_ensemble['resamples']) # im_ensemble['acc_ensembled']
    test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/image_ensemble_imcrop.npz'
    im_ensemble = get_accuracy_from_image_ensembles(test_file, 'imcrop', resample=True, verbose=False)
    test_crop_orig = im_ensemble['acc_original']
    test_crop_ens = np.mean(im_ensemble['resamples']) # im_ensemble['acc_ensembled']
    
    # sanity check: every pipeline must report exactly the same
    # original-image accuracy on each split
    assert(test_color_orig == test_orig)
    assert(test_crop_orig == test_orig)
    assert(test_orig_mix == test_orig)
    assert(val_color_orig == val_orig)
    assert(val_crop_orig == val_orig)
    assert(val_orig_mix == val_orig)
    
    # Row layout: 5 val columns then 5 test columns; diffs are relative to
    # the original-image accuracy (first entry).
    val_labels = ['Val Orig', 'Val Color', 'Val Crop', 'Val GAN', 'Val Combined']
    val_values = [val_orig, val_color_ens, val_crop_ens, val_top1, val_top1_mix]
    val_diffs = [x - val_values[0] for x in val_values]
    
    test_labels = ['Test Orig', 'Test Color', 'Test Crop', 'Test GAN', 'Test Combined']
    test_values = [test_orig, test_color_ens, test_crop_ens, test_top1_from_val, test_top1_from_val_mix]
    test_diffs = [x - test_values[0] for x in test_values]
    table_dict[attr] = val_diffs + test_diffs
    table_accs[attr] = val_values + test_values

In [None]:
# Per-attribute accuracy differences vs. the original image, plus an 'Avg'
# row; prints the standard error over attributes and shows the 'Avg' row.
table = pd.DataFrame.from_dict(table_dict, orient='index', columns=val_labels+test_labels)
# DataFrame.append was removed in pandas 2.0; concatenate the mean as a row.
table = pd.concat([table, table.mean(axis=0).rename('Avg').to_frame().T])
n_attrs = len(table) - 1  # number of attribute rows (excludes 'Avg')
std = table.iloc[:-1, :].std(axis=0).rename('Std')
print(std / np.sqrt(n_attrs))  # standard error of the mean over attributes
display(table.iloc[-1:, :])

In [None]:
# Per-attribute absolute accuracies, plus an 'Avg' row; prints the standard
# error over attributes and shows the 'Avg' row.
table_acc = pd.DataFrame.from_dict(table_accs, orient='index', columns=val_labels+test_labels)
# DataFrame.append was removed in pandas 2.0; concatenate the mean as a row.
table_acc = pd.concat([table_acc, table_acc.mean(axis=0).rename('Avg').to_frame().T])
n_attrs = len(table_acc) - 1  # number of attribute rows (excludes 'Avg')
std_acc = table_acc.iloc[:-1, :].std(axis=0).rename('Std')
print(std_acc / np.sqrt(n_attrs))  # standard error of the mean over attributes
display(table_acc.iloc[-1:, :])

In [None]:
# Bar plot of the averaged ('Avg' row, i.e. the last row) test accuracy for
# each test-time augmentation strategy. Columns 5: are the five test columns.
df = table_acc.iloc[[-1], 5:].T
df = df.reset_index()
display(df)
f, ax = plt.subplots(1, 1, figsize=(6, 3))
palette = adjust_saturation(make_blue_palette(3), 0.3)
ax.bar(np.arange(len(df)), df.loc[:, 'Avg'], color=palette[-1], edgecolor=(0.5, 0.5, 0.5))
ax.set_ylim([88.5, 89.5])
ax.set_xticks(range(5))
ax.set_xticklabels(['Single\nImage', 'Color\nJitter', 'Crop\nJitter', 'Style-mix\nJitter', 'Combined\nJitter'],
                   fontsize=12)
ax.set_ylabel('Classification Accuracy', fontsize=16)
# Tick.label was removed in Matplotlib 3.8; use tick_params instead.
ax.tick_params(axis='y', labelsize=12)
ax.set_xlabel('')
ax.set_xlim([-0.7, 4.7])
save(f, 'graph_face_testaug.pdf')

In [None]:
# Bar plot of the mean test accuracy difference vs. the original image for each
# test-time augmentation strategy. Error bars are the standard error over
# attributes.
f, ax = plt.subplots(1, 1, figsize=(6, 3))

diffs = table.iloc[:-1, 5:]  # drop the 'Avg' row; keep the five test columns
bar_height = diffs.mean(axis=0)
bar_err = diffs.std(axis=0) / np.sqrt(diffs.shape[0])
palette = adjust_saturation(make_blue_palette(3), 0.3)
ax.bar(range(5), bar_height, edgecolor=(0.5, 0.5, 0.5), yerr=bar_err, color=palette[-1], capsize=5)
ax.set_xticks(range(5))
ax.set_xticklabels(['Single\nImage', 'Color\nJitter', 'Crop\nJitter', 'Style-mix\nJitter', 'Combined\nJitter'],
                   fontsize=12)
ax.set_ylabel('Accuracy Difference', fontsize=16)
# Tick.label was removed in Matplotlib 3.8; use tick_params instead.
ax.tick_params(axis='y', labelsize=12)
ax.set_xlabel('')
ax.set_xlim([-0.7, 4.7])
ax.set_ylim([-0.1, 0.2])
save(f, 'graph_face_testaug_diffs.pdf')

# stylegan idinvert

In [None]:
# Repeat the GAN-ensemble evaluation for each CelebA-HQ attribute using the
# ID-Invert outputs (celebahq-idinvert). Only the original image, the style-mix
# (fine) GAN ensemble and the GAN ensemble with color/crop jitter
# ("tensortransform") are compared. table_dict stores differences vs. the
# original image, table_accs stores absolute accuracies; val_labels /
# test_labels are read by the table cells that follow.
# NOTE(review): temporaries are notebook globals. Keep the helper call order
# if editing, in case resample_wrapper uses shared random state.

# Column means, last column dropped (presumably not an attribute -- confirm);
# sorted so the attributes with mean closest to 0.5 come first.
attr_mean = data_celebahq.attr_celebahq.mean(axis=0)[:-1]
attr_order = sorted([(abs(v-0.5), v, k) for k, v in attr_mean.to_dict().items()])

table_dict = OrderedDict([])
table_accs = OrderedDict([])

for i, (_, _, attr) in enumerate(tqdm(attr_order[:40])):
    # print('========== %s ==========' % attr)
    
    # gan jitter: style-mix (fine) GAN ensemble
    val_file = f'results/precomputed_evaluations/celebahq-idinvert/output/{attr}_val/gan_ensemble_stylemix_fine.npz'
    test_file = f'results/precomputed_evaluations/celebahq-idinvert/output/{attr}_test/gan_ensemble_stylemix_fine.npz'
    expt_name = 'stylemix_fine'
    # resample
    resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                      add_aug=False, verbose=False)   
    val_orig = resampled_accs['val_acc_original']
    val_top1 = resampled_accs['val_avg']
    test_orig = resampled_accs['test_acc_original']
    test_top1_from_val = resampled_accs['test_avg']

    # gan jitter with color/crop jitter ("tensortransform" evaluation files)
    val_file = f'results/precomputed_evaluations/celebahq-idinvert/output/{attr}_val/gan_ensemble_stylemix_fine_tensortransform.npz'
    test_file = f'results/precomputed_evaluations/celebahq-idinvert/output/{attr}_test/gan_ensemble_stylemix_fine_tensortransform.npz'
    expt_name = 'stylemix_fine'
    # resample
    resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                      add_aug=False, verbose=False)   
    val_orig_mix = resampled_accs['val_acc_original']
    val_top1_mix = resampled_accs['val_avg']
    test_orig_mix = resampled_accs['test_acc_original']
    test_top1_from_val_mix = resampled_accs['test_avg']
    
    # sanity check: both ensemble files must report the same original-image accuracy
    assert(test_orig_mix == test_orig)
    assert(val_orig_mix == val_orig)
    
    # Row layout: 3 val columns then 3 test columns; diffs are relative to
    # the original-image accuracy (first entry).
    val_labels = ['Val Orig', 'Val GAN', 'Val Combined']
    val_values = [val_orig, val_top1, val_top1_mix]
    val_diffs = [x - val_values[0] for x in val_values]
    
    test_labels = ['Test Orig',  'Test GAN', 'Test Combined']
    test_values = [test_orig,test_top1_from_val, test_top1_from_val_mix]
    test_diffs = [x - test_values[0] for x in test_values]
    table_dict[attr] = val_diffs + test_diffs
    table_accs[attr] = val_values + test_values

In [None]:
# ID-Invert per-attribute accuracy differences vs. the original image, plus an
# 'Avg' row; prints the standard error over attributes and shows 'Avg'.
table_idinvert = pd.DataFrame.from_dict(table_dict, orient='index', columns=val_labels+test_labels)
# DataFrame.append was removed in pandas 2.0; concatenate the mean as a row.
table_idinvert = pd.concat([table_idinvert, table_idinvert.mean(axis=0).rename('Avg').to_frame().T])
n_attrs = len(table_idinvert) - 1  # number of attribute rows (excludes 'Avg')
std = table_idinvert.iloc[:-1, :].std(axis=0).rename('Std')
print(std / np.sqrt(n_attrs))  # standard error of the mean over attributes

display(table_idinvert.iloc[-1:, :])

In [None]:
# ID-Invert per-attribute absolute accuracies, plus an 'Avg' row; prints the
# standard error over attributes and shows 'Avg'.
table_idinvert_acc = pd.DataFrame.from_dict(table_accs, orient='index', columns=val_labels+test_labels)
# DataFrame.append was removed in pandas 2.0; concatenate the mean as a row.
table_idinvert_acc = pd.concat([table_idinvert_acc, table_idinvert_acc.mean(axis=0).rename('Avg').to_frame().T])
n_attrs = len(table_idinvert_acc) - 1  # number of attribute rows (excludes 'Avg')
std_acc = table_idinvert_acc.iloc[:-1, :].std(axis=0).rename('Std')
print(std_acc / np.sqrt(n_attrs))  # standard error of the mean over attributes
display(table_idinvert_acc.iloc[-1:, :])

In [None]:
from scipy.stats import pearsonr

# Scatter and Pearson correlation of per-attribute test accuracy gains:
# pre-trained FFHQ GAN + encoder (`table`) vs. ID-Invert (`table_idinvert`).
# Both tables end with an aggregate 'Avg' row. That row is not an attribute,
# so it is dropped: otherwise it shows up as an extra point and biases the
# correlation. Rows are aligned by attribute name rather than by position.
ffhq_diffs = table.drop(index='Avg', errors='ignore')
idinvert_diffs = table_idinvert.drop(index='Avg', errors='ignore').loc[ffhq_diffs.index]

f, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.plot(ffhq_diffs['Test GAN'], idinvert_diffs['Test GAN'], '*', label='GAN Aug')
ax.plot(ffhq_diffs['Test Combined'], idinvert_diffs['Test Combined'], '*', label='Combined Aug')
ax.set_xlabel('Pre-trained FFHQ + Encoder\nAccuracy Difference', fontsize=14)
ax.set_ylabel('ID-Invert\nAccuracy Difference', fontsize=14)
ax.legend(loc='lower right')

# Pool both augmentation types into a single correlation.
corr, pval = pearsonr(ffhq_diffs['Test GAN'].to_list() + ffhq_diffs['Test Combined'].to_list(),
                      idinvert_diffs['Test GAN'].to_list() + idinvert_diffs['Test Combined'].to_list())
print('Pearsons correlation: %.3f pval %f' % (corr, pval))
save(f, 'sm_graph_face_idinvert.pdf')

# different training approaches

In [None]:
# different training approaches
# For every attribute, evaluate four classifier training variants, selected
# by directory suffix:
#   ''                            -> 'Im'
#   '__latent'                    -> 'latent'
#   '__latent_stylemix_fine'      -> 'latent_stylemix'
#   '__latent_stylemix_fine_crop' -> 'latent_stylemix_crop'
# Each variant gets three evaluation methods: single image, style-mix GAN
# ensemble, and combined (GAN + color/crop jitter) ensemble. Rows hold
# 12 val values then 12 test values.
# All diffs are relative to the FIRST value of the row, i.e. the single-image
# accuracy of the image-trained ('') classifier.
# NOTE(review): temporaries are notebook globals. Keep the helper call order
# if editing, in case resample_wrapper uses shared random state.
attr_mean = data_celebahq.attr_celebahq.mean(axis=0)[:-1]
# Attributes sorted with mean closest to 0.5 first; all of them are used here
# (a later cell asserts 40 rows).
attr_order = sorted([(abs(v-0.5), v, k) for k, v in attr_mean.to_dict().items()])

table_dict = OrderedDict([])
table_accs = OrderedDict([])

for i, (_, _, attribute) in enumerate(tqdm(attr_order)):
    
    val_values = []
    val_diffs = []
    test_values = []
    test_diffs = []
    # Column names, in the same order the values are appended below
    # (train method outer, eval method inner).
    val_labels = ['Val ' + train_method + ' ' + eval_method for train_method in 
                  ['Im', 'latent', 'latent_stylemix', 'latent_stylemix_crop'] for eval_method in ['Single', 'GAN Ens', 'Combined Ens']]
    test_labels = ['Test ' + train_method + ' ' + eval_method for train_method in 
                   ['Im', 'latent', 'latent_stylemix', 'latent_stylemix_crop'] for eval_method in ['Single', 'GAN Ens', 'Combined Ens']]
    for suffix in ['', '__latent', '__latent_stylemix_fine', '__latent_stylemix_fine_crop']:
        attr = attribute + suffix
        # print('========== %s ==========' % attr)

        # gan jitter: style-mix (fine) GAN ensemble
        val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine.npz'
        test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/gan_ensemble_stylemix_fine.npz'
        expt_name = 'stylemix_fine'
        # resample
        resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                          add_aug=False, verbose=False)   
        val_orig = resampled_accs['val_acc_original']
        val_top1 = resampled_accs['val_avg']
        test_orig = resampled_accs['test_acc_original']
        test_top1_from_val = resampled_accs['test_avg']

        # gan jitter with color/crop jitter ("tensortransform" evaluation files)
        val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine_tensortransform.npz'
        test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/gan_ensemble_stylemix_fine_tensortransform.npz'
        expt_name = 'stylemix_fine'
        # resample
        resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                          add_aug=False, verbose=False)   
        val_orig_mix = resampled_accs['val_acc_original']
        val_top1_mix = resampled_accs['val_avg']
        test_orig_mix = resampled_accs['test_acc_original']
        test_top1_from_val_mix = resampled_accs['test_avg']
    
        # sanity check: both ensemble files must agree on the single-image accuracy
        assert(test_orig_mix == test_orig)
        assert(val_orig_mix == val_orig)

        new_val_values = [val_orig, val_top1, val_top1_mix]
        new_test_values = [test_orig, test_top1_from_val, test_top1_from_val_mix]
        val_values.extend(new_val_values)
        test_values.extend(new_test_values)
        # val_values[0] / test_values[0] is the image-trained single-image accuracy
        val_diffs.extend([x - val_values[0] for x in new_val_values])
        test_diffs.extend([x - test_values[0] for x in new_test_values])
        
    table_dict[attribute] = val_diffs + test_diffs
    table_accs[attribute] = val_values + test_values

In [None]:
# Per-attribute accuracy differences for each training/evaluation method
# (relative to the image-trained single-image accuracy), plus an 'Avg' row.
table = pd.DataFrame.from_dict(table_dict, orient='index', columns=val_labels+test_labels)
# DataFrame.append was removed in pandas 2.0; concatenate the mean as a row.
table = pd.concat([table, table.mean(axis=0).rename('Avg').to_frame().T])
n_attrs = len(table) - 1  # number of attribute rows (excludes 'Avg')
std = table.iloc[:-1, :].std(axis=0).rename('Std')
print(std / np.sqrt(n_attrs))  # standard error of the mean over attributes

# Columns 12: are the test split (4 training methods x 3 evaluation methods).
display(table.iloc[-1:, 12:])

In [None]:
# Per-attribute absolute accuracies for each training/evaluation method, plus
# an 'Avg' row.
table_acc = pd.DataFrame.from_dict(table_accs, orient='index', columns=val_labels+test_labels)
# DataFrame.append was removed in pandas 2.0; concatenate the mean as a row.
table_acc = pd.concat([table_acc, table_acc.mean(axis=0).rename('Avg').to_frame().T])
# Show the 'Avg' row for the test-split columns (12 onward).
display(table_acc.iloc[-1:, 12:])

In [None]:
# Grouped bar chart of mean test accuracy over attributes. The x-axis is the
# classifier training method (4 groups); bars within a group are the
# evaluation methods (3 bars).
per_attr_test_acc = table_acc.iloc[:-1, 12:]  # drop 'Avg'; keep the test columns
assert(per_attr_test_acc.shape[0] == 40)
df = {'train_method': ['Im', 'Im', 'Im', 'latent', 'latent', 'latent'] + ['latent_stylemix'] * 3 + ['latent_stylemix_crop'] * 3,
      'ens_method': ['Single Image', 'Style-mix Ensemble', 'Combined Ensemble'] * 4,
      'acc': per_attr_test_acc.mean(axis=0),
      'stderr': per_attr_test_acc.std(axis=0) / np.sqrt(per_attr_test_acc.shape[0])
     }
df = pd.DataFrame.from_dict(df)
display(df)

f, ax = plt.subplots(1, 1, figsize=(6, 4))
group_size = 3
bar_width = 0.2
n_groups = 4
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = make_blue_palette(group_size)

for i in range(group_size):
    # Every group_size-th row belongs to evaluation method i (one per training method).
    rows = df.iloc[np.arange(i, n_groups * group_size, group_size)]
    ens_name = rows['ens_method'].iloc[0]
    assert all(x == ens_name for x in rows['ens_method'])
    # Error bars are not drawn on this plot; pass yerr=rows['stderr'] to add them.
    ax.bar(bar_offsets[i], rows['acc'], width=bar_width, color=palette[i],
           label=ens_name, edgecolor=(0.5, 0.5, 0.5), capsize=5)
ax.set_ylim([88.5, 89.8])
ax.legend(prop={'size': 12})
ax.set_xticks(np.arange(1, n_groups + 1))
ax.set_xticklabels(['Train\nImage', 'Train\nLatent', 'Train\nStyle-mix', 'Train\nCombined'], fontsize=14)
# Tick.label was removed in Matplotlib 3.8; use tick_params instead.
ax.tick_params(axis='y', labelsize=12)
ax.set_xlabel('')
ax.set_ylabel('Accuracy', fontsize=16)
f.tight_layout()
save(f, 'graph_face_train_latent.pdf')

In [None]:
# Grouped bar chart of the mean test accuracy difference (vs. the image-trained
# single-image accuracy) over attributes, with standard-error bars. The x-axis
# is the training method; bars within a group are the evaluation methods.
per_attr_test_diff = table.iloc[:-1, 12:]  # drop 'Avg'; keep the test columns
assert(per_attr_test_diff.shape[0] == 40)
df = {'train_method': ['Im', 'Im', 'Im', 'latent', 'latent', 'latent'] + ['latent_stylemix'] * 3 + ['latent_stylemix_crop'] * 3,
      'ens_method': ['Single Image', 'Style-mix Ensemble', 'Combined Ensemble'] * 4,
      'acc': per_attr_test_diff.mean(axis=0),
      'stderr': per_attr_test_diff.std(axis=0) / np.sqrt(per_attr_test_diff.shape[0])
     }
df = pd.DataFrame.from_dict(df)
display(df)

f, ax = plt.subplots(1, 1, figsize=(6, 4))
group_size = 3
bar_width = 0.2
n_groups = 4
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = make_blue_palette(group_size)

for i in range(group_size):
    # Every group_size-th row belongs to evaluation method i (one per training method).
    rows = df.iloc[np.arange(i, n_groups * group_size, group_size)]
    ens_name = rows['ens_method'].iloc[0]
    assert all(x == ens_name for x in rows['ens_method'])
    ax.bar(bar_offsets[i], rows['acc'], width=bar_width, color=palette[i], yerr=rows['stderr'],
           label=ens_name, edgecolor=(0.5, 0.5, 0.5), capsize=5)
ax.set_ylim([-0.3, 0.8])
ax.legend(prop={'size': 12}, loc='upper left')
ax.set_xticks(np.arange(1, n_groups + 1))
ax.set_xticklabels(['Train\nImage', 'Train\nLatent', 'Train\nStyle-mix', 'Train\nCombined'], fontsize=14)
# Tick.label was removed in Matplotlib 3.8; use tick_params instead.
ax.tick_params(axis='y', labelsize=12)
ax.set_xlabel('')
ax.set_ylabel('Accuracy Difference', fontsize=16)
f.tight_layout()
save(f, 'graph_face_train_latent_diff.pdf')

# distribution of classification accuracies

In [None]:
# Histogram of per-attribute test accuracy for the image-trained classifier
# evaluated on single images. `table_acc` ends with an aggregate 'Avg' row,
# which is not an attribute and must not be counted in the histogram.
f, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.hist(table_acc['Test Im Single'].drop(index='Avg', errors='ignore'))
ax.set_xlim([50, 100])
ax.set_ylabel('Count', fontsize=14)
ax.set_xlabel('Test Accuracy', fontsize=14)
# Tick.label was removed in Matplotlib 3.8; use tick_params instead.
ax.tick_params(axis='both', labelsize=12)
save(f, 'sm_graph_face_acc_distribution.pdf')

# over 12 attributes, plot stylemix, isotropic, and PCA fine and coarse

In [None]:
# For the 12 attributes whose mean is closest to 0.5, compare GAN augmentation
# families: isotropic, PCA and style-mix, each at coarse and fine level.
# For each family, every listed hyperparameter setting is evaluated. The
# setting with the best val accuracy ('val_avg') is kept, and its val and test
# accuracies are recorded.
# Output is long-format df_val / df_test with 7 rows per attribute: 'original'
# first, then one row per family, in val_expts order.
# NOTE(review): temporaries are notebook globals (later cells read val_expts /
# test_expts). Keep the helper call order if editing, in case resample_wrapper
# uses shared random state.
attr_mean = data_celebahq.attr_celebahq.mean(axis=0)[:-1]
attr_order = sorted([(abs(v-0.5), v, k) for k, v in attr_mean.to_dict().items()])

df_val = defaultdict(list)
df_test = defaultdict(list)

for i, (_, _, attr) in enumerate(tqdm(attr_order[:12])):
    # print('========== %s ==========' % attr)
    
    # (val npz path, hyperparameter settings to try, display name)
    val_expts = [
        (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_isotropic_coarse.npz', 
         ('isotropic_coarse_0.10', 'isotropic_coarse_0.30'), 'Isotropic Coarse'),
        (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_isotropic_fine.npz', 
         ('isotropic_fine_0.10', 'isotropic_fine_0.30'), 'Isotropic Fine'),
        (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_pca_coarse.npz', 
         ('pca_coarse_1.00', 'pca_coarse_2.00', 'pca_coarse_3.00'), 'PCA Coarse'),
        (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_pca_fine.npz',
         ('pca_fine_1.00', 'pca_fine_2.00', 'pca_fine_3.00'), 'PCA Fine'),
        (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_coarse.npz', 
         ('stylemix_coarse',), 'Style-mix Coarse'),
        (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine.npz', 
         ('stylemix_fine',), 'Style-mix Fine'),
    ]
    # Same experiments on the test split: only the '<attr>_val/' directory changes.
    test_expts = [(x.replace('_val/', '_test/'), y, z) for x, y, z in val_expts]
    # NOTE: this inner `i` shadows the (unused) outer loop index.
    for i, (val, test) in enumerate(zip(val_expts, test_expts)):
        expt_settings = []
        for expt_name in val[1]:
            resampled_accs = resample_wrapper(val[0], test[0], expt_name, ens_size=31, 
                                              add_aug=False, verbose=False)            
            resampled_accs['expt_name'] = expt_name
            expt_settings.append(resampled_accs)
            
        # these should all be the same -- just standard test info
        assert(all([x['val_acc_original'] == expt_settings[0]['val_acc_original'] for x in expt_settings]))
        assert(all([x['test_acc_original'] == expt_settings[0]['test_acc_original'] for x in expt_settings]))
        
        # Record the un-augmented baseline once per attribute (stderr 0).
        if i == 0:
            df_val['attribute'].append(attr)
            df_val['acc'].append(expt_settings[0]['val_acc_original'])
            df_val['stderr'].append(0.)
            df_val['expt_group'].append('Original Image')
            df_val['expt'].append('original')
            df_test['attribute'].append(attr)
            df_test['acc'].append(expt_settings[0]['test_acc_original'])
            df_test['stderr'].append(0.)
            df_test['expt_group'].append('Original Image')
            df_test['expt'].append('original')

        
        # Pick the hyperparameter setting by val accuracy only; its test
        # numbers are then reported without further selection.
        best_expt = max(expt_settings, key=lambda x: x['val_avg']) # take the val accuracy
        # val result
        df_val['attribute'].append(attr)
        df_val['acc'].append(best_expt['val_avg'])
        df_val['stderr'].append(best_expt['val_stderr'])
        df_val['expt'].append(best_expt['expt_name'])
        df_val['expt_group'].append(val[2])
        
        # test result
        df_test['attribute'].append(attr) 
        df_test['acc'].append(best_expt['test_avg'])
        df_test['stderr'].append(best_expt['test_stderr'])
        df_test['expt'].append(best_expt['expt_name'])
        df_test['expt_group'].append(test[2])
        
df_val = pd.DataFrame.from_dict(df_val)
df_test = pd.DataFrame.from_dict(df_test)

In [None]:
# Pivot the long-format results into wide tables: one row per attribute and
# one column per experiment group ('Original' + each augmentation family).
# The long tables store group_size consecutive rows per attribute.
group_size = 7  # 'Original' + 6 augmentation families
num_attr = 12


def _to_per_attribute_table(long_df, expts):
    """Collapse blocks of `group_size` rows of `long_df` into one row per attribute."""
    rows = OrderedDict()
    for start in range(0, num_attr * group_size, group_size):
        block = long_df.iloc[start:start + group_size]
        names = list(block['attribute'])
        # Every row of a block must belong to the same attribute.
        assert all(name == names[0] for name in names)
        rows[names[0]] = list(block['acc'])
    column_names = ['Original'] + [display_name for _, _, display_name in expts]
    return pd.DataFrame.from_dict(rows, orient='index', columns=column_names)


df_per_attr_val = _to_per_attribute_table(df_val, val_expts)
df_per_attr_test = _to_per_attribute_table(df_test, test_expts)

In [None]:
# Display the per-attribute test accuracy table (one row per attribute;
# columns: 'Original' followed by each GAN augmentation family).
df_per_attr_test

In [None]:
# Mean accuracy change vs. the original image for each GAN augmentation family,
# averaged over attributes. Error bars are the standard error over attributes;
# validation and test bars sit side by side.
df_per_attr_val_diff = df_per_attr_val.sub(df_per_attr_val['Original'], axis=0).iloc[:, 1:]
df_per_attr_test_diff = df_per_attr_test.sub(df_per_attr_test['Original'], axis=0).iloc[:, 1:]
f, ax = plt.subplots(1, 1, figsize=(6, 3))

group_size = 2  # one validation bar and one test bar per augmentation family
bar_width = 0.25
n_groups = 6
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = sns.color_palette()

#### combined plot ####
split_tables = (('Validation', df_per_attr_val_diff), ('Test', df_per_attr_test_diff))
for col_idx, aug_type in enumerate(df_per_attr_val_diff.columns):
    for split_idx, (split_name, split_diffs) in enumerate(split_tables):
        aug_diffs = split_diffs[aug_type]
        sem = aug_diffs.std() / np.sqrt(split_diffs.shape[0])
        # Only the first group contributes legend entries.
        ax.bar(bar_offsets[split_idx][col_idx], aug_diffs.mean(), yerr=sem, width=bar_width,
               color=palette[split_idx], edgecolor=(0.5, 0.5, 0.5), capsize=5,
               label=split_name if col_idx == 0 else None)
ax.legend()
ax.set_ylabel('Accuracy Difference', fontsize=14)
ax.set_xticks(np.arange(1, n_groups + 1))
ax.set_xticklabels([name.replace(' ', '\n') for name in df_per_attr_val_diff.columns], fontsize=11)
save(f, 'graph_face_gan_aug_types.pdf')

# plot the accuracy vs alpha graph

In [None]:
# Accuracy vs. ensemble weight for the combined (style-mix + color/crop jitter)
# ensemble on validation and test. The dotted line marks the weight with the
# highest validation accuracy.
for attr in ['Smiling',]:
    val_expt = (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine_tensortransform.npz',
                ('stylemix_fine',), 'Style-Mix Fine')
    x, y, z = val_expt
    # Swap only the '<attr>_val/' directory component, so that no other '_val'
    # substring elsewhere in the path can be rewritten by accident.
    test_expt = (x.replace('_val/', '_test/'), y, z)

    val_res = get_accuracy_from_npz(val_expt[0], val_expt[1][0], add_aug=False, ens_size=31)
    test_res = get_accuracy_from_npz(test_expt[0], test_expt[1][0], add_aug=False, ens_size=31)

    f, ax = plt.subplots(1, 1, figsize=(6, 3))

    ax.plot(val_res['ensemble_table']['weight'], val_res['ensemble_table']['acc'], label='Validation')
    ax.plot(test_res['ensemble_table']['weight'], test_res['ensemble_table']['acc'], label='Test')

    # Mark the weight with the highest validation accuracy (last position of argsort).
    val_ensemble_table = val_res['ensemble_table']
    best_val_setting = val_ensemble_table.iloc[val_ensemble_table['acc'].argsort().iloc[-1], :]
    ax.axvline(best_val_setting.weight, color='k', linestyle=':', label='Selected Weight')

    ax.set_ylabel('Accuracy')
    ax.set_xlabel('Ensemble Weight')
    if attr == 'Smiling':
        ax.legend()

    f.tight_layout()
    save(f, 'sm_ensemble_alpha_%s_v2.pdf' % attr)

# stylegan corruptions

In [None]:
# For four attributes, evaluate clean images plus six corrupted/attacked
# variants, selected by directory prefix:
#   ''                           -> 'Im' (clean)
#   'corruption_jpeg_'           -> 'Jpeg'
#   'corruption_gaussian_blur_'  -> 'Blur'
#   'corruption_gaussian_noise_' -> 'Noise'
#   'fgsm_'                      -> 'FGSM'
#   'pgd_'                       -> 'PGD'
#   'cw_'                        -> 'CW'
# Each variant gets four evaluation methods: S=single image,
# R=GAN reconstruction, G=style-mix GAN ensemble, C=combined ensemble.
# Rows hold 28 val values then 28 test values. Diffs are relative to the
# clean single-image accuracy (the first value of the row).
# NOTE(review): n_samples below is not used in this cell.
table_dict = OrderedDict([])
table_accs = OrderedDict([])
table_stderrs = OrderedDict([])

# axes = [col for row in axes for col in row]

n_samples = 20

for i, attribute in enumerate(['Smiling', 'Arched_Eyebrows', 'Young', 'Wavy_Hair']):
    val_values = []
    test_values = []
    val_stderrs = []
    test_stderrs = []
    val_diffs = []
    test_diffs = []
    
    # Column names, in the same order the values are appended below
    # (corruption outer, eval method inner).
    val_labels = ['Val ' + corruption + ' ' + eval_method for corruption in 
                  ['Im', 'Jpeg', 'Blur', 'Noise', 'FGSM', 'PGD', 'CW'] for eval_method in ['S', 'R', 'G', 'C']]
    test_labels = ['Test ' + corruption + ' ' + eval_method for corruption in 
                   ['Im', 'Jpeg', 'Blur', 'Noise', 'FGSM', 'PGD', 'CW'] for eval_method in ['S', 'R', 'G', 'C']]    
        
    for prefix in ['', 'corruption_jpeg_', 'corruption_gaussian_blur_', 'corruption_gaussian_noise_', 'fgsm_', 'pgd_', 'cw_']:
        
        attr = prefix + attribute

        print(attr)
        
        # gan jitter fine: style-mix (fine) GAN ensemble
        val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine.npz'
        test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/gan_ensemble_stylemix_fine.npz'
        expt_name = 'stylemix_fine'
        # resample
        resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                          add_aug=False, verbose=False)   
        val_orig = resampled_accs['val_acc_original']
        val_top1 = resampled_accs['val_avg']
        val_stderr = resampled_accs['val_stderr']
        val_rec = resampled_accs['val_acc_rec']
        test_orig = resampled_accs['test_acc_original']
        test_top1_from_val = resampled_accs['test_avg']
        test_stderr = resampled_accs['test_stderr']
        test_rec = resampled_accs['test_acc_rec']

        # gan jitter with color/crop jitter ("tensortransform" evaluation files)
        val_file = f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine_tensortransform.npz'
        test_file = f'results/precomputed_evaluations/celebahq/output/{attr}_test/gan_ensemble_stylemix_fine_tensortransform.npz'
        expt_name = 'stylemix_fine'
        resampled_accs = resample_wrapper(val_file, test_file, expt_name, ens_size=31, 
                                          add_aug=False, verbose=False)   
        val_orig_mix = resampled_accs['val_acc_original']
        val_top1_mix = resampled_accs['val_avg']
        val_stderr_mix = resampled_accs['val_stderr']
        val_rec_mix = resampled_accs['val_acc_rec']
        test_orig_mix = resampled_accs['test_acc_original']
        test_top1_from_val_mix = resampled_accs['test_avg']
        test_stderr_mix = resampled_accs['test_stderr']
        test_rec_mix = resampled_accs['test_acc_rec']


        # sanity check: both ensemble files must agree on the single-image
        # and reconstruction accuracies
        assert(test_orig_mix == test_orig)
        assert(test_rec_mix == test_rec)
        assert(val_orig_mix == val_orig)
        assert(val_rec_mix == val_rec)


        # Single-image and reconstruction accuracies have no resampling,
        # so their stderr is recorded as 0.
        new_val_values = [val_orig, val_rec, val_top1, val_top1_mix]
        new_val_stderrs = [0., 0., val_stderr, val_stderr_mix]
        new_test_values = [test_orig, test_rec, test_top1_from_val, test_top1_from_val_mix]
        new_test_stderrs = [0., 0., test_stderr, test_stderr_mix]

        val_values.extend(new_val_values)
        test_values.extend(new_test_values)
        val_stderrs.extend(new_val_stderrs)
        test_stderrs.extend(new_test_stderrs) 
        # val_values[0] / test_values[0] is the clean single-image accuracy
        val_diffs.extend([x - val_values[0] for x in new_val_values])
        test_diffs.extend([x - test_values[0] for x in new_test_values])

    table_dict[attribute] = val_diffs + test_diffs
    table_accs[attribute] = val_values + test_values
    table_stderrs[attribute] = val_stderrs + test_stderrs

In [None]:
# Per-attribute accuracy differences for every corruption and evaluation method
# (one row per attribute); the trailing expression echoes the table shape.
corruption_columns = [*val_labels, *test_labels]
table = pd.DataFrame.from_dict(table_dict, orient='index', columns=corruption_columns)
table.shape

In [None]:
# Build the absolute-accuracy and stderr tables, then show the test-split
# columns (column 28 onward; the first 28 are the validation split) of the
# difference, accuracy and stderr tables.
corruption_columns = val_labels + test_labels
table_acc = pd.DataFrame.from_dict(table_accs, orient='index', columns=corruption_columns)
table_stderr = pd.DataFrame.from_dict(table_stderrs, orient='index', columns=corruption_columns)
for frame in (table, table_acc, table_stderr):
    display(frame.iloc[:, 28:])

In [None]:
# Untargeted corruptions: one subplot per attribute. Each subplot shows test
# accuracy on clean, JPEG, blurred and noisy images under four evaluation
# methods.
f, axes = plt.subplots(1, 4, figsize=(16, 3.5))

for row, attr in enumerate(table_acc.index):
    ax = axes[row]
    # Test columns 28:44 = Im (clean), Jpeg, Blur, Noise; 4 evaluation methods each.
    df = {'train_method': ['Uncorrupted'] * 4 + ['Jpeg'] * 4 + ['Blur'] * 4 + ['Noise'] * 4,
          'ens_method': ['Image', 'Reconstruction', 'Style-mix Ensemble', 'Combined Ensemble'] * 4,
          'acc': table_acc.iloc[row, 28:-12],
          'stderr': table_stderr.iloc[row, 28:-12]
         }
    df = pd.DataFrame.from_dict(df)
    palette = make_blue_palette(4)

    group_size = 4
    n_groups = 4
    bar_width = 0.2
    bar_offsets = bar_offset(group_size, n_groups, bar_width)
    for i in range(group_size):
        # Every group_size-th row belongs to evaluation method i.
        rows = df.iloc[np.arange(i, n_groups * group_size, group_size)]
        ens_name = rows['ens_method'].iloc[0]
        assert all(x == ens_name for x in rows['ens_method'])
        # Error bars are not drawn; pass yerr=rows['stderr'] to add them.
        ax.bar(bar_offsets[i], rows['acc'], width=bar_width, color=palette[i],
               label=ens_name, edgecolor=(0.5, 0.5, 0.5), capsize=5)

    ax.set_ylim([np.min(df['acc'])-1.0, np.max(df['acc'])+1.0])
    ax.set_xticks(np.arange(1, n_groups + 1))
    ax.set_xticklabels(['Clean', 'Jpeg', 'Blur', 'Noise'], fontsize=14)
    ax.set_xlabel('')
    ax.set_ylabel('Accuracy', fontsize=16)
    ax.set_title(attr.replace('_', ' '), fontsize=16)
    # Tick.label was removed in Matplotlib 3.8; use tick_params instead.
    ax.tick_params(axis='y', labelsize=12)

# All subplots use the same bar labels, so one shared figure legend is built
# from the last axis.
handles, labels = ax.get_legend_handles_labels()
lgd = f.legend(handles, labels, loc='lower center', ncol=4, prop={'size': 12},
               bbox_to_anchor=(0.5, -0.08), edgecolor='1.0')
f.tight_layout()
save(f, 'graph_face_untargeted_corruption.pdf')

In [None]:
# Adversarial attacks: one subplot per attribute. Each subplot shows test
# accuracy on clean images and under FGSM, PGD and CW attacks, for four
# evaluation methods.
f, axes = plt.subplots(1, 4, figsize=(16, 3.5))

for row, attr in enumerate(table_acc.index):
    ax = axes[row]
    # Test columns 28:32 = clean, 44:56 = FGSM, PGD, CW; 4 evaluation methods each.
    selected_columns = list(range(28, 32)) + list(range(44, 56))
    df = {'train_method': ['Uncorrupted'] * 4 + ['FGSM'] * 4 + ['PGD'] * 4 + ['CW'] * 4,
          'ens_method': ['Image', 'Reconstruction', 'Style-mix Ensemble', 'Combined Ensemble'] * 4,
          'acc': table_acc.iloc[row, selected_columns],
          'stderr': table_stderr.iloc[row, selected_columns]
         }
    df = pd.DataFrame.from_dict(df)
    palette = make_blue_palette(4)

    group_size = 4
    n_groups = 4
    bar_width = 0.2
    bar_offsets = bar_offset(group_size, n_groups, bar_width)
    for i in range(group_size):
        # Every group_size-th row belongs to evaluation method i.
        rows = df.iloc[np.arange(i, n_groups * group_size, group_size)]
        ens_name = rows['ens_method'].iloc[0]
        assert all(x == ens_name for x in rows['ens_method'])
        # Error bars are not drawn; pass yerr=rows['stderr'] to add them.
        ax.bar(bar_offsets[i], rows['acc'], width=bar_width, color=palette[i],
               label=ens_name, edgecolor=(0.5, 0.5, 0.5), capsize=5)

    # Attacks can push accuracy far down, so use the full 0-100 range.
    ax.set_ylim([0, 100])
    ax.set_xticks(np.arange(1, n_groups + 1))
    ax.set_xticklabels(['Clean', 'FGSM', 'PGD', 'CW'], fontsize=14)
    ax.set_xlabel('')
    ax.set_ylabel('Accuracy', fontsize=16)
    ax.set_title(attr.replace('_', ' '), fontsize=16)
    # Tick.label was removed in Matplotlib 3.8; use tick_params instead.
    ax.tick_params(axis='y', labelsize=12)

# All subplots use the same bar labels, so one shared figure legend is built
# from the last axis.
handles, labels = ax.get_legend_handles_labels()
lgd = f.legend(handles, labels, loc='lower center', ncol=4, prop={'size': 12}, bbox_to_anchor=(0.5, -0.08), edgecolor='1.0')
f.tight_layout()
save(f, 'graph_face_targeted_corruption.pdf')

# stylegan ensemble size

In [None]:
def compute_best_weight_ensemble_size(val_data_file, test_data_file, expt_name, verbose=True, add_aug=False, seed=None,
                                      ens_sizes=None, num_samples=16):
    """Select the ensemble weight on val, then measure test accuracy vs. ensemble size.

    The weight is chosen on the validation file using the full ensemble
    (``ens_size=31``). That weight is then held fixed while the test file is
    evaluated once per (ensemble size, sample) pair, where the sample index
    ``0 .. num_samples-1`` is passed as the ``seed``.

    Args:
        val_data_file: path of the validation ``.npz``; must contain 'val'.
        test_data_file: path of the test ``.npz``; must contain 'test'.
        expt_name: experiment name forwarded to ``get_accuracy_from_npz``.
        verbose: if True, print val accuracies and the selected weight.
        add_aug: forwarded to ``get_accuracy_from_npz``.
        seed: seed used for the validation evaluation only.
        ens_sizes: ensemble sizes evaluated on test. Defaults to
            ``[0, 2, 4, 8, 12, 16, 20, 24, 28, 30, 31]``.
        num_samples: number of seeds evaluated per ensemble size.

    Returns:
        dict with keys
          'val_info':    accuracy info of the validation evaluation,
          'test_info':   accuracy info of the LAST test evaluation run,
          'val_setting': ensemble-table row with the highest val accuracy,
          'test_df':     DataFrame indexed by ensemble size with one column
                         per seed, holding test accuracy at the val weight.

    Raises:
        ValueError: if ``ens_sizes`` is empty or ``num_samples`` < 1.
    """
    if ens_sizes is None:
        ens_sizes = [0, 2, 4, 8, 12, 16, 20, 24, 28, 30, 31]
    ens_sizes = list(ens_sizes)
    if not ens_sizes or num_samples < 1:
        # Otherwise the loop below never runs and 'test_info' is unbound.
        raise ValueError("need at least one ensemble size and one sample per size")
    assert('val' in val_data_file)
    assert('test' in test_data_file)

    # Pick the best weight on val using the full ensemble.
    val_accuracy_info = get_accuracy_from_npz(val_data_file, expt_name, add_aug=add_aug, ens_size=31, seed=seed)
    val_ensemble_table = val_accuracy_info['ensemble_table']
    # argsort() yields *positions*, so iloc (not loc/idxmax) selects the highest-acc row.
    best_val_setting = val_ensemble_table.iloc[val_ensemble_table['acc'].argsort().iloc[-1], :]

    if verbose:
        print("Val original %0.4f Val reconstructed %0.4f" %
              (val_accuracy_info['acc_original'], val_accuracy_info['acc_reconstructed']))
        print("%0.4f @ %0.4f %s" % (best_val_setting['acc'], best_val_setting['weight'], best_val_setting['expt_name']))

    # Test: iterate through ensemble sizes, drawing several seeded samples for each.
    accs_reconstructed = []
    accs_original = []
    test_table = OrderedDict((ens_size, []) for ens_size in ens_sizes)
    for ens_size in ens_sizes:
        for sample in range(num_samples):
            test_accuracy_info = get_accuracy_from_npz(test_data_file, expt_name, weight=best_val_setting['weight'],
                                                       add_aug=add_aug, ens_size=ens_size, seed=sample)

            accs_reconstructed.append(test_accuracy_info['acc_reconstructed'])
            accs_original.append(test_accuracy_info['acc_original'])
            test_ensemble_table = test_accuracy_info['ensemble_table']
            assert(test_ensemble_table.shape[0] == 1)  # only the fixed val weight is evaluated
            test_setting_from_val = test_ensemble_table.iloc[0, :]
            test_table[ens_size].append(test_setting_from_val['acc'])

    # Sanity check: the non-ensembled accuracies must not depend on ens_size or seed.
    assert(all(x == accs_reconstructed[0] for x in accs_reconstructed))
    assert(all(x == accs_original[0] for x in accs_original))

    test_df = pd.DataFrame.from_dict(test_table, orient='index', columns=range(num_samples))

    return {'val_info': val_accuracy_info, 'test_info': test_accuracy_info,
            'val_setting': best_val_setting, 'test_df': test_df}

In [None]:
# Test accuracy as a function of the number of GAN samples in the style-mix
# ensemble, one subplot per CelebA-HQ attribute. The line is the mean over
# seeds; the shaded band is +/- one standard error over seeds.
expt_name = 'stylemix_fine'
# Every attribute uses the same file layout, filled in as
# data_file_base % (attribute, split).
_celebahq_npz_template = ('results/precomputed_evaluations/celebahq/output/%s_%s/'
                          'gan_ensemble_stylemix_fine_tensortransform.npz')
expt_data = [(attr, _celebahq_npz_template)
             for attr in ['Smiling', 'Arched_Eyebrows', 'Wavy_Hair', 'Young']]

f, axes = plt.subplots(1, 4, figsize=(16, 4))

for i, (attr, data_file_base) in enumerate(expt_data):
    ax = axes[i]
    output = compute_best_weight_ensemble_size(data_file_base % (attr, 'val'),
                                               data_file_base % (attr, 'test'),
                                               expt_name)
    # test_df rows are ensemble sizes, columns are seeds.
    plot_vals = output['test_df'].to_numpy()
    m = np.mean(plot_vals, axis=1)
    s = np.std(plot_vals, axis=1) / np.sqrt(plot_vals.shape[1])
    ax.plot(output['test_df'].index, m)
    ax.fill_between(output['test_df'].index, m - s, m + s, alpha=0.3)
    ax.set_title(attr.replace('_', ' '), fontsize=16)
    ax.set_xlabel('Number of\nGAN samples', fontsize=14)
    ax.set_ylabel('Accuracy', fontsize=16)
    # Tick.label was deprecated in Matplotlib 3.3 and removed in 3.8;
    # tick_params sets the label size for both axes (including later ticks).
    ax.tick_params(axis='both', which='major', labelsize=12)
f.tight_layout()
save(f, 'graph_face_ensemble_size.pdf')

# CIFAR-10

In [None]:
# For each classifier and style-mix layer, compare accuracy on the original
# images with the GAN-ensemble accuracy, using two weights:
#   - the ensemble weight selected on val (via resample_wrapper);
#   - the "oracle" weight with the best accuracy on the test set itself.
# Each table_dict row is [val weight, val acc, test acc, oracle weight, oracle acc].
table_dict = {}

for classifier in ['imageclassifier', 'latentclassifier', 'latentclassifier_layer6', 'latentclassifier_layer7']:
    print("==================")
    for expt_name in ['stylemix_layer6', 'stylemix_layer7']:
        print("---> %s %s" % (classifier, expt_name))
        # Build both split paths explicitly; str.replace('_val', '_test') would
        # rewrite every '_val' substring anywhere in the path.
        val_data_file, test_data_file = (
            f'results/precomputed_evaluations/cifar10/output/{classifier}_{split}/gan_ensemble_{expt_name}.npz'
            for split in ('val', 'test')
        )
        resampled_accs = resample_wrapper(val_data_file, test_data_file, expt_name, ens_size=31,
                                          add_aug=False, verbose=False)

        print("val improvement: %0.3f" % (resampled_accs['val_avg'] - resampled_accs['val_acc_original']))
        print("test improvement: %0.3f" % (resampled_accs['test_avg'] - resampled_accs['test_acc_original']))

        # Oracle: the ensemble-table row with the highest accuracy on the test set.
        oracle = get_accuracy_from_npz(test_data_file, expt_name)
        oracle_table = oracle['ensemble_table']
        oracle_setting = oracle_table.iloc[oracle_table['acc'].argsort().iloc[-1], :]
        print("oracle improvement: %0.3f" % (oracle_setting['acc'] - resampled_accs['test_acc_original']))
        if expt_name == 'stylemix_layer6':
            # Record the classifier's accuracy on the original images once per
            # classifier; the weight columns do not apply and are NaN.
            table_dict['%s %s' % (classifier, 'images')] = [np.nan, resampled_accs['val_acc_original'],
                                                         resampled_accs['test_acc_original'], np.nan, np.nan]
        # presumably resampled_accs['weights'] holds one selected weight per resample — verify
        table_dict['%s %s' % (classifier, expt_name)] = [np.mean(resampled_accs['weights']), resampled_accs['val_avg'],
                                                         resampled_accs['test_avg'], oracle_setting['weight'],
                                                         oracle_setting['acc']]

In [None]:
# One row per table_dict key ('<classifier> <setting>'); columns follow the
# order in which each row list was assembled.
table = pd.DataFrame.from_dict(
    data=table_dict,
    orient='index',
    columns=['val weight', 'val acc', 'test acc', 'oracle weight', 'oracle acc'],
)
table

In [None]:
# Grouped bar chart of CIFAR-10 test accuracy. There is one group per
# classifier training distribution (in table_dict order). Each group has a bar
# for plain images and one each for the style-mix layer-6 and layer-7 ensembles.
f, ax = plt.subplots(1, 1, figsize=(6, 4))

group_size = 3  # bars per group: images, style-mix layer 6, style-mix layer 7
bar_width = 0.2
n_groups = 4  # training configurations
bar_offsets = bar_offset(group_size, n_groups, bar_width)
palette = make_yellow_palette(2)[1:] + make_blue_palette(2)[1:] + make_green_palette(2)[1:]

# Boolean masks keep the table's row order, i.e. one value per classifier.
ax.bar(bar_offsets[0], table.loc[table.index.str.endswith('images'), 'test acc'],
       width=bar_width, color=palette[0], label='Image', edgecolor=(0.5, 0.5, 0.5), capsize=5)

for i, layer in enumerate([6, 7]):
    ax.bar(bar_offsets[i + 1], table.loc[table.index.str.endswith('layer%d' % layer), 'test acc'],
           width=bar_width, color=palette[i + 1], label='Style-mix Layer%d' % layer,
           edgecolor=(0.5, 0.5, 0.5), capsize=5)

ax.set_ylim([92, 96])
ax.set_ylabel('Classification Accuracy', fontsize=14)
ax.set_xticks(np.arange(1, n_groups + 1))
ax.legend()
# Tick order matches the classifier order: imageclassifier, latentclassifier,
# latentclassifier_layer6, latentclassifier_layer7.
ax.set_xticklabels(['Original\nImages', 'GAN\nReconstructions',
                    'Style-mix\nLayer 6', 'Style-mix\nLayer 7'], fontsize=12)
ax.set_xlabel('Classifier training distribution', fontsize=16)
save(f, 'graph_cifar10.pdf')