---
execute:
    echo: false
    warning: false

format:
    html:
        theme: "theme_ipol.scss"
        self-contained: true
---

In [None]:
# Simulation parameters. The trailing "#@param" comments are Colab form annotations
# that expose each value as an interactive widget; keep them exactly as written.
# n_nodes: number of nodes (clients) taking part in the federation.
# prob_law: distribution of the number of samples per node (one of the names handled by get_law_data_nodes).
# feature_dims: dimension of the feature vectors of the synthetic regression data.
n_nodes = 78 #@param {type:"slider", min:2, max:300, step:1}
prob_law = "log_normal_low_var" #@param ["log_normal_low_var", "log_normal_high_var", "gaussian_low_var", "gaussian_med_var", "gaussian_high_var", "poisson", "uniform_low_var", "uniform_high_var", "laplacian_low_var", "laplacian_med_var", "laplacian_high_var", "pareto", "weibull"]
feature_dims = 10 #@param {type:"slider", min:10, max:100, step:1}


In [None]:
# Number of Monte Carlo repetitions averaged for every gamma.
n_montecarlo = 1
# Exponents gamma: the mean number of samples per node is n_nodes**gamma.
# The first entry is the one whose samples-per-node draw is kept for the histogram.
gamma_list = [
    1.2,
    1.0, 0.9, 0.8, 0.7, 0.6,
    0.5, 0.4, 0.3, 0.2,
]

In [None]:
import numpy as np
import scipy
import random
import math
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.linear_model import Ridge, RidgeCV
from scipy.stats import lognorm, poisson, uniform, norm, randint, weibull_min, laplace, pareto
from matplotlib import rc
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio
# Render plotly figures with the "notebook_connected" renderer and
# put matplotlib in interactive mode.
pio.renderers.default = "notebook_connected"
plt.ion()

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-<name>' and 3.8 removed
# the old names, so 'seaborn-darkgrid' raises on current versions. Use the renamed style when
# the installed matplotlib provides it, and fall back to the legacy name on older versions.
for _legacy_style in ('seaborn-darkgrid', 'seaborn-pastel'):
    _renamed_style = _legacy_style.replace('seaborn', 'seaborn-v0_8', 1)
    mpl.style.use(_renamed_style if _renamed_style in mpl.style.available else _legacy_style)
mpl.rcParams['font.size'] = 15

In [None]:
def generate_data(dimension, number_samples_per_node, rng=None):
    '''
    Generate synthetic linear-regression data and labels for all the nodes.

    A ground-truth parameter vector theta_star is drawn with i.i.d. components uniform on [0, 1).
    For every node, the data samples are drawn from a standard normal distribution and the labels
    are y = x @ theta_star + noise, with standard normal noise.
    Args:
      dimension (int): dimension of the data vector (x)
      number_samples_per_node (list or array): contains the number of samples for each node.
      rng (numpy.random.Generator or numpy.random.RandomState, optional): source of randomness.
        Defaults to the global ``numpy.random`` state; pass e.g. ``np.random.default_rng(seed)``
        to make a run reproducible.
    Returns:
      x (list): the data samples of each node, one array of shape (n_samples, dimension) per node
      y (list): the labels of each node, one array of shape (n_samples,) per node
      theta_star (array): the parameter vector, of shape (dimension,)
    '''
    # The numpy.random module and Generator/RandomState objects expose the same
    # uniform/normal methods, so the default keeps using the global state.
    rng = np.random if rng is None else rng

    theta_star = rng.uniform(size=dimension)
    x = []
    y = []
    for n_samples in number_samples_per_node:
        x_i = rng.normal(loc=0., scale=1., size=(n_samples, dimension))
        y_i = x_i @ theta_star + rng.normal(loc=0., scale=1., size=n_samples)
        x.append(x_i)
        y.append(y_i)

    return x, y, theta_star

def get_law_data_nodes(prob_law, n_nodes, gamma, random_state=None):
    '''
    Returns the number of samples per node drawn from a given probability law with mean n_nodes**gamma.
    Args:
      prob_law (str): the probability distribution of the number of samples per node. Supported: 'gaussian_low_var', 'gaussian_med_var',
      'gaussian_high_var', 'log_normal_low_var', 'log_normal_high_var', 'poisson', 'uniform_low_var',
      'uniform_high_var', 'laplacian_low_var', 'laplacian_med_var', 'laplacian_high_var', 'pareto', 'weibull'.
      n_nodes (int): the total number of nodes
      gamma (float): the parameter to set the mean number of samples per node (which is equal to n_nodes**gamma)
      random_state (None, int, numpy.random.Generator or numpy.random.RandomState, optional): forwarded
        to the scipy.stats ``rvs`` calls. None (default) uses numpy's global random state.

    Returns:
      an integer array of length n_nodes containing the number of samples for each node. Every draw is
      clipped to be at least 1 and then truncated to an integer, so the empirical mean can deviate slightly
      from n_nodes**gamma (notably for the Gaussian and Laplacian laws, whose draws can fall below 1).

    Raises:
      NotImplementedError: if the probability law is not supported.
    '''
    mean = n_nodes**gamma
    rs = random_state

    if prob_law == 'gaussian_low_var':
        law_datanodes = norm.rvs(mean, 0.1*mean, size=n_nodes, random_state=rs)
    elif prob_law == 'gaussian_med_var':
        law_datanodes = norm.rvs(mean, 0.5*mean, size=n_nodes, random_state=rs)
    elif prob_law == 'gaussian_high_var':
        law_datanodes = norm.rvs(mean, mean, size=n_nodes, random_state=rs)
    elif prob_law == 'log_normal_low_var':
        # lognorm(s, scale) has mean scale*exp(s**2/2); for s=1 the scale is mean*exp(-1/2).
        law_datanodes = lognorm.rvs(1, scale=mean*np.exp(-1/2), size=n_nodes, random_state=rs)
    elif prob_law == 'log_normal_high_var':
        # For s=2 the mean is scale*exp(2), hence scale = mean*exp(-2)
        # (a scale of mean*exp(+2) would give a mean of mean*exp(4)).
        law_datanodes = lognorm.rvs(2, scale=mean*np.exp(-2), size=n_nodes, random_state=rs)
    elif prob_law == 'poisson':
        law_datanodes = poisson.rvs(mean, size=n_nodes, random_state=rs)
    elif prob_law == 'uniform_low_var':
        # Integers in [0.5*mean, 1.5*mean): centred on the target mean, narrower than uniform_high_var.
        law_datanodes = randint.rvs(int(0.5*mean), int(1.5*mean), size=n_nodes, random_state=rs)
    elif prob_law == 'uniform_high_var':
        # Integers in [0, 2*mean).
        law_datanodes = randint.rvs(0, int(2*mean), size=n_nodes, random_state=rs)
    elif prob_law == 'laplacian_low_var':
        law_datanodes = laplace.rvs(loc=mean, scale=0.1*mean, size=n_nodes, random_state=rs)
    elif prob_law == 'laplacian_med_var':
        law_datanodes = laplace.rvs(loc=mean, scale=0.5*mean, size=n_nodes, random_state=rs)
    elif prob_law == 'laplacian_high_var':
        law_datanodes = laplace.rvs(loc=mean, scale=mean, size=n_nodes, random_state=rs)
    elif prob_law == 'pareto':
        # pareto(b, scale) has mean b*scale/(b-1); for b=2 the scale is mean/2.
        law_datanodes = pareto.rvs(b=2, scale=mean/2, size=n_nodes, random_state=rs)
    elif prob_law == 'weibull':
        # weibull_min(c, scale) has mean scale*Gamma(1 + 1/c).
        law_datanodes = weibull_min.rvs(c=1.5, scale=mean/math.gamma(1+1/1.5), size=n_nodes, random_state=rs)
    else:
        raise NotImplementedError('Unsupported probability law: {!r}'.format(prob_law))

    # Ensure that every node has at least 1 sample, then truncate to integer sample counts.
    return np.maximum(law_datanodes, 1).astype(int)

def weights(samples_per_node):
    '''
    Computes the weights for the FESC algorithm.

    Nodes are visited in decreasing order of sample size n_i, with a_i = 1/n_i and b_i = 1/n_i**2.
    Node k is admitted while

        b_k <= (2 + sum_{i<=k} b_i/a_i) / (sum_{i<=k} 1/a_i)

    and the first node that fails this test stops the scan. Every admitted node i receives

        w_i = -b_i/(2 a_i) + (1/a_i) * (1 + S_ab/2) / S_a,

    where S_ab = sum b_i/a_i and S_a = sum 1/a_i run over all admitted nodes. The remaining
    nodes get weight 0, and the admitted weights sum to 1.
    Args:
      samples_per_node (list or array): number of samples of each node (n_nodes strictly positive entries).

    Returns:
      an array of length n_nodes with the weights of each node for the final model aggregation
    '''
    n = np.asarray(samples_per_node, dtype=float)
    w = np.zeros(len(n))
    if len(n) == 0:
        return w

    # sort in decreasing order of sample size
    order = np.argsort(-n, kind='stable')
    n_sorted = n[order]
    a = 1./n_sorted
    b = 1./n_sorted**2

    # k counts the admitted nodes (not the index of the last visited one), so that when every
    # node passes the test all of them receive a weight and the weights still sum to 1.
    sum_ord_ab = 0.
    sum_ord_a = 0.
    k = 0
    for i in range(len(n_sorted)):
        cand_ab = sum_ord_ab + b[i]/a[i]
        cand_a = sum_ord_a + 1./a[i]
        if b[i] > (2 + cand_ab) / cand_a:
            break
        sum_ord_ab, sum_ord_a = cand_ab, cand_a
        k = i + 1

    # calculate the weights w_i for every admitted node (the first node is always admitted)
    admitted = order[:k]
    w[admitted] = -0.5*b[:k]/a[:k] + (1./a[:k])*(1. + 0.5*sum_ord_ab)/sum_ord_a

    return w

def launch_simulation(montecarlo_rounds, n_nodes, feature_dims, gamma_list, prob_law):
    '''
    Launches the simulation. For each gamma, data is generated and the theta parameter is estimated
    in 3 different ways: centralized learning, classic Federated Learning and the FESC algorithm.
    Args:
      montecarlo_rounds (int): number of Monte Carlo iterations per gamma (must be >= 1)
      n_nodes (int): number of nodes
      feature_dims (int): number of dimensions of the features space
      gamma_list (list or array): non-empty; the gammas for which the simulation is performed
        (the mean number of samples per node is n_nodes**gamma)
      prob_law (str): probability distribution of the number of samples per node (of mean n_nodes**gamma)
    Returns:
      mse_centralized (array): squared error ||theta - theta_star||^2 of the centralized model,
        averaged over the Monte Carlo rounds, for each gamma
      mse_fed (array): the same for the classic federated model, for each gamma
      mse_FESC (array): the same for the FESC model, for each gamma
      samples_per_node_to_plot (array): samples per node drawn in the last Monte Carlo round for the
        first gamma of gamma_list (1.2 with the notebook's default list). This is used to illustrate
        the chosen probability distribution.
    Raises:
      ValueError: if montecarlo_rounds < 1 or gamma_list is empty (nothing would be simulated).
    '''
    if montecarlo_rounds < 1 or len(gamma_list) == 0:
        raise ValueError('launch_simulation needs montecarlo_rounds >= 1 and a non-empty gamma_list')

    # initialization
    mse_centralized = np.zeros(len(gamma_list))
    mse_fed = np.zeros(len(gamma_list))
    mse_FESC = np.zeros(len(gamma_list))
    samples_per_node_to_plot = None

    for gamma_idx, gamma in enumerate(gamma_list):
        # for each gamma, the expected sample size per node is n_nodes**gamma
        for _ in range(montecarlo_rounds):

            # generate number of samples per node
            samples_per_node = get_law_data_nodes(prob_law, n_nodes, gamma)
            lambda_ridge = 1./np.sqrt(np.sum(samples_per_node))  # lambda parameter of the ridge regression

            # keep the samples-per-node draw of the first gamma for the histogram
            if gamma_idx == 0:
                samples_per_node_to_plot = samples_per_node

            # generate data and labels
            x, y, theta_star = generate_data(feature_dims, samples_per_node)

            # centralized training on the pooled data of all the nodes
            model = Ridge(alpha=lambda_ridge)
            model.fit(np.concatenate(x), np.concatenate(y))
            theta_centralized = model.coef_

            # federated training: local steps in all the nodes
            theta_nodes = np.zeros((n_nodes, feature_dims))
            for i, (xi, yi) in enumerate(zip(x, y)):
                model = Ridge(alpha=lambda_ridge)
                model.fit(xi, yi)
                theta_nodes[i] = model.coef_

            # server aggregation
            # classic FL: weights proportional to each node's share of the total sample size
            theta_fed = np.dot(samples_per_node/np.sum(samples_per_node), theta_nodes)
            # FESC: weights given by the FESC rule
            theta_FESC = np.dot(weights(samples_per_node), theta_nodes)

            # accumulate the Monte Carlo average of the squared parameter error
            mse_centralized[gamma_idx] += np.linalg.norm(theta_centralized - theta_star, 2)**2 / montecarlo_rounds
            mse_fed[gamma_idx] += np.linalg.norm(theta_fed - theta_star, 2)**2 / montecarlo_rounds
            mse_FESC[gamma_idx] += np.linalg.norm(theta_FESC - theta_star, 2)**2 / montecarlo_rounds

    return mse_centralized, mse_fed, mse_FESC, samples_per_node_to_plot

In [None]:
def plot_images_interactive(mse_centralized, mse_fed, mse_FESC, gammas, n_nodes, samples_per_node):
    '''
    Plots interactive (plotly) figures that show the results: a histogram of the samples per node
    (left) and the MSE of each method as a function of gamma on a log scale (right).
    Args:
      mse_centralized (array): contains the MSE of the centralized model, for each gamma
      mse_fed (array): contains the MSE of the classic federated model, for each gamma
      mse_FESC (array): contains the MSE of the FESC model, for each gamma
      gammas (list or array): gammas used for each MSE
      n_nodes (int): number of nodes (sets the number of histogram bins to 20% of it)
      samples_per_node (list or array): contains the number of samples for each node
    '''

    fig = make_subplots(rows=1, cols=2, subplot_titles=('samples per node histogram', 'MSE vs mean samples per node'))

    # histnorm="percent" makes each bar the percentage of nodes falling in the bin
    fig.add_trace(go.Histogram(x=samples_per_node, nbinsx=int(0.2*n_nodes), histnorm="percent", showlegend=False), 1, 1)
    fig.update_xaxes(title_text="number of samples per node", row=1, col=1)
    fig.update_yaxes(title_text="percentage of nodes (%)", row=1, col=1)

    fig.add_trace(go.Scatter(x=gammas, y=mse_centralized, name="Centralized"), 1, 2)
    fig.add_trace(go.Scatter(x=gammas, y=mse_fed, name='Federated'), 1, 2)
    fig.add_trace(go.Scatter(x=gammas, y=mse_FESC, name='FESC'), 1, 2)

    fig.update_xaxes(title_text="gamma", row=1, col=2)
    fig.update_yaxes(title_text="mean squared error (MSE)", type="log", row=1, col=2)

    fig.update_layout(showlegend=True)

    fig.show()

def plot_result(mse_centralized, mse_fed, mse_FESC, gammas, n_nodes, samples_per_node):
    '''
    Plots the simulation results with matplotlib: a normalized histogram of the samples per node
    (left) and the MSE of each method as a function of gamma on a log scale (right).
    Args:
      mse_centralized (array): contains the MSE of the centralized model, for each gamma
      mse_fed (array): contains the MSE of the classic federated model, for each gamma
      mse_FESC (array): contains the MSE of the FESC model, for each gamma
      gammas (list or array): gammas used for each MSE
      n_nodes (int): number of nodes (sets the number of histogram bins to 20% of it, at least 1)
      samples_per_node (list or array): contains the number of samples for each node
    '''

    f = plt.figure(figsize=(14, 5))
    ax1 = f.add_subplot(121)
    ax2 = f.add_subplot(122)

    # samples per node histogram, normalized so the bars sum to 1 (fraction of nodes per bin).
    # At least one bin is required: int(0.2*n_nodes) is 0 for n_nodes < 5.
    node_fraction = np.ones(len(samples_per_node)) / float(len(samples_per_node))
    ax1.hist(samples_per_node, bins=max(1, int(0.2*n_nodes)), weights=node_fraction)
    ax1.set_xlabel("number of samples per node")
    ax1.set_ylabel("fraction of nodes")
    ax1.set_title('samples per node histogram')
    ax1.grid(True)

    # MSE of each estimator versus gamma
    ax2.plot(gammas, mse_centralized, '+-', label="Centralized")
    ax2.plot(gammas, mse_fed, '+-', label="Federated")
    ax2.plot(gammas, mse_FESC, '+-', label="FESC")
    ax2.set_yscale('log')
    ax2.legend()
    ax2.set_xlabel("gamma")
    ax2.set_title('MSE vs mean samples per node')
    ax2.set_ylabel("mean squared error (MSE)")
    ax2.grid(True)

    f.tight_layout()

In [1]:
# Run the simulation over every gamma; keep the per-gamma errors of the three estimators
# and one samples-per-node draw for the histogram.
(mse_centralized, mse_fed, mse_FESC,
 samples_per_node) = launch_simulation(montecarlo_rounds=n_montecarlo,
                                       n_nodes=n_nodes,
                                       feature_dims=feature_dims,
                                       gamma_list=gamma_list,
                                       prob_law=prob_law)

NameError: ignored

::: {.column-screen}

In [None]:
# The printed threshold is n_nodes, i.e. the mean sample size n_nodes**gamma at gamma = 1.
# NOTE(review): this is a fixed statement, not derived from the simulated MSE curves.
print('With Federated Learning, you need at least {} samples per node (on average) to match the performance of centralized training'.format(n_nodes))

:::

::: {.column-screen}

In [None]:
# Samples-per-node histogram (left) and MSE versus gamma for each method (right).
plot_images_interactive(mse_centralized=mse_centralized,
                        mse_fed=mse_fed,
                        mse_FESC=mse_FESC,
                        gammas=gamma_list,
                        n_nodes=n_nodes,
                        samples_per_node=samples_per_node)

:::