In [1]:
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import seml.database as db_utils
from pathlib import Path


from itertools import product


import pandas as pd

import os

import sys
sys.path.append('../../../..')
from utils import load_results, merge_guarantees

import pickle

In [2]:
collection = 'group_amplification_neurips24_rdp'

# MongoDB connection settings for seml. Values are read from environment variables so
# credentials do not have to be hardcoded into (and committed with) the notebook; the
# placeholder strings are used only when the corresponding variable is unset.
jk_config = {
    'username': os.environ.get('MONGODB_USERNAME', 'YOURUSERNAME'),
    'password': os.environ.get('MONGODB_PASSWORD', 'YOURPASSWORD'),
    'host': os.environ.get('MONGODB_HOST', 'YOURDATABASEHOST'),
    'port': int(os.environ.get('MONGODB_PORT', 27017)),
    'db_name': os.environ.get('MONGODB_DB_NAME', 'YOURDATABASENAME'),
}

col = db_utils.get_collection(collection, mongodb_config=jk_config)

In [3]:
def get_experiments(col, restrictions=None):
    """Query completed experiments from a seml/MongoDB collection.

    Parameters
    ----------
    col : collection-like
        Object providing ``count_documents(query)`` and ``find(query, projection)``.
    restrictions : dict, optional
        Additional MongoDB filter terms. The dict is NOT modified; a
        ``'status': 'COMPLETED'`` term is added to an internal copy.

    Returns
    -------
    The result of ``col.find`` projected onto ``config``, ``result`` and ``_id``.

    Raises
    ------
    ValueError
        If no completed experiment matches the restrictions.
    """
    # Work on a copy: avoids both the shared-mutable-default pitfall and
    # silently adding 'status' to the caller's dict.
    query = dict(restrictions) if restrictions is not None else {}
    query['status'] = 'COMPLETED'

    if col.count_documents(query) == 0:
        raise ValueError('No matches!')

    return col.find(query, {'config': 1, 'result': 1, '_id': 1})

In [4]:
def get_dp_guarantees(save_file):
    """Load the RDP curve stored in a pickled results file.

    Parameters
    ----------
    save_file : str or path-like
        Path to a pickle containing a mapping with ``'alphas'`` and ``'epsilons'``.

    Returns
    -------
    dict
        ``{'alphas': np.ndarray, 'epsilons': np.ndarray}``; other keys in the file are ignored.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
    with open(save_file, 'rb') as handle:
        raw = pickle.load(handle)

    guarantees = {}
    for key in ('alphas', 'epsilons'):
        guarantees[key] = np.array(raw[key])
    return guarantees

In [5]:
def generate_exp_result_dict(exp):
    """Flatten one experiment document into a single row of the results table.

    Parameters
    ----------
    exp : dict
        Experiment document with ``config`` (base mechanism and amplification
        settings) and ``result['save_file']`` (path to the pickled RDP curve).

    Returns
    -------
    dict
        Scalar experiment settings, the raw results path, plus the ``alphas`` /
        ``epsilons`` arrays loaded via ``get_dp_guarantees``.
    """
    config = exp['config']
    noise_std = config['base_mechanism']['params']['standard_deviation']
    amplification = config['amplification']
    amp_params = amplification['params']

    result_dict = {
        'std': noise_std,
        'dataset_size': amp_params['dataset_size'],
        'batch_size': amp_params['batch_size'],
        'group_size': amp_params['group_size'],
        'tight': bool(amplification['tight']),
        'eval_method': amp_params['eval_method'],
        # Missing flag in eval_params is treated as "no self-consistency".
        'self_consistency': bool(amp_params['eval_params'].get('use_self_consistency', False)),
    }

    raw_results_file = exp['result']['save_file']
    result_dict['raw_results_file'] = raw_results_file
    result_dict.update(get_dp_guarantees(raw_results_file))

    return result_dict

In [None]:
experiments = get_experiments(col, {'config.amplification.subsampling_scheme': 'withoutreplacement',
                                    'config.base_mechanism.name': 'gaussian',
                                    'config.alphas.space': 'log'
                                    })
results = load_results(
            generate_exp_result_dict,
            experiments,
            results_file='./raw_data',
            overwrite=False
            )

# Restrict to the configurations shown in the figures.
results = results.loc[results['eval_method'].isin(['recursive', 'directtransport'])]
results = results.loc[results['std'].isin([0.5, 5.0])]
results = results.loc[~results['batch_size'].isin([2000, 5000])]
results = results.loc[results['group_size'].isin([1, 2, 4])]
# Keep tight results and the self-consistency (non-tight) results.
results = results.loc[results['tight'] | results['self_consistency']]

# Duplicate the non-tight rows with dataset_size and batch_size both scaled up and
# down by a factor of 10, so these curves also appear for those configurations.
# NOTE(review): presumes the non-tight guarantee depends only on the ratio
# batch_size / dataset_size -- confirm.
# `.assign` on an explicit copy avoids chained assignment (SettingWithCopyWarning)
# and leaves `baseline_results_orig` holding the unscaled rows.
baseline_results_orig = results.loc[~results['tight']].copy()
baseline_upscaled = baseline_results_orig.assign(
    dataset_size=baseline_results_orig['dataset_size'].mul(10),
    batch_size=baseline_results_orig['batch_size'].mul(10),
)
baseline_results = baseline_results_orig.assign(
    dataset_size=baseline_results_orig['dataset_size'].floordiv(10),
    batch_size=baseline_results_orig['batch_size'].floordiv(10),
)
# Row order matches the original two-step concat: downscaled, upscaled, then filtered results.
results = pd.concat((baseline_results, baseline_upscaled, results))

In [None]:
# Display the combined results table assembled in the previous cell.
results

In [None]:
# Spot-check one configuration: dataset size 1000 with batch size 100.
results.loc[(results['dataset_size'] == 1000) & (results['batch_size'] == 100)]

In [9]:
def prepare_plot_dict(data):
    """Group RDP curves by evaluation method and group size for plotting.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows with columns ``alphas``, ``epsilons`` (array-likes), ``eval_method``
        (``'recursive'`` or ``'directtransport'``) and ``group_size``.

    Returns
    -------
    dict
        ``{eval_method: {group_size: (alphas, epsilons), ..., 'label': display_name}}``.

    Raises
    ------
    AssertionError
        On an unknown eval method, or if an (eval_method, group_size) pair repeats.
    """
    method_label_map = {
        'recursive': 'Posthoc',
        'directtransport': 'No conditioning'
    }

    plot_dict = {}

    for _, row in data.iterrows():
        alphas, epsilons, eval_method, group_size = row.loc[['alphas', 'epsilons', 'eval_method', 'group_size']]

        assert eval_method in ['recursive', 'directtransport']

        if eval_method == 'recursive':
            # Renyi divergence is non-decreasing in alpha, so each value may be replaced
            # by the minimum over itself and all later entries (suffix minimum). This can
            # only lower the curve, i.e. it favors the baseline.
            # Assumes alphas are sorted ascending.
            epsilons = np.minimum.accumulate(epsilons[::-1])[::-1]

        method_entry = plot_dict.get(eval_method)
        if method_entry is None:
            plot_dict[eval_method] = {
                group_size: (alphas, epsilons),
                'label': method_label_map[eval_method],
            }
        else:
            assert group_size not in method_entry
            method_entry[group_size] = (alphas, epsilons)

    return plot_dict

In [10]:
def plot_plot_dict(plot_dict, draw_legend_group_size=False, draw_legend_method=False, width=0.49):
    """Plot RDP curves (rho vs. alpha, log-log) per evaluation method and group size.

    Parameters
    ----------
    plot_dict : dict
        Output of ``prepare_plot_dict``: maps eval method name to a dict whose
        non-string keys are group sizes (values ``(alphas, epsilons)``) and whose
        ``'label'`` key holds a display name.
    draw_legend_group_size : bool
        Add a legend mapping line color to group size (lower right).
    draw_legend_method : bool
        Add a legend mapping line style to method (dashed: no conditioning,
        solid: post-hoc).
    width : float
        Currently unused. NOTE(review): presumably meant as relative figure width
        (e.g. half page) -- confirm before relying on it.

    Returns
    -------
    tuple
        ``(fig, ax)`` of the created figure, so callers can save/close it explicitly.
    """
    sns.set_theme()

    fig, ax = plt.subplots()

    # Colors are assigned by rank among the group sizes (largest first). Size the palette
    # to the largest number of group sizes so more than three do not raise IndexError;
    # the first three colors stay the same as with a 3-color palette.
    max_num_group_sizes = max(
        (sum(1 for k in method_dict if not isinstance(k, str)) for method_dict in plot_dict.values()),
        default=0,
    )
    pal = sns.color_palette('colorblind', max(3, max_num_group_sizes))

    # Methods are drawn in reverse insertion order, group sizes from largest to smallest.
    for eval_method, eval_method_dict in reversed(list(plot_dict.items())):
        # String keys (e.g. 'label') are metadata, not group sizes.
        group_sizes = np.sort([k for k in eval_method_dict if not isinstance(k, str)])

        for j, group_size in enumerate(group_sizes[::-1]):
            alphas, epsilons = eval_method_dict[group_size]

            # Only the solid (recursive) curves get a label, so the group-size legend
            # has one entry per group size.
            prob_label = group_size if eval_method == 'recursive' else None
            linestyle = 'solid' if eval_method == 'recursive' else 'dashed'

            ax.plot(alphas, epsilons, label=prob_label, c=pal[j], linestyle=linestyle)

    ax.tick_params('both', which='major', length=2.5, width=0.75)
    ax.tick_params('both', which='minor', length=1.5, width=0.75, left=False)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(2, 10000)
    ax.set_xlabel('RDP $\\alpha$', fontsize=9)
    ax.set_ylabel('RDP $\\rho(\\alpha)$', fontsize=9)

    if draw_legend_group_size:
        legend_group_size = ax.legend(loc='lower right', title='Group size', title_fontsize=9)

    if draw_legend_method:
        # Empty (data-less) black lines serve as line-style handles for the method legend.
        handles_ls = []
        handles_ls.append(ax.plot([], [], c='black', ls='dashed')[0])
        handles_ls.append(ax.plot([], [], c='black', ls='solid')[0])
        ax.legend(handles_ls, ['No conditioning', 'Post-hoc'], loc=('upper left' if draw_legend_group_size else 'lower right'))

        if draw_legend_group_size:
            # ax.legend replaces the existing legend; re-attach the group-size legend.
            ax.add_artist(legend_group_size)

    return fig, ax

In [None]:
save_dir = '/ceph/hdd/staff/schuchaj/group_amplification_plots/neurips24/rdp/without_replacement/direct_transport_vs_posthoc/gaussian/half_page/both_legends'
# NOTE(review): absolute, user-specific output path -- consider making it configurable.
Path(save_dir).mkdir(parents=True, exist_ok=True)

# One figure (with both legends) per (dataset_size, batch_size, std) configuration.
for (dataset_size, batch_size, std), group in results.groupby(['dataset_size', 'batch_size', 'std']):
    plot_dict = prepare_plot_dict(group)

    plot_plot_dict(plot_dict, draw_legend_group_size=True, draw_legend_method=True, width=0.49)

    plt.savefig(f'{save_dir}/{dataset_size}_{batch_size}_{std}.png', dpi=256)
    # Close the saved figure so open figures don't accumulate across iterations.
    plt.close()

In [None]:
save_dir = '/ceph/hdd/staff/schuchaj/group_amplification_plots/neurips24/rdp/without_replacement/direct_transport_vs_posthoc/gaussian/half_page/no_legend'
# NOTE(review): absolute, user-specific output path -- consider making it configurable.
Path(save_dir).mkdir(parents=True, exist_ok=True)

# One legend-free figure per (dataset_size, batch_size, std) configuration.
for (dataset_size, batch_size, std), group in results.groupby(['dataset_size', 'batch_size', 'std']):
    plot_dict = prepare_plot_dict(group)

    plot_plot_dict(plot_dict, draw_legend_group_size=False, draw_legend_method=False, width=0.49)

    plt.savefig(f'{save_dir}/{dataset_size}_{batch_size}_{std}.png', dpi=256)
    # Close the saved figure so open figures don't accumulate across iterations.
    plt.close()