In [1]:
import torch
import pickle

from typing import Dict, List, Optional, Tuple, Any
from collections import OrderedDict
import numpy as np
import numpy.typing as npt
from functools import reduce

NDArray = npt.NDArray[Any]
NDArrayInt = npt.NDArray[np.int_]
NDArrayFloat = npt.NDArray[np.float_]
NDArrays = List[NDArray]

def fedavg_aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
    """Compute weighted average."""
    # Calculate the total number of examples used during training
    num_examples_total = sum([num_examples for _, num_examples in results])

    # Create a list of weights, each multiplied by the related number of examples
    weighted_weights = [
        [layer * num_examples for layer in weights] for weights, num_examples in results
    ]

    # Compute average weights of each layer
    weights_prime: NDArrays = [
        reduce(np.add, layer_updates) / num_examples_total
        for layer_updates in zip(*weighted_weights)
    ]
    return weights_prime

with open('database/server_database.pkl', 'rb') as file:
    data = pickle.load(file)

    
client_model_record = data['client_model_record']

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from io import BytesIO
from dataclasses import dataclass
from typing import cast

@dataclass
class Parameters:
    """Model parameters."""

    tensors: List[bytes]
    tensor_type: str

def ndarray_to_bytes(ndarray: NDArray) -> bytes:
    """Serialize NumPy ndarray to bytes."""
    bytes_io = BytesIO()
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
    np.save(bytes_io, ndarray, allow_pickle=False)
    return bytes_io.getvalue()

def bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Deserialize NumPy ndarray from bytes."""
    bytes_io = BytesIO(tensor)
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
    ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
    return cast(NDArray, ndarray_deserialized)

def ndarrays_to_parameters(ndarrays: NDArrays) -> Parameters:
    """Convert NumPy ndarrays to parameters object."""
    tensors = [ndarray_to_bytes(ndarray) for ndarray in ndarrays]
    return Parameters(tensors=tensors, tensor_type="numpy.ndarray")

def parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Convert parameters object to NumPy ndarrays."""
    return [bytes_to_ndarray(tensor) for tensor in parameters.tensors]

client_results = [(client_model_record[uid],1) for uid in client_model_record.keys()]

agged_params = parameters_to_ndarrays(ndarrays_to_parameters(fedavg_aggregate(client_results)))

In [5]:
print(type(agged_params))

<class 'list'>


In [6]:
import cifar10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v.astype(float)) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)
    
net = cifar10.load_model().to(DEVICE)
set_parameters(net, agged_params)  # Update model with the latest parameters
test_loader = cifar10.load_test_data()
loss, accuracy = cifar10.test(net, test_loader, DEVICE)

Test:  Total time: 0:00:07
Test: test_acc1=10.016, test_acc5=50.020, test_loss=0.094603, samples/s=1336.441


In [13]:
params_dict = zip(net.state_dict().keys(), agged_params, client_results[0][0])

for k, v1, v2 in params_dict:
    try:
        x = k, torch.Tensor(v1.astype(float))
    except Exception as e:
        # print(f"Error: {e}")
        print(k, v1, type(v1), v2, type(v2))

TypeError: 'Parameters' object is not iterable

In [26]:
print(len(client_results[0][0]))
print(len(agged_params))
for (a, b) in zip(agged_params, client_results[0][0]):
    if not a.shape == b.shape:
        print("Shoot")

62
62


In [28]:
print(client_results[0][0])

[array([[[[-0.03193101, -0.04681609, -0.08554577],
         [ 0.04826915, -0.10189971,  0.00575278],
         [ 0.05455356,  0.05608114, -0.00593127]],

        [[ 0.00157055, -0.01169919, -0.03589138],
         [ 0.1343464 ,  0.00608538,  0.00397135],
         [-0.02272212,  0.0255009 ,  0.07078151]],

        [[ 0.00760836,  0.04291692,  0.00889237],
         [-0.03915408,  0.05053937, -0.02259889],
         [ 0.05619406,  0.01907415, -0.09961216]]],


       [[[-0.08808386,  0.00061832,  0.06730109],
         [-0.08916007,  0.09640606,  0.03528554],
         [-0.01792393,  0.01245996, -0.01472389]],

        [[-0.09512833,  0.00654377, -0.02469591],
         [-0.03000345, -0.08427338, -0.07946981],
         [-0.0580694 , -0.11525851,  0.10477521]],

        [[-0.0827251 , -0.0659014 ,  0.03870768],
         [-0.03071537,  0.14246555,  0.03732529],
         [ 0.04258449, -0.03963811,  0.06814142]]],


       [[[-0.06990564, -0.05646835, -0.01391071],
         [-0.0253953 , -0.0744684

In [29]:
print(agged_params)

[array([[[[-3.73555385e-02, -5.10356762e-02, -8.74608830e-02],
         [ 4.11670543e-02, -1.08922265e-01,  3.67169199e-03],
         [ 4.82296906e-02,  4.97252606e-02, -9.26045328e-03]],

        [[-6.89759268e-04, -1.07384315e-02, -3.13703604e-02],
         [ 1.30725145e-01,  3.21049849e-03,  5.62964659e-03],
         [-2.48031635e-02,  2.34678350e-02,  7.07523823e-02]],

        [[ 7.10816728e-03,  4.60211560e-02,  1.41827185e-02],
         [-4.43727523e-02,  4.74103987e-02, -2.24622581e-02],
         [ 5.23091555e-02,  1.65270250e-02, -1.02627754e-01]]],


       [[[-8.32807273e-02,  3.48277343e-03,  7.30921924e-02],
         [-8.31925124e-02,  9.66393054e-02,  4.13115919e-02],
         [-1.85209718e-02,  1.00266105e-02, -1.05716866e-02]],

        [[-9.16458964e-02,  9.25024692e-03, -2.09117476e-02],
         [-2.61753350e-02, -8.16411600e-02, -7.32468143e-02],
         [-5.85471541e-02, -1.14945695e-01,  1.09984778e-01]],

        [[-7.93888792e-02, -6.26898408e-02,  4.24771905e-