In [1]:
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import seml.database as db_utils
from pathlib import Path


from itertools import product


import pandas as pd

import os

import sys
sys.path.append('../../../..')
from utils import load_results, merge_guarantees

import pickle

In [2]:
collection = 'group_amplification_neurips24_rdp'

# MongoDB connection settings for seml. Values are read from environment variables so that real
# credentials never have to be typed into (and committed with) the notebook; the placeholder
# strings are used when a variable is not set.
jk_config = {
    'username': os.environ.get('MONGODB_USERNAME', 'YOURUSERNAME'),
    'password': os.environ.get('MONGODB_PASSWORD', 'YOURPASSWORD'),
    'host': os.environ.get('MONGODB_HOST', 'YOURDATABASEHOST'),
    'port': int(os.environ.get('MONGODB_PORT', 27017)),
    'db_name': os.environ.get('MONGODB_DB_NAME', 'YOURDATABASENAME'),
}

col = db_utils.get_collection(collection, mongodb_config=jk_config)

In [3]:
def get_experiments(col, restrictions=None):
    """Query the completed experiments in a seml/MongoDB collection.

    Parameters
    ----------
    col
        Collection handle exposing ``count_documents(filter)`` and ``find(filter, projection)``.
    restrictions : dict, optional
        Extra MongoDB filter. It is copied before ``'status': 'COMPLETED'`` is added, so neither
        the caller's dict nor a shared default argument is ever modified.

    Returns
    -------
    Whatever ``col.find`` returns, projected onto ``config``, ``result`` and ``_id``.

    Raises
    ------
    ValueError
        If no completed experiment matches the filter.
    """
    # Copy instead of mutating: the old ``restrictions={}`` default was shared across calls and
    # caller-supplied dicts silently gained a 'status' key.
    query = dict(restrictions) if restrictions is not None else {}
    query['status'] = 'COMPLETED'

    if col.count_documents(query) == 0:
        raise ValueError('No matches!')

    return col.find(query, {'config': 1, 'result': 1, '_id': 1})

In [4]:
def get_dp_guarantees(save_file):
    """Load the RDP curve stored in one raw experiment result file.

    The file is unpickled and its ``'alphas'`` and ``'epsilons'`` entries are returned as
    numpy arrays under the same keys (alphas first); any other entries are ignored.

    NOTE(review): ``pickle.load`` can execute arbitrary code — only point this at trusted,
    locally produced result files.
    """
    with open(save_file, 'rb') as handle:
        raw = pickle.load(handle)

    return {key: np.array(raw[key]) for key in ('alphas', 'epsilons')}

In [5]:
def generate_exp_result_dict(exp):
    """Flatten one completed experiment document into a single results-table row.

    Reads the randomized-response parameter, the amplification parameters and the evaluation
    settings from ``exp['config']``, records the raw result path from ``exp['result']``, and
    merges in the ``'alphas'`` / ``'epsilons'`` arrays loaded by ``get_dp_guarantees``.
    """
    config = exp['config']
    true_response_prob = config['base_mechanism']['params']['true_response_prob']

    amplification = config['amplification']
    params = amplification['params']

    row = {
        'true_response_prob': true_response_prob,
        'dataset_size': params['dataset_size'],
        'batch_size': params['batch_size'],
        'group_size': params['group_size'],
        'tight': bool(amplification['tight']),
        'eval_method': params['eval_method'],
        # Missing flag means self-consistency was not used.
        'self_consistency': bool(params['eval_params'].get('use_self_consistency', False)),
    }

    raw_path = exp['result']['save_file']
    row['raw_results_file'] = raw_path
    row.update(get_dp_guarantees(raw_path))

    return row

In [None]:
# Completed randomized-response runs: subsampling without replacement, log-spaced alphas,
# group size 1 and dataset size 10000.
experiments = get_experiments(col, {
    'config.amplification.subsampling_scheme': 'withoutreplacement',
    'config.base_mechanism.name': 'randomizedresponse',
    'config.alphas.space': {'$in': ['log']},
    'config.amplification.params.group_size': 1,
    'config.amplification.params.dataset_size': 10000,
})

results = load_results(
    generate_exp_result_dict,
    experiments,
    results_file='./raw_data_randomized_response',
    overwrite=False,
)

# Keep runs without self-consistency, the agnostic ('recursive') and specific ('quadrature')
# evaluation methods, and three true-response probabilities.
# To compare further methods (e.g. 'expansion'), extend the eval_method list below.
row_mask = (
    ~results['self_consistency']
    & results['eval_method'].isin(['recursive', 'quadrature'])
    & results['true_response_prob'].isin([0.6, 0.75, 0.9])
)
results = results.loc[row_mask]


In [None]:
# Overview of the filtered runs. The per-alpha 'alphas' / 'epsilons' arrays are dropped from the
# displayed table only (``results`` itself is unchanged) to keep the output readable.
results.drop(columns=['alphas', 'epsilons'])

In [8]:
def prepare_plot_dict(dataset_size, batch_size, data):
    """Group the RDP curves of one (dataset_size, batch_size) slice by evaluation method.

    Parameters
    ----------
    dataset_size, batch_size
        Not used by the body; kept for the existing call signature.
    data : pandas.DataFrame
        Rows with columns 'alphas', 'epsilons', 'eval_method' and 'true_response_prob'.
        'eval_method' must be 'recursive' or 'quadrature' (asserted).

    Returns
    -------
    dict
        ``{eval_method: {true_response_prob: (alphas, epsilons), ..., 'label': str}}`` where the
        label is 'Agnostic' for 'recursive' and 'Specific' for 'quadrature'. Each
        (method, probability) pair may occur only once (asserted).
    """
    labels = {
        'recursive': 'Agnostic',
        'quadrature': 'Specific',
    }
    columns = ['alphas', 'epsilons', 'eval_method', 'true_response_prob']

    plot_dict = {}

    for _, row in data.iterrows():
        alphas, epsilons, method, prob = row.loc[columns]

        assert method in ['recursive', 'quadrature']

        if method == 'recursive':
            # RDP epsilons are non-decreasing in alpha: replace every value by the minimum over
            # itself and all later entries (running minimum from the right). This can only lower
            # the agnostic baseline, i.e. it favors the baseline in the comparison.
            # Assumes alphas are stored in ascending order — TODO confirm.
            epsilons = np.minimum.accumulate(epsilons[::-1])[::-1]

        method_entry = plot_dict.get(method)
        if method_entry is None:
            plot_dict[method] = {prob: (alphas, epsilons), 'label': labels[method]}
        else:
            assert prob not in method_entry
            method_entry[prob] = alphas, epsilons

    return plot_dict

In [9]:
def plot_plot_dict(plot_dict, draw_legend_method=False, draw_legend_probs=False, width=0.49):
    """Plot RDP curves of the specific vs. agnostic bounds on a new log-log figure.

    Applies seaborn's theme globally (``sns.set_theme``) and draws on a fresh ``plt.subplots``
    figure, which also becomes matplotlib's current figure.

    Parameters
    ----------
    plot_dict : dict
        ``{eval_method: {true_response_prob: (alphas, epsilons), ..., 'label': str}}``. The string
        'label' entry is the method's legend text; all non-string keys are curves.
    draw_legend_method : bool
        Add a legend mapping line style to method label ('quadrature' solid, others dashed).
    draw_legend_probs : bool
        Add a legend mapping colour to true-response probability (theta). Only the
        'quadrature' curves are labelled, so this legend is empty without them.
    width : float
        Not used by the body; kept for call compatibility.

    Returns
    -------
    (fig, ax)
        The created figure and axes (previously nothing was returned).
    """
    sns.set_theme()

    fig, ax = plt.subplots()

    def linestyle_of(eval_method):
        # The method-specific ('quadrature') bound is solid; the agnostic baseline is dashed.
        return 'solid' if eval_method == 'quadrature' else 'dashed'

    def probs_of(eval_method_dict):
        # Curve keys are the non-string keys; 'label' is metadata.
        return [k for k in eval_method_dict if not isinstance(k, str)]

    # One colour per probability shared by ALL methods (largest probability -> first colour), so
    # the solid and dashed curve of the same theta always match even if the methods cover
    # different probability sets. The palette grows with the number of probabilities instead
    # of raising IndexError beyond three.
    all_probs = sorted({p for d in plot_dict.values() for p in probs_of(d)}, reverse=True)
    pal = sns.color_palette('colorblind', max(3, len(all_probs)))
    color_of = {p: pal[idx] for idx, p in enumerate(all_probs)}

    # Draw methods in reverse insertion order so the first method in plot_dict is drawn last
    # (on top).
    for eval_method, eval_method_dict in reversed(list(plot_dict.items())):
        linestyle = linestyle_of(eval_method)

        for true_response_prob in sorted(probs_of(eval_method_dict), reverse=True):
            alphas, epsilons = eval_method_dict[true_response_prob]

            # Label only the 'quadrature' curves so each theta appears once in the theta legend.
            prob_label = true_response_prob if eval_method == 'quadrature' else None

            ax.plot(alphas, epsilons, label=prob_label, c=color_of[true_response_prob], linestyle=linestyle)

    ax.tick_params('both', which='major', length=2.5, width=0.75)
    ax.tick_params('both', which='minor', length=1.5, width=0.75, left=False)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(2, 10000)
    ax.set_xlabel('RDP $\\alpha$', fontsize=9)
    ax.set_ylabel('RDP $\\rho(\\alpha)$', fontsize=9)

    if draw_legend_probs:
        legend_probs = ax.legend(loc='lower right', title='$\\theta$', title_fontsize=9)

    if draw_legend_method:
        # Empty proxy lines, one per method drawn with that method's own line style. Pairing
        # each handle with its own method's label keeps style and label matched regardless of
        # plot_dict's insertion order (previously the handles were always [dashed, solid]).
        methods = list(plot_dict)
        handles_ls = [ax.plot([], [], c='black', ls=linestyle_of(m))[0] for m in methods]
        labels_ls = [plot_dict[m]['label'] for m in methods]
        ax.legend(handles_ls, labels_ls, loc=('upper left' if draw_legend_probs else 'lower right'))

        if draw_legend_probs:
            # ax.legend replaces the existing legend; re-attach the theta legend so both show.
            ax.add_artist(legend_probs)

    return fig, ax

In [None]:
# Output directory for one PNG per (dataset_size, batch_size) group.
# NOTE(review): absolute cluster path — adjust for your machine.
save_dir = '/ceph/hdd/staff/schuchaj/group_amplification_plots/neurips24/rdp/without_replacement/specific_vs_agnostic/randomized_response/half_page'
# Create the directory up front; otherwise savefig fails on a fresh checkout.
Path(save_dir).mkdir(parents=True, exist_ok=True)

for (dataset_size, batch_size), group in results.groupby(['dataset_size', 'batch_size']):
    plot_dict = prepare_plot_dict(dataset_size, batch_size, group)

    # Both legends are drawn only for batch sizes 10 and 100.
    show_legends = batch_size in [10, 100]
    plot_plot_dict(plot_dict, draw_legend_method=show_legends, draw_legend_probs=show_legends)

    plt.savefig(f'{save_dir}/{dataset_size}_{batch_size}.png', dpi=256)
    # Close the (current) figure once it is on disk so the loop does not accumulate open
    # figures; the PNGs in save_dir are the output of this cell.
    plt.close()