In [1]:
import torch
import numpy.linalg as LA
import torch.quantization
import torch.nn as nn
import torch.optim as optim
import pickle
import os
import sys
sys.path.append("..")
from utils import categorical_gumbel_softmax_sampling, categorical_softmax, get_acc_and_bac, continuous_sigmoid_bound, Timer
import torch
from attacks.initializations import _uniform_initialization, _gaussian_initialization, _mean_initialization, \
    _dataset_sample_initialization, _likelihood_prior_sample_initialization, _mixed_initialization, \
    _best_sample_initialization
from attacks.priors import _joint_gmm_prior, _mean_field_gmm_prior, _categorical_prior, _categorical_l2_prior, \
    _categorical_mean_field_jensen_shannon_prior, _continuous_uniform_prior, _theoretical_optimal_prior, \
    _theoretical_typicality_prior, _theoretical_marginal_prior, _theoretical_marginal_typicality_prior
from attacks.inversion_losses import _weighted_CS_SE_loss, _gradient_norm_weighted_CS_SE_loss, _squared_error_loss, _cosine_similarity_loss
from attacks.ensembling import pooled_ensemble
from collections import OrderedDict
from models import MetaMonkey
import numpy as np
import copy
import pickle
import os
import multiprocessing
from defenses import dp_defense
from fair_loss import FairLoss
import torch.nn as nn
import os
import attacks
import numpy as np
import torch
from utils import match_reconstruction_ground_truth, Timer, post_process_continuous
from attacks import train_and_attack_fed_avg
from models import FullyConnected
from datasets import ADULT
import pickle
from attacks import calculate_random_baseline
from torch.quantization import QuantStub, DeQuantStub


# Quantization Aware Training

In [2]:
class LinReLU(nn.Module):
    """
    A linear layer followed by a ReLU activation layer.
    """
    def __init__(self, in_size, out_size):
        super(LinReLU, self).__init__()
        self.linear = nn.Linear(in_size, out_size)
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(self.linear, self.relu)

    def reset_parameters(self):
        self.linear.reset_parameters()
        return self

    def forward(self, x):
        x = self.layers(x)
        return x

class FullyConnected(nn.Module):
    """
    A fully connected neural network with ReLU activations and QAT support.
    """
    def __init__(self, input_size, layout):
        super(FullyConnected, self).__init__()
        self.quant = torch.quantization.QuantStub()
        layers = [nn.Flatten()]
        prev_fc_size = input_size
        for i, fc_size in enumerate(layout):
            if i + 1 < len(layout):
                layers += [LinReLU(prev_fc_size, fc_size)]
            else:
                layers += [nn.Linear(prev_fc_size, 1), nn.Sigmoid()]
            prev_fc_size = fc_size
        self.layers = nn.Sequential(*layers)
            
        self.dequant = torch.quantization.DeQuantStub()        
        
        # self.qconfig = None

    def forward(self, x):
        x = self.quant(x)
        x = self.layers(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        """
        Fuses Linear and ReLU layers in LinReLU modules for QAT
        """
        for module in self.modules():
            if isinstance(module, LinReLU):
                torch.quantization.fuse_modules(
                    module.layers, 
                    ['0', '1'],  # Fuse Linear and ReLU
                    inplace=True
                )

In [3]:
def epoch_matching_prior_mean_square_error(epoch_data, device=None):
    """
    Permutation invariant prior that can be applied over the individual datapoints in the epochs. We first average up
    each dataset in the epoch and then calculate pairwise L2 distances between the epoch-data. It is normalized for
    number of features and number of epochs.

    :param epoch_data: (list of torch.tensor) List of the data-tensors used for each epoch.
    :param device: (str) Name of the device on which the tensors are stored. If None is given, the device on which the
        first of the epoch data is taken.
    :return: prior (torch.tensor) The calculated value of the prior with gradient information.

    """
    n_epochs = len(epoch_data)
    n_features = epoch_data[0].size()[-1]
    if device is None:
        device = epoch_data[0].device
    average_local_data = torch.stack([1/data.size()[0] * data.sum(dim=0) for data in epoch_data]).to(device)
    prior = torch.tensor([0.], device=device)
    for i in range(n_epochs):
        prior += 1/(n_epochs**2) * 1/n_features * (average_local_data - average_local_data[i]).pow(2).sum()
    return prior

In [4]:
def simulate_local_training_for_attack(client_net, lr, criterion, dataset, labels, original_params,
                                       reconstructed_data_per_epoch, local_batch_size, priors=None,
                                       epoch_matching_prior=None, softmax_trick=True, gumbel_softmax_trick=False,
                                       sigmoid_trick=False, temperature=None, apply_projection_to_features=None,
                                       device=None):
    """
    Simulates the local training such that it can be differentiated through with the Pytorch engine.

    :param client_net: (MetaMonkey) A MetaMonkey wrapped nn.Module neural network that supports parameter assignment$
        directly through assigning and OrderedDict.
    :param lr: (float) The learning rate of the local training.
    :param criterion: (nn.Module) The loss function of the training.
    :param dataset: (datasets.BaseDataset) The dataset with which we work. It contains usually the data necessary for
        the calculation of the prior.
    :param labels: (torch.tensor) The labels for a whole local epoch, ordered as the batches should be.
    :param original_params: (OrderedDict) The original parameter dictionary of the network before training.
    :param reconstructed_data_per_epoch: (list of torch.tensor) List of the concatenated batches of data used for
        training. This is what we optimize for.
    :param local_batch_size: (int) The batch size of the local training.
    :param priors: (list of tuple(float, str)) The regularization parameter(s) plus the name(s) of the prior(s) we wish
        to use. Default None accounts to no prior.
    :param epoch_matching_prior: tuple(float, str) The regularization parameter of the epoch matching prior plus its
        name. If None is given (default), then no epoch matching prior will be applied.
    :param softmax_trick: (bool) Toggle to apply the softmax trick to the categorical features. Effectively, it serves
        as a structural prior on the features.
    :param gumbel_softmax_trick: (bool) Toggle to apply the gumbel-softmax trick to the categorical features.
    :param sigmoid_trick: (bool) Apply the sigmoid trick to the continuous features to enforce the bounds.
    :param apply_projection_to_features: (list) If given, both the softmax trick and the gumbel softmax trick will be
        applied only to the set of features given in this list.
    :param temperature: (float) Temperature parameter for the softmax in the categorical prior.
    :param device: (str) Name of the device on which the tensors are stored.
    :return: resulting_two_point_gradient: (list of torch.tensor) Two-point gradient estimate over a local training.
    """
    if device is None:
        device = dataset.device

    if apply_projection_to_features is None:
        apply_projection_to_features = 'all'

    available_priors = {
        'categorical_prior': _categorical_prior,
        'cont_uniform': _continuous_uniform_prior,
        'cont_joint_gmm': _joint_gmm_prior,
        'cont_mean_field_gmm': _mean_field_gmm_prior,
        'cat_mean_field_JS': _categorical_mean_field_jensen_shannon_prior,
        'cat_l2': _categorical_l2_prior,
        'theoretical_optimal': _theoretical_optimal_prior,
        'theoretical_typicality': _theoretical_typicality_prior,
        'theoretical_marginal': _theoretical_marginal_prior,
        'theoretical_marginal_typicality': _theoretical_marginal_typicality_prior
    }

    available_epoch_matching_priors = {
        'mean_squared_error': epoch_matching_prior_mean_square_error
    }

    if priors is not None:
        # will raise a key error of we chose a non-implemented prior
        prior_params = [prior_params[0] for prior_params in priors]
        prior_loss_functions = [available_priors[prior_params[1]] for prior_params in priors]
    else:
        prior_loss_functions = None
        prior_params = None

    regularizer = torch.as_tensor([0.0], device=device)

    n_data_lines = labels.size()[0]
    for local_epoch, reconstructed_data in enumerate(reconstructed_data_per_epoch):

        n_batches = int(np.ceil(n_data_lines / local_batch_size))
        for b in range(n_batches):
            current_batch_X = reconstructed_data[b*local_batch_size:min(n_data_lines, (b+1)*local_batch_size)]
            current_batch_y = labels[b*local_batch_size:min(n_data_lines, (b+1)*local_batch_size)].clone().detach()

            # apply softmax or gumbel-softmax
            if gumbel_softmax_trick:
                x_rec = categorical_gumbel_softmax_sampling(current_batch_X, tau=temperature, dataset=dataset)
                categoricals_projected = True
            elif softmax_trick:
                x_rec = categorical_softmax(current_batch_X, tau=temperature, dataset=dataset,
                                            apply_to=apply_projection_to_features)
                categoricals_projected = True
            else:
                x_rec = current_batch_X * 1.
                categoricals_projected = False

            if sigmoid_trick:
                x_rec = continuous_sigmoid_bound(x_rec, dataset=dataset, T=temperature)

            outputs = client_net(x_rec, client_net.parameters)

            #Change by Chirag

            current_batch_y=current_batch_y.unsqueeze(1).float()

            training_loss = criterion(outputs, current_batch_y)
            grad = torch.autograd.grad(training_loss, client_net.parameters.values(), retain_graph=True,
                                       create_graph=True, only_inputs=True, allow_unused=True)

            client_net.parameters = OrderedDict((name, param - lr * param_grad) for ((name, param), param_grad) in zip(client_net.parameters.items(), grad))

            # keep track of a regularizer if needed
            if priors is not None:
                for prior_param, prior_function in zip(prior_params, prior_loss_functions):
                    regularizer += 1/(n_batches*local_epoch) * prior_param * prior_function(x_reconstruct=x_rec,
                                                                                            dataset=dataset,
                                                                                            softmax_trick=categoricals_projected,
                                                                                            labels=current_batch_y,
                                                                                            T=temperature)

    # if we have an epoch matching prior, we calculate its value, for this, we have to reapply any projections made on
    # the data previously
    if epoch_matching_prior is not None:
        epoch_matching_prior_param = epoch_matching_prior[0]
        epoch_matching_prior_function = available_epoch_matching_priors[epoch_matching_prior[1]]

        # reapply the projections if any
        if softmax_trick or gumbel_softmax_trick:
            projected_epoch_data = [categorical_softmax(epoch_data, dataset=dataset, tau=temperature,
                                                        apply_to=apply_projection_to_features) for epoch_data in reconstructed_data_per_epoch]
        else:
            projected_epoch_data = reconstructed_data_per_epoch
        # reapply the sigmoid if given
        if sigmoid_trick:
            projected_bounded_epoch_data = [continuous_sigmoid_bound(pd, dataset=dataset, T=temperature) for pd in projected_epoch_data]
        else:
            projected_bounded_epoch_data = projected_epoch_data
        regularizer += epoch_matching_prior_param * epoch_matching_prior_function(projected_bounded_epoch_data, device=device)

    # end of training, time to extract the parameters
    resulting_parameters = list(client_net.parameters.values())
    resulting_two_point_gradient = [original_param - param for original_param, param in
                                    zip(original_params, resulting_parameters)]

    return resulting_two_point_gradient, regularizer

In [5]:
def fed_avg_attack(original_net, attacked_clients_params, n_local_epochs, local_batch_size, lr,
                   dataset, per_client_ground_truth_data, per_client_ground_truth_labels, attack_iterations=1000,
                   attack_learning_rate=0.06, reconstruction_loss='cosine_sim', priors=None, epoch_matching_prior=None,
                   initialization_mode='uniform', softmax_trick=True, gumbel_softmax_trick=False, temperature_mode=None,
                   sigmoid_trick=False, sign_trick=True, apply_projection_to_features=None, device=None):
    """
    FedAVG attack following Dimitrov et al. 2022.
    """
    if device is None:
        device = dataset.device

    # attack setups
    rec_loss_function = {
        'squared_error': _squared_error_loss,
        'cosine_sim': _cosine_similarity_loss,
        'weighted_combined': _weighted_CS_SE_loss,
        'norm_weighted_combined': _gradient_norm_weighted_CS_SE_loss
    }

    initialization = {
        'uniform': _uniform_initialization,
        'gaussian': _gaussian_initialization,
        'mean': _mean_initialization,
        'dataset_sample': _dataset_sample_initialization,
        'likelihood_sample': _likelihood_prior_sample_initialization,
        'mixed': _mixed_initialization,
        'best_sample': _best_sample_initialization
    }

    temperature_configs = {
        'cool': (1000., 0.98),
        'constant': (1., 1.),
        'heat': (0.1, 1.01)
    }

    if reconstruction_loss not in list(rec_loss_function.keys()):
        raise NotImplementedError(
            f'The desired loss function is not implemented, available loss function are: {list(rec_loss_function.keys())}')

    final_reconstructions_per_client = []
    final_loss_per_client = []

    # we will go by attacked client and then completely restart every time
    for attacked_client, (attacked_client_params, ground_truth_data, ground_truth_labels) in enumerate(zip(attacked_clients_params, per_client_ground_truth_data, per_client_ground_truth_labels)):
        # fix the client network and extract its starting parameters
        original_params = [param.detach().clone() for param in original_net.parameters()]
        
        for original_param, new_param in zip(original_params, attacked_client_params):
             print(f"Original param shape: {original_param.shape}, New param shape: {new_param.shape}")

        true_two_point_gradient = [(original_param - new_param).detach().clone() for original_param, new_param in zip(original_params, attacked_client_params)]

        # we reconstruct independently in each epoch and aggregate in the end, as per Dimitrov et al.
        # initialize the data
        reconstructed_data_per_epoch = [initialization[initialization_mode](ground_truth_data, dataset, device) for _ in range(n_local_epochs)]
        for reconstructed_data in reconstructed_data_per_epoch:
            reconstructed_data.requires_grad = True

        optimizer = torch.optim.Adam(reconstructed_data_per_epoch, lr=attack_learning_rate)

        T = temperature_configs[temperature_mode][0]

        for it in range(attack_iterations):

            optimizer.zero_grad()
            original_net.zero_grad()
            client_net = MetaMonkey(copy.deepcopy(original_net))
            # criterion = torch.nn.CrossEntropyLoss()
            criterion = torch.nn.BCELoss()

            resulting_two_point_gradient, regularizer = simulate_local_training_for_attack(
                client_net=client_net,
                lr=lr,
                criterion=criterion,
                dataset=dataset,
                labels=ground_truth_labels,
                original_params=original_params,
                reconstructed_data_per_epoch=reconstructed_data_per_epoch,
                local_batch_size=local_batch_size,
                priors=priors,
                epoch_matching_prior=epoch_matching_prior,
                softmax_trick=softmax_trick,
                gumbel_softmax_trick=gumbel_softmax_trick,
                sigmoid_trick=sigmoid_trick,
                apply_projection_to_features=apply_projection_to_features,
                temperature=T
            )

            # calculate the final objective
            loss = rec_loss_function[reconstruction_loss](resulting_two_point_gradient, true_two_point_gradient, device)
            loss += regularizer
            loss.backward()


            if sign_trick:
                for reconstructed_data in reconstructed_data_per_epoch:
                    reconstructed_data.grad.sign_()

            optimizer.step()

            # adjust the temperature
            T *= temperature_configs[temperature_mode][1]

        # if we used the sigmoid trick, we reapply it
        if sigmoid_trick:
            sigmoid_reconstruction = [continuous_sigmoid_bound(rd, dataset=dataset, T=T) for rd in reconstructed_data_per_epoch]
            reconstructed_data_per_epoch = sigmoid_reconstruction

        # after the optimization has finished for the given client, we project and match the data
        epoch_pooling = 'soft_avg+softmax' if softmax_trick or gumbel_softmax_trick else 'soft_avg'
        final_reconstruction = pooled_ensemble([reconstructed_data.clone().detach() for reconstructed_data in reconstructed_data_per_epoch],
                                               reconstructed_data_per_epoch[0].clone().detach(), dataset,
                                               pooling=epoch_pooling)
        final_reconstructions_per_client.append(final_reconstruction)

        # with the aggregated datapoint, we can finally run it again through the process to record its loss
        final_reconstruction_projected = dataset.project_batch(final_reconstruction, standardized=dataset.standardized)
        client_net = MetaMonkey(copy.deepcopy(original_net))

        # criterion = torch.nn.CrossEntropyLoss()
        criterion = torch.nn.BCELoss()
        final_resulting_two_point_gradient, _ = simulate_local_training_for_attack(
                client_net=client_net,
                lr=lr,
                criterion=criterion,
                dataset=dataset,
                labels=ground_truth_labels,
                original_params=original_params,
                reconstructed_data_per_epoch=[final_reconstruction_projected for _ in range(n_local_epochs)],
                local_batch_size=local_batch_size,
                priors=None,
                softmax_trick=softmax_trick,
                gumbel_softmax_trick=gumbel_softmax_trick,
                apply_projection_to_features=apply_projection_to_features,
                temperature=T
        )
        final_loss = rec_loss_function[reconstruction_loss](final_resulting_two_point_gradient, true_two_point_gradient, device)
        final_loss_per_client.append(final_loss.detach().item())

    return final_reconstructions_per_client, final_loss_per_client

In [6]:
# def train_and_attack_fed_avg(net, n_clients, n_global_epochs, n_local_epochs, local_batch_size, lr, dataset, shuffle=False,
#                              attacked_clients=None, attack_iterations=1000, reconstruction_loss='cosine_sim', priors=None,
#                              epoch_matching_prior=None, post_selection=1, attack_learning_rate=0.06, return_all=False,
#                              pooling=None, perfect_pooling=False, initialization_mode='uniform', softmax_trick=True,
#                              gumbel_softmax_trick=False, sigmoid_trick=False, temperature_mode='constant',
#                              sign_trick=True, fish_for_features=None, device=None, verbose=False, max_n_cpus=50, first_cpu=0,
#                              max_client_dataset_size=None, parallelized=False, metadata_path='metadata', state_name="AL"):

#     if device is None:
#         device = dataset.device

#     if attacked_clients is None:
#         attacked_clients = []
#     elif attacked_clients == 'all':
#         attacked_clients = list(np.arange(n_clients))

#     if max_client_dataset_size is None:
#         max_client_dataset_size = len(dataset)

#     per_global_epoch_per_client_reconstructions = []
#     per_global_epoch_per_client_ground_truth = []
#     training_data = np.zeros((n_global_epochs, 2))
    

#     # Split data into client datasets
#     if shuffle:
#         dataset.shuffle()

#     Xtrain, ytrain = dataset.get_Xtrain(), dataset.get_ytrain()
#     split_size = min(max_client_dataset_size, int(np.ceil(Xtrain.size()[0] / n_clients)))
#     Xtrain_splits = [Xtrain[i*split_size:min(int(Xtrain.size()[0]), (i+1)*split_size)].clone().detach() for i in range(n_clients)]
#     ytrain_splits = [ytrain[i*split_size:min(int(Xtrain.size()[0]), (i+1)*split_size)].clone().detach() for i in range(n_clients)]

#     # Loss function
#     criterion = torch.nn.BCELoss()
#     timer = Timer(n_global_epochs)
#     # Load pre-trained model
#     pre_trained_model_path = "50_clients_data/clients_trained_model/pre_trained_model.pth"
#     state_dict = torch.load(pre_trained_model_path)
#     weights_dict = {k: v for k, v in state_dict.items() if 'weight' in k}
#     net.load_state_dict(weights_dict, strict=False)
#     print("Pre-trained model loaded.")

    
#     def prepare_qat_model(model, backend='qnnpack'):
#         """
#         Prepare model for Quantization Aware Training
#         """
#         model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
#         torch.backends.quantized.engine = backend
#         model.fuse_model()
#         torch.quantization.prepare_qat(model, inplace=True)
#         return model

#     # Training loop
#     for global_epoch in range(n_global_epochs):
#         acc, bac = get_acc_and_bac(net, dataset.get_Xtest(), dataset.get_ytest())
#         if verbose:
#             print(f'Global Epochs: {global_epoch + 1}/{n_global_epochs}    Acc: {acc * 100:.2f}%    BAcc: {bac * 100:.2f}%')

#         training_data[global_epoch] = acc, bac

#         # Create client networks
#         # client_nets = [copy.deepcopy(net) for _ in range(n_clients)]
#         # client_nets = [prepare_qat_model(client_net) for client_net in client_nets]


#         # After creating client networks
#         client_nets = [copy.deepcopy(net) for _ in range(n_clients)]
#         client_nets = [prepare_qat_model(client_net) for client_net in client_nets]

#         for client, (client_X, client_y, client_net) in enumerate(zip(Xtrain_splits, ytrain_splits, client_nets)):
#             client_net.train()  # Set to training mode for QAT
#             n_batches = int(np.ceil(client_X.size()[0] / local_batch_size))

#             print(f"Training client {client + 1}/{n_clients}")
#             print("QAT training")
#             print("n_batches is", n_batches)
#             print("local_epoch is", n_local_epochs)

#             # Training loop remains the same
#             for local_epoch in range(n_local_epochs):
#                 for b in range(n_batches):
#                     current_batch_X = client_X[b * local_batch_size:min(int(client_X.size()[0]), (b+1)*local_batch_size)].clone().detach()
#                     current_batch_y = client_y[b * local_batch_size:min(int(client_X.size()[0]), (b+1)*local_batch_size)].clone().detach()
#                     outputs = client_net(current_batch_X)

#                     current_batch_y = current_batch_y.unsqueeze(1).float()
#                     loss = criterion(outputs, current_batch_y)
#                     grad = torch.autograd.grad(loss, client_net.parameters(), retain_graph=True)

#                     with torch.no_grad():
#                         for param, param_grad in zip(client_net.parameters(), grad):
#                             param -= lr * param_grad

#             # Evaluation phase
#             client_net.eval()
            
#             # First evaluate QAT model
#             val_running_loss = 0.0
#             val_correct = 0
#             val_total = 0

#             inputs, labels = dataset.get_Xtest(), dataset.get_ytest()
#             labels = labels.unsqueeze(1).float()
#             val_n_batches = int(np.ceil(inputs.size()[0] / local_batch_size))
            
#             with torch.no_grad():
#                 for b in range(val_n_batches):
#                     val_batch_X = inputs[b * local_batch_size:min(int(inputs.size()[0]), (b + 1) * local_batch_size)].clone().detach()
#                     val_batch_y = labels[b * local_batch_size:min(int(labels.size()[0]), (b + 1) * local_batch_size)].clone().detach()

#                     outputs = client_net(val_batch_X)
#                     val_loss = criterion(outputs, val_batch_y)
#                     val_running_loss += val_loss.item()

#                     predicted_classes = (outputs > 0.5).float()
#                     val_correct += (predicted_classes == val_batch_y).sum().item()
#                     val_total += val_batch_y.size(0)

#             val_epoch_loss = val_running_loss / len(inputs)
#             val_accuracy = val_correct / val_total
#             print(f"QAT Model - Validation Accuracy: {val_accuracy:.4f}, Loss: {val_epoch_loss:.4f}")

#             # Save QAT model
#             qat_model_path = f"50_clients_data/clients_trained_model/qat_{state_name}.pth"
#             torch.save(client_net.state_dict(), qat_model_path)
#             print(f"QAT model is saved to {qat_model_path}")
#             print("QAT Model Size:", os.path.getsize(qat_model_path) / 1024, "KB")

#             quantized_net = torch.quantization.convert(client_net.eval(), inplace=False)
            
#             # # Evaluate quantized model
#             # val_running_loss = 0.0
#             # val_correct = 0
#             # val_total = 0

#             # with torch.no_grad():
#             #     for b in range(val_n_batches):
#             #         val_batch_X = inputs[b * local_batch_size:min(int(inputs.size()[0]), (b + 1) * local_batch_size)].clone().detach()
#             #         val_batch_y = labels[b * local_batch_size:min(int(labels.size()[0]), (b + 1) * local_batch_size)].clone().detach()

#             #         outputs = quantized_net(val_batch_X)
#             #         val_loss = criterion(outputs, val_batch_y)
#             #         val_running_loss += val_loss.item()

#             #         predicted_classes = (outputs > 0.5).float()
#             #         val_correct += (predicted_classes == val_batch_y).sum().item()
#             #         val_total += val_batch_y.size(0)

#             # val_epoch_loss = val_running_loss / len(inputs)
#             # val_accuracy = val_correct / val_total
#             # print(f"Quantized Model - Validation Accuracy: {val_accuracy:.4f}, Loss: {val_epoch_loss:.4f}")

#             # # Save quantized model
#             model_path = f"50_clients_data/clients_trained_model/quantized_{state_name}.pth"
#             torch.save(quantized_net.state_dict(), model_path)
#             print(f"Quantized model saved to {model_path}")
#             print("Quantized Model Size:", os.path.getsize(model_path) / 1024, "KB")


#         # quantization_flag= True
#         # if quantization_flag is True:
#         #     print("Attack on quantized_net")
#         #     # clients_params = [[param.clone().detach() for param in quantized_net.parameters()] for quantized_net in client_nets]
#         #     clients_params = [[param.clone().detach() for param in quantized_net.parameters()] for client_net in client_nets]
#         # else:
#         #     clients_params = [[param.clone().detach() for param in client_net.parameters()] for client_net in client_nets]

#         clients_params = [[param.clone().detach() for param in client_net.parameters()] for client_net in client_nets]

#         # quantization_flag = True
#         # if quantization_flag is True:
#         #     print("Attack on quantized_net")                                    
#         #     quantized_nets = []
#         #     for client_net in client_nets:
#         #         client_net.eval()
#         #         quantized_net = torch.quantization.convert(client_net, inplace=False)
#         #         quantized_nets.append(quantized_net)
#         #     clients_params = [[param.clone().detach() for param in qnet.parameters()] for qnet in quantized_nets]
#         # else:
#         #     clients_params = [[param.clone().detach() for param in client_net.parameters()] for client_net in client_nets]

#         attack_bool= True
#         if attack_bool is False:
#             print("---------------------------------------")
#             print("-----Attack Is NOT Being Applied-------")
#             print("---------------------------------------")
#             break
#         else:
#             # -------------- ATTACK -------------- #
#             per_client_all_reconstructions = [[] for _ in range(len(attacked_clients))]
#             per_client_best_reconstructions = [None for _ in range(len(attacked_clients))]
#             per_client_best_scores = [None for _ in range(len(attacked_clients))]
#             per_client_ground_truth_data = [Xtrain_splits[attacked_client].detach().clone() for attacked_client in attacked_clients]
#             per_client_ground_truth_labels = [ytrain_splits[attacked_client].detach().clone() for attacked_client in attacked_clients]
#             attacked_clients_params = [[param.clone().detach() for param in clients_params[attacked_client]] for attacked_client in attacked_clients]
#             print(attacked_clients_params,"attacked_clients_params")
#             for _ in range(post_selection):

#                 if parallelized:
#                     print("parallelized")
#                     per_client_candidate_reconstructions, per_client_final_losses = fed_avg_attack_parallelized_over_clients(
#                         original_net=copy.deepcopy(net),
#                         attacked_clients_params=attacked_clients_params,
#                         attack_iterations=attack_iterations,
#                         attack_learning_rate=attack_learning_rate,
#                         n_local_epochs=n_local_epochs,
#                         local_batch_size=local_batch_size,
#                         lr=lr,
#                         dataset=dataset,
#                         per_client_ground_truth_data=per_client_ground_truth_data,
#                         per_client_ground_truth_labels=per_client_ground_truth_labels,
#                         reconstruction_loss=reconstruction_loss,
#                         priors=priors,
#                         epoch_matching_prior=epoch_matching_prior,
#                         initialization_mode=initialization_mode,
#                         softmax_trick=softmax_trick,
#                         gumbel_softmax_trick=gumbel_softmax_trick,
#                         sigmoid_trick=sigmoid_trick,
#                         temperature_mode=temperature_mode,
#                         sign_trick=sign_trick,
#                         apply_projection_to_features=fish_for_features,
#                         max_n_cpus=max_n_cpus,
#                         first_cpu=first_cpu,
#                         device=device,
#                         metadata_path=metadata_path
#                     )

#                 else:
#                     print("parallelized--OFF ")
#                     per_client_candidate_reconstructions, per_client_final_losses = fed_avg_attack(
#                         original_net=copy.deepcopy(net),
#                         attacked_clients_params=attacked_clients_params,
#                         attack_iterations=attack_iterations,
#                         attack_learning_rate=attack_learning_rate,
#                         n_local_epochs=n_local_epochs,
#                         local_batch_size=local_batch_size,
#                         lr=lr,
#                         dataset=dataset,
#                         per_client_ground_truth_data=per_client_ground_truth_data,
#                         per_client_ground_truth_labels=per_client_ground_truth_labels,
#                         reconstruction_loss=reconstruction_loss,
#                         priors=priors,
#                         epoch_matching_prior=epoch_matching_prior,
#                         initialization_mode=initialization_mode,
#                         softmax_trick=softmax_trick,
#                         gumbel_softmax_trick=gumbel_softmax_trick,
#                         sigmoid_trick=sigmoid_trick,
#                         temperature_mode=temperature_mode,
#                         sign_trick=sign_trick,
#                         apply_projection_to_features=fish_for_features,
#                         device=device
#                     )

#                 # enter the results in the collectors
#                 for client_idx in range(len(attacked_clients)):
#                     per_client_all_reconstructions[client_idx].append(per_client_candidate_reconstructions[client_idx].detach().clone())
#                     if (per_client_best_scores[client_idx] is None) or (per_client_best_scores[client_idx] > per_client_final_losses[client_idx]):
#                         per_client_best_scores[client_idx] = per_client_final_losses[client_idx]
#                         per_client_best_reconstructions[client_idx] = per_client_candidate_reconstructions[client_idx].detach().clone()

#             if return_all:
#                 per_global_epoch_per_client_reconstructions.append(per_client_all_reconstructions)
#             elif pooling is not None:
#                 if perfect_pooling:
#                     per_client_pooled = [pooled_ensemble(all_reconstructions, ground_truth_data, dataset, pooling=pooling)
#                                         for all_reconstructions, ground_truth_data in zip(per_client_all_reconstructions, per_client_ground_truth_data)]
#                 else:
#                     per_client_pooled = [pooled_ensemble(all_reconstructions, best_reconstruction, dataset, pooling=pooling)
#                                         for all_reconstructions, best_reconstruction in zip(per_client_all_reconstructions, per_client_best_reconstructions)]
#                 per_global_epoch_per_client_reconstructions.append(per_client_pooled)
#             else:
#                 per_global_epoch_per_client_reconstructions.append(per_client_best_reconstructions)
#             per_global_epoch_per_client_ground_truth.append(per_client_ground_truth_data)
#             # -------------- ATTACK END -------------- #

#         # Continue the training
#         # transpose the list
#         transposed_clients_params = [[] for _ in range(len(clients_params[0]))]
#         for client_params in clients_params:
#             for i, param in enumerate(client_params):
#                 transposed_clients_params[i].append(param.clone().detach())

#         # aggregate the params using mean aggregation
#         aggregated_params = [torch.mean(torch.stack(params_over_clients), dim=0) for params_over_clients in transposed_clients_params]

#         # write the new parameters into the main network
#         with torch.no_grad():
#             for param, agg_param in zip(net.parameters(), aggregated_params):
#                 param.copy_(agg_param)
        
#         # timer.end()
#     # timer.duration()

#     # random_baseline = calculate_random_baseline(dataset=dataset, recover_batch_sizes=reconstruction_batch_sizes,
#     #                                                     tolerance_map=tolerance_map, n_samples=n_samples, mode=mode,
#     #                                                     device=args.device)

#     return net, training_data, per_global_epoch_per_client_reconstructions, per_global_epoch_per_client_ground_truth



In [9]:
def train_and_attack_fed_avg(net, n_clients, n_global_epochs, n_local_epochs, local_batch_size, lr, dataset, shuffle=False,
                             attacked_clients=None, attack_iterations=1000, reconstruction_loss='cosine_sim', priors=None,
                             epoch_matching_prior=None, post_selection=1, attack_learning_rate=0.06, return_all=False,
                             pooling=None, perfect_pooling=False, initialization_mode='uniform', softmax_trick=True,
                             gumbel_softmax_trick=False, sigmoid_trick=False, temperature_mode='constant',
                             sign_trick=True, fish_for_features=None, device=None, verbose=False, max_n_cpus=50, first_cpu=0,
                             max_client_dataset_size=None, parallelized=False, metadata_path='metadata', state_name="AL"):

    if device is None:
        device = dataset.device

    if attacked_clients is None:
        attacked_clients = []
    elif attacked_clients == 'all':
        attacked_clients = list(np.arange(n_clients))

    if max_client_dataset_size is None:
        max_client_dataset_size = len(dataset)

    per_global_epoch_per_client_reconstructions = []
    per_global_epoch_per_client_ground_truth = []
    training_data = np.zeros((n_global_epochs, 2))
    

    # Split data into client datasets
    if shuffle:
        dataset.shuffle()

    Xtrain, ytrain = dataset.get_Xtrain(), dataset.get_ytrain()
    split_size = min(max_client_dataset_size, int(np.ceil(Xtrain.size()[0] / n_clients)))
    Xtrain_splits = [Xtrain[i*split_size:min(int(Xtrain.size()[0]), (i+1)*split_size)].clone().detach() for i in range(n_clients)]
    ytrain_splits = [ytrain[i*split_size:min(int(Xtrain.size()[0]), (i+1)*split_size)].clone().detach() for i in range(n_clients)]

    # Loss function
    criterion = torch.nn.BCELoss()
    timer = Timer(n_global_epochs)
    # Load pre-trained model
    pre_trained_model_path = "50_clients_data/clients_trained_model/pre_trained_model.pth"
    state_dict = torch.load(pre_trained_model_path)
    weights_dict = {k: v for k, v in state_dict.items() if 'weight' in k}
    net.load_state_dict(weights_dict, strict=False)
    print("Pre-trained model loaded.")


    def prepare_qat_model(model, backend='qnnpack'):
        """
        Prepare model for Quantization Aware Training
        """
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        torch.backends.quantized.engine = backend
        model.fuse_model()
        torch.quantization.prepare_qat(model, inplace=True)
        return model

    def prepare_qat_model_aggresive(model, backend='qnnpack'):
        """
        Prepare model for Quantization Aware Training with aggressive quantization.
        """
        # Aggressive quantization config
        model.qconfig = torch.quantization.QConfig(
            activation=torch.quantization.FakeQuantize.with_args(observer=torch.quantization.MinMaxObserver, dtype=torch.quint8,
                                                                  qscheme=torch.per_tensor_symmetric),
            weight=torch.quantization.default_per_channel_weight_fake_quant
        )
        torch.backends.quantized.engine = backend
        model.fuse_model()
        torch.quantization.prepare_qat(model, inplace=True)
        return model
    
    def quantize_model_post_train(model, calibration_data, backend='qnnpack'):
        model.qconfig = torch.quantization.get_default_qconfig(backend)
        torch.backends.quantized.engine = backend
        # model.fuse_model()
    
        # Prepare model (add observers)
        quantized_model = torch.quantization.prepare(model, inplace=False)
    
        # Calibration step
        with torch.no_grad():
            for batch_X in calibration_data:
                quantized_model(batch_X)
    
        # Convert to quantized model
        quantized_model = torch.quantization.convert(quantized_model, inplace=False)
        return quantized_model

    # Training loop
    for global_epoch in range(n_global_epochs):
        acc, bac = get_acc_and_bac(net, dataset.get_Xtest(), dataset.get_ytest())
        if verbose:
            print(f'Global Epochs: {global_epoch + 1}/{n_global_epochs}    Acc: {acc * 100:.2f}%    BAcc: {bac * 100:.2f}%')

        training_data[global_epoch] = acc, bac

        # Create client networks

        client_nets = [copy.deepcopy(net) for _ in range(n_clients)]
        client_nets = [prepare_qat_model(client_net) for client_net in client_nets]
    

        for client, (client_X, client_y, client_net) in enumerate(zip(Xtrain_splits, ytrain_splits, client_nets)):
            client_net.train()  # Set to training mode for QAT
            n_batches = int(np.ceil(client_X.size()[0] / local_batch_size))

            print(f"Training client {client + 1}/{n_clients}")
            print("QAT training")
            print("n_batches is", n_batches)
            print("local_epoch is", n_local_epochs)

            # Training loop remains the same
            for local_epoch in range(n_local_epochs):
                for b in range(n_batches):
                    current_batch_X = client_X[b * local_batch_size:min(int(client_X.size()[0]), (b+1)*local_batch_size)].clone().detach()
                    current_batch_y = client_y[b * local_batch_size:min(int(client_X.size()[0]), (b+1)*local_batch_size)].clone().detach()
                    outputs = client_net(current_batch_X)

                    current_batch_y = current_batch_y.unsqueeze(1).float()
                    loss = criterion(outputs, current_batch_y)
                    grad = torch.autograd.grad(loss, client_net.parameters(), retain_graph=True)

                    with torch.no_grad():
                        for param, param_grad in zip(client_net.parameters(), grad):
                            param -= lr * param_grad

            # Evaluation phase
            client_net.eval()

            inputs, labels = dataset.get_Xtest(), dataset.get_ytest()
            labels = labels.unsqueeze(1).float()
            val_n_batches = int(np.ceil(inputs.size()[0] / local_batch_size))
            
            # Save QAT model
            qat_model_path = f"50_clients_data/clients_trained_model/qat_{state_name}.pth"
            torch.save(client_net.state_dict(), qat_model_path)

            print(f"QAT model is saved to {qat_model_path}")
            print("QAT Model Size:", os.path.getsize(qat_model_path) / 1024, "KB")


            quantized_net = torch.quantization.convert(client_net.eval(), inplace=False)

            # # Evaluate quantized model
            val_running_loss = 0.0
            val_correct = 0
            val_total = 0
            quantized_net.eval()
            print("quantized_net is being used for testing")
            with torch.no_grad():
                for b in range(val_n_batches):
                    val_batch_X = inputs[b * local_batch_size:min(int(inputs.size()[0]), (b + 1) * local_batch_size)].clone().detach()
                    val_batch_y = labels[b * local_batch_size:min(int(labels.size()[0]), (b + 1) * local_batch_size)].clone().detach()

                    outputs = quantized_net(val_batch_X)
                    val_loss = criterion(outputs, val_batch_y)
                    val_running_loss += val_loss.item()

                    predicted_classes = (outputs > 0.5).float()
                    val_correct += (predicted_classes == val_batch_y).sum().item()
                    val_total += val_batch_y.size(0)

            val_epoch_loss = val_running_loss / len(inputs)
            val_accuracy = val_correct / val_total
            print(f"Quantized Model - Validation Accuracy: {val_accuracy:.4f}, Loss: {val_epoch_loss:.4f}")

            # # Save quantized model
            model_path = f"50_clients_data/clients_trained_model/quantized_{state_name}.pth"
            torch.save(quantized_net.state_dict(), model_path)
            print(f"Quantized model saved to {model_path}")
            print("Quantized Model Size:", os.path.getsize(model_path) / 1024, "KB")


        # print(list(quantized_net.state_dict().values()),"clients_params--After")
        # clients_params=[list(quantized_net.state_dict().values())]

       #----------------- Need to Check this ------------------------#
        processed_params = []

        # just to check the shape
        original_params = [param.clone().detach() for param in net.parameters()]

        for original_param in original_params:
            matched = False 
            for name, module in quantized_net.named_modules():
                if hasattr(module, 'weight') and module.weight is not None:
                    weight = module.weight
                    
                    weight_tensor = weight() if callable(weight) else weight
                    if isinstance(weight_tensor, torch.Tensor):
                        weight_tensor = weight_tensor.dequantize().detach() if weight_tensor.is_quantized else weight_tensor.detach()
                    
                        if weight_tensor.shape == original_param.shape:
                            processed_params.append(weight_tensor)
                            matched = True
                            break 

                if hasattr(module, 'bias') and module.bias is not None:
                    bias = module.bias
                    
                    bias_tensor = bias() if callable(bias) else bias
                    if isinstance(bias_tensor, torch.Tensor):
                        bias_tensor = bias_tensor.detach()
                       
                        if bias_tensor.shape == original_param.shape:
                            processed_params.append(bias_tensor)
                            matched = True
                            break 

            if not matched:
                print(f"Warning: No matching parameter found for shape {original_param.shape}")

        # processed_params = []

        # # Clone and detach original parameters for comparison
        # original_params = [param.clone().detach() for param in net.parameters()]

        # for original_param in original_params:
        #     matched = False  # Flag to ensure only one match per original_param
        #     for name, module in quantized_net.named_modules():
        #         if hasattr(module, 'weight') and module.weight is not None:
        #             weight = module.weight
        #             # Handle packed weights
        #             weight_tensor = weight() if callable(weight) else weight
        #             if isinstance(weight_tensor, torch.Tensor) and weight_tensor.is_quantized:  # Ensure it's quantized
        #                 weight_tensor = weight_tensor.dequantize().detach()

        #                 # Check if the shape matches the original parameter
        #                 if weight_tensor.shape == original_param.shape:
        #                     processed_params.append(weight_tensor)
        #                     matched = True
        #                     break  # Move to the next original parameter

        #         if hasattr(module, 'bias') and module.bias is not None:
        #             bias = module.bias
        #             # Handle packed biases
        #             bias_tensor = bias() if callable(bias) else bias
        #             if isinstance(bias_tensor, torch.Tensor):  # Bias is typically not quantized
        #                 bias_tensor = bias_tensor.detach()

        #                 # Check if the shape matches the original parameter
        #                 if bias_tensor.shape == original_param.shape:
        #                     processed_params.append(bias_tensor)
        #                     matched = True
        #                     break  # Move to the next original parameter

        #     if not matched:
        #         print(f"Warning: No matching quantized parameter found for shape {original_param.shape}")

        clients_params = processed_params

        attacked_clients_params=[clients_params]

        # THis will not work in Quantization.
        # clients_params = [[param.clone().detach() for param in client_net.parameters()] for client_net in client_nets]
        # print(clients_params,"clients_params After")
        
        attack_bool= True
        if attack_bool is False:
            print("---------------------------------------")
            print("-----Attack Is NOT Being Applied-------")
            print("---------------------------------------")
            break
        else:
            # -------------- ATTACK -------------- #
            per_client_all_reconstructions = [[] for _ in range(len(attacked_clients))]
            per_client_best_reconstructions = [None for _ in range(len(attacked_clients))]
            per_client_best_scores = [None for _ in range(len(attacked_clients))]
            per_client_ground_truth_data = [Xtrain_splits[attacked_client].detach().clone() for attacked_client in attacked_clients]
            per_client_ground_truth_labels = [ytrain_splits[attacked_client].detach().clone() for attacked_client in attacked_clients]
            # attacked_clients_params = [[param.clone().detach() for param in clients_params[attacked_client]] for attacked_client in attacked_clients]

            # attacked_clients_params=[clients_params]
            # print(attacked_clients_params,"attacked_clients_params")
            
            for _ in range(post_selection):

                if parallelized:
                    print("parallelized")
                    per_client_candidate_reconstructions, per_client_final_losses = fed_avg_attack_parallelized_over_clients(
                        original_net=copy.deepcopy(net),
                        attacked_clients_params=attacked_clients_params,
                        attack_iterations=attack_iterations,
                        attack_learning_rate=attack_learning_rate,
                        n_local_epochs=n_local_epochs,
                        local_batch_size=local_batch_size,
                        lr=lr,
                        dataset=dataset,
                        per_client_ground_truth_data=per_client_ground_truth_data,
                        per_client_ground_truth_labels=per_client_ground_truth_labels,
                        reconstruction_loss=reconstruction_loss,
                        priors=priors,
                        epoch_matching_prior=epoch_matching_prior,
                        initialization_mode=initialization_mode,
                        softmax_trick=softmax_trick,
                        gumbel_softmax_trick=gumbel_softmax_trick,
                        sigmoid_trick=sigmoid_trick,
                        temperature_mode=temperature_mode,
                        sign_trick=sign_trick,
                        apply_projection_to_features=fish_for_features,
                        max_n_cpus=max_n_cpus,
                        first_cpu=first_cpu,
                        device=device,
                        metadata_path=metadata_path
                    )

                else:

                    print("parallelized--OFF ")
                    per_client_candidate_reconstructions, per_client_final_losses = fed_avg_attack(
                        original_net=copy.deepcopy(net),
                        attacked_clients_params=attacked_clients_params,
                        attack_iterations=attack_iterations,
                        attack_learning_rate=attack_learning_rate,
                        n_local_epochs=n_local_epochs,
                        local_batch_size=local_batch_size,
                        lr=lr,
                        dataset=dataset,
                        per_client_ground_truth_data=per_client_ground_truth_data,
                        per_client_ground_truth_labels=per_client_ground_truth_labels,
                        reconstruction_loss=reconstruction_loss,
                        priors=priors,
                        epoch_matching_prior=epoch_matching_prior,
                        initialization_mode=initialization_mode,
                        softmax_trick=softmax_trick,
                        gumbel_softmax_trick=gumbel_softmax_trick,
                        sigmoid_trick=sigmoid_trick,
                        temperature_mode=temperature_mode,
                        sign_trick=sign_trick,
                        apply_projection_to_features=fish_for_features,
                        device=device
                    )

                # enter the results in the collectors
                for client_idx in range(len(attacked_clients)):
                    per_client_all_reconstructions[client_idx].append(per_client_candidate_reconstructions[client_idx].detach().clone())
                    if (per_client_best_scores[client_idx] is None) or (per_client_best_scores[client_idx] > per_client_final_losses[client_idx]):
                        per_client_best_scores[client_idx] = per_client_final_losses[client_idx]
                        per_client_best_reconstructions[client_idx] = per_client_candidate_reconstructions[client_idx].detach().clone()

            if return_all:
                per_global_epoch_per_client_reconstructions.append(per_client_all_reconstructions)
            elif pooling is not None:
                if perfect_pooling:
                    per_client_pooled = [pooled_ensemble(all_reconstructions, ground_truth_data, dataset, pooling=pooling)
                                        for all_reconstructions, ground_truth_data in zip(per_client_all_reconstructions, per_client_ground_truth_data)]
                else:
                    per_client_pooled = [pooled_ensemble(all_reconstructions, best_reconstruction, dataset, pooling=pooling)
                                        for all_reconstructions, best_reconstruction in zip(per_client_all_reconstructions, per_client_best_reconstructions)]
                per_global_epoch_per_client_reconstructions.append(per_client_pooled)
            else:
                per_global_epoch_per_client_reconstructions.append(per_client_best_reconstructions)
            per_global_epoch_per_client_ground_truth.append(per_client_ground_truth_data)
            print("# -------------- ATTACK END -------------- #")
            # -------------- ATTACK END -------------- #

        # Continue the training
        # transpose the list
        transposed_clients_params = [[] for _ in range(len(clients_params[0]))]
        for client_params in clients_params:
            for i, param in enumerate(client_params):
                transposed_clients_params[i].append(param.clone().detach())

        # aggregate the params using mean aggregation
        # aggregated_params = [torch.mean(torch.stack(params_over_clients), dim=0) for params_over_clients in transposed_clients_params]

        # # write the new parameters into the main network
        # with torch.no_grad():
        #     for param, agg_param in zip(net.parameters(), aggregated_params):
        #         param.copy_(agg_param)
        
        # timer.end()
    # timer.duration()

    # random_baseline = calculate_random_baseline(dataset=dataset, recover_batch_sizes=reconstruction_batch_sizes,
    #                                                     tolerance_map=tolerance_map, n_samples=n_samples, mode=mode,
    #                                                     device=args.device)

    return net, training_data, per_global_epoch_per_client_reconstructions, per_global_epoch_per_client_ground_truth



In [12]:
def calculate_fed_avg_local_dataset_inversion_performance(architecture_layout, dataset, max_client_dataset_size,
                                                          local_epochs, local_batch_sizes, epoch_prior_params,
                                                          tolerance_map, n_samples, config, max_n_cpus, first_cpu, device, state_name="AL"):
    
    collected_data = np.zeros((len(local_epochs), len(local_batch_sizes), len(epoch_prior_params), 3, 5))

    timer = Timer(len(local_epochs) * len(local_batch_sizes) * len(epoch_prior_params))
    with open(f'50_clients_data/reconstr_and_GT/dataset_{state_name}.pkl', 'wb') as f:
        pickle.dump(dataset, f)

    with open(f'50_clients_data/reconstr_and_GT/tolerance_map_{state_name}.pkl', 'wb') as f:        
        pickle.dump(tolerance_map, f)   
        
    for i, lepochs in enumerate(local_epochs):
        for j, lbatch_size in enumerate(local_batch_sizes):
            for k, epoch_prior_param in enumerate(epoch_prior_params):
                timer.start()                
                print(timer)
                print(dataset.num_features)
                net = FullyConnected(dataset.num_features, architecture_layout)
                # print(net)
                epoch_matching_prior = (epoch_prior_param, config['epoch_matching_prior']) if epoch_prior_param > 0. else None

                _, _, reconstructions, ground_truths= train_and_attack_fed_avg(
                    net=net,
                    n_clients=n_samples,
                    n_global_epochs=config['n_global_epochs'],
                    n_local_epochs=lepochs,
                    local_batch_size=lbatch_size,
                    lr=config['lr'],
                    dataset=dataset,
                    shuffle=config['shuffle'],
                    attacked_clients=config['attacked_clients'],
                    attack_iterations=config['attack_iterations'],
                    reconstruction_loss=config['reconstruction_loss'],
                    priors=config['priors'],
                    epoch_matching_prior=epoch_matching_prior,
                    post_selection=config['post_selection'],
                    attack_learning_rate=config['attack_learning_rate'],
                    return_all=config['return_all'],
                    pooling=config['pooling'],
                    perfect_pooling=config['perfect_pooling'],
                    initialization_mode=config['initialization_mode'],
                    softmax_trick=config['softmax_trick'],
                    gumbel_softmax_trick=config['gumbel_softmax_trick'],
                    sigmoid_trick=config['sigmoid_trick'],
                    temperature_mode=config['temperature_mode'],
                    sign_trick=config['sign_trick'],
                    fish_for_features=None,
                    max_n_cpus=max_n_cpus,
                    first_cpu=first_cpu,
                    device=device,
                    verbose=False,
                    max_client_dataset_size=max_client_dataset_size,
                    parallelized=False,
                    state_name=state_name
                )
                all_errors = []
                cat_errors = []
                cont_errors = []
                
                with open(f'50_clients_data/reconstr_and_GT/reconstructions_ground_truths_{state_name}.pkl', 'wb') as f:
                        pickle.dump({'reconstructions': reconstructions, 'ground_truths': ground_truths}, f)
                
                print("reconstructions_and_ground_truths is dumped")


                for epoch_reconstruction, epoch_ground_truth in zip(reconstructions, ground_truths):
                    for client_reconstruction, client_ground_truth in zip(epoch_reconstruction, epoch_ground_truth):
                        if config['post_process_cont']:
                            client_reconstruction = post_process_continuous(client_reconstruction, dataset=dataset)
                        client_recon_projected, client_gt_projected = dataset.decode_batch(client_reconstruction, standardized=True), dataset.decode_batch(client_ground_truth, standardized=True)
                        _, batch_cost_all, batch_cost_cat, batch_cost_cont = match_reconstruction_ground_truth(client_gt_projected, client_recon_projected, tolerance_map)
                        all_errors.append(np.mean(batch_cost_all))
                        cat_errors.append(np.mean(batch_cost_cat))
                        cont_errors.append(np.mean(batch_cost_cont))

                collected_data[i, j, k, 0] = np.mean(all_errors), np.std(all_errors), np.median(all_errors), np.min(all_errors), np.max(all_errors)
                collected_data[i, j, k, 1] = np.mean(cat_errors), np.std(cat_errors), np.median(cat_errors), np.min(cat_errors), np.max(cat_errors)
                collected_data[i, j, k, 2] = np.mean(cont_errors), np.std(cont_errors), np.median(cont_errors), np.min(cont_errors), np.max(cont_errors)

                timer.end()

            best_param_index = np.argmin(collected_data[i, j, :, 0, 0]).item()

            print(f'Performance at {lepochs} Epochs and {lbatch_size} Batch Size: {100*(1-collected_data[i, j, best_param_index, 0, 0]):.1f}% +- {100*collected_data[i, j, best_param_index, 0, 1]:.2f}')
            
            display_map = {
                'mean': 0,
                'std': 1,
                'median': 2,
                'min': 3,
                'max': 4
            }
            display = 'mean'
            random_baseline = calculate_random_baseline(dataset=dataset, recover_batch_sizes=[lbatch_size],
                                                        tolerance_map=tolerance_map, n_samples=n_samples)
            batch_sizes = [lbatch_size]
            # print("random acc:  ",random_baseline)
            for l, batch_size in enumerate(batch_sizes):
                print("random_baseline", (np.around(100 - 100*random_baseline[l, 0, display_map[display]], 1), np.around(100*random_baseline[l, 0, 1], 1)))

def main(args):
    datasets = {
        'ADULT': ADULT,
    }

    configs = {
        0: {
            'n_global_epochs': 1,
            'lr': 0.01,
            'shuffle': True,
            'attacked_clients': 'all',
            'attack_iterations': 1500,
            'reconstruction_loss': 'cosine_sim',
            'priors': None,
            'epoch_matching_prior': 'mean_squared_error',
            'post_selection': 1,
            'attack_learning_rate': 0.06,
            'return_all': False,
            'pooling': None,
            'perfect_pooling': False,
            'initialization_mode': 'uniform',
            'softmax_trick': False,
            'gumbel_softmax_trick': False,
            'sigmoid_trick': False,
            'temperature_mode': 'constant',
            'sign_trick': True,
            'verbose': False,
            'max_client_dataset_size': 32,
            'post_process_cont': False
        }
    }

    architecture_layout = [100, 100, 2]  
    max_client_dataset_size = 2000
    local_epochs =[5]
    local_batch_sizes = [250]
    epoch_prior_params =[0.01]
    tol = 0.319

    config = configs[0]
    # dataset = ADULT(device="cpu", random_state=2, name_state="AL")
    dataset = ADULT(device="cpu", random_state=42,name_state=args.name_state)

    dataset.standardize()
    tolerance_map = dataset.create_tolerance_map(tol=tol)

    np.random.seed(2)
    torch.manual_seed(2)

    base_path = f'experiment_data/fedavg_experiments/ADULT/experiment_0'
    os.makedirs(base_path, exist_ok=True)
    specific_file_path = base_path + f'/inversion_data_all_0_ADULT_1_{epoch_prior_params}_{tol}_2_{args.name_state}.npy'

    if os.path.isfile(specific_file_path):
        print('This experiment has already been conducted')
        os.remove(specific_file_path)
        print(f'File {specific_file_path} has been removed.')
    else:
        inversion_data = calculate_fed_avg_local_dataset_inversion_performance(
            architecture_layout=architecture_layout,
            dataset=dataset,
            max_client_dataset_size=max_client_dataset_size,
            local_epochs=local_epochs,
            local_batch_sizes=local_batch_sizes,
            epoch_prior_params=epoch_prior_params,
            tolerance_map=tolerance_map,
            n_samples=1,
            config=config,
            max_n_cpus=4,
            first_cpu=0,
            device="cpu",
            state_name=args.name_state
        )
        np.save(specific_file_path, inversion_data)
    print('Complete                           ')
    print('==================================================================')
    print('==================================================================')

import argparse

# state_name = ["AK","AZ","AR","CA","ID","NH","NM","NY"]
state_name = ["NY"]

if __name__ == '__main__':
    import sys
    if 'ipykernel' in sys.modules:  # Check if running in Jupyter Notebook
        class Args:
            def __init__(self, name_state):
                self.name_state = name_state
    else:
        import argparse
        parser = argparse.ArgumentParser('run_inversion_parser')
        parser.add_argument('--name_state', type=str, default='CA', help='State Code')
        args = parser.parse_args()

    for state in state_name:
        if 'ipykernel' in sys.modules:
            args = Args(name_state=state) 
        else:
            args.name_state = state
        main(args)  

State Code::  NY
training sample:: NY.data and len is 2000
testing sample:: NY.test and len is 99
0%: ??h ??m ??s          
10
Pre-trained model loaded.
Training client 1/1
QAT training
n_batches is 8
local_epoch is 5
QAT model is saved to 50_clients_data/clients_trained_model/qat_NY.pth
QAT Model Size: 78.2861328125 KB
quantized_net is being used for testing
Quantized Model - Validation Accuracy: 0.7677, Loss: 0.0048




Quantized model saved to 50_clients_data/clients_trained_model/quantized_NY.pth
Quantized Model Size: 30.552734375 KB
parallelized--OFF 
Original param shape: torch.Size([100, 10]), New param shape: torch.Size([100, 10])
Original param shape: torch.Size([100]), New param shape: torch.Size([100])
Original param shape: torch.Size([100, 100]), New param shape: torch.Size([100, 100])
Original param shape: torch.Size([100]), New param shape: torch.Size([100])
Original param shape: torch.Size([1, 100]), New param shape: torch.Size([1, 100])
Original param shape: torch.Size([1]), New param shape: torch.Size([1])
# -------------- ATTACK END -------------- #
reconstructions_and_ground_truths is dumped
Performance at 5 Epochs and 250 Batch Size: 56.9% +- 0.00
random_baseline (50.8, 0.0)
Complete                           


In [None]:
#QUANTIZED

# AK: Performance at 5 Epochs and 250 Batch Size: 55.3% +- 0.00  - from 66.50%
# AR: Performance at 5 Epochs and 250 Batch Size: 51.0% +- 0.00  - from 72.8% 
# AZ: Performance at 5 Epochs and 250 Batch Size: 63.4% +- 0.00  - from 69.90%
# CA: Performance at 5 Epochs and 250 Batch Size: 65.1% +- 0.00  - from 68.40% 
# ID: Performance at 5 Epochs and 250 Batch Size: 60.2% +- 0.00  - from 71.50%
# NH: Performance at 5 Epochs and 250 Batch Size: 59.4% +- 0.00  - from 71.70%
# NM: Performance at 5 Epochs and 250 Batch Size: 62.9% +- 0.00  - from 72.50%
# NY: Performance at 5 Epochs and 250 Batch Size: 56.9% +- 0.00  - from 67.11


In [9]:
# AK: Performance at 5 Epochs and 250 Batch Size: 66.7% +- 0.00   - from 73%
# AZ: Performance at 5 Epochs and 250 Batch Size: 66.6% +- 0.00 -  from 70%
# AR: Performance at 5 Epochs and 250 Batch Size: 69.2% +- 0.00 -  from 73.08%
# CA: Performance at 5 Epochs and 250 Batch Size: 67.1% +- 0.00  - from 68.98 
# ID: Performance at 5 Epochs and 250 Batch Size: 68.5% +- 0.00  - from 72.22%
# NH: Performance at 5 Epochs and 250 Batch Size: 66.3% +- 0.00  - from 71.59
# NM: Performance at 5 Epochs and 250 Batch Size: 69.3% +- 0.00  - from 72.11
# NY: Performance at 5 Epochs and 250 Batch Size: 64.7% +- 0.00 -  from 67.11


# Post quantization traning

In [5]:
class LinReLU(nn.Module):
    """
    A linear layer followed by a ReLU activation layer.
    """

    def __init__(self, in_size, out_size):
        super(LinReLU, self).__init__()
        self.linear = nn.Linear(in_size, out_size)
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(self.linear, self.relu)

    def reset_parameters(self):
        self.linear.reset_parameters()
        return self

    def forward(self, x):
        x = self.layers(x)
        return x


class FullyConnected(nn.Module):
    """
    A simple fully connected neural network with ReLU activations.
    """
    def __init__(self, input_size, layout):
        super(FullyConnected, self).__init__()
        layers = [nn.Flatten()]
        prev_fc_size = input_size
        for i, fc_size in enumerate(layout):
            if i + 1 < len(layout):
                layers += [LinReLU(prev_fc_size, fc_size)]
            else:
                layers += [nn.Linear(prev_fc_size, 1), nn.Sigmoid()]
            prev_fc_size = fc_size
        self.layers = nn.Sequential(*layers)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.layers(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        """
        Fuses Linear and ReLU layers in LinReLU modules
        """
        for module in self.modules():
            if isinstance(module, LinReLU):
                torch.quantization.fuse_modules(
                    module.layers, 
                    ['0', '1'],  # Fuse first (Linear) and second (ReLU) 
                    inplace=True
                )

In [6]:
def train_and_attack_fed_avg(net, n_clients, n_global_epochs, n_local_epochs, local_batch_size, lr, dataset, shuffle=False,
                             attacked_clients=None, attack_iterations=1000, reconstruction_loss='cosine_sim', priors=None,
                             epoch_matching_prior=None, post_selection=1, attack_learning_rate=0.06, return_all=False,
                             pooling=None, perfect_pooling=False, initialization_mode='uniform', softmax_trick=True,
                             gumbel_softmax_trick=False, sigmoid_trick=False, temperature_mode='constant',
                             sign_trick=True, fish_for_features=None, device=None, verbose=False, max_n_cpus=50, first_cpu=0,
                             max_client_dataset_size=None, parallelized=False, metadata_path='metadata', state_name="AL"):

    if device is None:
        device = dataset.device

    if attacked_clients is None:
        attacked_clients = []
    elif attacked_clients == 'all':
        attacked_clients = list(np.arange(n_clients))

    if max_client_dataset_size is None:
        max_client_dataset_size = len(dataset)

    per_global_epoch_per_client_reconstructions = []
    per_global_epoch_per_client_ground_truth = []
    training_data = np.zeros((n_global_epochs, 2))
    

    # Split data into client datasets
    if shuffle:
        dataset.shuffle()

    Xtrain, ytrain = dataset.get_Xtrain(), dataset.get_ytrain()
    split_size = min(max_client_dataset_size, int(np.ceil(Xtrain.size()[0] / n_clients)))
    Xtrain_splits = [Xtrain[i*split_size:min(int(Xtrain.size()[0]), (i+1)*split_size)].clone().detach() for i in range(n_clients)]
    ytrain_splits = [ytrain[i*split_size:min(int(Xtrain.size()[0]), (i+1)*split_size)].clone().detach() for i in range(n_clients)]

    # Loss function
    criterion = torch.nn.BCELoss()
    timer = Timer(n_global_epochs)
    # Load pre-trained model
    pre_trained_model_path = "50_clients_data/clients_trained_model/pre_trained_model.pth"
    state_dict = torch.load(pre_trained_model_path)
    weights_dict = {k: v for k, v in state_dict.items() if 'weight' in k}
    net.load_state_dict(weights_dict, strict=False)
    print("Pre-trained model loaded.")

    # Helper to quantize the model
    # def quantize_model(model, calibration_data, backend='qnnpack'):
    #     model.qconfig = torch.quantization.get_default_qconfig(backend)
    #     torch.backends.quantized.engine = backend

    #     # Prepare model (add observers)
    #     quantized_model = torch.quantization.prepare(model, inplace=False)

    #     # Calibration step
    #     with torch.no_grad():
    #         for batch_X in calibration_data:
    #             quantized_model(batch_X)

    #     # Convert to quantized model
    #     quantized_model = torch.quantization.convert(quantized_model, inplace=False)
    #     return quantized_model
    
    def quantize_model(model, calibration_data, backend='qnnpack'):
        model.qconfig = torch.quantization.get_default_qconfig(backend)
        torch.backends.quantized.engine = backend
        model.fuse_model()
    
        # Prepare model (add observers)
        quantized_model = torch.quantization.prepare(model, inplace=False)
    
        # Calibration step
        with torch.no_grad():
            for batch_X in calibration_data:
                quantized_model(batch_X)
    
        # Convert to quantized model
        quantized_model = torch.quantization.convert(quantized_model, inplace=False)
        return quantized_model

    # Training loop
    for global_epoch in range(n_global_epochs):
        acc, bac = get_acc_and_bac(net, dataset.get_Xtest(), dataset.get_ytest())
        if verbose:
            print(f'Global Epochs: {global_epoch + 1}/{n_global_epochs}    Acc: {acc * 100:.2f}%    BAcc: {bac * 100:.2f}%')

        training_data[global_epoch] = acc, bac

        # Create client networks
        client_nets = [copy.deepcopy(net) for _ in range(n_clients)]

        for client, (client_X, client_y, client_net) in enumerate(zip(Xtrain_splits, ytrain_splits, client_nets)):
            n_batches = int(np.ceil(client_X.size()[0] / local_batch_size))

            print(f"Training client {client + 1}/{n_clients}")
            print("Normal training")
            print("n_batches is",n_batches)
            print("local_epoch is ",n_local_epochs)
            for local_epoch in range(n_local_epochs):
                for b in range(n_batches):
                    current_batch_X = client_X[b * local_batch_size:min(int(client_X.size()[0]), (b+1)*local_batch_size)].clone().detach()
                    current_batch_y = client_y[b * local_batch_size:min(int(client_X.size()[0]), (b+1)*local_batch_size)].clone().detach()
                    outputs = client_net(current_batch_X)

                    current_batch_y = current_batch_y.unsqueeze(1).float()
                    loss = criterion(outputs, current_batch_y)
                    grad = torch.autograd.grad(loss, client_net.parameters(), retain_graph=True)

                    with torch.no_grad():
                        for param, param_grad in zip(client_net.parameters(), grad):
                            param -= lr * param_grad
            
            client_net.eval()
            val_running_loss = 0.0
            val_correct = 0
            val_total = 0

            inputs, labels = dataset.get_Xtest(), dataset.get_ytest()
            labels = labels.unsqueeze(1).float()
            val_n_batches = int(np.ceil(inputs.size()[0] / local_batch_size))
            with torch.no_grad():
                for b in range(val_n_batches):
                    val_batch_X = inputs[b * local_batch_size:min(int(inputs.size()[0]), (b + 1) * local_batch_size)].clone().detach()
                    val_batch_y = labels[b * local_batch_size:min(int(labels.size()[0]), (b + 1) * local_batch_size)].clone().detach()

                    outputs = client_net(val_batch_X)
                    val_loss = criterion(outputs, val_batch_y)
                    val_running_loss += val_loss.item()

                    predicted_classes = (outputs > 0.5).float()
                    val_correct += (predicted_classes == val_batch_y).sum().item()
                    val_total += val_batch_y.size(0)

            val_epoch_loss = val_running_loss / len(inputs)
            val_accuracy = val_correct / val_total
            print(f"Original Model - Validation Accuracy: {val_accuracy:.4f}, Loss: {val_epoch_loss:.4f}")

            # Save the quantized model
            model_path = f"50_clients_data/clients_trained_model/{state_name}.pth"
            torch.save(client_net.state_dict(), model_path)
            print(f"Original model for client {client + 1} saved to {model_path}")
            print("Original Model Size:", os.path.getsize(model_path) / 1024, "KB")
            
            # Quantization step
            print(f"Quantizing model for client {client + 1}")
            calibration_data = [
                client_X[b * local_batch_size:min(int(client_X.size()[0]), (b+1)*local_batch_size)].clone().detach()
                for b in range(n_batches)
            ]

            quantized_net = quantize_model(client_net, calibration_data)

            quantized_net.eval()
            val_running_loss = 0.0
            val_correct = 0
            val_total = 0

            inputs, labels = dataset.get_Xtest(), dataset.get_ytest()
            labels = labels.unsqueeze(1).float()
            val_n_batches = int(np.ceil(inputs.size()[0] / local_batch_size))

            with torch.no_grad():
                for b in range(val_n_batches):
                    val_batch_X = inputs[b * local_batch_size:min(int(inputs.size()[0]), (b + 1) * local_batch_size)].clone().detach()
                    val_batch_y = labels[b * local_batch_size:min(int(labels.size()[0]), (b + 1) * local_batch_size)].clone().detach()

                    outputs = quantized_net(val_batch_X)
                    val_loss = criterion(outputs, val_batch_y)
                    val_running_loss += val_loss.item()

                    predicted_classes = (outputs > 0.5).float()
                    val_correct += (predicted_classes == val_batch_y).sum().item()
                    val_total += val_batch_y.size(0)

            val_epoch_loss = val_running_loss / len(inputs)
            val_accuracy = val_correct / val_total
            print(f"Quantized Model - Validation Accuracy: {val_accuracy:.4f}, Loss: {val_epoch_loss:.4f}")

            # Save the quantized model
            model_path = f"50_clients_data/clients_trained_model/quantized_{state_name}.pth"
            torch.save(quantized_net.state_dict(), model_path)
            
            print("Quantized Model Size:", os.path.getsize(model_path) / 1024, "KB")
            
            print(f"Quantized model saved to {model_path}")

        quantization_flag= True
        if quantization_flag is True:
            print("Attack on quantized_net")
            clients_params = [[param.clone().detach() for param in quantized_net.parameters()] for quantized_net in client_nets]
        else:
            clients_params = [[param.clone().detach() for param in client_net.parameters()] for client_net in client_nets]

        attack_bool= True
        if attack_bool is False:
            print("---------------------------------------")
            print("-----Attack Is NOT Being Applied-------")
            print("---------------------------------------")
            break
        else:
            # -------------- ATTACK -------------- #
            per_client_all_reconstructions = [[] for _ in range(len(attacked_clients))]
            per_client_best_reconstructions = [None for _ in range(len(attacked_clients))]
            per_client_best_scores = [None for _ in range(len(attacked_clients))]
            per_client_ground_truth_data = [Xtrain_splits[attacked_client].detach().clone() for attacked_client in attacked_clients]
            per_client_ground_truth_labels = [ytrain_splits[attacked_client].detach().clone() for attacked_client in attacked_clients]
            attacked_clients_params = [[param.clone().detach() for param in clients_params[attacked_client]] for attacked_client in attacked_clients]

            for _ in range(post_selection):

                if parallelized:
                    print("parallelized")
                    per_client_candidate_reconstructions, per_client_final_losses = fed_avg_attack_parallelized_over_clients(
                        original_net=copy.deepcopy(net),
                        attacked_clients_params=attacked_clients_params,
                        attack_iterations=attack_iterations,
                        attack_learning_rate=attack_learning_rate,
                        n_local_epochs=n_local_epochs,
                        local_batch_size=local_batch_size,
                        lr=lr,
                        dataset=dataset,
                        per_client_ground_truth_data=per_client_ground_truth_data,
                        per_client_ground_truth_labels=per_client_ground_truth_labels,
                        reconstruction_loss=reconstruction_loss,
                        priors=priors,
                        epoch_matching_prior=epoch_matching_prior,
                        initialization_mode=initialization_mode,
                        softmax_trick=softmax_trick,
                        gumbel_softmax_trick=gumbel_softmax_trick,
                        sigmoid_trick=sigmoid_trick,
                        temperature_mode=temperature_mode,
                        sign_trick=sign_trick,
                        apply_projection_to_features=fish_for_features,
                        max_n_cpus=max_n_cpus,
                        first_cpu=first_cpu,
                        device=device,
                        metadata_path=metadata_path
                    )

                else:
                    print("parallelized--OFF ")
                    per_client_candidate_reconstructions, per_client_final_losses = fed_avg_attack(
                        original_net=copy.deepcopy(net),
                        attacked_clients_params=attacked_clients_params,
                        attack_iterations=attack_iterations,
                        attack_learning_rate=attack_learning_rate,
                        n_local_epochs=n_local_epochs,
                        local_batch_size=local_batch_size,
                        lr=lr,
                        dataset=dataset,
                        per_client_ground_truth_data=per_client_ground_truth_data,
                        per_client_ground_truth_labels=per_client_ground_truth_labels,
                        reconstruction_loss=reconstruction_loss,
                        priors=priors,
                        epoch_matching_prior=epoch_matching_prior,
                        initialization_mode=initialization_mode,
                        softmax_trick=softmax_trick,
                        gumbel_softmax_trick=gumbel_softmax_trick,
                        sigmoid_trick=sigmoid_trick,
                        temperature_mode=temperature_mode,
                        sign_trick=sign_trick,
                        apply_projection_to_features=fish_for_features,
                        device=device
                    )

                # enter the results in the collectors
                for client_idx in range(len(attacked_clients)):
                    per_client_all_reconstructions[client_idx].append(per_client_candidate_reconstructions[client_idx].detach().clone())
                    if (per_client_best_scores[client_idx] is None) or (per_client_best_scores[client_idx] > per_client_final_losses[client_idx]):
                        per_client_best_scores[client_idx] = per_client_final_losses[client_idx]
                        per_client_best_reconstructions[client_idx] = per_client_candidate_reconstructions[client_idx].detach().clone()

            if return_all:
                per_global_epoch_per_client_reconstructions.append(per_client_all_reconstructions)
            elif pooling is not None:
                if perfect_pooling:
                    per_client_pooled = [pooled_ensemble(all_reconstructions, ground_truth_data, dataset, pooling=pooling)
                                        for all_reconstructions, ground_truth_data in zip(per_client_all_reconstructions, per_client_ground_truth_data)]
                else:
                    per_client_pooled = [pooled_ensemble(all_reconstructions, best_reconstruction, dataset, pooling=pooling)
                                        for all_reconstructions, best_reconstruction in zip(per_client_all_reconstructions, per_client_best_reconstructions)]
                per_global_epoch_per_client_reconstructions.append(per_client_pooled)
            else:
                per_global_epoch_per_client_reconstructions.append(per_client_best_reconstructions)
            per_global_epoch_per_client_ground_truth.append(per_client_ground_truth_data)
            # -------------- ATTACK END -------------- #

        # Continue the training
        # transpose the list
        transposed_clients_params = [[] for _ in range(len(clients_params[0]))]
        for client_params in clients_params:
            for i, param in enumerate(client_params):
                transposed_clients_params[i].append(param.clone().detach())

        # aggregate the params using mean aggregation
        aggregated_params = [torch.mean(torch.stack(params_over_clients), dim=0) for params_over_clients in transposed_clients_params]

        # write the new parameters into the main network
        with torch.no_grad():
            for param, agg_param in zip(net.parameters(), aggregated_params):
                param.copy_(agg_param)
        
        # timer.end()
    # timer.duration()

    # random_baseline = calculate_random_baseline(dataset=dataset, recover_batch_sizes=reconstruction_batch_sizes,
    #                                                     tolerance_map=tolerance_map, n_samples=n_samples, mode=mode,
    #                                                     device=args.device)

    return net, training_data, per_global_epoch_per_client_reconstructions, per_global_epoch_per_client_ground_truth



In [7]:
def calculate_fed_avg_local_dataset_inversion_performance(architecture_layout, dataset, max_client_dataset_size,
                                                          local_epochs, local_batch_sizes, epoch_prior_params,
                                                          tolerance_map, n_samples, config, max_n_cpus, first_cpu, device, state_name="AL"):
    
    collected_data = np.zeros((len(local_epochs), len(local_batch_sizes), len(epoch_prior_params), 3, 5))

    timer = Timer(len(local_epochs) * len(local_batch_sizes) * len(epoch_prior_params))
    with open(f'50_clients_data/reconstr_and_GT/dataset_{state_name}.pkl', 'wb') as f:
        pickle.dump(dataset, f)

    with open(f'50_clients_data/reconstr_and_GT/tolerance_map_{state_name}.pkl', 'wb') as f:        
        pickle.dump(tolerance_map, f)   
        
    for i, lepochs in enumerate(local_epochs):
        for j, lbatch_size in enumerate(local_batch_sizes):
            for k, epoch_prior_param in enumerate(epoch_prior_params):
                timer.start()                
                print(timer)

                net = FullyConnected(dataset.num_features, architecture_layout)

                epoch_matching_prior = (epoch_prior_param, config['epoch_matching_prior']) if epoch_prior_param > 0. else None

                _, _, reconstructions, ground_truths= train_and_attack_fed_avg(
                    net=net,
                    n_clients=n_samples,
                    n_global_epochs=config['n_global_epochs'],
                    n_local_epochs=lepochs,
                    local_batch_size=lbatch_size,
                    lr=config['lr'],
                    dataset=dataset,
                    shuffle=config['shuffle'],
                    attacked_clients=config['attacked_clients'],
                    attack_iterations=config['attack_iterations'],
                    reconstruction_loss=config['reconstruction_loss'],
                    priors=config['priors'],
                    epoch_matching_prior=epoch_matching_prior,
                    post_selection=config['post_selection'],
                    attack_learning_rate=config['attack_learning_rate'],
                    return_all=config['return_all'],
                    pooling=config['pooling'],
                    perfect_pooling=config['perfect_pooling'],
                    initialization_mode=config['initialization_mode'],
                    softmax_trick=config['softmax_trick'],
                    gumbel_softmax_trick=config['gumbel_softmax_trick'],
                    sigmoid_trick=config['sigmoid_trick'],
                    temperature_mode=config['temperature_mode'],
                    sign_trick=config['sign_trick'],
                    fish_for_features=None,
                    max_n_cpus=max_n_cpus,
                    first_cpu=first_cpu,
                    device=device,
                    verbose=False,
                    max_client_dataset_size=max_client_dataset_size,
                    parallelized=False,
                    state_name=state_name
                )
                all_errors = []
                cat_errors = []
                cont_errors = []
                
                with open(f'50_clients_data/reconstr_and_GT/reconstructions_ground_truths_{state_name}.pkl', 'wb') as f:
                        pickle.dump({'reconstructions': reconstructions, 'ground_truths': ground_truths}, f)
                
                print("reconstructions_and_ground_truths is dumped")


                for epoch_reconstruction, epoch_ground_truth in zip(reconstructions, ground_truths):
                    for client_reconstruction, client_ground_truth in zip(epoch_reconstruction, epoch_ground_truth):
                        if config['post_process_cont']:
                            client_reconstruction = post_process_continuous(client_reconstruction, dataset=dataset)
                        client_recon_projected, client_gt_projected = dataset.decode_batch(client_reconstruction, standardized=True), dataset.decode_batch(client_ground_truth, standardized=True)
                        _, batch_cost_all, batch_cost_cat, batch_cost_cont = match_reconstruction_ground_truth(client_gt_projected, client_recon_projected, tolerance_map)
                        all_errors.append(np.mean(batch_cost_all))
                        cat_errors.append(np.mean(batch_cost_cat))
                        cont_errors.append(np.mean(batch_cost_cont))

                collected_data[i, j, k, 0] = np.mean(all_errors), np.std(all_errors), np.median(all_errors), np.min(all_errors), np.max(all_errors)
                collected_data[i, j, k, 1] = np.mean(cat_errors), np.std(cat_errors), np.median(cat_errors), np.min(cat_errors), np.max(cat_errors)
                collected_data[i, j, k, 2] = np.mean(cont_errors), np.std(cont_errors), np.median(cont_errors), np.min(cont_errors), np.max(cont_errors)

                timer.end()

            best_param_index = np.argmin(collected_data[i, j, :, 0, 0]).item()

            print(f'Performance at {lepochs} Epochs and {lbatch_size} Batch Size: {100*(1-collected_data[i, j, best_param_index, 0, 0]):.1f}% +- {100*collected_data[i, j, best_param_index, 0, 1]:.2f}')
            
            display_map = {
                'mean': 0,
                'std': 1,
                'median': 2,
                'min': 3,
                'max': 4
            }
            display = 'mean'
            random_baseline = calculate_random_baseline(dataset=dataset, recover_batch_sizes=[lbatch_size],
                                                        tolerance_map=tolerance_map, n_samples=n_samples)
            batch_sizes = [lbatch_size]
            # print("random acc:  ",random_baseline)
            for l, batch_size in enumerate(batch_sizes):
                print("random_baseline", (np.around(100 - 100*random_baseline[l, 0, display_map[display]], 1), np.around(100*random_baseline[l, 0, 1], 1)))

def main(args):
    datasets = {
        'ADULT': ADULT,
    }

    configs = {
        0: {
            'n_global_epochs': 1,
            'lr': 0.01,
            'shuffle': True,
            'attacked_clients': 'all',
            'attack_iterations': 1500,
            'reconstruction_loss': 'cosine_sim',
            'priors': None,
            'epoch_matching_prior': 'mean_squared_error',
            'post_selection': 1,
            'attack_learning_rate': 0.06,
            'return_all': False,
            'pooling': None,
            'perfect_pooling': False,
            'initialization_mode': 'uniform',
            'softmax_trick': False,
            'gumbel_softmax_trick': False,
            'sigmoid_trick': False,
            'temperature_mode': 'constant',
            'sign_trick': True,
            'verbose': False,
            'max_client_dataset_size': 32,
            'post_process_cont': False
        }
    }

    architecture_layout = [100, 100, 2]  
    max_client_dataset_size = 2000
    local_epochs =[5]
    local_batch_sizes = [250]
    epoch_prior_params =[0.01]
    tol = 0.319

    config = configs[0]
    # dataset = ADULT(device="cpu", random_state=2, name_state="AL")
    dataset = ADULT(device="cpu", random_state=42,name_state=args.name_state)

    dataset.standardize()
    tolerance_map = dataset.create_tolerance_map(tol=tol)

    np.random.seed(2)
    torch.manual_seed(2)

    base_path = f'experiment_data/fedavg_experiments/ADULT/experiment_0'
    os.makedirs(base_path, exist_ok=True)
    specific_file_path = base_path + f'/inversion_data_all_0_ADULT_1_{epoch_prior_params}_{tol}_2_{args.name_state}.npy'

    if os.path.isfile(specific_file_path):
        print('This experiment has already been conducted')
        os.remove(specific_file_path)
        print(f'File {specific_file_path} has been removed.')
    else:
        inversion_data = calculate_fed_avg_local_dataset_inversion_performance(
            architecture_layout=architecture_layout,
            dataset=dataset,
            max_client_dataset_size=max_client_dataset_size,
            local_epochs=local_epochs,
            local_batch_sizes=local_batch_sizes,
            epoch_prior_params=epoch_prior_params,
            tolerance_map=tolerance_map,
            n_samples=1,
            config=config,
            max_n_cpus=4,
            first_cpu=0,
            device="cpu",
            state_name=args.name_state
        )
        np.save(specific_file_path, inversion_data)
    print('Complete                           ')
    print('==================================================================')
    print('==================================================================')

import argparse
# state_name=["AZ","ID"]
# if __name__ == '__main__':
#     import sys
#     if 'ipykernel' in sys.modules:  # Check if running in Jupyter Notebook
#         class Args:
#             name_state = 'AZ'  
#         args = Args()
#     else:
#         parser = argparse.ArgumentParser('run_inversion_parser')
#         parser.add_argument('--name_state', type=str, default='CA', help='State Code')
#         args = parser.parse_args()
    
#     main(args)

# client_models = ["AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA",
#                "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
#                "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
#                "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
#                "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"]

state_name = ["AK","AZ","AR","CA", "ID","NH","NM","NY"]

if __name__ == '__main__':
    import sys
    if 'ipykernel' in sys.modules:  # Check if running in Jupyter Notebook
        class Args:
            def __init__(self, name_state):
                self.name_state = name_state
    else:
        import argparse
        parser = argparse.ArgumentParser('run_inversion_parser')
        parser.add_argument('--name_state', type=str, default='CA', help='State Code')
        args = parser.parse_args()

    for state in state_name:
        if 'ipykernel' in sys.modules:
            args = Args(name_state=state)  # Create args for each state in the list
        else:
            args.name_state = state  # Modify the argument directly
        main(args)  # Call main with the current state


State Code::  AK
training sample:: AK.data and len is 2000
testing sample:: AK.test and len is 99
0%: ??h ??m ??s          
Pre-trained model loaded.
Training client 1/1
Normal training
n_batches is 8
local_epoch is  5
Original Model - Validation Accuracy: 0.7475, Loss: 0.0049
Original model for client 1 saved to 50_clients_data/clients_trained_model/AK.pth
Original Model Size: 47.0634765625 KB
Quantizing model for client 1




Quantized Model - Validation Accuracy: 0.7475, Loss: 0.0049
Quantized Model Size: 30.5478515625 KB
Quantized model saved to 50_clients_data/clients_trained_model/quantized_AK.pth
Attack on quantized_net
parallelized--OFF 
reconstructions_and_ground_truths is dumped
Performance at 5 Epochs and 250 Batch Size: 66.5% +- 0.00
random_baseline (51.1, 0.0)
Complete                           
State Code::  AZ
training sample:: AZ.data and len is 2000
testing sample:: AZ.test and len is 99
0%: ??h ??m ??s          
Pre-trained model loaded.
Training client 1/1
Normal training
n_batches is 8
local_epoch is  5
Original Model - Validation Accuracy: 0.8384, Loss: 0.0047
Original model for client 1 saved to 50_clients_data/clients_trained_model/AZ.pth
Original Model Size: 47.06640625 KB
Quantizing model for client 1
Quantized Model - Validation Accuracy: 0.8283, Loss: 0.0048
Quantized Model Size: 30.5546875 KB
Quantized model saved to 50_clients_data/clients_trained_model/quantized_AZ.pth
Attack on 

In [None]:
# AK: Performance at 5 Epochs and 250 Batch Size: 66.5% +- 0.00     - from 68.4%
# AZ: Performance at 5 Epochs and 250 Batch Size: 69.9% +- 0.00     - from 70%
# AR: Performance at 5 Epochs and 250 Batch Size: 72.8% +- 0.00     - from 73.08%
# CA: Performance at 5 Epochs and 250 Batch Size: 68.4% +- 0.00      - from 68.98 
# ID: Performance at 5 Epochs and 250 Batch Size: 71.5%  +- 0.00    - from 72.22%
# NH: Performance at 5 Epochs and 250 Batch Size: 71.7% +- 0.00     - from 71.59
# NM: Performance at 5 Epochs and 250 Batch Size: 72.5% +- 0.00     - from 72.11
# NY: Performance at 5 Epochs and 250 Batch Size: 68.8% +- 0.00     - from 67.11



In [None]:
##QAT

# AK: Performance at 5 Epochs and 250 Batch Size: 69.2%/66.7%    - from 73%
# AZ: Performance at 5 Epochs and 250 Batch Size: 66.6% +- 0.00  - from 70%
# AR: Performance at 5 Epochs and 250 Batch Size: 69.2% +- 0.00  - from 73.08%
# CA: Performance at 5 Epochs and 250 Batch Size: 67.1% +- 0.00  - from 68.98 
# ID: Performance at 5 Epochs and 250 Batch Size: 68.5% +- 0.00  - from 72.22%
# NH: Performance at 5 Epochs and 250 Batch Size: 66.3% +- 0.00  - from 71.59
# NM: Performance at 5 Epochs and 250 Batch Size: 66.3% +- 0.00  - from 72.11
# NY: Performance at 5 Epochs and 250 Batch Size: 64.7% +- 0.00 -  from 67.11


In [None]:
# client_models = ["AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA",
#                "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
#                "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
#                "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
#                "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"]