In [2]:
import os

# The later cells import the `FLF` package relative to the working directory,
# so the kernel is moved one level up (presumably from a notebooks folder to
# the project root -- TODO confirm the layout).
# A bare `os.chdir('..')` is not idempotent: every re-run of this cell would
# climb one more level. Only move when `FLF` is not already visible from here.
if not os.path.isdir('FLF'):
    os.chdir('..')

In [None]:
import torch as th

In [None]:
# Number of *direct* subclasses of torch.optim.Optimizer; classes that derive
# from another optimizer are not counted here.
len(th.optim.Optimizer.__subclasses__())

In [None]:
def all_subclasses(cls):
    """Collect every subclass of ``cls``, at any inheritance depth.

    Args:
        cls: The class whose descendants are wanted. ``cls`` itself is not
            part of the result.

    Returns:
        set: All direct and indirect subclasses currently known to the
        interpreter; empty when ``cls`` has none.
    """
    descendants = set()
    # Worklist of classes whose own subclasses still have to be visited.
    todo = list(cls.__subclasses__())
    while todo:
        klass = todo.pop()
        if klass in descendants:
            # Reached before through another base (multiple inheritance).
            continue
        descendants.add(klass)
        todo.extend(klass.__subclasses__())
    return descendants

In [None]:
# Number of Optimizer subclasses at any depth (direct and indirect), for
# comparison with the direct-only count computed earlier.
len(all_subclasses(th.optim.Optimizer))

In [None]:
from FLF.TorchOptRepo import TorchOptRepo

In [None]:
# For every optimizer name the repo exposes, show what
# TorchOptRepo.supported_parameters reports for it.
{
    opt_name: TorchOptRepo.supported_parameters(opt_name)
    for opt_name in TorchOptRepo.get_opt_names()
}

In [None]:
from FLF.TorchFederatedLearnerMNIST import Net

In [None]:
# Build the `Net` model from FLF.TorchFederatedLearnerMNIST with its default
# constructor arguments; it only serves as a parameter source for the optimizer below.
model = Net()

In [None]:
# Resolve the optimizer class registered under the name 'Adam' and instantiate
# it over the model's parameters, passing no other hyper-parameters.
opt = TorchOptRepo.name2cls('Adam')(model.parameters())

In [None]:
# Exploratory: list the attributes available on the optimizer instance.
dir(opt)

In [None]:
from comet_ml import Experiment
import logging
from FLF.TorchFederatedLearnerMNIST import (
    TorchFederatedLearnerMNIST,
    TorchFederatedLearnerMNISTConfig,
)
from FLF.TorchOptRepo import TorchOptRepo

# Root logger setup: INFO and above, each record prefixed with a
# second-resolution timestamp and a left-aligned, 8-character level name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
)


def train_fakes(opt):
    """Train a throw-away federated MNIST learner for one round with ``opt``.

    A Comet ``Experiment`` is created in the "federated-learning" workspace
    (project "2C_opt"), named "fake", and handed to a
    ``TorchFederatedLearnerMNIST`` configured for 2 clients, client fraction 1,
    non-IID data, batch size 50, 1 epoch per client and a single round.

    Args:
        opt: Optimizer identifier forwarded unchanged as the ``OPT`` config
            field; the notebook passes names from ``TorchOptRepo.get_opt_names()``.

    Returns:
        The ``TorchFederatedLearnerMNIST`` instance after ``train()`` has run.
    """
    # NOTE(review): every call uses the same run name, so runs for different
    # optimizers are not distinguishable by name -- confirm this is intended.
    run_name = "fake"
    logging.info(run_name)

    experiment = Experiment(workspace="federated-learning", project_name="2C_opt")
    experiment.set_name(run_name)
    # NOTE(review): the key looks like a typo of "opt_strategy"; left untouched
    # because it is the name recorded in Comet.
    experiment.log_parameter("opt_srategy", "reinit")

    # TODO: misspelled config parameter names did not raise an error.
    settings = dict(
        LEARNING_RATE=0.000001,
        OPT=opt,
        IS_IID_DATA=False,
        BATCH_SIZE=50,
        CLIENT_FRACTION=1,
        N_CLIENTS=2,
        N_EPOCH_PER_CLIENT=1,
        MAX_ROUNDS=1,
    )
    config = TorchFederatedLearnerMNISTConfig(**settings)

    learner = TorchFederatedLearnerMNIST(experiment, config)
    learner.train()
    return learner

# One trained learner per optimizer name exposed by TorchOptRepo, keyed by that name.
learners = {
    opt_name: train_fakes(opt_name)
    for opt_name in TorchOptRepo.get_opt_names()
}

In [None]:
# `learner` is never bound at notebook scope (it only exists as a local inside
# train_fakes and as a comprehension variable), so on a fresh kernel the bare
# name raises NameError. Pick one trained learner from `learners` explicitly;
# 'Adam' is the optimizer the later inspection cells look at as well.
learners['Adam'].clients[0].opt.state_dict()['state']

In [99]:
# Keys of the optimizer state held by client 0, for every trained optimizer.
{
    opt_name: trained.clients[0].opt.state_dict()['state'].keys()
    for opt_name, trained in learners.items()
}

{'Adadelta': dict_keys([140520088536160, 140520088536000, 140520091289872, 140520088538400, 140522572013072, 140520088134832, 140520088555024, 140520088555264]),
 'Adagrad': dict_keys([140520087621296, 140520087617936, 140520087738640, 140520088029296, 140520088029856, 140520088030976, 140520088028656, 140520088030816]),
 'Adam': dict_keys([140520058870272, 140520058870832, 140520058867872, 140520058870352, 140520058868592, 140520058871312, 140521900644224, 140520058828832]),
 'AdamW': dict_keys([140520059851472, 140520059853952, 140520059853232, 140520086758576, 140520058892768, 140520086758656, 140520086758176, 140520086757616]),
 'Adamax': dict_keys([140520058819040, 140520058822560, 140520058820080, 140520058820400, 140520058820880, 140520058941680, 140520088252000, 140520088248800]),
 'ASGD': dict_keys([140520058016384, 140520058017184, 140520058016784, 140520058016304, 140520058017344, 140520058017264, 140520059203904, 140520058469424]),
 'SGD': dict_keys([]),
 'Rprop': dict_keys

In [98]:
# Hyper-parameter groups (`param_groups`) of client 0's optimizer, for every trained optimizer.
{
    opt_name: trained.clients[0].opt.state_dict()['param_groups']
    for opt_name, trained in learners.items()
}

{'Adadelta': [{'lr': 1e-06,
   'rho': 0.9,
   'eps': 1e-06,
   'weight_decay': 0,
   'params': [140520088536160,
    140520088536000,
    140520091289872,
    140520088538400,
    140522572013072,
    140520088134832,
    140520088555024,
    140520088555264]}],
 'Adagrad': [{'lr': 1e-06,
   'lr_decay': 0,
   'eps': 1e-10,
   'weight_decay': 0,
   'initial_accumulator_value': 0,
   'params': [140520087621296,
    140520087617936,
    140520087738640,
    140520088029296,
    140520088029856,
    140520088030976,
    140520088028656,
    140520088030816]}],
 'Adam': [{'lr': 1e-06,
   'betas': (0.9, 0.999),
   'eps': 1e-08,
   'weight_decay': 0,
   'amsgrad': False,
   'params': [140520058870272,
    140520058870832,
    140520058867872,
    140520058870352,
    140520058868592,
    140520058871312,
    140521900644224,
    140520058828832]}],
 'AdamW': [{'lr': 1e-06,
   'betas': (0.9, 0.999),
   'eps': 1e-08,
   'weight_decay': 0.01,
   'amsgrad': False,
   'params': [140520059851472,
 

In [81]:
# Object ids of the model parameters of both clients in the Adam run.
# The recorded output of this cell is a pair of lists, but the second
# expression had been commented out, so source and output disagreed;
# it is restored here so that re-running reproduces the same structure.
(
    [id(param) for param in learners['Adam'].clients[0].model.parameters()],
    [id(param) for param in learners['Adam'].clients[1].model.parameters()],
)

([140520058870272,
  140520058870832,
  140520058867872,
  140520058870352,
  140520058868592,
  140520058871312,
  140521900644224,
  140520058828832],
 [140520087972752,
  140520089115616,
  140520089115216,
  140520464606944,
  140520087751888,
  140520087748688,
  140520087751408,
  140520087749328])

In [80]:
# Shapes of every state_dict tensor for client 0 and client 1 of the Adam run,
# returned as a pair of lists so the two clients can be compared side by side.
tuple(
    [tensor.shape for tensor in learners['Adam'].clients[client_idx].model.state_dict().values()]
    for client_idx in (0, 1)
)

([torch.Size([32, 1, 5, 5]),
  torch.Size([32]),
  torch.Size([64, 32, 5, 5]),
  torch.Size([64]),
  torch.Size([512, 1024]),
  torch.Size([512]),
  torch.Size([10, 512]),
  torch.Size([10])],
 [torch.Size([32, 1, 5, 5]),
  torch.Size([32]),
  torch.Size([64, 32, 5, 5]),
  torch.Size([64]),
  torch.Size([512, 1024]),
  torch.Size([512]),
  torch.Size([10, 512]),
  torch.Size([10])])