In [5]:
import os
# import json
import math
import torch
from itertools import product
from dataclasses import dataclass, field

from gamesopt.attacks import Attack
from gamesopt.optimizer import Optimizer
from gamesopt.aggregator import Aggregator
from gamesopt.games import Game, create_matrix, create_bias
from gamesopt.train_distributed import train

# Notebook display / reload configuration (IPython magics, not plain Python).
# Alternative interactive backend, kept for reference:
# %matplotlib widget
%matplotlib inline
# NOTE(review): re-running this cell emits "The autoreload extension is already
# loaded"; `%reload_ext autoreload` would avoid that message.
%load_ext autoreload
# NOTE(review): per the IPython docs, mode 1 reloads only modules registered via
# `%aimport`; none are registered in the cells shown, so edits to `gamesopt`
# are presumably not picked up. `%autoreload 2` reloads all modules - confirm intent.
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Export the MLFLOW_* settings in one call. The experiment name is the name of
# the current working directory. (These variables are presumably consumed by
# the training/logging code, which is not shown here.)
os.environ.update(
    MLFLOW_VERBOSE='True',
    MLFLOW_EXPERIMENT_NAME=os.path.basename(os.getcwd()),
)

In [7]:
@dataclass
class Config:
    """Settings for a single call to ``train``.

    Fields that default to ``None`` are left unset here; the experiment cell
    passes or assigns the ones each method needs before calling ``train``.
    """

    # ---- run length ----
    n_iter: int = 6001

    # ---- peers ----
    n_peers: int = 20
    n_byzan: int = 4

    # ---- game / problem definition ----
    # default_factory defers the lookup of Game.Quadratic to instantiation time.
    game: Game = field(default_factory=lambda: Game.Quadratic)
    num_samples: int = 1000
    dim: int = 25
    with_bias: bool = True
    mu: float = 0.1
    ell: float = 1000.0

    # ---- optimizer ----
    optimizer: Optimizer = None
    batch_size: int = None
    # Evaluated once, while the class body executes, from the default `ell`
    # above; passing a different `ell` to the constructor does not change it.
    lr: float = 1 / 2 / ell / 15
    alpha: float = None
    sigmaC: float = None

    # ---- attack ----
    attack: Attack = None
    n_attacking: int = None
    attack_param: float = None

    # ---- bucketing ----
    use_bucketing: bool = None
    bucketing_s: int = 2

    # ---- aggregator ----
    aggregator: Aggregator = None
    # Aggregator-specific parameters. The author's notes list example values
    # (meaning depends on the aggregator - TODO confirm against gamesopt):
    #   param_a: trimmed_mean_b=10; krum_m=2; clipping_n_iter=3, clipping_tau=10
    #   param_b: rfa_T=10, rfa_nu=0.1
    aggregator_param_a: int = 10
    aggregator_param_b: float = 0.1

In [8]:
# init=False: the hand-written __init__ below is the constructor.
# eq=False:   keep identity-based equality/hash. A field-wise __eq__ over
#             tensors is not meaningful, and the previous field-less dataclass
#             made every instance compare equal and unhashable.
@dataclass(init=False, eq=False)
class Data:
    """Problem data generated once from a config and shared by all runs.

    Parameters
    ----------
    config
        Object providing ``dim``, ``num_samples``, ``mu``, ``ell`` and
        ``with_bias`` (e.g. a ``Config`` instance).

    Attributes
    ----------
    matrix
        Result of ``create_matrix(dim, num_samples, mu, ell, with_bias=...)``.
    bias
        Result of ``create_bias(dim, num_samples, with_bias=...)``.
    true
        Solution ``x`` of ``mean(matrix, dim=0) @ x = -mean(bias, dim=0)``.
    players
        ``true`` shifted by the constant ``0.1 / sqrt(2 * dim)`` in every
        coordinate (presumably the starting point for training - confirm
        against ``train``).
    """

    matrix: torch.Tensor
    bias: torch.Tensor
    true: torch.Tensor
    players: torch.Tensor

    def __init__(self, config):
        self.matrix = create_matrix(config.dim, config.num_samples,
                                    config.mu, config.ell,
                                    with_bias=config.with_bias)
        self.bias = create_bias(config.dim, config.num_samples,
                                with_bias=config.with_bias)
        # Solve the averaged linear system for the reference solution.
        self.true = torch.linalg.solve(self.matrix.mean(dim=0),
                                       -self.bias.mean(dim=0))
        # Scalar offset, broadcast over every coordinate of `true`.
        self.players = self.true + .1/math.sqrt(2 * config.dim)

In [10]:
# Default configuration; in the cells shown it is only used to build `data`.
base_config = Config()
# Problem data, built once and passed unchanged to every `train` call below.
# NOTE(review): no random seed is set in the visible cells; if
# create_matrix/create_bias draw random numbers, each kernel restart will
# produce different data - confirm and seed if reproducibility matters.
data = Data(base_config)
# base_config_items = set(base_config.__dict__.items())

In [11]:
def _make_config(attack, attack_param, batch_size, n_iter_rule=None, **overrides):
    """Build the ``Config`` for one method under one attack setting.

    Parameters
    ----------
    attack, attack_param, batch_size
        Passed straight to the ``Config`` constructor.
    n_iter_rule
        Optional callable mapping the default ``n_iter`` to the value this
        method should use; ``None`` keeps the default.
    **overrides
        Remaining ``Config`` attributes to set. ``n_attacking`` defaults to
        ``config.n_byzan`` when not given.

    Returns
    -------
    Config
        The fully populated configuration.
    """
    config = Config(attack=attack, attack_param=attack_param,
                    batch_size=batch_size)
    if n_iter_rule is not None:
        config.n_iter = n_iter_rule(config.n_iter)
    overrides.setdefault('n_attacking', config.n_byzan)
    for attr_name, attr_value in overrides.items():
        setattr(config, attr_name, attr_value)
    return config


def _launch(config, run_data, run_name, run_tag, run_title):
    """Export the MLFLOW_RUN_* variables for this run, then call ``train``."""
    os.environ['MLFLOW_RUN_TAG'] = run_tag
    os.environ['MLFLOW_RUN_NAME'] = run_name
    os.environ['MLFLOW_RUN_TITLE'] = run_title
    train(config, run_data)


def _halve_plus_one(n_iter):
    """Return ``int(n_iter / 2) + 1`` (used by SEGRA and RDEG)."""
    return int(n_iter / 2) + 1


# One row per method, executed in this order for every attack setting:
# (run name, run tag, n_iter rule or None, Config overrides).
_METHODS = (
    ('SGDA', 'RA', None,
     dict(optimizer=Optimizer.SGDA, aggregator=Aggregator.Mean)),
    ('SGDARA', 'RA', None,
     dict(optimizer=Optimizer.SGDARA, aggregator=Aggregator.RFA,
          use_bucketing=True)),
    ('MSGDARA', 'RA M', None,
     dict(optimizer=Optimizer.MSGDARA, alpha=0.1,
          aggregator=Aggregator.RFA, use_bucketing=True)),
    ('SEGRA', 'RA', _halve_plus_one,
     dict(optimizer=Optimizer.SEGRA, aggregator=Aggregator.RFA,
          use_bucketing=True)),
    ('RDEG', 'RA', _halve_plus_one,
     dict(optimizer=Optimizer.RDEG, aggregator=Aggregator.UnivariateTM)),
    ('SGDACC', 'CC', lambda n_iter: n_iter * 2 - 1,
     dict(optimizer=Optimizer.SGDACC, aggregator=Aggregator.Mean,
          sigmaC=100., n_attacking=1)),
    ('SEGCC', 'CC', lambda n_iter: n_iter + 1,
     dict(optimizer=Optimizer.SEGCC, aggregator=Aggregator.Mean,
          sigmaC=100., n_attacking=1)),
)

# Sweep: every attack x attack strength x batch size, and for each of those
# every method in `_METHODS`, all trained on the same `data`.
for at, ap, bs in product([Attack.ALIE, Attack.IPM],
                          [1e-1, 1., 1e1, 1e2],
                          [1, 10, 100]):
    # The title depends only on the attack setting, so build it once.
    run_title = '%s (%.e), bs=%i' % (at, ap, bs)
    for run_name, run_tag, n_iter_rule, overrides in _METHODS:
        config = _make_config(at, ap, bs, n_iter_rule=n_iter_rule, **overrides)
        _launch(config, data, run_name, run_tag, run_title)


Trying port 7790
Trying port 32783
Trying port 21936
Trying port 46120
Trying port 21409
Trying port 41842
Trying port 45745
Trying port 6206
Trying port 5804
Trying port 41249
Trying port 27602
Trying port 3708
Trying port 10957
Trying port 32721
Trying port 42987
Trying port 17175
Trying port 31842
Trying port 21657
Trying port 26068
Trying port 40150
Trying port 29681
Trying port 36947
Trying port 42375
Trying port 48811
Trying port 2712
Trying port 15115
Trying port 38492
Trying port 23223
Trying port 24181
Trying port 15822
Trying port 41819
Trying port 42664
Trying port 22363
Trying port 16565
Trying port 44063
Trying port 20320
Trying port 17814
Trying port 25200
Trying port 5340
Trying port 23298
Trying port 38815
Trying port 26703
Trying port 33020
Trying port 44238
Trying port 34662
Trying port 41658
Trying port 16384
Trying port 2581
Trying port 37324
Trying port 26492
Trying port 7367
Trying port 7923
Trying port 4192
Trying port 11747
Trying port 42883
Trying port 9440
Try

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Trying port 12490
Trying port 1887
Trying port 4575
Trying port 18696
Trying port 22285
Trying port 3328
Trying port 16931
Trying port 47634
Trying port 48701
Trying port 6260
Trying port 31018
Trying port 31426
Trying port 26292
Trying port 27145
Trying port 35811
Trying port 40763
Trying port 29602
Trying port 11242
Trying port 10360
Trying port 28013
Trying port 29717
Trying port 29960
Trying port 25270
Trying port 43189
Trying port 10268
Trying port 37189
Trying port 31707
Trying port 24865
Trying port 39275
Trying port 46101
Trying port 45040
Trying port 31937
Trying port 27761
Trying port 33969
Trying port 33627
Trying port 1937
Trying port 44500
Trying port 3172
Trying port 10174
Trying port 13808
Trying port 44288
Trying port 35283
Trying port 44517
Trying port 43157
Trying port 36618
Trying port 37629
Trying port 10332
Trying port 7615
Trying port 18929
Trying port 38133
Trying port 41313
Trying port 13732
Trying port 32273
Trying port 6471
Trying port 37657
Trying port 38599
