# data/utils/partition/dirichlet.py
 * Client's Data Split Code
 * Odd-numbered clients are allocated 10 times more data than even-numbered clients.
 * data_indices : list with one array of integer sample indices per client

In [None]:
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
from torch.utils.data import Dataset
import math


def dirichlet(
    ori_dataset: Dataset,
    num_clients: int,
    alpha: float,
    least_samples: int,
    dump_path: str = "file.txt",
) -> Tuple[Dict, Dict]:
    """Split a labelled dataset across clients with a Dirichlet prior.

    For every class, the class's sample indices are shuffled and divided
    among the clients according to a draw from ``Dirichlet(alpha)``. Clients
    that already hold at least ``len(targets) / num_clients`` samples get a
    zero share for that class. The draw is repeated until the smallest
    client holds at least ``least_samples`` samples.

    Afterwards every even-indexed client keeps only the trailing ~10% of its
    indices (from ``floor(0.9 * len)`` onwards); odd-indexed clients keep all
    of theirs. The ``least_samples`` guarantee applies to the sizes *before*
    this truncation.

    Args:
        ori_dataset: Object exposing ``classes`` and ``targets``.
        num_clients: Number of clients to split the data among.
        alpha: Concentration parameter of the Dirichlet distribution.
        least_samples: Minimum pre-truncation size required for each client.
        dump_path: File that receives one ``str(indices)`` line per client.
            Pass an empty string to skip writing the file.

    Returns:
        ``(partition, stats)``. ``partition["data_indices"]`` is the list of
        per-client index arrays. ``stats[i]`` holds the sample count
        (``"x"``) and label histogram (``"y"``) of client ``i``;
        ``stats["sample per client"]`` holds the mean and standard deviation
        of the client sizes.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1, or if
            ``least_samples`` cannot be reached by every client at once.
    """
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1")

    num_classes = len(ori_dataset.classes)
    stats = {}
    partition = {"separation": None, "data_indices": None}

    targets_numpy = np.array(ori_dataset.targets, dtype=np.int32)
    # The smallest client can never exceed the average client size, so an
    # unattainable request would otherwise spin in the loop below forever.
    if least_samples * num_clients > len(targets_numpy):
        raise ValueError(
            "least_samples * num_clients exceeds the number of samples in the dataset"
        )

    class_indices = [np.where(targets_numpy == i)[0] for i in range(num_classes)]
    per_client_cap = len(targets_numpy) / num_clients

    # Always draw at least once so `data_indices` exists even when
    # `least_samples` is 0 or negative.
    while True:
        data_indices = [[] for _ in range(num_clients)]
        for k in range(num_classes):
            np.random.shuffle(class_indices[k])
            distrib = np.random.dirichlet(np.repeat(alpha, num_clients))
            # Clients that already reached the cap receive nothing of class k.
            distrib = np.array(
                [
                    p * (len(idx_j) < per_client_cap)
                    for p, idx_j in zip(distrib, data_indices)
                ]
            )
            distrib = distrib / distrib.sum()
            # Cumulative shares become split positions inside this class.
            cuts = (np.cumsum(distrib) * len(class_indices[k])).astype(int)[:-1]
            data_indices = [
                np.concatenate((idx_j, chunk.tolist())).astype(np.int64)
                for idx_j, chunk in zip(
                    data_indices, np.split(class_indices[k], cuts)
                )
            ]
        min_size = min(len(idx_j) for idx_j in data_indices)
        if min_size >= least_samples:
            break

    # Even-indexed clients keep only the last ~10% of their samples.
    data_indices = [
        client_idx[math.floor(len(client_idx) * 0.9):] if cid % 2 == 0 else client_idx
        for cid, client_idx in enumerate(data_indices)
    ]

    if dump_path:
        with open(dump_path, "w") as f:
            for s in data_indices:
                f.write(str(s) + "\n")

    for i in range(num_clients):
        client_targets = targets_numpy[data_indices[i]]
        stats[i] = {
            "x": len(client_targets),
            "y": Counter(client_targets.tolist()),
        }

    num_samples = np.array([stats[i]["x"] for i in range(num_clients)])
    mean_samples = num_samples.mean()
    stats["sample per client"] = {
        "mean": mean_samples,
        # Legacy key: it has always carried the mean, kept so existing
        # readers of "std" keep working.
        "std": mean_samples,
        "stddev": num_samples.std(),
    }

    partition["data_indices"] = data_indices

    return partition, stats


# src/config/args.py
 * Newly added parameters 
 * self.args.lmb represents the lambda parameter vector in FedProx

In [None]:

def get_fedavg_argparser() -> ArgumentParser:
    """Build the base FedAvg argument parser.

    Adds the ``prox_lambda`` option (integer, default 0), reachable both
    through the original single-dash spelling ``-prox_lambda`` and through
    ``--prox_lambda``, which matches the other options in this file.

    Returns:
        The configured parser, so derived parsers can extend it.
    """
    parser = ArgumentParser()
    parser.add_argument("-prox_lambda", "--prox_lambda", type=int, default=0)
    return parser

def get_fedavgm_argparser() -> ArgumentParser:
    """Return the FedAvg parser extended with the ``--server_momentum`` option."""
    fedavgm_parser = get_fedavg_argparser()
    fedavgm_parser.add_argument(
        "--server_momentum",
        default=0.9,
        type=float,
    )
    return fedavgm_parser

def get_fedprox_argparser() -> ArgumentParser:
    """Return the FedAvg parser extended with the FedProx options.

    ``--mu`` is a float (default 1.0). ``--lmb`` accepts zero or more floats
    and defaults to an empty list.
    """
    parser = get_fedavg_argparser()
    parser.add_argument("--mu", type=float, default=1.0)
    # `type=list` would split the raw argument string into single characters;
    # parse a whitespace-separated sequence of floats instead.
    parser.add_argument("--lmb", type=float, nargs="*", default=[])
    return parser

# src/client/fedprox.py
 * Each client has its own lambda


In [None]:
from fedavg import FedAvgClient
from src.config.utils import trainable_params
import numpy as np
import math
import torch 

class FedProxClient(FedAvgClient):
    """FedAvg client whose local training adds a per-client proximal term.

    The proximal coefficient of client ``c`` is read from
    ``self.args.lmb[c]``.
    """

    def __init__(self, model, args, logger):
        super().__init__(model, args, logger)

    def train(self, client_id, new_parameters, verbose=False):
        """Train one client and return ``(delta, lambda, stats)``.

        ``delta`` and ``stats`` come from the parent ``train`` call (invoked
        with ``return_diff=True``); the second element is this client's entry
        of ``self.args.lmb`` in place of the value the parent returns there.
        """
        # fit() indexes self.args.lmb with self.client_id, so the id has to be
        # in place before the parent starts training.
        # NOTE(review): the parent may already set it - confirm in FedAvgClient.
        self.client_id = client_id
        delta, _, stats = super().train(
            client_id, new_parameters, return_diff=True, verbose=verbose
        )

        # The parent's second return value is replaced by the client's lambda.
        return delta, self.args.lmb[client_id], stats

    def fit(self):
        """Run ``self.local_epoch`` epochs, adding the proximal gradient term."""
        self.model.train()
        # Snapshot of the parameters at the start of local training; the
        # proximal term pulls the weights back towards this point.
        global_params = [p.clone().detach() for p in trainable_params(self.model)]
        lmb = self.args.lmb[self.client_id]
        for _ in range(self.local_epoch):
            for x, y in self.trainloader:
                # Batches with at most one sample are skipped.
                if len(x) <= 1:
                    continue

                x, y = x.to(self.device), y.to(self.device)
                logit = self.model(x)
                loss = self.criterion(logit, y)
                self.optimizer.zero_grad()
                loss.backward()
                for w, w_t in zip(trainable_params(self.model), global_params):
                    # Parameters that took no part in the forward pass have
                    # no gradient to adjust.
                    if w.grad is None:
                        continue
                    w.grad.data += lmb * (w.data - w_t.data)
                self.optimizer.step()



# src/server/fedavg.py
 * The way to assign the lambdas to the clients
 * Defines how fedprox allocates lambdas
    * self.args.prox_lambda == 1 : proportional to the size of each client's dataset
    * self.args.prox_lambda == 2 : inversely proportional to the size of each client's dataset
    * self.args.prox_lambda == anything else (usually 0) : 1
 * To enhance the difference between the lambdas, self.args.lmb is set to the L2-normalized square of the dataset sizes (scheme 1) or of their reciprocals (scheme 2).

In [None]:

# Assign one FedProx lambda per client according to `self.args.prox_lambda`.
scheme = self.args.prox_lambda
if scheme == 1 or scheme == 2:
    # Both size-based schemes start from the squared train-set size per client.
    train_sizes = [
        len(partition["data_indices"][cid]["train"])
        for cid in range(self.client_num_in_total)
    ]
    squared_sizes = torch.FloatTensor(train_sizes) ** 2
    # 1: lambda grows with the dataset size; 2: lambda shrinks with it.
    raw_lmb = squared_sizes if scheme == 1 else 1 / squared_sizes
    self.args.lmb = normalize(raw_lmb, p=2.0, dim=0)
else:
    # Any other value: every client uses lambda = 1.
    self.args.lmb = torch.FloatTensor([1] * self.client_num_in_total)


# src/server/fedavg.py
 * save clients stats at @self.table
 * save clients stats


In [None]:
  
# Rows of per-client test accuracy; `log_info` appends one list with an entry
# for every client (0 for clients that were not selected).
self.table = []

def log_info(self):
    """Record accuracy figures for the current epoch.

    For each evaluated split ("train" when ``args.eval_train`` is set, "test"
    when ``args.eval_test`` is set and the dataset split is not "user"), the
    aggregate accuracy of the selected clients before and after local
    training is appended to ``self.metrics``.

    If at least one split was evaluated, one row of per-client test accuracy
    (in percent, 0 for clients that were not selected) is appended to
    ``self.table``. The row is appended once per call; it used to be appended
    once per evaluated split, which duplicated rows when both splits were
    enabled.
    """
    epoch = self.current_epoch
    evaluated_any = False

    for label in ["train", "test"]:
        # In the `user` split, there is no test data held by train clients, so plotting is unnecessary.
        wanted = (label == "train" and self.args.eval_train) or (
            label == "test"
            and self.args.eval_test
            and self.args.dataset_args["split"] != "user"
        )
        if not wanted:
            continue
        evaluated_any = True

        correct_before = torch.tensor(
            [
                self.client_stats[c][epoch]["before"][f"{label}_correct"]
                for c in self.selected_clients
            ]
        )
        correct_after = torch.tensor(
            [
                self.client_stats[c][epoch]["after"][f"{label}_correct"]
                for c in self.selected_clients
            ]
        )
        num_samples = torch.tensor(
            [
                self.client_stats[c][epoch]["before"][f"{label}_size"]
                for c in self.selected_clients
            ]
        )
        acc_before = (
            correct_before.sum(dim=-1, keepdim=True) / num_samples.sum() * 100.0
        ).item()
        acc_after = (
            correct_after.sum(dim=-1, keepdim=True) / num_samples.sum() * 100.0
        ).item()
        self.metrics[f"{label}_before"].append(acc_before)
        self.metrics[f"{label}_after"].append(acc_after)

    if evaluated_any:
        # Per-client test accuracy after local training, one entry per client.
        row = [
            self.client_stats[c][epoch]["after"]["test_correct"]
            / self.client_stats[c][epoch]["before"]["test_size"]
            * 100
            if c in self.selected_clients
            else 0
            for c in range(self.client_num_in_total)
        ]
        self.table.append(row)


# src/server/fedavg.py
 * save clients performance figures in self.args.save_metrics

In [None]:
def run(self):
    """Train, report the test results and save the requested artifacts.

    Depending on ``args.save_log``, ``args.save_fig`` and
    ``args.save_metrics``, the log, an accuracy plot and two CSV files
    (aggregate accuracy per round and per-client accuracy from
    ``self.table``) are written to ``OUT_DIR / self.algo``.

    Raises:
        RuntimeError: If ``self.trainer`` is ``None``.
    """
    if self.trainer is None:
        raise RuntimeError(
            "Specify your unique trainer or set `default_trainer` as True."
        )

    if self.args.visible:
        self.viz.close(win=self.viz_win_name)

    self.train()

    self.logger.log(
        "=" * 20, self.algo, "TEST RESULTS:", "=" * 20, self.test_results
    )
    self.check_convergence()

    if not (self.args.save_log or self.args.save_fig or self.args.save_metrics):
        return

    # Every artifact shares the same directory and file-name stem.
    out_dir = OUT_DIR / self.algo
    os.makedirs(out_dir, exist_ok=True)
    stem = (
        f"{self.args.dataset}_gr{self.args.global_epoch}"
        f"_le{self.args.local_epoch}_{self.args.model}_{self.args.prox_lambda}"
    )

    if self.args.save_log:
        self.logger.save_text(out_dir / f"{stem}_log.html")

    if self.args.save_fig:
        import matplotlib

        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use("Agg")
        from matplotlib import pyplot as plt

        linestyle = {
            "test_before": "solid",
            "test_after": "solid",
            "train_before": "dotted",
            "train_after": "dotted",
        }
        # Draw on a dedicated figure and close it afterwards, so the plot
        # neither inherits nor leaves behind global pyplot state.
        fig, ax = plt.subplots()
        try:
            for label, acc in self.metrics.items():
                if len(acc) > 0:
                    ax.plot(acc, label=label, ls=linestyle[label])
            ax.set_title(f"{self.algo}_{self.args.dataset}")
            ax.set_ylim(0, 100)
            ax.set_xlabel("Communication Rounds")
            ax.set_ylabel("Accuracy")
            ax.legend()
            fig.savefig(out_dir / f"{stem}.jpeg", bbox_inches="tight")
        finally:
            plt.close(fig)

    if self.args.save_metrics:
        import pandas as pd
        import numpy as np

        accuracies = []
        labels = []
        for label, acc in self.metrics.items():
            if len(acc) > 0:
                accuracies.append(np.array(acc).T)
                labels.append(label)
        # np.stack rejects an empty list, so skip the file when nothing was
        # recorded.
        if accuracies:
            pd.DataFrame(np.stack(accuracies, axis=1), columns=labels).to_csv(
                out_dir / f"{stem}_acc_metrics.csv",
                index=False,
            )
        # NOTE(review): the first recorded row is skipped, as before -
        # confirm that dropping it is intended.
        pd.DataFrame(np.array(self.table[1:])).to_csv(
            out_dir / f"{stem}_client_acc_metrics.csv",
            index=False,
        )
        
