In [1]:
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import seml.database as db_utils
from pathlib import Path


from itertools import product


import pandas as pd

import os

import sys
sys.path.append('../../../..')
from utils import load_results, merge_guarantees

import pickle

In [2]:
collection = 'group_amplification_neurips24_rdp'

# MongoDB connection settings passed to seml. Values are read from environment
# variables (MONGODB_USERNAME, MONGODB_PASSWORD, MONGODB_HOST, MONGODB_PORT,
# MONGODB_DB_NAME) so real credentials never need to be typed into — and
# committed with — this notebook. The fallbacks are the original placeholders.
jk_config = {
    'username': os.environ.get('MONGODB_USERNAME', 'YOURUSERNAME'),
    'password': os.environ.get('MONGODB_PASSWORD', 'YOURPASSWORD'),
    'host': os.environ.get('MONGODB_HOST', 'YOURDATABASEHOST'),
    'port': int(os.environ.get('MONGODB_PORT', 27017)),
    'db_name': os.environ.get('MONGODB_DB_NAME', 'YOURDATABASENAME'),
}

col = db_utils.get_collection(collection, mongodb_config=jk_config)

In [3]:
def get_experiments(col, restrictions=None):
    """Query a seml/MongoDB collection for completed experiments.

    Parameters
    ----------
    col
        Collection object exposing ``count_documents`` and ``find`` (in this
        notebook, the object returned by ``db_utils.get_collection``).
    restrictions : dict, optional
        Additional MongoDB filter. It is copied before ``'status'`` is forced
        to ``'COMPLETED'``. Neither the caller's dict nor a shared default
        argument is mutated.

    Returns
    -------
    The result of ``col.find`` with only ``config``, ``result`` and ``_id``
    projected.

    Raises
    ------
    ValueError
        If no completed experiment matches the filter.
    """
    # Copy so repeated calls / caller-owned dicts are not modified in place.
    query = dict(restrictions or {})
    query['status'] = 'COMPLETED'

    if col.count_documents(query) == 0:
        raise ValueError('No matches!')

    return col.find(query, {'config': 1, 'result': 1, '_id': 1})

In [4]:
def get_dp_guarantees(save_file):
    """Load the RDP curve stored in a pickled results file.

    Parameters
    ----------
    save_file : str or path-like
        Path to a pickle containing a mapping with ``'alphas'`` and
        ``'epsilons'`` entries.

    Returns
    -------
    dict
        ``{'alphas': np.ndarray, 'epsilons': np.ndarray}``; any other entries
        in the pickle are ignored.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load result
    # files produced by trusted experiment runs.
    with open(save_file, 'rb') as handle:
        raw = pickle.load(handle)

    return {key: np.array(raw[key]) for key in ('alphas', 'epsilons')}

In [5]:
def generate_exp_result_dict(exp):
    """Flatten one seml experiment record into a flat dict (one results row).

    Reads the relevant hyperparameters from ``exp['config']``. Then loads the
    RDP curve referenced by ``exp['result']['save_file']`` and appends its
    ``alphas`` / ``epsilons`` arrays.

    ``forward`` is treated as False when absent from the amplification params.
    """
    config = exp['config']
    amplification = config['amplification']
    amp_params = amplification['params']

    row = {
        'space': config['alphas']['space'],
        'true_response_prob': config['base_mechanism']['params']['true_response_prob'],
        'subsampling_rate': amp_params['subsampling_rate'],
        'group_size': amp_params['group_size'],
        'insertions': amp_params['insertions'],
        'tight': bool(amplification['tight']),
        'forward': bool(amp_params.get('forward', 0)),
        'eval_method': amp_params['eval_method'],
        'raw_results_file': exp['result']['save_file'],
    }

    row.update(get_dp_guarantees(row['raw_results_file']))

    return row

In [None]:
experiments = get_experiments(
    col,
    {
        'config.amplification.subsampling_scheme': 'poisson',
        'config.base_mechanism.name': 'randomizedresponse',
        'config.alphas.space': {'$in': ['log', 'linear']},
    },
)

# load_results comes from the project's utils module; presumably it caches the
# per-experiment rows in `results_file` (overwrite=False) — see utils.
results = load_results(
    generate_exp_result_dict,
    experiments,
    results_file='./raw_data_randomized_response',
    overwrite=False,
)

# Keep only the two evaluation methods compared below and group sizes 1, 2, 4, 8.
# (Optionally also restrict `subsampling_rate`, e.g. to [0.2, 0.1, 0.001].)
results = results.loc[
    results['eval_method'].isin(['recursive', 'quadrature'])
    & results['group_size'].isin([1, 2, 4, 8])
]


In [None]:
# Display the filtered experiment results loaded above.
results

In [8]:
# Legend label per evaluation method. Insertion order matters: plot_plot_dict
# pairs these values, in order, with (dashed, solid) line-style handles.
method_label_map = dict(
    recursive='Post-hoc',
    expansion='Specific',
)

In [9]:
def prepare_plot_dict(data):
    """Group RDP curves by evaluation method and group size for plotting.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows with at least the columns ``alphas``, ``epsilons``,
        ``eval_method`` and ``group_size`` (e.g. one ``groupby`` group of
        ``results``). ``eval_method`` must be one of ``'recursive'``,
        ``'expansion'`` or ``'quadrature'``.

    Returns
    -------
    dict
        ``{eval_method: {group_size: (alphas, epsilons), ..., 'label': str}}``.
        Rows with ``'quadrature'`` are stored under ``'expansion'``. Labels
        come from ``method_label_map``. Rows sharing a method and group size
        are combined with ``merge_guarantees(..., max)``.
    """

    def _monotone_from_right(values):
        # Running minimum taken from the right: each epsilon is replaced by the
        # smallest value at any larger alpha, making the curve non-decreasing.
        # This only lowers values, i.e. it favors the post-hoc baseline.
        return np.minimum.accumulate(values[::-1])[::-1]

    plot_dict = {}

    for _, row in data.iterrows():
        alphas, epsilons, eval_method, group_size = row.loc[['alphas', 'epsilons', 'eval_method', 'group_size']]

        assert eval_method in ['recursive', 'expansion', 'quadrature']

        if eval_method == 'recursive':
            # Renyi divergence is non-decreasing in alpha -> enforce it by lowering values.
            epsilons = _monotone_from_right(epsilons)
        elif eval_method == 'quadrature':
            # Quadrature results are plotted under the 'expansion' ('Specific') entry.
            eval_method = 'expansion'

        method_dict = plot_dict.get(eval_method)

        if method_dict is None:
            plot_dict[eval_method] = {
                group_size: (alphas, epsilons),
                'label': method_label_map[eval_method],
            }
        elif group_size not in method_dict:
            method_dict[group_size] = (alphas, epsilons)
        else:
            old_alphas, old_epsilons = method_dict[group_size]
            merged_alphas, merged_epsilons = merge_guarantees(old_alphas, alphas,
                                                              old_epsilons, epsilons,
                                                              max)

            if eval_method == 'recursive':
                # Merging may break monotonicity again -> re-apply.
                merged_epsilons = _monotone_from_right(merged_epsilons)

            method_dict[group_size] = (merged_alphas, merged_epsilons)

    return plot_dict

In [10]:
def plot_plot_dict(plot_dict, draw_legend_group_size=False, draw_legend_method=False, width=0.49, xlim=[2, 10000]):
    """Plot RDP curves produced by ``prepare_plot_dict`` on log-log axes.

    Parameters
    ----------
    plot_dict : dict
        ``{eval_method: {group_size: (alphas, epsilons), ..., 'label': str}}``.
        Non-string keys are treated as group sizes.
    draw_legend_group_size : bool
        Add a "Group size" legend. Labels are attached only to ``'expansion'``
        curves, so each group size appears once.
    draw_legend_method : bool
        Add a legend mapping line style to method (dashed: first entry of
        ``method_label_map``, solid: second entry).
    width : float
        Unused; kept so existing calls remain valid.
    xlim : sequence of two values
        ``[left, right]`` x-axis limits (not mutated). ``left=None`` means
        "start at the smallest plotted alpha". For ``[None, 2]``, ``[None, 3]``,
        ``[None, 4]`` and ``[None, 10]``, hand-picked minor ticks are added.
        Any other value is applied as given. The default gives ``[2, 10**4]``.

    Returns
    -------
    (fig, ax)
        The created matplotlib figure and axes, so callers can save/close them.
    """
    sns.set_theme()

    fig, ax = plt.subplots()

    # Reversed 4-color colorblind palette. The index int(log2(group_size)) - 1
    # maps group sizes 2, 4, 8 to pal[0], pal[1], pal[2] and group size 1 to
    # pal[-1] (negative indexing). NOTE(review): other group sizes (e.g. 3, 16)
    # would reuse a color.
    pal = sns.color_palette('colorblind', 4)[::-1]

    smallest_alpha = None

    for eval_method, eval_method_dict in plot_dict.items():
        # Skip the 'label' entry (string key); the remaining keys are group sizes.
        group_sizes = np.sort([k for k in eval_method_dict if not isinstance(k, str)])

        # Draw from the largest group size down.
        for group_size in group_sizes[::-1]:
            alphas, epsilons = eval_method_dict[group_size]
            smallest_alpha = alphas.min() if smallest_alpha is None else min(smallest_alpha, alphas.min())

            # Label only one method so each group size appears once in the legend.
            prob_label = group_size if eval_method == 'expansion' else None
            linestyle = 'solid' if eval_method in ['quadrature', 'expansion'] else 'dashed'

            ax.plot(alphas, epsilons, label=prob_label,
                    c=pal[int(np.log2(group_size)) - 1], linestyle=linestyle)

    ax.set_xscale('log')
    ax.set_yscale('log')

    ax.tick_params('both', which='major', length=2.5, width=0.75)
    ax.tick_params('both', which='minor', length=1.5, width=0.75, left=False)

    # Hand-picked minor ticks for the zoomed-in views [None, right].
    zoom_minor_ticks = {
        2: [1.1, 1.5, 2.0],
        3: [1.1, 2, 3.0],
        4: [1.1, 2, 4.0],
        10: [1.1, 4, 10],
    }

    left, right = (2, 10**4) if xlim is None else xlim

    if left is None and right in zoom_minor_ticks:
        ax.set_xlim(smallest_alpha, right)
        ax.set_xticks(zoom_minor_ticks[right], minor=True)
    else:
        ax.set_xlim(left=smallest_alpha if left is None else left, right=right)

    ax.set_xlabel('RDP order $\\alpha$', fontsize=9)
    ax.set_ylabel('RDP $\\epsilon(\\alpha)$', fontsize=9)

    if draw_legend_group_size:
        legend_group_size = ax.legend(loc='lower right', title='Group size', title_fontsize=9)

    if draw_legend_method:
        # Empty proxy lines that only serve as line-style handles for the legend.
        handles_ls = [
            ax.plot([], [], c='black', ls='dashed')[0],
            ax.plot([], [], c='black', ls='solid')[0],
        ]

        ax.legend(handles_ls, list(method_label_map.values()), loc='upper left')

        # A second ax.legend call replaces the first one; re-attach the
        # group-size legend as an extra artist so both are shown.
        if draw_legend_group_size:
            ax.add_artist(legend_group_size)

    return fig, ax

In [11]:
# Currently empty and not read by any cell in this notebook section.
xlim_dict = dict()

In [None]:
# NOTE(review): this path points to a 'laplace' folder although the experiments
# loaded above use the randomized-response mechanism — confirm the intended
# directory, otherwise these figures end up among the Laplace plots.
save_dir = '/ceph/hdd/staff/schuchaj/group_amplification_plots/neurips24/rdp/poisson/specific_vs_posthoc/laplace/half_page'

# savefig does not create missing directories.
Path(save_dir).mkdir(parents=True, exist_ok=True)

for (subsampling_rate, true_response_prob), group in results.groupby(['subsampling_rate', 'true_response_prob']):
    plot_dict = prepare_plot_dict(group)

    plot_plot_dict(plot_dict, draw_legend_group_size=True, draw_legend_method=True, width=0.49)
    plt.savefig(f'{save_dir}/{subsampling_rate}_{true_response_prob}.png', dpi=256)

    # Render this figure now. With the default inline backend this also closes
    # it, so the loop does not keep one open figure per group.
    plt.show()