In [None]:
import numpy as np
import attacks
from datasets import ADULT, German, HealthHeritage, Lawschool
from attacks.ensembling import pooled_ensemble
from attacks.inversion_losses import _cosine_similarity_loss
from attacks import invert_grad
import torch
from models import FullyConnected
from utils import calculate_entropy_heat_map, Timer, match_reconstruction_ground_truth, categorical_accuracy_continuous_tolerance_score, categorical_softmax, post_process_continuous
from scipy.optimize import linear_sum_assignment
from scipy.stats import kendalltau
import os
import pickle
import pandas as pd

# Results on Assessment via Entropy

In [None]:
# Set the dataset here and rerun the cells below to display the results for the chosen dataset.
# Choose from 'ADULT', 'German', 'Lawschool', 'HealthHeritage'.
dataset_name = 'ADULT'

In [None]:
# specify the setting
n_samples = 50
size_of_ensemble = 30
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]

# fail fast on an unsupported dataset name instead of hitting a missing-file error further down
dataset_classes = {'ADULT': ADULT, 'German': German, 'Lawschool': Lawschool, 'HealthHeritage': HealthHeritage}
if dataset_name not in dataset_classes:
    raise ValueError('No such dataset')

# every artifact of this experiment lives below one directory; build the paths once
experiment_root = f'experiment_data/large_scale_experiments/{dataset_name}/experiment_46/'
metadata_dir = experiment_root + f'metadata_46_{dataset_name}_{n_samples}_{size_of_ensemble}_1500_128_0.319_42_epoch0/'

# load the network
# NOTE: unpickling executes code stored in the file -- only load experiment data you trust
with open(metadata_dir + 'net.pickle', 'rb') as f:
    net = pickle.load(f)

# load all the data; each list below gets one entry per batch size, in the order of `batch_sizes`
reconstruction_ensembles_over_batch_size = []
ground_truths_over_batch_size = []
ground_truth_labels_over_batch_size = []
orig_recons_over_batch_size = []
for batch_size in batch_sizes:
    path = metadata_dir + f'batch_size_{batch_size}/'
    # n_samples x size_of_ensemble reconstructions: one ensemble per attacked sample
    reconstruction_ensembles = [[torch.tensor(np.load(path + f'all_reconstructions_{sample}/ensemble_recon_{num}.npy')) for num in range(size_of_ensemble)] for sample in range(n_samples)]
    ground_truths = [torch.tensor(np.load(path + f'ground_truth_{batch_size}_{sample}.npy')) for sample in range(n_samples)]
    orig_recons = [torch.tensor(np.load(path + f'reconstruction_{batch_size}_{sample}.npy')) for sample in range(n_samples)]
    # labels are cast to long because they are later passed to CrossEntropyLoss as class targets
    ground_truth_labels = [torch.tensor(np.load(path + f'true_labels_{batch_size}_{sample}.npy')).long() for sample in range(n_samples)]
    reconstruction_ensembles_over_batch_size.append(reconstruction_ensembles)
    ground_truths_over_batch_size.append(ground_truths)
    ground_truth_labels_over_batch_size.append(ground_truth_labels)
    orig_recons_over_batch_size.append(orig_recons)
print('Data loaded')

# the file name follows the configured sample/ensemble sizes instead of a hard-coded "50_30"
base_path = experiment_root + f'inversion_data_all_46_{dataset_name}_{n_samples}_{size_of_ensemble}_1500_{batch_sizes[-1]}_0.319_42.npy'
# NOTE(review): `stuff` does not appear to be used by the cells visible in this notebook -- confirm before removing
stuff = np.load(base_path)


# initialize the dataset
print('Instantiating the dataset')
dataset = dataset_classes[dataset_name]()
print('Dataset instantiated')
dataset.standardize()

The next cell calculates the error bitmaps and the entropy maps for each feature in each reconstruction (50 samples for each batch size). For this, the 30 reconstructions in each ensemble are first matched to the minimum-loss reconstruction. Then, the reconstructions are pooled. Finally, we calculate the entropy and the error maps from the pooled reconstruction. The first run of this cell might take up to a few hours for each dataset, but any later runs will be faster, as the results are automatically saved and loaded.

In [None]:
# we want to get the error maps and the entropy maps for each reconstruction
# (outer list: one entry per batch size; inner list: one entry per sample)
error_maps_over_batch_size = []
entropy_maps_over_batch_size = []
# NOTE: the two decoded_* lists are only filled when the maps are recomputed; they stay empty on a cache hit
decoded_ground_truths_over_batch_size = []
decoded_reconstructions_over_batch_size = []

timer = Timer(len(batch_sizes)*n_samples)

# locations of the cached results
results_dir = f'experiment_data/large_scale_experiments/{dataset_name}/experiment_46/metadata_46_{dataset_name}_{n_samples}_{size_of_ensemble}_1500_128_0.319_42_epoch0/'
error_maps_path = results_dir + 'error_maps_over_batch_size.pickle'
entropy_maps_path = results_dir + 'entropy_maps_over_batch_size.pickle'

# use the cache only when BOTH files are present; a partial cache is recomputed instead of crashing on open()
if os.path.isfile(error_maps_path) and os.path.isfile(entropy_maps_path):
    # NOTE: unpickling executes code stored in the file -- only load experiment data you trust
    with open(error_maps_path, 'rb') as f:
        error_maps_over_batch_size = pickle.load(f)

    with open(entropy_maps_path, 'rb') as f:
        entropy_maps_over_batch_size = pickle.load(f)

else:
    for i, batch_size in enumerate(batch_sizes):
        error_maps = []
        entropy_maps = []
        decoded_ground_truths = []
        decoded_reconstructions = []
        errors = []

        for j in range(n_samples):
            timer.start()
            print(timer, end='\r')
            # --------- calculate minimum loss sample --------- #
            # gradient of the network loss on the true batch; every ensemble member is scored against it
            losses = []
            criterion = torch.nn.CrossEntropyLoss()
            outputs = net(ground_truths_over_batch_size[i][j])
            loss = criterion(outputs, ground_truth_labels_over_batch_size[i][j])
            true_grad = [grad.detach() for grad in torch.autograd.grad(loss, net.parameters(), create_graph=True)]
            for k in range(size_of_ensemble):
                # calculate the guessed gradient
                outputs = net(categorical_softmax(reconstruction_ensembles_over_batch_size[i][j][k], dataset))
                loss = criterion(outputs, ground_truth_labels_over_batch_size[i][j])
                guessed_gradient = [grad.detach() for grad in torch.autograd.grad(loss, net.parameters(), create_graph=True)]
                loss = _cosine_similarity_loss(guessed_gradient, true_grad, device='cpu').item()
                losses.append(loss)
            # the ensemble member with the smallest gradient-matching loss is the reference for matching
            min_index = np.argmin(np.array(losses)).item()
            min_loss_sample = reconstruction_ensembles_over_batch_size[i][j][min_index]
            # --------- minimum loss sample calculated --------- #

            # --------- reorder wrt minimum sample --------- #
            reconstructions_decoded = [dataset.decode_batch(rec.detach().clone()) for rec in reconstruction_ensembles_over_batch_size[i][j]]
            tolerance_map = dataset.create_tolerance_map()
            all_indices_match = []
            for reconstruction in reconstructions_decoded:
                # only the returned index permutation (last return value) is needed here
                _, _, _, _, _, indices = match_reconstruction_ground_truth(dataset.decode_batch(min_loss_sample.detach().clone()),
                                                                           reconstruction, tolerance_map=tolerance_map,
                                                                           return_indices=True, match_based_on='all')
                all_indices_match.append(indices)
            # apply each permutation to its (encoded) ensemble member and stack them into one tensor
            reordered_reconstructions = torch.stack([rec[idx].detach().clone() for rec, idx in zip(reconstruction_ensembles_over_batch_size[i][j], all_indices_match)])

            # two poolings of the same reordered ensemble: 'soft_avg+softmax' feeds the entropy map,
            # 'median+softmax' feeds the error map
            resulting_reconstruction_for_entropy, cont_stds = pooled_ensemble(reconstructions=reordered_reconstructions, match_to_batch=min_loss_sample, dataset=dataset, pooling='soft_avg+softmax', return_std=True, already_reordered=True)
            resulting_reconstruction, _ = pooled_ensemble(reconstructions=reordered_reconstructions, match_to_batch=min_loss_sample, dataset=dataset, pooling='median+softmax', return_std=True, already_reordered=True)
            decoded_reconstruction, decoded_ground_truth = dataset.decode_batch(resulting_reconstruction, standardized=True), dataset.decode_batch(ground_truths_over_batch_size[i][j], standardized=True)
            # match the pooled reconstruction to the ground truth; `idx` reorders the pooled batch accordingly
            reordered_reconstruction, _, _, _, _, idx = match_reconstruction_ground_truth(decoded_ground_truth, decoded_reconstruction, tolerance_map=dataset.create_tolerance_map(), return_indices=True)
            # calculate_entropy_heat_map returns (entropy map, error map); one output is taken from each call
            entropy_heat_map, _ = calculate_entropy_heat_map(resulting_reconstruction_for_entropy[idx], ground_truths_over_batch_size[i][j], cont_stds[idx], dataset)
            _, error_heat_map = calculate_entropy_heat_map(post_process_continuous(resulting_reconstruction[idx], dataset), ground_truths_over_batch_size[i][j], cont_stds[idx], dataset)
            decoded_ground_truths.append(decoded_ground_truth)
            decoded_reconstructions.append(reordered_reconstruction)
            error_maps.append(error_heat_map)
            entropy_maps.append(entropy_heat_map)
            errors.append(np.mean(error_heat_map))
            timer.end()
        error_maps_over_batch_size.append(error_maps)
        entropy_maps_over_batch_size.append(entropy_maps)
        decoded_ground_truths_over_batch_size.append(decoded_ground_truths)
        decoded_reconstructions_over_batch_size.append(decoded_reconstructions)
        # batch size, mean accuracy (1 - mean error) and the std of the per-sample mean errors
        print(f'{batch_size} {1-np.mean(errors):.4f} {np.std(errors):.4f}')

    # persist the results so that later runs of this cell can skip the computation
    with open(error_maps_path, 'wb') as f:
        pickle.dump(error_maps_over_batch_size, f)

    with open(entropy_maps_path, 'wb') as f:
        pickle.dump(entropy_maps_over_batch_size, f)
timer.duration()

## Results of quantiling based on entropy, as table 4 in the main paper, and tables 29-32 in the Appendix

In [None]:
cat_indices = dataset.train_cat_indices
cont_indices = dataset.train_cont_indices
percentile = 0.25 # change to display different results -- in the paper we use 0.25
display_data = np.zeros((len(batch_sizes), 5), dtype='object')


def _mean_error_at_entropy_extremes(error_map, entropy_map, fraction):
    """Mean error of the lowest-entropy and of the highest-entropy entries of one map.

    Both maps are flattened; ``ceil(fraction * number_of_entries)`` entries are selected
    at each end of the entropy ordering.

    :param error_map: per-entry errors of one reconstruction (restricted to one feature type)
    :param entropy_map: per-entry entropies, aligned with ``error_map``
    :param fraction: share of the entries to select at each end, e.g. 0.25
    :return: tuple ``(mean error of lowest-entropy entries, mean error of highest-entropy entries)``
    """
    flat_errors = error_map.flatten()
    flat_entropies = entropy_map.flatten()
    # the quantile size is derived from the map being evaluated, so the categorical and the
    # continuous maps each get their own size
    n_selected = int(np.ceil(fraction * len(flat_errors)))
    entropy_order = np.argsort(flat_entropies)  # ascending: lowest entropy first
    lowest_entropy = entropy_order[:n_selected]
    highest_entropy = entropy_order[len(entropy_order) - n_selected:]
    return np.mean(flat_errors[lowest_entropy]), np.mean(flat_errors[highest_entropy])


def _accuracy_mean_std(errors):
    """Turn a list of per-sample mean errors into ``(mean accuracy %, std %)``, NaNs ignored."""
    return np.around(100 - 100 * np.nanmean(errors), 1), np.around(100 * np.nanstd(errors), 1)


for i, batch_size in enumerate(batch_sizes):
    # "top" = entries with the lowest entropy, "bottom" = entries with the highest entropy
    cont_top_errors = []
    cont_bottom_errors = []
    cat_top_errors = []
    cat_bottom_errors = []
    for j in range(n_samples):
        error_map = error_maps_over_batch_size[i][j]
        entropy_map = entropy_maps_over_batch_size[i][j]

        mean_top_error_cat, mean_bottom_error_cat = _mean_error_at_entropy_extremes(
            error_map[:, cat_indices], entropy_map[:, cat_indices], percentile)
        mean_top_error_cont, mean_bottom_error_cont = _mean_error_at_entropy_extremes(
            error_map[:, cont_indices], entropy_map[:, cont_indices], percentile)

        cont_top_errors.append(mean_top_error_cont)
        cont_bottom_errors.append(mean_bottom_error_cont)
        cat_top_errors.append(mean_top_error_cat)
        cat_bottom_errors.append(mean_bottom_error_cat)

    # fill the row cell by cell so that every (mean, std) tuple is stored as a single object
    display_data[i, 0] = batch_size
    display_data[i, 1] = _accuracy_mean_std(cat_top_errors)
    display_data[i, 2] = _accuracy_mean_std(cat_bottom_errors)
    display_data[i, 3] = _accuracy_mean_std(cont_top_errors)
    display_data[i, 4] = _accuracy_mean_std(cont_bottom_errors)
display_data_df = pd.DataFrame(data=display_data, columns=['Batch Size', f'Categorical Top {int(100*percentile)}%', f'Categorical Bottom {int(100*percentile)}%', f'Continuous Top {int(100*percentile)}%', f'Continuous Bottom {int(100*percentile)}%'])

In [None]:
# show the table assembled in the previous cell
display_data_df

## Results on the correlation of entropy and correctness for each batch size, as tables 25-28 in the Appendix

In [None]:
cat_indices = dataset.train_cat_indices
cont_indices = dataset.train_cont_indices
display_data = np.zeros((len(batch_sizes), 7), dtype='object')

for i, batch_size in enumerate(batch_sizes):
    # one mean error and one mean entropy per sample, separately for each feature type
    cont_errors, cont_entropies = [], []
    cat_errors, cat_entropies = [], []
    for sample_idx in range(n_samples):
        sample_error_map = error_maps_over_batch_size[i][sample_idx]
        sample_entropy_map = entropy_maps_over_batch_size[i][sample_idx]

        cont_errors.append(np.nanmean(sample_error_map[:, cont_indices]))
        # invalid entries of the continuous entropy map are masked before averaging
        cont_entropies.append(np.nanmean(np.ma.masked_invalid(sample_entropy_map[:, cont_indices])))
        cat_errors.append(np.nanmean(sample_error_map[:, cat_indices]))
        cat_entropies.append(np.nanmean(sample_entropy_map[:, cat_indices]))

    # rank correlation between the per-sample accuracy (1 - error) and the per-sample mean entropy
    cont_tau, _ = kendalltau(1-np.array(cont_errors), cont_entropies)
    cat_tau, _ = kendalltau(1-np.array(cat_errors), cat_entropies)

    # row layout: batch size | categorical (accuracy, entropy, tau) | continuous (accuracy, entropy, tau)
    row = [
        batch_size,
        (np.around(100-100*np.nanmean(cat_errors), 1), np.around(100*np.nanstd(cat_errors), 1)),
        (np.around(np.nanmean(cat_entropies), 2), np.around(np.nanstd(cat_entropies), 2)),
        np.around(cat_tau, 2),
        (np.around(100-100*np.nanmean(cont_errors), 1), np.around(100*np.nanstd(cont_errors), 1)),
        (np.around(np.nanmean(cont_entropies), 2), np.around(np.nanstd(cont_entropies), 2)),
        np.around(cont_tau, 2),
    ]
    for column, value in enumerate(row):
        display_data[i, column] = value

column_names = ['Batch Size', 'Discrete Accuracy %', 'Discrete Entropy', 'Discrete Kendall\'s Tau', 'Continuous Accuracy %', 'Continuous Entropy', 'Continuous Kendall\'s Tau']
display_data_df = pd.DataFrame(data=display_data, columns=column_names)

In [None]:
# show the table assembled in the previous cell
display_data_df