In [133]:
#!./.mnist-pytorch/bin/python
import numpy as np
import os
import collections
import json

import docker
import fire
import torch

from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics

HELPER_MODULE = 'numpyhelper'
helper = get_helper(HELPER_MODULE)

NUM_CLASSES = 10

In [15]:
def compile_model():
    """ Compile the pytorch model.

    :return: The compiled model.
    :rtype: torch.nn.Module
    """
    torch.manual_seed(42)

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = torch.nn.Linear(784, 64)
            self.fc2 = torch.nn.Linear(64, 32)
            self.fc3 = torch.nn.Linear(32, 10)

        def forward(self, x):
            x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
            x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
            x = torch.nn.functional.relu(self.fc2(x))
            x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
            return x

    return Net()

def load_parameters(model_path):
    """ Load model parameters from file and populate model.

    param model_path: The path to load from.
    :type model_path: str
    :return: The loaded model.
    :rtype: torch.nn.Module
    """
    model = compile_model()
    parameters_np = helper.load(model_path)

    params_dict = zip(model.state_dict().keys(), parameters_np)
    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
    model.load_state_dict(state_dict, strict=True)
    return model

In [79]:
model_ids = [int(x.split(sep='.')[0]) for x in os.listdir("../parameter_store/models/")]
model_ids

[1717633814,
 1717633860,
 1717633675,
 1717633721,
 1717633537,
 1717633583,
 1717633767,
 1717633906,
 1717633629]

In [80]:
model_count = len(model_ids)
model_count

9

In [81]:
(model_count != 0 and model_count % 3 == 0)

True

In [82]:
latest_model = load_parameters(model_path = f"../parameter_store/models/{model_ids[model_count - 1]}.npz")
reference_model = load_parameters(model_path = f"../parameter_store/models/{model_ids[model_count - 3]}.npz")

In [71]:
[param for param in latest_model.parameters()]

[Parameter containing:
 tensor([[ 0.0273,  0.0296, -0.0084,  ..., -0.0142,  0.0093,  0.0135],
         [-0.0188, -0.0354,  0.0187,  ..., -0.0106, -0.0001,  0.0115],
         [-0.0008,  0.0017,  0.0045,  ..., -0.0127, -0.0188,  0.0059],
         ...,
         [-0.0141, -0.0243,  0.0114,  ...,  0.0268, -0.0158,  0.0333],
         [ 0.0172, -0.0045,  0.0321,  ...,  0.0190,  0.0077, -0.0114],
         [ 0.0308, -0.0262,  0.0271,  ...,  0.0297, -0.0043,  0.0139]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0322, -0.0146,  0.0050,  0.0400,  0.0590,  0.0200,  0.0563,  0.0096,
          0.0218,  0.0055, -0.0109, -0.0203,  0.0461,  0.0405,  0.0025,  0.0048,
          0.0062, -0.0074, -0.0050,  0.0118, -0.0132, -0.0062,  0.0362,  0.0507,
          0.0011,  0.0011,  0.0326, -0.0158,  0.0027, -0.0041, -0.0211,  0.0280,
         -0.0037,  0.0208,  0.0319, -0.0069,  0.0055,  0.0445, -0.0339,  0.0228,
          0.0110, -0.0008,  0.0320,  0.0073,  0.0282,  0.0072,  0.0238,  0.0320

In [83]:
latest_model_parameters_np = helper.load(f"../parameter_store/models/{model_ids[model_count - 1]}.npz")
reference_model_parameters_np = helper.load(f"../parameter_store/models/{model_ids[model_count - 3]}.npz")

In [132]:
updated_model = []

for i in range(6):
    updated_model.append(latest_model_parameters_np[i] - reference_model_parameters_np[i])

updated_model[0][1]

array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        1.64471567e-06,  3.55145894e-05,  3.55150551e-05,  1.47894025e-06,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  9.98200849e-05,  2.33381987e-04,
        1.49566680e-04,  2.53878534e-04,  4.81236726e-04,  5.05007803e-04,
        4.69768420e-04,  4.90930397e-04,  7.14605674e-04,  6.89825043e-04,
        5.86127862e-04,  4.53555956e-04,  2.67783180e-04,  2.02784780e-04,
        1.96289271e-04,  7.25183636e-05,  6.33150339e-05,  1.38040632e-05,
        0.00000000e+00,  

In [126]:
reference_model_parameters_np[0]

array([[ 0.02730495,  0.02964314, -0.00836688, ..., -0.01415538,
         0.00929471,  0.0134693 ],
       [-0.01881211, -0.03541354,  0.01869409, ..., -0.01060584,
        -0.0001305 ,  0.01146434],
       [-0.0007847 ,  0.00165338,  0.00452454, ..., -0.01267597,
        -0.01883879,  0.0058522 ],
       ...,
       [-0.01414198, -0.02433132,  0.01139325, ...,  0.02677841,
        -0.01575457,  0.03333875],
       [ 0.01716694, -0.00450961,  0.03212668, ...,  0.01900063,
         0.00766945, -0.01143111],
       [ 0.03075903, -0.02617348,  0.02706156, ...,  0.02968384,
        -0.00425721,  0.01390254]], dtype=float32)

In [130]:
(latest_model_parameters_np[0] - reference_model_parameters_np[0])[1]

array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        1.64471567e-06,  3.55145894e-05,  3.55150551e-05,  1.47894025e-06,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  9.98200849e-05,  2.33381987e-04,
        1.49566680e-04,  2.53878534e-04,  4.81236726e-04,  5.05007803e-04,
        4.69768420e-04,  4.90930397e-04,  7.14605674e-04,  6.89825043e-04,
        5.86127862e-04,  4.53555956e-04,  2.67783180e-04,  2.02784780e-04,
        1.96289271e-04,  7.25183636e-05,  6.33150339e-05,  1.38040632e-05,
        0.00000000e+00,  

In [139]:
with open('../parameter_store/client_counts.json', 'r') as json_file:
    counts = json.load(json_file)
    ben_count = counts['ben_count']
    mal_count = counts['mal_count']

ben_count = 10
mal_count = 3

int(ben_count / mal_count)

3