In [1]:
import os
import math
import torch
from dataclasses import dataclass, field

from code.train import train
from code.optimizers import Optimizer
# from code.problems import Problem
from code.problem import Loss
from code.datasets import Dataset
from code.models import Model

# %matplotlib widget
%load_ext autoreload
%autoreload 1

In [2]:
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

In [3]:
os.environ['MLFLOW_VERBOSE'] = 'True'
os.environ['MLFLOW_EXPERIMENT_NAME'] = os.path.basename(os.getcwd())

In [4]:
os.path.basename(os.getcwd())

'MeritFed-M'

In [100]:
@dataclass
class Config():
    n_iters: int = 5000
    n_peers: int = 20
    seed: int = 0

    model: Model = field(default_factory=lambda: Model.Linear)
    loss: Loss = field(default_factory=lambda: Loss.CrossEntropy)

    dataset: Dataset = field(default_factory=lambda: Dataset.MNIST)

    
    # model: Model = field(default_factory=lambda: Model.Mean)
    # loss: Loss = field(default_factory=lambda: Loss.MSE)

    # dataset: Dataset = field(default_factory=lambda: Dataset.Normal)
    n_samples: int = 500
    h_ratio: float = 0.99
    mu_normal: float = None

    optimizer: Optimizer = field(default_factory=lambda: Optimizer.SGD)
    batch_size: int = 10
    lr: float = 1e-2

    true_weights: bool = None

    md_n_iters_: int = None
    md_full_: bool = None
    md_lr_: int = None

In [105]:
config = Config()
config.optimizer = Optimizer.SGD
config.true_weights = True
os.environ['MLFLOW_RUN_NAME'] = config.optimizer.name
%time train(config)

Trying port 25065
test len  500
8011
full test len  8011
full test len  8011
calc3 9
calc1 438
calc2 62
actual3 9
actual1 438
actual2 62
CPU times: user 67.7 ms, sys: 69 ms, total: 137 ms
Wall time: 26.9 s


In [102]:
config = Config()
config.optimizer = Optimizer.MeritFed
config.md_full_ = True
config.md_n_iters_ = 20
config.md_lr_ = 0.05
os.environ['MLFLOW_RUN_NAME'] = config.optimizer.name
%time train(config)

Trying port 5393
test len  500
8011
full test len  8011
full test len  8011
actual3 9
calc3 9
CPU times: user 115 ms, sys: 79.5 ms, total: 194 ms
Wall time: 2min 1s


In [101]:
config = Config()
config.optimizer = Optimizer.SGD
config.true_weights = False
os.environ['MLFLOW_RUN_NAME'] = config.optimizer.name
%time train(config)

Trying port 10946
100
CPU times: user 62.1 ms, sys: 72.3 ms, total: 134 ms
Wall time: 44.6 s


In [31]:
config = Config()
config.optimizer = Optimizer.MeritFed
config.md_full_ = False
config.md_n_iters_ = 20
config.md_lr_ = 0.05
os.environ['MLFLOW_RUN_NAME'] = config.optimizer.name
%time train(config)

Trying port 46002
2 3
3 3
650
full test len  1000
full test len  1000
CPU times: user 130 ms, sys: 50.8 ms, total: 181 ms
Wall time: 48.1 s


In [173]:
class MNIST(datasets.MNIST):
    def __init__(self, config, rank, train=True):
        root = '/tmp'
        self.root = root
        if config.n_peers and not rank and self.download:
            self.download()

        # torch.distributed.barrier()
        super().__init__(root=root, train=train, transform=transforms.ToTensor(), download=False)

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()
        
        if train is False:
            if rank:
                raise RuntimeError("Non-master client accessed test dataset")
            mask = self.targets == 1
            indices = torch.nonzero(mask).squeeze()
            indices = indices[:config.n_samples]
            mask = torch.zeros_like(mask).scatter_(0, indices, 1)
            self.targets = self.targets[mask].float()
            self.data = self.data[mask]
            return

        self.output_dim = len(self.classes)
        self.input_dim = len(self.data[0].view(-1))
        # print(self.input_dim)

        target_rank_below = 1
        near_target_rank_below = 10
        self.true_weights = torch.zeros(config.n_peers)
        self.true_weights[:target_rank_below] = 1 / target_rank_below
        mask = self.targets == 1
        if rank < target_rank_below:
            indices = torch.nonzero(mask).squeeze()
            n = target_rank_below*config.n_samples
            if n > len(indices):
                raise ValueError('target_rank_below*n_samples too big')
            per_worker = config.n_samples
            beg = rank * per_worker
            end = beg + per_worker
            if end > len(indices) - 1:
                raise ValueError('invalid partitioning')
            indices = indices[beg:end]
            # print(beg, end)
            # print(indices)
            # return dataset
        elif target_rank_below <= rank and rank < near_target_rank_below:
            target_ratio = config.h_ratio
            indices = torch.nonzero(mask).squeeze()
            n = target_rank_below*config.n_samples
            indices = indices[n:]
            if len(indices) < target_ratio*config.n_samples*(near_target_rank_below-target_rank_below):
                raise ValueError(f'target_ratio*n_samples*(near_target_rank_below-target_rank_below) too big: {len(indices)} available {target_ratio*config.n_samples*(near_target_rank_below-target_rank_below)} needed')
                
            per_worker = int(target_ratio*config.n_samples)
            beg = (rank-target_rank_below) * per_worker
            end = beg + per_worker
            # print(end, len(indices))
            if end > len(indices) - 1:
                raise ValueError('invalid partitioning')
            indices = indices[beg:end]

            mask = self.targets == 0
            for i in range(2, 3):
                mask = torch.logical_or(mask, self.targets == i)
            more_indices = torch.nonzero(mask).squeeze()
            # print((1-target_ratio)*config.n_samples)
            per_worker = config.n_samples - int(target_ratio*config.n_samples)
            beg = (rank-target_rank_below) * per_worker
            end = beg + per_worker
            if end > len(more_indices) - 1:
                raise ValueError('invalid rounding')
            more_indices = more_indices[beg:end]
            # print(beg, end)
            # print(indices)
            # print(len(indices))
            # print(more_indices)
            # print(len(more_indices))
            # print(len(torch.cat((indices, more_indices), 0)))
            indices = torch.cat((indices, more_indices), 0)
            # print(indices)
            # print(len(indices))    
            # return dataset
        else:
            mask = self.targets > 4
            indices = torch.nonzero(mask).squeeze()
            per_worker = config.n_samples
            beg = (rank-target_rank_below) * per_worker
            end = min(beg + per_worker, len(indices) - 1)
            if end > len(indices) - 1:
                raise ValueError('invalid partitioning')
            indices = indices[beg:end]
            # print(indices)
            # print(len(indices))    
            # return dataset
        # print(indices)
        # mask = torch.zeros_like(mask).scatter_(0, indices, 1)
        self.targets = self.targets[indices]
        print(len(self.targets))
        # self.data = self.data[mask]
        # if len(self.targets) != config.n_samples:
        #         raise ValueError('config failed')

    def model_args(self):
        return self.input_dim, self.output_dim
    
    def loss_star(self, full_batch, criterion):
        return criterion(full_batch[1], full_batch[1]).data

NameError: name 'datasets' is not defined

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
@dataclass
class Config():
    n_iters: int = 6000
    n_peers: int = 20
    seed: int = 0

    model: Model = field(default_factory=lambda: Model.Mean)
    loss: Loss = field(default_factory=lambda: Loss.MSE)

    dataset: Dataset = field(default_factory=lambda: Dataset.Normal)
    n_samples: int = 600
    h_ratio: float = 0.58
    mu_normal: float = None

    optimizer: Optimizer = field(default_factory=lambda: Optimizer.SGD)
    batch_size: int = 100
    lr: float = 1e-2

    true_weights: bool = None

    md_n_iters_: int = None
    md_full_: bool = None
    md_lr_: int = None

config = Config()
for i in range(config.n_peers):
    MNIST(config, i, True)