In [44]:
import numpy as np
import pandas as pd
from scipy import stats

In [47]:
def exact_MSE_mean_est_weighted(var = stats.beta(a=7, b=5).var(), mue = stats.beta(a=7, b=5).mean(), 
                                n_list = [10, 20, 30], w_best = False, w_list = [0.2, 0.4, 0.6], v_best = False,
                               v_mat = [[0.1, 0.6, 0.3], [0.2, 0.8, 0.0], [0.3, 0.5, 0.2]]):
    '''
    Calculate exact error for mean estimation.  

    Args:
        var: variance of true mean distributions
        mue: mean of true error distribution. 
        n_list: a list of length M (number of players) with the number of samples each has.
        w_best: boolean, if true, calculates error given optimal values for w
        w_list: if w_best is false, a list of w-weights (in [0, 1]) for coarse-grained federation.
        v_best: boolean, if true, calculates error given optimal values for v
        v_mat: a matrix (list of lists) of weights each player uses in fine-grained federation: the rows sum up 
               to 1.
    Returns:
        dataframe with average error for each player, for: local, uniform, coarse-grained, and fine-grained 
        federation.  
    '''
    # dataframe for storing error
    player_error = pd.DataFrame(data = 0.0, index = ['local', 'uniform', 'coarse', 'fine'], 
                                columns = range(len(n_list)))
    N = sum(n_list)
    
    # for each player, calculate their true error 
    for j in range(len(n_list)):
        n = n_list[j]
        
        # local
        player_error.loc['local'][j] = mue/n
        
        sumsquares = sum([nval**2 for nval in n_list]) - n**2 + (N-n)**2
        
        # uniform
        player_error.loc['uniform'][j] = mue/N + sumsquares * var/(N**2)
        
        # coarse-grained
        if w_best: 
            if len(n_list) == 1: # division by 0 issue if length 1 list - equivalent to local
                w_err = player_error.loc['local'][j]
            else:
                w_err = (mue * (N-n) + var * sumsquares)/((N-n)*N + n*var*sumsquares/mue)
        else:
            w = w_list[j]
            w_err = mue * ( w**2/n + (1-w**2)/N) + ((1-w)**2/(N**2)) * sumsquares* var
        player_error.loc['coarse'][j] = w_err
        
        # fine-grained
        if v_best: 
            # calculate optimal v weights
            V_list = [var + mue/ni for ni in n_list]
            sum_inv = sum([1/Vi for Vi in V_list]) - 1/V_list[j]
            vjj = (1 + var * sum_inv)/(1 + V_list[j] * sum_inv)
            weights = [(V_list[j]-var)/(Vk * (1 + V_list[j]*sum_inv)) for Vk in V_list]
            weights[j] = vjj
            v_vec = pd.DataFrame(weights)
        else:
            v_vec = pd.DataFrame(v_mat[j])
            
        player_error.loc['fine'][j] = (mue * (v_vec**2).T.dot(pd.DataFrame([1/nval for nval in n_list])) + 
                                       var * ((v_vec**2).sum() - v_vec.iloc[j]**2 + (1 - v_vec.iloc[j])**2))[0][0]
        
    return player_error

In [48]:
def simulate_means(mean_dist = stats.beta(a=7, b=5), err_dist = stats.beta(a=5, b=7), draws_dist = stats.norm, 
                   n_list = [10, 20, 30], w_list = [0.2, 0.4, 0.6], 
                   v_mat = [[0.1, 0.6, 0.3], [0.2, 0.8, 0.0], [0.3, 0.5, 0.2]], 
                   world_nrun = 100, sample_nrun = 1):
    
    '''
    Simulate mean estimation. 
    

    Args:
        mean_dist: distribution to draw true means from (mean = theta)
        err_dist: distribution to draw true error parameters from (err = epsilon^2)
        draws_dist: distribution each player draws from: with mean theta and variance epsilon^2. 
        n_list: a list of length M (number of players) with the number of samples each has. 
        w_list: a list of w-weights each player uses for coarse-grained federation.  
        v_mat: a matrix (list of lists) of weights each player uses in fine-grained federation: the rows sum up 
               to 1.
        world_nrun: number of times where means and errors are re-drawn
        sample_nrun: for each worldrun, number of times samples are re-drawn
    Returns:
        dataframe with average error for each player, for local, uniform, coarse-grained, and fine-grained 
        federation.  
    '''
    M = len(w_list)
    n_list_pd = pd.DataFrame(n_list)
    w_list_pd = pd.DataFrame(w_list)
    v_mat_pd = pd.DataFrame(v_mat)
    
    # dataframe for storing error
    player_error = pd.DataFrame(data = 0, index = ['local', 'uniform', 'coarse', 'fine'], 
                                columns = range(len(n_list)))
    
    for wn in range(world_nrun):
        # draw means and errors
        means = mean_dist.rvs(M) 
        errors = err_dist.rvs(M)    
        
        for sn in range(sample_nrun): 
            
            # draw samples for each player
            sample_dict = {}
            for i in range(M):
                sample_dict[i] = draws_dist(means[i], np.sqrt(errors[i])).rvs(n_list[i])
            
            # calculate mean estimates
            local_est = pd.DataFrame([sample_dict[i].mean() for i in range(M)])
            uniform_est = local_est.T.dot(n_list_pd)[0][0]/sum(n_list)
            coarse_est = w_list_pd * local_est + (1-w_list_pd) * uniform_est
            fine_est = v_mat_pd.dot(local_est)
            
            # calculate MSE
            player_error.loc['local'] += ((local_est - pd.DataFrame(means))**2).values.flatten()
            player_error.loc['uniform'] += ((uniform_est - pd.DataFrame(means))**2).values.flatten()
            player_error.loc['coarse'] += ((coarse_est - pd.DataFrame(means))**2).values.flatten()
            player_error.loc['fine'] += ((fine_est - pd.DataFrame(means))**2).values.flatten()
    
    player_error = player_error/(world_nrun * sample_nrun)
    
    return player_error

Simulate empirical means, compare to true means.

In [49]:
simulate_means(world_nrun = 1000)

Unnamed: 0,0,1,2
local,0.045953,0.021641,0.014396
uniform,0.026291,0.020928,0.013509
coarse,0.020256,0.014181,0.010271
fine,0.033016,0.017766,0.028017


In [50]:
exact_MSE_mean_est_weighted()

Unnamed: 0,0,1,2
local,0.058333,0.029167,0.019444
uniform,0.029458,0.023225,0.016993
coarse,0.024297,0.017694,0.014386
fine,0.036391,0.022496,0.031642
