In [None]:
import numpy as np
import scipy.special
from scipy.optimize import minimize
import pandas as pd
import time
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

# --- Config and States ---

@dataclass
class RandOpts:
    method: str
    eps1_coeff: List[float]
    dirichlet_coeff: float

@dataclass
class OnEstOpts:
    on_est_vec: List[int]
    active_on_est_method: int
    online_EM_step_size_param: float
    offline_EM_iters: int
    offline_EM_update_per: int
    P_quad_prog_update_per: int

@dataclass
class FinEstOpts:
    fin_est_vec: List[int]
    offline_EM_iters: int
    online_EM_step_size_param: float
    online_EM_iters: int

@dataclass
class SimulationState:
    # Core data arrays
    Y: np.ndarray
    log_P_yx: np.ndarray
    Theta_Est: List[np.ndarray]
    
    # State for specific estimators
    C: Optional[np.ndarray] = None
    theta_online_EM_est: Optional[np.ndarray] = None
    theta_offline_est: Optional[np.ndarray] = None
    theta_QP1_est: Optional[np.ndarray] = None
    theta_QP2_est: Optional[np.ndarray] = None
    
    # QP Accumulators
    Sum_GHTGH: Optional[np.ndarray] = None
    sum_GHTc: Optional[np.ndarray] = None
    Sum_HTGH: Optional[np.ndarray] = None
    sum_HTc: Optional[np.ndarray] = None
    
    # Estimation State
    theta_tilde: Optional[np.ndarray] = None

# --- Utilities ---

def log_sum_exp(x):
    max_x = np.max(x, axis=0)
    return np.log(np.sum(np.exp(x - max_x), axis=0)) + max_x

def calculate_utility(theta, G, util_type):
    theta_ord = np.sort(theta)[::-1]
    if util_type == 6: # Prob(Y = X)
        return np.sum(theta_ord * np.diag(G))
    return 0

def make_G(K, k0, eps1, eps2):
    exp_eps1, exp_eps2 = np.exp(eps1), np.exp(eps2)
    p11 = exp_eps1 / (exp_eps1 + min(k0, K - 1))
    p12 = 1 / (exp_eps1 + min(k0, K - 1))
    if K - k0 - 1 <= 0 or (exp_eps2 + K - k0 - 1) == 0:
        p22, p21 = 1, 0
    else:
        p22 = exp_eps2 / (exp_eps2 + K - k0 - 1)
        p21 = 1 / (exp_eps2 + K - k0 - 1)
    G = np.zeros((K, K))
    if k0 > 0: G[0:k0, 0:k0] = (p11 - p12) * np.eye(k0) + p12
    if K > k0:
        G[k0:K, k0:K] = p11 * ((p22 - p21) * np.eye(K-k0) + p21)
        G[k0:K, 0:k0] = p12 / (K - k0) if K > k0 else 0
        G[0:k0, k0:K] = p12
    return G

# --- Estimation Algorithms ---

def MLE_RR(Y, K, eps_DP):
    S, _ = np.histogram(Y, bins=np.arange(1, K + 2))
    S_sorted = np.sort(S)
    k = 0
    while k < K:
        denominator = S_sorted[k]
        if denominator == 0: break
        if (K - k) + np.exp(eps_DP) - 1 - np.sum(S_sorted[k:]) / denominator < 0:
            k += 1
        else: break
    sum_S_sorted_k_onward = np.sum(S_sorted[k:])
    if sum_S_sorted_k_onward == 0: return np.ones(K) / K
    Phi_vec = S / sum_S_sorted_k_onward
    term1 = (K - k) / (np.exp(eps_DP) - 1) + 1
    term2 = 1 / (np.exp(eps_DP) - 1)
    theta = np.maximum(0, Phi_vec * term1 - term2)
    return theta / np.sum(theta) if np.sum(theta) > 0 else np.ones(K) / K

def offline_LDP_EM(theta0, log_P_yx, M):
    theta_est = theta0.flatten().copy()
    K = len(theta_est)
    Theta_est = np.zeros((M, K))
    for m in range(M):
        log_Pi_post = log_P_yx + np.log(theta_est[:, np.newaxis])
        Pi_post = np.exp(log_Pi_post - log_sum_exp(log_Pi_post))
        theta_est = np.mean(Pi_post, axis=1)
        Theta_est[m, :] = theta_est
    return Theta_est

def estimate_online_em(current_theta, p_yx_vec, t, step_size_param):
    pi_post_unnorm = np.log(current_theta) + np.log(p_yx_vec)
    pi_post = np.exp(pi_post_unnorm - log_sum_exp(pi_post_unnorm))
    gamma_t = (t + 1) ** (-step_size_param)
    return (1 - gamma_t) * current_theta + gamma_t * pi_post

def estimate_online_mm(C, H, y, t, K, K_prime, eps_DP):
    C[H == y] += 1
    p = np.exp(eps_DP) / (np.exp(eps_DP) + K_prime - 1)
    q = 1 / K_prime
    F = (C - (t + 1) * q) / (p - q)
    F_max = np.maximum(F, 0)
    return F_max / np.sum(F_max) if np.sum(F_max) > 0 else np.ones(K)/K

def estimate_online_qp1(current_theta, Sum_GHTGH, sum_GHTc):
    def quad_fun(theta): return 0.5 * theta.T @ Sum_GHTGH @ theta - sum_GHTc @ theta
    cons = ({'type': 'eq', 'fun': lambda x: np.sum(x) - 1}, {'type': 'ineq', 'fun': lambda x: x})
    res = minimize(quad_fun, current_theta, method='SLSQP', constraints=cons, options={'disp': False})
    return res.x

def estimate_online_qp2(current_theta, Sum_HTGH, sum_HTc):
    MM2 = Sum_HTGH.T @ Sum_HTGH
    vv2 = -Sum_HTGH.T @ sum_HTc
    def quad_fun(theta): return 0.5 * theta.T @ MM2 @ theta + vv2 @ theta
    cons = ({'type': 'eq', 'fun': lambda x: np.sum(x) - 1}, {'type': 'ineq', 'fun': lambda x: x})
    res = minimize(quad_fun, current_theta, method='SLSQP', constraints=cons, options={'disp': False})
    return res.x

def estimate_final_online_em(log_P_yx, initial_theta, T, iters, step_size_param):
    start_time = time.time()
    theta_est = initial_theta.copy()
    theta_est_all = np.zeros((iters, len(initial_theta)))
    for i in range(iters):
        for t_inner in range(T):
            theta_est = estimate_online_em(theta_est, np.exp(log_P_yx[:, t_inner]), i * T + t_inner, step_size_param)
        theta_est_all[i, :] = theta_est
    return theta_est_all, time.time() - start_time

def estimate_final_mm(C, T, K, K_prime, eps_DP):
    p = np.exp(eps_DP) / (np.exp(eps_DP) + K_prime - 1)
    q = 1 / K_prime
    F = (C - T * q) / (p - q)
    F_max = np.maximum(F, 0)
    return F_max / np.sum(F_max) if np.sum(F_max) > 0 else np.ones(K)/K, 0.0

# --- Components ---

def initialize_state(K, T, theta0, on_est_opts, fin_est_opts):
    use_qp = np.sum(fin_est_opts.fin_est_vec[4:6]) + np.sum(on_est_opts.on_est_vec[4:6]) > 0
    return SimulationState(
        Y=np.zeros(T), log_P_yx=np.zeros((K, T)), Theta_Est=[np.zeros((K, T)) for _ in range(6)],
        C=np.zeros(K) if on_est_opts.on_est_vec[3] or fin_est_opts.fin_est_vec[3] else None,
        theta_online_EM_est=theta0.copy() if on_est_opts.on_est_vec[1] else None,
        theta_offline_est=theta0.copy() if on_est_opts.on_est_vec[2] else None,
        theta_QP1_est=theta0.copy() if on_est_opts.on_est_vec[4] else None,
        theta_QP2_est=theta0.copy() if on_est_opts.on_est_vec[5] else None,
        Sum_GHTGH=np.zeros((K, K)) if use_qp else None, sum_GHTc=np.zeros(K) if use_qp else None,
        Sum_HTGH=np.zeros((K, K)) if use_qp else None, sum_HTc=np.zeros(K) if use_qp else None,
        theta_tilde=theta0.copy()
    )

def initialize_adaptive_mechanisms(K_prime, eps_DP, rand_opts):
    GG = []
    for eps1_coeff in rand_opts.eps1_coeff:
        eps1 = eps1_coeff * eps_DP
        Sc_min = int(np.ceil(np.exp(eps_DP - eps1)))
        k_vec = np.arange(Sc_min, K_prime + 1)
        eps2_UB = np.full(K_prime + 1, eps_DP)
        denom = np.exp(eps1 - eps_DP) * k_vec - 1
        valid_k_vec = k_vec[denom != 0]
        eps2_UB[valid_k_vec] = (valid_k_vec - 1) / denom[denom != 0]
        eps2_vec = np.minimum(eps_DP, eps2_UB)
        for k in range(1, K_prime + 1): GG.append(make_G(K_prime, k, eps1, eps2_vec[k]))
    return GG

def perform_privatization(t, x, T, state, G0, GG, K, K_prime, rand_opts, hashing):
    H_mtx = np.eye(K_prime, dtype=int)[np.random.randint(0, K_prime, size=K)].T if hashing else None
    H = np.arange(K) if not hashing else np.where(H_mtx.T)[1]
    if rand_opts.method not in ['RRRR', 'OLH+RRRR'] or t < T // 10:
        G_current = G0
    else:
        theta_tilde_compressed = H_mtx @ state.theta_tilde if hashing else state.theta_tilde
        ord_ind = np.argsort(theta_tilde_compressed)[::-1]
        L = [calculate_utility(theta_tilde_compressed[ord_ind], g, 6) for g in GG]
        k_best = np.argmax(L)
        ord_ind_inv = np.argsort(ord_ind)
        G_current = GG[k_best][ord_ind_inv, :][:, ord_ind_inv]
    y = np.random.choice(K_prime, p=G_current[:, H[x]])
    p_yx_vec = G_current[y, H]
    return y, p_yx_vec, H, G_current, H_mtx

def update_online_estimates(t, y, p_yx_vec, H, G_current, H_mtx, state, K, K_prime, eps_DP, on_est_opts, hashing):
    if on_est_opts.on_est_vec[0]: state.Theta_Est[0][:, t] = MLE_RR(state.Y[:t+1], K, eps_DP)
    if on_est_opts.on_est_vec[1]:
        state.theta_online_EM_est = estimate_online_em(state.theta_online_EM_est, p_yx_vec, t, on_est_opts.online_EM_step_size_param)
        state.Theta_Est[1][:, t] = state.theta_online_EM_est
    if on_est_opts.on_est_vec[2]:
        if (t + 1) % on_est_opts.offline_EM_update_per == 0:
            state.theta_offline_est = offline_LDP_EM(state.theta_offline_est, state.log_P_yx[:, :t+1], on_est_opts.offline_EM_iters)[-1,:]
        state.Theta_Est[2][:, t] = state.theta_offline_est
    if on_est_opts.on_est_vec[3]: state.Theta_Est[3][:, t] = estimate_online_mm(state.C, H, y, t, K, K_prime, eps_DP)
    if state.Sum_GHTGH is not None:
        GH_curr = G_current @ H_mtx if hashing else G_current
        ct = np.eye(K_prime)[y]
        state.Sum_GHTGH += GH_curr.T @ GH_curr; state.sum_GHTc += GH_curr.T @ ct
        state.Sum_HTGH += H_mtx.T @ GH_curr if hashing else GH_curr; state.sum_HTc += H_mtx.T @ ct if hashing else ct
    if on_est_opts.on_est_vec[4]:
        if (t + 1) % on_est_opts.P_quad_prog_update_per == 0:
            state.theta_QP1_est = estimate_online_qp1(state.theta_QP1_est, state.Sum_GHTGH, state.sum_GHTc)
        state.Theta_Est[4][:, t] = state.theta_QP1_est
    if on_est_opts.on_est_vec[5]:
        if (t + 1) % on_est_opts.P_quad_prog_update_per == 0:
            state.theta_QP2_est = estimate_online_qp2(state.theta_QP2_est, state.Sum_HTGH, state.sum_HTc)
        state.Theta_Est[5][:, t] = state.theta_QP2_est

def adapt_mechanism(t, state, rand_opts, on_est_opts):
    theta_est_to_be_used = state.Theta_Est[on_est_opts.active_on_est_method - 1][:, t]
    gamma_shape = 1 + rand_opts.dirichlet_coeff * (t + 1) * theta_est_to_be_used
    gamma_shape[gamma_shape < 0] = 0
    state.theta_tilde = np.random.gamma(gamma_shape)
    state.theta_tilde /= np.sum(state.theta_tilde)

def run_final_estimation(state, K, K_prime, eps_DP, theta0, fin_est_opts, on_est_opts):
    theta_fin_est = [np.zeros(K) for _ in range(6)]
    comp_times = np.zeros(6)
    initial_theta_for_final = state.Theta_Est[on_est_opts.active_on_est_method - 1][:, -1] if on_est_opts.active_on_est_method > 0 else theta0

    if fin_est_opts.fin_est_vec[0]:
        start = time.time(); theta_fin_est[0] = MLE_RR(state.Y, K, eps_DP); comp_times[0] = time.time() - start
    if fin_est_opts.fin_est_vec[1]:
        theta_fin_est[1], comp_times[1] = estimate_final_online_em(state.log_P_yx, initial_theta_for_final, len(state.Y), fin_est_opts.online_EM_iters, fin_est_opts.online_EM_step_size_param)
    if fin_est_opts.fin_est_vec[2]:
        start = time.time(); theta_fin_est[2] = offline_LDP_EM(initial_theta_for_final, state.log_P_yx, fin_est_opts.offline_EM_iters); comp_times[2] = time.time() - start
    if fin_est_opts.fin_est_vec[3]:
        theta_fin_est[3], comp_times[3] = estimate_final_mm(state.C, len(state.Y), K, K_prime, eps_DP)
    if fin_est_opts.fin_est_vec[4]:
        start = time.time(); theta_fin_est[4] = estimate_online_qp1(theta0, state.Sum_GHTGH, state.sum_GHTc); comp_times[4] = time.time() - start
    if fin_est_opts.fin_est_vec[5]:
        start = time.time(); theta_fin_est[5] = estimate_online_qp2(theta0, state.Sum_HTGH, state.sum_HTc); comp_times[5] = time.time() - start
        
    return theta_fin_est, comp_times

# --- Runner ---

def Adaptive_LDP(X, K, K_prime, eps_DP, rand_opts: RandOpts, on_est_opts: OnEstOpts, fin_est_opts: FinEstOpts, theta0):
    T = len(X)
    hashing = 1 if rand_opts.method in ['OLH', 'OLH+RRRR'] else 0
    if not hashing: K_prime = K
    
    state = initialize_state(K, T, theta0, on_est_opts, fin_est_opts)
    G0 = np.full((K_prime, K_prime), 1 / (np.exp(eps_DP) + K_prime - 1))
    np.fill_diagonal(G0, np.exp(eps_DP) / (np.exp(eps_DP) + K_prime - 1))
    GG = initialize_adaptive_mechanisms(K_prime, eps_DP, rand_opts) if rand_opts.method in ['RRRR', 'OLH+RRRR'] else []

    for t in range(T):
        y, p_yx_vec, H, G_current, H_mtx = perform_privatization(t, X[t] - 1, T, state, G0, GG, K, K_prime, rand_opts, hashing)
        state.Y[t] = y + 1
        state.log_P_yx[:, t] = np.log(p_yx_vec)
        update_online_estimates(t, y, p_yx_vec, H, G_current, H_mtx, state, K, K_prime, eps_DP, on_est_opts, hashing)
        if rand_opts.method in ['RRRR', 'OLH+RRRR'] and on_est_opts.active_on_est_method > 0:
            adapt_mechanism(t, state, rand_opts, on_est_opts)

    theta_fin_est, comp_times = run_final_estimation(state, K, K_prime, eps_DP, theta0, fin_est_opts, on_est_opts)
    return {"Theta_Online_Est": state.Theta_Est, "Theta_Fin_Est": theta_fin_est, "offline_est_comp_times": comp_times}

# --- Experiment ---

if __name__ == '__main__':
    np.random.seed(1)
    
    # Experiment Configuration
    eps_DP_vec = [0.5, 1, 2]; K_vec = [100]; rho_coeff_vec = [0.01, 0.05, 0.5]; MC_run = 10
    Online_est_methods = ['Exact MLE-on', 'online EM-on', 'offline EM-on', 'Simple MM-on', 'Quad Prog1-on', 'Quad Prog2-on']
    Offline_est_methods = ['Exact MLE-off', 'online EM-off', 'offline EM-off', 'Simple MM-off', 'Quad Prog1-off', 'Quad Prog2-off']
    method_params_cell = [
        {'rand_method': 'SRR', 'on_est_vec': [1, 0, 0, 1, 1, 0], 'fin_est_vec': [1, 0, 1, 1, 1, 1], 'active_on_est_method': 0},
        {'rand_method': 'OLH', 'on_est_vec': [0, 1, 0, 1, 1, 1], 'fin_est_vec': [0, 1, 1, 1, 1, 1], 'active_on_est_method': 0},
        {'rand_method': 'RRRR', 'on_est_vec': [0, 1, 0, 0, 0, 0], 'fin_est_vec': [0, 1, 1, 0, 1, 1], 'active_on_est_method': 2},
        {'rand_method': 'RRRR', 'on_est_vec': [0, 0, 0, 0, 1, 0], 'fin_est_vec': [0, 1, 1, 0, 1, 1], 'active_on_est_method': 5},
        {'rand_method': 'RRRR', 'on_est_vec': [0, 0, 0, 0, 0, 1], 'fin_est_vec': [0, 1, 1, 0, 1, 1], 'active_on_est_method': 6},
        {'rand_method': 'OLH+RRRR', 'on_est_vec': [0, 1, 0, 0, 0, 0], 'fin_est_vec': [0, 1, 1, 0, 1, 1], 'active_on_est_method': 2},
        {'rand_method': 'OLH+RRRR', 'on_est_vec': [0, 0, 0, 0, 1, 0], 'fin_est_vec': [0, 1, 1, 0, 1, 1], 'active_on_est_method': 5},
        {'rand_method': 'OLH+RRRR', 'on_est_vec': [0, 0, 0, 0, 0, 1], 'fin_est_vec': [0, 1, 1, 0, 1, 1], 'active_on_est_method': 6}
    ]
    
    results_list = []
    total_runs = len(K_vec) * len(rho_coeff_vec) * len(eps_DP_vec) * len(method_params_cell)
    run_count = 0

    for K in K_vec:
        for rho_coeff in rho_coeff_vec:
            for eps_DP in eps_DP_vec:
                for params in method_params_cell:
                    run_count += 1
                    active_method_str = 'None' if params['active_on_est_method'] == 0 else Online_est_methods[params['active_on_est_method']-1]
                    print(f"\n--- Run {run_count}/{total_runs}: K={K}, rho={rho_coeff}, eps={eps_DP}, method={params['rand_method']}, active={active_method_str} ---")
                    
                    mc_results = []
                    for mc in range(MC_run):
                        T = 100 * K; theta0 = np.ones(K) / K
                        theta_true = np.random.gamma(rho_coeff * np.ones(K), 1); theta_true /= np.sum(theta_true)
                        X = np.random.choice(np.arange(1, K + 1), size=T, p=theta_true)
                        K_prime = int(np.ceil(np.exp(eps_DP) + 1))
                        rand_opts = RandOpts(params['rand_method'], [0.8, 0.9, 1], 0.01)
                        on_est_opts = OnEstOpts(params['on_est_vec'], params['active_on_est_method'], 0.55, 100, 100, 100)
                        fin_est_opts = FinEstOpts(params['fin_est_vec'], 1000, 0.55, 100)
                        outputs = Adaptive_LDP(X, K, K_prime, eps_DP, rand_opts, on_est_opts, fin_est_opts, theta0)
                        
                        res = {'mc': mc, 'K': K, 'rho': rho_coeff, 'eps_DP': eps_DP, 'rand_method': params['rand_method'], 'active_method': active_method_str}
                        for i, name in enumerate(Online_est_methods):
                            if on_est_opts.on_est_vec[i]: res[f'TV_{name}'] = 0.5 * np.sum(np.abs(outputs['Theta_Online_Est'][i][:, -1] - theta_true))
                        for i, name in enumerate(Offline_est_methods):
                            if fin_est_opts.fin_est_vec[i]:
                                est = outputs['Theta_Fin_Est'][i]
                                final_est = est[-1, :] if est.ndim > 1 else est
                                res[f'TV_{name}'] = 0.5 * np.sum(np.abs(final_est - theta_true))
                        mc_results.append(res)

                    temp_df = pd.DataFrame(mc_results)
                    tv_cols = [col for col in temp_df.columns if 'TV_' in col]
                    mean_tvs = temp_df[tv_cols].mean()
                    
                    print("    Average TV Distances for this run:")
                    for col, val in mean_tvs.items():
                        print(f"      {col:<22}: {val:.4f}")
                    results_list.extend(mc_results)

    results_df = pd.DataFrame(results_list)
    grouping_cols = ['K', 'rho', 'eps_DP', 'rand_method', 'active_method']
    mean_results = results_df.groupby(grouping_cols).mean().drop('mc', axis=1)
    
    print("\n\n--- FINAL AGGREGATE RESULTS (MEAN TV) ---")
    with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 200):
        print(mean_results)


--- Run 1/72: K=100, rho=0.01, eps=0.5, method=SRR, active=None ---
