In [7]:
# Automatically reload imported modules (e.g. the local quantized_schemes / utils
# star-imported below) before each cell runs, so edits are picked up without a
# kernel restart. Re-running this cell only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
from math import log, sqrt
import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.linalg import hadamard
import scipy.io as io
import scipy.special as sc
from tqdm import tqdm
from scipy.stats import ortho_group
import os
import dp_accounting
from dp_accounting import dp_event
from dp_accounting import privacy_accountant
from dp_accounting import privacy_accountant_test
from dp_accounting.rdp import rdp_privacy_accountant
from quantized_schemes import *
from utils import *
import pickle

In [9]:
def run_exp(d, n, c, num_itr, delta_target=1e-5,
            seed=1234, scaled=True, out_dir="new_results"):
    """Compare SIGM and CSGM private mean estimation over a grid of (epsilon, gamma).

    For every target epsilon in ``np.linspace(0.5, 4.0, 15)`` and every gamma in
    ``[0.3, 0.5, 1.0]`` the noise level ``sigma`` is obtained from
    ``get_SGM_sigma_from_rdp``. ``num_itr`` synthetic datasets of ``n`` points in
    ``d`` dimensions are then privatised with ``SIGM`` and with ``CSGM``. For each,
    the squared L2 error of the returned private mean against the true mean is
    recorded, together with the bit count each scheme returns.

    ``get_SGM_sigma_from_rdp``, ``SIGM`` and ``CSGM`` come from the star imports of
    ``quantized_schemes`` / ``utils``; their exact semantics are defined there.

    Args:
        d: Dimension of each data point.
        n: Number of data points per dataset.
        c: Magnitude bound of the generated coordinates (every entry lies in
            (-c, c)); divided by sqrt(d) when ``scaled`` is True. Also passed to
            the sigma calibration (presumably as the sensitivity bound; confirm).
        num_itr: Number of independent synthetic datasets (repetitions).
        delta_target: Target delta for the sigma calibration; also part of the
            output file names.
        seed: Seed for NumPy's global RNG (``np.random.seed``).
        scaled: If True, rescale ``c`` to ``c / sqrt(d)``.
        out_dir: Directory the result pickles are written to; created if missing.

    Returns:
        ``(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)``. Each
        dict maps ``(eps_target, gamma)`` to a length-``num_itr`` array.

    Side effects:
        Writes ``<name>_exact.pkl``, ``<name>_exact_bits.pkl``, ``<name>_quant.pkl``
        and ``<name>_quant_bits.pkl`` into ``out_dir``, where
        ``<name> = n{n}_d{d}_c{scaled}_delta{delta_target}``.
    """
    np.random.seed(seed)
    gamma_list = [0.3, 0.5, 1.0]
    eps_target_list = np.linspace(start=0.5, stop=4.0, num=15)
    if scaled:
        c = c / np.sqrt(d)

    # Result logs: (eps_target, gamma) -> per-repetition array.
    mse_quant, bits_quant = {}, {}
    mse_exact, bits_exact = {}, {}
    tables = (mse_quant, bits_quant, mse_exact, bits_exact)

    # Calibrate sigma for every (eps_target, gamma) pair and allocate the logs.
    sigma_dic = {}
    for eps_target in tqdm(eps_target_list):
        for gamma in gamma_list:
            sigma_dic[(eps_target, gamma)] = get_SGM_sigma_from_rdp(
                n, d, gamma, eps_target, c=c, delta=delta_target)
            for table in tables:
                table[(eps_target, gamma)] = np.zeros(num_itr)

    # Synthetic data: each entry is +c*U with probability 0.8 and -c*U otherwise,
    # U ~ Uniform[0, 1). The draws stay in this per-dataset order so results
    # are reproducible for a given seed.
    p = 0.8 * np.ones((n, d))
    X = np.zeros((num_itr, n, d))
    for i in range(num_itr):
        X[i] = c * (2 * np.random.binomial(1, p) - 1) * np.random.uniform(size=(n, d))
    X_true_mean = np.mean(X, axis=1)

    for gamma in gamma_list:
        for eps_target in tqdm(eps_target_list):
            key = (eps_target, gamma)
            sigma = sigma_dic[key]

            # SIGM: record squared error and reported bit count.
            for i in range(num_itr):
                X_priv_mean_e, bits = SIGM(X[i], sigma=sigma, gamma=gamma)
                mse_exact[key][i] = np.sum((X_priv_mean_e - X_true_mean[i]) ** 2)
                bits_exact[key][i] = bits

            # CSGM's `s` is set to ceil(2 ** (SIGM's mean bits per coordinate)) so
            # both schemes are compared at a matched budget. The value is fixed for
            # this (eps_target, gamma), so it is computed once, outside the loop.
            b = np.ceil(2 ** np.mean(bits_exact[key] / (n * d)))
            for i in range(num_itr):
                X_priv_mean_q, bits = CSGM(X[i], sigma=sigma, gamma=gamma, s=b)
                mse_quant[key][i] = np.sum((X_priv_mean_q - X_true_mean[i]) ** 2)
                bits_quant[key][i] = bits

    print("Saving the files")
    os.makedirs(out_dir, exist_ok=True)
    name = "n" + str(n) + "_" + "d" + str(d) + "_" + "c" + str(scaled) + "_" + "delta" + str(delta_target)
    # Every file is prefixed with the per-configuration `name`, so different
    # runs never overwrite each other's results.
    outputs = {
        "_exact.pkl": mse_exact,
        "_exact_bits.pkl": bits_exact,
        "_quant.pkl": mse_quant,
        "_quant_bits.pkl": bits_quant,
    }
    for suffix, payload in outputs.items():
        with open(os.path.join(out_dir, name + suffix), 'wb') as file:
            pickle.dump(payload, file)
    return eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant

In [10]:
def draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant,
                 d=None, n=None, gamma_list=(0.3, 0.5, 1.0)):
    """Plot SIGM's bit cost and the SIGM/CSGM MSE against the privacy target.

    Draws two figures:
      1. mean bits per coordinate used by SIGM (total bits / (n * d));
      2. mean squared error of SIGM (solid) and CSGM (dashed), on a log scale.

    Args:
        eps_target_list: Epsilon values used as the x-axis and as dict keys.
        mse_exact, bits_exact: SIGM logs keyed by ``(eps, gamma)``.
        mse_quant, bits_quant: CSGM logs keyed by ``(eps, gamma)``;
            ``bits_quant`` is accepted for symmetry with ``run_exp`` but not plotted.
        d, n: Dimension and number of points of the run being plotted. If
            omitted, the notebook-level globals ``d`` and ``n`` are used, which
            is what the existing calls in this notebook rely on.
        gamma_list: gamma values to plot; each ``(eps, gamma)`` pair must be
            present in the logs.
    """
    # Backward-compatible fallback to the notebook globals set by the calling cell.
    if d is None:
        d = globals()["d"]
    if n is None:
        n = globals()["n"]

    sigm_colors = {0.3: "gold", 0.5: "orange", 1.0: "red"}
    csgm_colors = {0.3: "cyan", 0.5: "dodgerblue", 1.0: "navy"}

    # Figure 1: SIGM communication cost per coordinate.
    fig, ax = plt.subplots()
    for gamma in gamma_list:
        bits_per_coord = [np.mean(bits_exact[(e, gamma)] / (d * n)) for e in eps_target_list]
        ax.plot(eps_target_list, bits_per_coord,
                label=rf'SIGM ($\gamma$={gamma})', color=sigm_colors.get(gamma),
                ms=5, marker='v', linewidth=1)
    ax.set_xlabel(r'Privacy ($\varepsilon$)')
    ax.set_ylabel('Bits per coordinate')
    ax.legend()
    plt.show()

    # Figure 2: MSE of SIGM (solid, triangles) and CSGM (dashed, circles).
    fig, ax = plt.subplots(figsize=(5, 5))
    for gamma in gamma_list:
        ax.plot(eps_target_list, [np.mean(mse_exact[(e, gamma)]) for e in eps_target_list],
                label=rf'SIGM ($\gamma$={gamma})', color=sigm_colors.get(gamma),
                ms=5, marker='v', linewidth=1)
    # NOTE(review): CSGM results are labelled "CIGM" (kept from the original);
    # confirm this is the intended display name.
    for gamma in gamma_list:
        ax.plot(eps_target_list, [np.mean(mse_quant[(e, gamma)]) for e in eps_target_list],
                linestyle='--', label=rf'CIGM ($\gamma$={gamma})', color=csgm_colors.get(gamma),
                ms=4, marker='o', linewidth=1)
    ax.legend(fontsize=10)
    ax.set_xlabel(r'Privacy ($\varepsilon$)', fontsize=16)
    ax.set_ylabel('MSE', fontsize=16)
    ax.set_yscale('log')
    # tick_params keeps the font size for ticks regenerated by the log scale.
    ax.tick_params(axis='both', labelsize=16)
    ax.grid(alpha=0.2, which='both')
    plt.show()

In [11]:
# Number of independent synthetic datasets (repetitions) per (epsilon, gamma) setting.
num_itr: int = 10

## Scaled Experiments

Configurations (`c` rescaled by $1/\sqrt{d}$):
* n=500, d=500
* n=500, d=5_000
* n=500, d=1_000

#### $\delta = 10^{-5}$

In [None]:
# Scaled run with the default delta_target=1e-5: n = 500, d = 500.
d, c, n = 500, 1, 500
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(d=d, n=n, c=c, num_itr=num_itr)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)


In [None]:
# Scaled run with the default delta_target=1e-5: n = 500, d = 1000.
d, c, n = 1000, 1, 500
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(d=d, n=n, c=c, num_itr=num_itr)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

In [None]:
# Scaled run with the default delta_target=1e-5: n = 500, d = 5000.
d, c, n = 5000, 1, 500
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(d=d, n=n, c=c, num_itr=num_itr)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

#### $\delta = 10^{-6}$

In [None]:
# Scaled run with delta_target=1e-6: n = 500, d = 500.
d, c, n = 500, 1, 500
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, delta_target=1e-6)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

In [None]:
# Scaled run with delta_target=1e-6: n = 500, d = 1000.
d, c, n = 1000, 1, 500
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, delta_target=1e-6)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

In [None]:
# Scaled run with delta_target=1e-6: n = 500, d = 5000.
d, c, n = 5000, 1, 500
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, delta_target=1e-6)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

## More experiments

In [None]:
# Unscaled run (c is not divided by sqrt(d)): n = 100, d = 20.
d, c, n = 20, 1, 100
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, scaled=False)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

In [None]:
# Unscaled run (c is not divided by sqrt(d)): n = 200, d = 20.
d, c, n = 20, 1, 200
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, scaled=False)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

In [None]:
# Unscaled run (c is not divided by sqrt(d)): n = 100, d = 20.
# NOTE(review): this configuration is also run in other cells of this notebook.
# run_exp reseeds NumPy with the same default seed, so the results are
# presumably identical; consider removing the duplicate.
d, c, n = 20, 1, 100
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, scaled=False)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

In [12]:
# Unscaled run (c is not divided by sqrt(d)): n = 100, d = 20.
# NOTE(review): this configuration is also run in other cells of this notebook.
# run_exp reseeds NumPy with the same default seed, so the results are
# presumably identical; consider keeping only one copy.
d, c, n = 20, 1, 100
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, scaled=False)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

 67%|██████▋   | 10/15 [00:26<00:13,  2.71s/it]

In [None]:
# Scaled counterpart of the small run (c divided by sqrt(d)): n = 100, d = 20.
d, c, n = 20, 1, 100
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, scaled=True)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)

In [None]:
# Unscaled run (c is not divided by sqrt(d)): n = 1000, d = 100.
d, c, n = 100, 1, 1000
eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant = run_exp(
    d=d, n=n, c=c, num_itr=num_itr, scaled=False)
draw_figures(eps_target_list, mse_exact, bits_exact, mse_quant, bits_quant)