# Plot Classification Accuracy as a function of Ensemble Weight

In [None]:
import sys
import numpy as np
from matplotlib import rc
import matplotlib.pyplot as plt
import sklearn.metrics
from collections import defaultdict
import os
import glob
from collections import OrderedDict, namedtuple
from tqdm import tqdm
import pandas as pd
import seaborn as sns
pd.options.display.float_format = '{:0.2f}'.format
rc('font', **{'family': 'serif'})

from data import data_celebahq

%matplotlib inline

# Define utility functions

In [None]:
### data evaluation utilities ### 

def softmax_to_prediction(softmax_prediction):
    """Convert soft classifier outputs into discrete class labels.

    Args:
        softmax_prediction: numpy array of rank 2 or rank 3.
            Rank 2 is treated as binary scores (N x ensembles) and is
            thresholded at 0.5 (strictly greater than 0.5 maps to 1).
            Rank 3 is treated as per-class scores along the last axis
            (N x ensembles x classes) and is reduced with argmax.

    Returns:
        Integer array of labels. Rank-2 input keeps its shape. Rank-3 input
        loses its last axis and then every remaining length-1 axis.

    Raises:
        ValueError: if the input is neither rank 2 nor rank 3. (This used to
            be a bare ``assert False``, which is stripped under ``python -O``
            and made the function silently return None.)
    """
    rank = np.ndim(softmax_prediction)
    if rank == 2:
        # N x ensembles binary prediction
        return (softmax_prediction > 0.5).astype(int)
    if rank == 3:
        # N x ensembles x classes
        # NOTE: squeeze() removes *all* singleton axes, so a batch holding a
        # single example also loses its N axis.
        return np.argmax(softmax_prediction, axis=-1).squeeze()
    raise ValueError(
        'softmax_prediction must have 2 (binary) or 3 (multi-class) '
        'dimensions, got %d' % rank)
        
def get_accuracy_from_image_ensembles(data_file, key, resample=False, seed=0,
                                      n_resamples=20, ens_size=32, verbose=True):
    """Compute single-image and augmentation-ensembled accuracy from an npz file.

    Intended for files such as image_ensemble_imcolor.npz or
    image_ensemble_imcrop.npz.

    Args:
        data_file: path to an npz archive containing the arrays 'original',
            'label' and ``key``.
        key: name of the array holding the predictions on the augmented views;
            it is concatenated with 'original' along axis 1.
        resample: if True, additionally bootstrap the ensemble members.
        seed: seed of the RandomState used for the bootstrap.
        n_resamples: number of bootstrap rounds.
        ens_size: number of members drawn (with replacement) per round; must
            equal the total number of views in the file (sanity check).
        verbose: print a '*' per bootstrap round and "done" at the end.

    Returns:
        dict with keys
            'acc_original': accuracy (percent) of the 'original' predictions,
            'acc_ensembled': accuracy (percent) of the mean over all views,
            'resamples': list of bootstrap accuracies (percent), or None when
                ``resample`` is False.
    """
    # Members of an npz archive are read from disk on every access and the
    # archive keeps its file handle open, so read each array exactly once and
    # close the file right away.
    with np.load(data_file) as encoded_data:
        labels = encoded_data['label']
        original = encoded_data['original']
        augmented = encoded_data[key]

    preds_original = softmax_to_prediction(original)
    acc_original = sklearn.metrics.accuracy_score(labels, preds_original) * 100

    # all views (original + augmentations) stacked along the ensemble axis
    jitters = np.concatenate([original, augmented], axis=1)
    preds_ensembled = softmax_to_prediction(np.mean(jitters, axis=1, keepdims=True))
    acc_ensembled = sklearn.metrics.accuracy_score(labels, preds_ensembled) * 100

    resamples = None
    if resample:
        # sample n_resamples ensembles with replacement, compute accuracy of each
        assert jitters.shape[1] == ens_size  # sanity check
        rng = np.random.RandomState(seed)
        resamples = []
        for _ in range(n_resamples):
            if verbose:
                print('*', end='')
            indices = rng.choice(jitters.shape[1], ens_size, replace=True)
            resampled_mean = np.mean(jitters[:, indices], axis=1, keepdims=True)
            preds_resampled = softmax_to_prediction(resampled_mean)
            resamples.append(sklearn.metrics.accuracy_score(labels, preds_resampled) * 100)
        if verbose:
            print("done")
    return {'acc_original': acc_original, 'acc_ensembled': acc_ensembled, 'resamples': resamples}
    
def sample_ensemble(raw_preds, ens_size=None, seed=None):
    """Select ensemble members (columns along axis 1) from raw predictions.

    Args:
        raw_preds: N x ens_size for binary classification, or
            N x ens_size x classes.
        ens_size: number of members to take; None takes all of them.
        seed: random seed for sampling with replacement; None takes the first
            ``ens_size`` members in their stored order.

    Returns:
        Array with the selected members along axis 1.
    """
    n_take = raw_preds.shape[1] if ens_size is None else ens_size
    if seed is None:
        # deterministic: leading members, in order
        columns = range(n_take)
    else:
        # bootstrap: draw member indices with replacement
        columns = np.random.RandomState(seed).choice(raw_preds.shape[1], n_take, replace=True)
    return raw_preds[:, columns]

def get_accuracy_from_npz(data_file, expt_name, weight=None, ens_size=None, seed=None, return_preds=False,
                          add_aug=False, aug_name='image_ensemble_imcrop', aug_key='imcrop'):
    """Compute weighted-ensemble accuracies of image and GAN-view predictions.

    The image prediction and the mean prediction over the GAN views are mixed
    as ``(1 - w) * image + w * gan_mean`` for every ensemble weight ``w``.

    Args:
        data_file: path to an npz archive containing 'label', 'reconstructed',
            ``expt_name`` and (unless ``add_aug``) 'original'.
        expt_name: name of the array with the perturbed GAN predictions
            (N x ens_size x softmax distribution).
        weight: a single ensemble weight; None sweeps np.linspace(0, 1, 21).
        ens_size: number of ensemble members passed to ``sample_ensemble``;
            None uses all, 0 skips the GAN views and reuses the image
            prediction in their place.
        seed: sampling seed passed to ``sample_ensemble``; None takes the
            members in order.
        return_preds: also return the raw/discrete predictions; requires a
            single ``weight``.
        add_aug: use the mean over image augmentations, loaded from
            ``<dir of data_file>/<aug_name>.npz``, instead of the single
            original-image prediction.
        aug_name: base name (without .npz) of the augmentation archive.
        aug_key: name of the augmented-view array inside that archive.

    Returns:
        dict with 'expt_settings' (data_file base name up to the first '.'),
        'acc_original', 'acc_reconstructed' (percent) and 'ensemble_table'
        (DataFrame with columns acc, weight, expt_name). When ``return_preds``
        is set, a tuple of that dict and a dict of predictions.

    Raises:
        ValueError: if ``return_preds`` is requested without a single
            ``weight`` (the returned predictions would be ambiguous).
    """
    # setup
    expt_settings = os.path.basename(data_file).split('.')[0]
    weights = np.linspace(0, 1, 21) if weight is None else [weight]
    if return_preds and len(weights) != 1:
        raise ValueError('return_preds requires a single ensemble weight')

    # Members of an npz archive are read from disk on every access and the
    # archive keeps its file handle open: read each array once, then close.
    with np.load(data_file) as encoded_data:
        labels = encoded_data['label']
        reconstructed = encoded_data['reconstructed']
        perturbed = encoded_data[expt_name]  # N x ens_size x softmax distribution
        # only needed (and only required to exist) in the basic case
        original = None if add_aug else encoded_data['original']

    # determine image classification accuracy
    if add_aug:
        # ensemble also with the image augmentations data
        print('.', end='')
        # os.path.dirname also handles paths without any '/' component
        aug_path = os.path.join(os.path.dirname(data_file), '%s.npz' % aug_name)
        with np.load(aug_path) as im_aug_data:
            im_aug_ens = np.concatenate([im_aug_data['original'], im_aug_data[aug_key]], axis=1)
        im_aug_ens = sample_ensemble(im_aug_ens, ens_size, seed)
        original = np.mean(im_aug_ens, axis=1, keepdims=True)  # full softmax distribution
    preds_original = softmax_to_prediction(original)
    acc_original = sklearn.metrics.accuracy_score(labels, preds_original) * 100

    # determine GAN reconstruction accuracy
    preds_reconstructed = softmax_to_prediction(reconstructed)
    acc_reconstructed = sklearn.metrics.accuracy_score(labels, preds_reconstructed) * 100

    # determine GAN ensemble accuracy
    if ens_size == 0:
        gan_ens = original  # dummy case: don't use gan reconstructed images
    else:
        gan_ens = sample_ensemble(np.concatenate((reconstructed, perturbed), axis=1), ens_size, seed)
    gan_mean = np.mean(gan_ens, axis=1, keepdims=True)  # does not depend on the weight

    table = {'acc': [], 'weight': [], 'expt_name': []}
    for alpha in weights:  # alpha weighting hyperparameter
        # for binary classification: original.shape = N x 1, gan_ens.shape = N x ens_size
        # for multi-class classification: original.shape = N x 1 x classes; gan_ens.shape = N x ens_size x classes
        ensembled = (1 - alpha) * original + alpha * gan_mean
        preds_ensembled = softmax_to_prediction(ensembled)
        table['acc'].append(sklearn.metrics.accuracy_score(labels, preds_ensembled) * 100)
        table['weight'].append(alpha)
        table['expt_name'].append(expt_name)

    # table of expt_name x weight
    return_data = {'expt_settings': expt_settings,
                   'acc_original': acc_original,
                   'acc_reconstructed': acc_reconstructed,
                   'ensemble_table': pd.DataFrame.from_dict(table)}
    if not return_preds:
        return return_data

    pred_data = {
        'original': original,  # original softmax
        'reconstruction': gan_ens,  # softmax of all gan views
        'ensembled': ensembled,  # softmax of the weighted ensemble
        'pred_original': preds_original,
        'pred_reconstruction': preds_reconstructed,
        'pred_ensemble': preds_ensembled,
        'label': labels,
    }
    return return_data, pred_data

# Make plot

In [None]:
# Attribute classifier to evaluate, and the (npz path, experiment keys, display name)
# triple describing the validation-split experiment.
attr = 'Smiling'
val_expt = (f'results/precomputed_evaluations/celebahq/output/{attr}_val/gan_ensemble_stylemix_fine_tensortransform.npz',
            ('stylemix_fine',), 'Style-Mix Fine')
# The test-split experiment is the same triple with the path pointed at the test directory.
x, y, z = val_expt
test_expt = (x.replace('_val', '_test'), y, z)

# Sweep the ensemble weight on both splits.
val_res = get_accuracy_from_npz(val_expt[0], val_expt[1][0], add_aug=False, ens_size=31)
test_res = get_accuracy_from_npz(test_expt[0], test_expt[1][0], add_aug=False, ens_size=31)

f, ax = plt.subplots(1, 1, figsize=(6, 3)) # , sharey=True)

ax.plot(val_res['ensemble_table']['weight'], val_res['ensemble_table']['acc'], label='Validation')
ax.plot(test_res['ensemble_table']['weight'], test_res['ensemble_table']['acc'], label='Test')

# plot the ensemble weight: the row with the highest validation accuracy
val_ensemble_table = val_res['ensemble_table']
best_val_setting = val_ensemble_table.iloc[val_ensemble_table['acc'].argsort().iloc[-1], :]
ax.axvline(best_val_setting.weight, color='k', linestyle=':', label='Selected Weight')

ax.set_ylabel('Accuracy')
ax.set_xlabel('Ensemble Weight')
ax.legend()
f.tight_layout()

print("Test Accuracy, Images: %0.4f" % test_res['acc_original'])
print("Test Accuracy, GAN Reconstructions: %0.4f" % test_res['acc_reconstructed'])
# Take the matching test row as a Series (.iloc[0]) so that .acc / .weight are
# scalars; formatting a one-row DataFrame column with %f relies on implicit
# float(Series) conversion, which pandas has deprecated. np.isclose avoids
# depending on exact float equality of the weights.
test_table = test_res['ensemble_table']
test_weighted = test_table.loc[np.isclose(test_table['weight'], best_val_setting.weight)].iloc[0]
print("Test Accuracy, Weighted GAN Reconstructions (weight from validation): %0.4f @ weight=%0.2f"
      % (test_weighted.acc, test_weighted.weight))
# Oracle: the weight that is best on the test split itself.
test_oracle = test_table.iloc[test_table['acc'].argsort().iloc[-1], :]
print("Test Accuracy, Weighted GAN Reconstructions (Oracle): %0.4f @ weight=%0.2f"
      % (test_oracle.acc, test_oracle.weight))