In [1]:
import os
os.chdir('..')

In [2]:
from comet_ml import Experiment
import torch as th

In [3]:
len(th.optim.Optimizer.__subclasses__())

11

In [4]:
def all_subclasses(cls):
    return set(cls.__subclasses__()).union(
        [s for c in cls.__subclasses__() for s in all_subclasses(c)])

In [5]:
len(all_subclasses(th.optim.Optimizer))

11

In [5]:
from FLF.TorchOptRepo import TorchOptRepo

In [7]:
{opt: TorchOptRepo.supported_parameters(opt) for opt in TorchOptRepo.get_opt_names()}

{'Adadelta': ['lr', 'rho', 'eps', 'weight_decay'],
 'Adagrad': ['lr',
  'lr_decay',
  'weight_decay',
  'initial_accumulator_value',
  'eps',
  'group',
  'p',
  'state'],
 'Adam': ['lr', 'betas', 'eps', 'weight_decay', 'amsgrad'],
 'AdamW': ['lr', 'betas', 'eps', 'weight_decay', 'amsgrad'],
 'Adamax': ['lr', 'betas', 'eps', 'weight_decay'],
 'ASGD': ['lr', 'lambd', 'alpha', 't0', 'weight_decay'],
 'SGD': ['lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'],
 'Rprop': ['lr', 'etas', 'step_sizes'],
 'RMSprop': ['lr', 'alpha', 'eps', 'weight_decay', 'momentum', 'centered']}

In [3]:
from FLF.TorchFederatedLearnerMNIST import Net

In [4]:
model = Net()

In [6]:
opt = TorchOptRepo.name2cls('Adam')(model.parameters())

In [7]:
dir(opt)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'add_param_group',
 'defaults',
 'load_state_dict',
 'param_groups',
 'state',
 'state_dict',
 'step',
 'zero_grad']

In [8]:
import logging
from FLF.TorchFederatedLearnerMNIST import (
    TorchFederatedLearnerMNIST,
    TorchFederatedLearnerMNISTConfig,
)
from FLF.TorchOptRepo import TorchOptRepo

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def train_fakes(opt):
    C = 1
    NC = 2
    E = 1
    B = 50
    is_iid = False
    lr = 0.000001
    name = f"fake"

    logging.info(name)
    experiment = Experiment(workspace="federated-learning", project_name="2C_opt_tmp")
    experiment.set_name(name)
    experiment.log_parameter("opt_srategy", "reinit")
    # TODO a paraméterek helytelen nevére nem adott hibát
    config = TorchFederatedLearnerMNISTConfig(
        LEARNING_RATE=lr,
        OPT=opt,
        IS_IID_DATA=is_iid,
        BATCH_SIZE=B,
        CLIENT_FRACTION=C,
        N_CLIENTS=NC,
        N_EPOCH_PER_CLIENT=E,
        MAX_ROUNDS=1,
    )
    learner = TorchFederatedLearnerMNIST(experiment, config)
    learner.train()
    return learner

learners = {opt: train_fakes(opt) for opt in TorchOptRepo.get_opt_names()[:3]}

2020-10-10 08:39:48 INFO     fake
COMET INFO: old comet version (3.1.13) detected. current: 3.2.3 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/federated-learning/2c-opt-tmp/6cc4d2631acf49f087d4d37664445d7f



TypeError: __init__() missing 1 required positional argument: 'config_technical'

In [19]:
learners

{'Adadelta': <FLF.TorchFederatedLearnerMNIST.TorchFederatedLearnerMNIST at 0x7f4b1a07a950>,
 'Adagrad': <FLF.TorchFederatedLearnerMNIST.TorchFederatedLearnerMNIST at 0x7f4b189497d0>,
 'Adam': <FLF.TorchFederatedLearnerMNIST.TorchFederatedLearnerMNIST at 0x7f4b192648d0>}

In [22]:
client_sample = learners['Adam'].clients
client_opt_states = [client.get_opt_state() for client in client_sample]

In [34]:
from syftutils.multipointer import avg_model_state_dicts

In [61]:
def avg_model_state_dicts(state_dicts):
    final_state_dict = {}
    print(len(state_dicts))
    with th.no_grad():
        for parameter_name in state_dicts[0].keys():
            print(parameter_name)
            if parameter_name =='step':
                final_state_dict[parameter_name] = state_dicts[0][parameter_name]
                continue
                
            final_state_dict[parameter_name] = th.mean(
                th.stack(
                    [
                        model_parameters[parameter_name]
                        for model_parameters in state_dicts
                    ]
                ),
                dim=0,
            )
    return final_state_dict

In [77]:
avg_model_state_dicts(list(zip(*client_opt_states))[0])

2
step
exp_avg
exp_avg_sq


{'step': 600,
 'exp_avg': tensor([[[[ 8.5307e-04, -1.0999e-03, -7.8328e-04,  1.3258e-03,  1.1865e-03],
           [-4.3276e-03, -5.0802e-03, -1.7018e-03,  6.9346e-04, -8.3694e-04],
           [-8.0066e-03, -9.9915e-03, -1.0106e-02, -8.1569e-03, -9.3344e-03],
           [-1.4537e-02, -1.9543e-02, -2.0863e-02, -1.9990e-02, -1.7010e-02],
           [-1.5216e-02, -1.9640e-02, -2.1702e-02, -1.9934e-02, -1.5097e-02]]],
 
 
         [[[ 3.6672e-03,  2.8733e-03, -1.7224e-03, -1.1769e-02, -2.3535e-02],
           [ 3.3332e-03,  1.4517e-03, -6.0933e-03, -2.0205e-02, -2.5510e-02],
           [ 6.8041e-04, -2.3003e-03, -1.0315e-02, -2.1203e-02, -2.1761e-02],
           [-2.3686e-03, -4.9392e-03, -1.0044e-02, -1.6146e-02, -1.5651e-02],
           [-3.1645e-03, -4.8171e-03, -8.1654e-03, -1.1326e-02, -9.3288e-03]]],
 
 
         [[[-1.1299e-02, -8.2992e-03, -7.3182e-03, -5.5364e-03, -7.7314e-03],
           [-8.9774e-03, -7.3191e-03, -8.1185e-03, -9.0809e-03, -8.9165e-03],
           [-5.1738e-03, -7

In [76]:
type(client_opt_states), type(client_opt_states[0]), type(list(zip(*client_opt_states))[0][0])

(list, dict_values, dict)

In [70]:
len(list(zip(*client_opt_states))[0]), len(client_opt_states)

(2, 2)

In [59]:
[[list(l)[0] for l in a] for a in zip(client_opt_states)][0]

[{'step': 600,
  'exp_avg': tensor([[[[ 2.7799e-03, -3.2033e-03, -5.8579e-03, -3.9240e-03, -2.9226e-03],
            [-4.8700e-03, -9.9199e-03, -9.4049e-03, -6.5694e-03, -7.2308e-03],
            [-1.1720e-02, -1.6525e-02, -1.9153e-02, -1.5167e-02, -1.6411e-02],
            [-2.0409e-02, -2.7605e-02, -2.8884e-02, -2.7229e-02, -2.3060e-02],
            [-2.3655e-02, -2.8551e-02, -2.9272e-02, -2.6844e-02, -2.0762e-02]]],
  
  
          [[[ 3.7848e-03,  3.1216e-03, -1.9664e-03, -1.2930e-02, -2.4253e-02],
            [ 3.3882e-03,  2.5533e-03, -6.0769e-03, -2.1581e-02, -2.5955e-02],
            [ 4.0254e-04, -1.6271e-03, -1.1039e-02, -2.3318e-02, -2.3015e-02],
            [-3.7343e-03, -6.0174e-03, -1.2283e-02, -1.9032e-02, -1.6045e-02],
            [-5.1996e-03, -8.1498e-03, -1.1596e-02, -1.3421e-02, -8.9791e-03]]],
  
  
          [[[-4.9190e-03, -5.0532e-03, -7.5589e-03, -7.0996e-03, -8.4855e-03],
            [-5.1049e-03, -7.6112e-03, -9.9251e-03, -1.1236e-02, -1.0163e-02],
          

In [50]:
list(client_opt_states[0])[0]['exp_avg']

tensor([[[[ 2.7799e-03, -3.2033e-03, -5.8579e-03, -3.9240e-03, -2.9226e-03],
          [-4.8700e-03, -9.9199e-03, -9.4049e-03, -6.5694e-03, -7.2308e-03],
          [-1.1720e-02, -1.6525e-02, -1.9153e-02, -1.5167e-02, -1.6411e-02],
          [-2.0409e-02, -2.7605e-02, -2.8884e-02, -2.7229e-02, -2.3060e-02],
          [-2.3655e-02, -2.8551e-02, -2.9272e-02, -2.6844e-02, -2.0762e-02]]],


        [[[ 3.7848e-03,  3.1216e-03, -1.9664e-03, -1.2930e-02, -2.4253e-02],
          [ 3.3882e-03,  2.5533e-03, -6.0769e-03, -2.1581e-02, -2.5955e-02],
          [ 4.0254e-04, -1.6271e-03, -1.1039e-02, -2.3318e-02, -2.3015e-02],
          [-3.7343e-03, -6.0174e-03, -1.2283e-02, -1.9032e-02, -1.6045e-02],
          [-5.1996e-03, -8.1498e-03, -1.1596e-02, -1.3421e-02, -8.9791e-03]]],


        [[[-4.9190e-03, -5.0532e-03, -7.5589e-03, -7.0996e-03, -8.4855e-03],
          [-5.1049e-03, -7.6112e-03, -9.9251e-03, -1.1236e-02, -1.0163e-02],
          [-5.8421e-03, -1.0742e-02, -1.7306e-02, -1.6521e-02, -9.02

In [None]:
new_opt = TorchOptRepo.name2cls("Adam")(learners["Adam"].model.parameters())

In [None]:
new_opt.state_dict()

In [None]:
new_state_dict = new_opt.state_dict()#['param_groups'][0]['params']

In [None]:
new_state_dict['state'].update(zip(new_state_dict['param_groups'][0]['params'], learners["Adam"].clients[0].opt.state_dict()['state'].values()))

In [None]:
new_opt.load_state_dict(new_state_dict)

In [None]:
new_opt.state_dict()

In [None]:
{opt: learner.clients[0].opt.state_dict()['state'].values() for opt, learner in learners.items()}#, learners['SGD'].clients[1].opt.state_dict()

In [None]:
{opt: learner.clients[0].opt.state_dict()['param_groups'] for opt, learner in learners.items()}#, learners['SGD'].clients[1].opt.state_dict()

In [None]:
[id(x) for x in learners['Adam'].clients[0].model.parameters()]#, [id(x) for x in learners['Adam'].clients[1].model.parameters()]

In [None]:
[x.shape for x in learners['Adam'].clients[0].model.state_dict().values()], [x.shape for x in learners['Adam'].clients[1].model.state_dict().values()]