# Project 2: Implementing FL algorithms [100 pts]
In this project, you have to implement some FL algorithms and the dirichlet distribution based non-iid partition that you learned in the class. Please refer to the paepers below. Be sure to **store** the intermediate outputs generated while running the cells. Although we will rerun the code to evaluate your final score in top-down fashion, missing outputs may lead to grade reductions.

Please submit your iPython file named as follows: `'Project2_AIGS_(student_ID)_(your name).ipynb'`. For example, `Project2_AIGS_20220000_Gildong_Hong.ipynb`.

### Library Imports  
You can import any additional libraries here, but **DO NOT** include external libraries that have a direct implementation FL algorithms.


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, Subset, DataLoader
from torchvision import datasets, transforms
from collections import defaultdict
from typing import List, Dict

# import additional libraries here if necessary.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    #random.seed(seed)
seed_everything(42)

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)

        return x

## Problem 1. Dirichlet distribution based Non-IID partition (20 pt.)

Complete the skeleton code according to the comments.

(1) Generate Dirichlet distribution probabilities

(2) Split indices based on Dirichlet probabilities, Ensure all indices are used

(3) Using .extend function, compose each user's data into user_indices[user_id].

In [2]:
def split_data_iid(dataset: Dataset, num_users: int) -> List[Subset]:
    num_items = len(dataset) // num_users
    indices = np.random.permutation(len(dataset))
    return [Subset(dataset, indices[i * num_items:(i + 1) * num_items]) for i in range(num_users)]

def split_data_non_iid(dataset: Dataset, num_users: int, alpha: float) -> List[Subset]:
    """Split dataset into Non-IID subsets for each user using Dirichlet distribution."""
    num_classes = len(dataset.classes)
    class_indices = defaultdict(list)

    # Group indices by class
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)

    # Ensure each class is represented by its indices
    class_indices = {k: np.array(v) for k, v in class_indices.items()}

    # (1) Generate Dirichlet distribution probabilities
    dirichlet_dist = np.random.dirichlet([alpha] * num_users, num_classes) # Implement here!

    user_indices = [[] for _ in range(num_users)]
    for class_id, probabilities in enumerate(dirichlet_dist):
        class_idx = class_indices[class_id]
        np.random.shuffle(class_idx)  # Shuffle class indices

        # (2) Split indices based on Dirichlet probabilities, Ensure all indices are used
        split = (probabilities * len(class_idx)).astype(int) # Implement here! 
        split[-1] = len(class_idx) - np.sum(split[:-1]) # Implement here! 

        # (3) Using .extend function, compose each user's data into user_indices[user_id].
        # Implement here!
        start = 0
        for user_id, size in enumerate(split): 
            end = start + size
            user_indices[user_id].extend(class_idx[start:end]) 
            start = end


    # Create Subsets for each user
    return [Subset(dataset, indices) for indices in user_indices]

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

test_loader = DataLoader(cifar10_test, batch_size=128, shuffle=False)

non_iid_data = split_data_non_iid(cifar10_train, 10, 0.5)
for i, user_data in enumerate(non_iid_data):
    labels = [cifar10_train[idx][1] for idx in user_data.indices]
    print(f"User {i}: {len(user_data)} samples, Label distribution: {np.bincount(labels, minlength=10)}")

Files already downloaded and verified
Files already downloaded and verified
User 0: 7353 samples, Label distribution: [ 268  382    0 1124  548  127 1063 2719  352  770]
User 1: 5134 samples, Label distribution: [1262   83   47   82 1303   39    0   31  428 1859]
User 2: 1472 samples, Label distribution: [ 46 202  69  20  12  63 183 756  90  31]
User 3: 4310 samples, Label distribution: [   6   38  214    5  634 1666   15   84 1636   12]
User 4: 8386 samples, Label distribution: [ 719  375 2567  836  666  326 2643  202   52    0]
User 5: 4744 samples, Label distribution: [   0    4 1722    0  230  711  283  643  786  365]
User 6: 8417 samples, Label distribution: [2097 3290  263  616    0 1223  139   12  220  557]
User 7: 2858 samples, Label distribution: [ 63  90   5 832   0 725 260 154   6 723]
User 8: 4108 samples, Label distribution: [ 177  518    1 1380   83   76  409    2 1335  127]
User 9: 3218 samples, Label distribution: [ 362   18  112  105 1524   44    5  397   95  556]


## Problem 2~4. FL Algorithms

Please implement each algorithm by referring to the papers presented below. 

And explain those algorithms, including the formula. **Write it short, focusing on the differences from FedAvg.**

Reach the highest performance in a given federated learning situation. **You can freely set up hyperparameters.**

Even if your performances are not too high, we will also consider the completeness of your code when we grade your project.


1. **FedAvgM** (25 pt.) -> You only need to modify the Server class.
2. **FedProx** (25 pt.) -> You only need to modify the User class.
3. **SCAFFOLD** (30 pt.) -> Use the given variables to implement if you want. There are no restrictions on code modification when implementing this algorithm.

**Paper List**

[1] FedAvgM - https://arxiv.org/abs/1909.06335

[2] FedProx - https://arxiv.org/pdf/1812.06127

[3] SCAFFOLD - https://arxiv.org/pdf/1910.06378

### FedAvgM (Explanation 10 pt. Code 15 pt.)

In [37]:
class ServerFedAvgM:
    def __init__(self, model: nn.Module, users: List[User], momentum: float = 0.9, lr: float = 1.0):
        self.global_model = model
        self.users = users
        self.momentum = momentum
        self.lr = lr
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.global_model.to(self.device)

        self.momentum_buffer = {name: torch.zeros_like(param) for name, param in self.global_model.state_dict().items()}

    def aggregate(self, updates: List[Dict[str, torch.Tensor]]):
        avg_update = {name: torch.zeros_like(param) for name, param in updates[0].items()}

        for update in updates:
            for name, param in update.items():
                avg_update[name] += param / len(updates)

        for name in avg_update.keys():
            self.momentum_buffer[name] = (
                self.momentum * self.momentum_buffer[name] + avg_update[name]
            )
            avg_update[name] = self.momentum_buffer[name]

        for name, param in self.global_model.state_dict().items():
            param -= self.lr * avg_update[name]

(Explain the FedAvgM algorithm here.)

## FedAvgM (Federated Averaging with Momentum)

FedAvgM extends the FedAvg algorithm by incorporating momentum into the global model updates. This helps smooth the updates and improves convergence speed and stability by considering the history of updates.

### Update Formula:

$m_{t+1} = \beta \cdot m_t + \frac{1}{K} \sum_{k=1}^{K} g_k^t$

Where:
- $m_t$: Momentum buffer at round $t$
- $\beta$: Momentum coefficient
- $g_k^t$: Local update from client $k$
- $K$: Number of selected clients

The global model is updated as: $w_{t+1} = w_t - \eta \cdot m_{t+1}$

Where:
- $\eta$: Learning rate

### Key Differences from FedAvg:

1. **Momentum Buffer**: FedAvgM introduces a momentum buffer $m_t$ that accumulates past gradients, allowing the algorithm to maintain a "velocity" in parameter space.

2. **Update Smoothing**: The momentum term $\beta \cdot m_t$ helps smooth out oscillations in the optimization trajectory, potentially leading to faster convergence.

3. **Historical Information**: By incorporating information from previous rounds, FedAvgM can overcome local optima and saddle points more effectively than FedAvg.

4. **Adaptive Step Sizes**: The effective step size in parameter space can be larger for dimensions with consistent gradients across rounds, accelerating progress in those directions.

5. **Hyperparameter Sensitivity**: FedAvgM introduces an additional hyperparameter $\beta$ (momentum coefficient), which requires tuning but can lead to improved performance when set correctly.

6. **Computational Overhead**: FedAvgM requires maintaining and updating the momentum buffer, slightly increasing the computational and memory requirements compared to FedAvg.

These differences allow FedAvgM to potentially achieve better convergence and generalization performance, especially in scenarios with non-IID data distributions or when dealing with complex optimization landscapes.

### FedProx (Explanation 10 pt. Code 15 pt.)

In [38]:
class UserFedProx:
    def __init__(self, user_id: int, model: nn.Module, data: Dataset, lr: float, mu: float):
        self.user_id = user_id
        self.model = model
        self.data = data
        self.lr = lr
        self.mu = mu
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def train(self, global_model: nn.Module, epochs: int):
        self.model.train()
        global_params = {name: param.clone().detach() for name, param in global_model.state_dict().items()}
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

        train_loader = DataLoader(self.data, batch_size=32, shuffle=True)
        total_loss = 0.0

        for epoch in range(epochs):
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)

                prox_term = 0.5 * self.mu * sum(
                    torch.norm(param - global_params[name]) ** 2
                    for name, param in self.model.state_dict().items()
                )
                loss += prox_term
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

        avg_loss = total_loss / (len(train_loader) * epochs)
        return self.model.state_dict(), avg_loss


(Explain the FedProx algorithm here.)

## FedProx (Federated Proximal)

FedProx is an extension of the FedAvg algorithm that adds a proximity term to the local loss function. This term penalizes the divergence of local models from the global model, ensuring stability in non-IID data scenarios.

### Loss Function:

$L_{\text{FedProx}} = L_{\text{CE}} + \frac{\mu}{2} \sum_{i} | w_i - w_{t,i} |^2$

Where:
- $L_{\text{CE}}$: Cross-entropy loss
- $w_t$: Global model parameters at round $t$
- $w$: Local model parameters
- $\mu$: Proximity term coefficient

### Key Differences from FedAvg:

1. **Proximity Term**: FedProx introduces a regularization term $\frac{\mu}{2} \sum_{i} | w_i - w_{t,i} |^2$ that penalizes large deviations of local model parameters from the global model parameters. This helps maintain consistency across clients.

2. **Handling Non-IID Data**: The proximity term stabilizes training in non-IID data settings by discouraging local models from straying too far from the global model, addressing one of the main challenges in federated learning.

3. **Regularization Effect**: The added penalty acts as a form of regularization, which can improve generalization by preventing overfitting to local data distributions.

4. **Hyperparameter $\mu$**: The proximity term introduces a new hyperparameter $\mu$, which controls the strength of the regularization. Proper tuning of this parameter is crucial for balancing convergence speed and stability.

5. **Local Training Dynamics**: Unlike FedAvg, where local updates can diverge significantly, FedProx ensures that updates are more aligned with the global objective, potentially leading to more stable and faster convergence.

These differences make FedProx particularly suitable for federated learning scenarios where data heterogeneity is a significant concern.

### SCAFFOLD (Explanation 15 pt. Code 15 pt.)

In [39]:
class UserScaffold:
    def __init__(self, user_id: int, model: nn.Module, data: Dataset, lr: float):
        self.user_id = user_id
        self.model = model
        self.data = data
        self.lr = lr
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.local_control = {name: torch.zeros_like(param) for name, param in self.model.state_dict().items()}

    def train(self, global_model: nn.Module, global_control: Dict[str, torch.Tensor], epochs: int):
        self.model.train()
        global_params = {name: param.clone().detach() for name, param in global_model.state_dict().items()}
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

        train_loader = DataLoader(self.data, batch_size=32, shuffle=True)
        total_loss = 0.0

        for epoch in range(epochs):
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                loss.backward()

                # Adjust gradients using control variates
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        param.grad += self.local_control[name] - global_control[name]
                
                optimizer.step()
                total_loss += loss.item()

        delta_control = {
            name: global_control[name] - self.local_control[name]
            for name in self.local_control.keys()
        }
        self.local_control = delta_control

        avg_loss = total_loss / (len(train_loader) * epochs)
        return self.model.state_dict(), delta_control, avg_loss

class ServerScaffold:
    def __init__(self, model: nn.Module, users: List[User]):
        self.global_model = model
        self.users = users
        self.global_control = {name: torch.zeros_like(param) for name, param in self.global_model.state_dict().items()}

    def aggregate(self, updates: List[Dict[str, torch.Tensor]], delta_controls: List[Dict[str, torch.Tensor]]):
        avg_update = {name: torch.zeros_like(param) for name, param in updates[0].items()}
        for update in updates:
            for name, param in update.items():
                avg_update[name] += param / len(updates)

        for name, param in self.global_model.state_dict().items():
            param.data += avg_update[name]

        avg_delta_control = {name: torch.zeros_like(param) for name, param in delta_controls[0].items()}
        for delta_control in delta_controls:
            for name, param in delta_control.items():
                avg_delta_control[name] += param / len(delta_controls)
        for name in self.global_control.keys():
            self.global_control[name] += avg_delta_control[name]


(Explain the SCAFFOLD algorithm here.)

## SCAFFOLD (Stochastic Controlled Averaging for Federated Learning)

SCAFFOLD mitigates the effect of client drift in non-IID data by introducing control variates. These variates adjust local updates, reducing the variance caused by differences in client data distributions.

### Local Update Formula:

$w_k^{t+1} = w_k^t - \eta \cdot \left( \nabla L + c_k - c \right)$

Where:
- $c_k$: Local control variate for client $k$
- $c$: Global control variate
- $\nabla L$: Gradient of the loss function
- $\eta$: Learning rate

### Global Update Formula:

$c^{t+1} = c^t + \frac{1}{K} \sum_{k=1}^{K} \left( c_k - c^t \right)$

Where:
- $K$: Number of selected clients

### Key Differences from FedAvg:

1. **Control Variates**: SCAFFOLD introduces local and global control variates ($c_k$ and $c$) to adjust the gradients during local training. This helps align local updates more closely with the global objective.

2. **Variance Reduction**: By using control variates, SCAFFOLD reduces the variance in updates caused by non-IID data distributions, leading to more stable and consistent convergence.

3. **Mitigating Client Drift**: The use of control variates helps mitigate client drift, a common issue in federated learning where local models diverge significantly from the global model due to data heterogeneity.

4. **Convergence Speed**: SCAFFOLD can achieve faster convergence compared to FedAvg by ensuring that updates are better aligned with the global model, reducing the need for frequent communication rounds.

5. **Additional Overhead**: The algorithm requires maintaining and updating control variates, which introduces additional computational and memory overhead compared to FedAvg.

These differences make SCAFFOLD particularly effective in federated learning scenarios with highly non-IID data distributions, where traditional methods like FedAvg struggle to maintain model performance.

In [47]:
class User:
    def __init__(self, user_id: int, model: nn.Module, data: Dataset, lr: float, method: str, mu: float):
        self.user_id = user_id
        self.model = model
        self.data = data
        self.lr = lr
        self.method = method
        self.mu = mu
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.c_local = {name: torch.zeros_like(param) for name, param in self.model.state_dict().items()}

    def train(self, global_model: nn.Module, epochs: int, global_control: Dict[str, torch.Tensor] = {"None": torch.Tensor(0)}):
        """Train the local model."""
        self.model.load_state_dict(global_model.state_dict())
        self.model.train()
        global_params = {name: param.clone().detach() for name, param in global_model.state_dict().items()}
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr)

        train_loader = DataLoader(self.data, batch_size=32, shuffle=True)
        total_loss = 0.0

        # You can use this variable to implement SCAFFOLD.
        delta_c_local = {name: torch.zeros_like(param) for name, param in self.c_local.items()}

        for epoch in range(epochs):
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)

                # (2) FedProx
                if self.method == 'FedProx':
                    prox_term = 0.5 * self.mu * sum(
                        torch.norm(param - global_params[name]) ** 2
                        for name, param in self.model.state_dict().items()
                    )
                    loss += prox_term


                # (3)SCAFFOLD
                elif self.method == 'SCAFFOLD':
                    for name, param in self.model.named_parameters():
                        if param.grad is not None:
                            param.grad += self.c_local[name] - global_control[name]

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

        # Update delta_c_local for SCAFFOLD
        if self.method == 'SCAFFOLD':
            for name in delta_c_local.keys():
                delta_c_local[name] = global_control[name] - self.c_local[name]
                self.c_local[name] += delta_c_local[name]
                
        avg_loss = total_loss / (len(train_loader) * epochs)
        
        return self.model.state_dict(), delta_c_local, avg_loss

# Define the Server class
class Server:
    def __init__(self, model: nn.Module, users: List[User], method: str, momentum: float):
        self.global_model = model
        self.users = users
        self.method = method
        self.momentum = momentum
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.global_model.to(self.device)

        # FedAvgM momentum state
        self.momentum_buffer = {name: torch.zeros_like(param) for name, param in self.global_model.state_dict().items()}

        # SCAFFOLD c_global
        self.c_global = {name: torch.zeros_like(param) for name, param in self.global_model.state_dict().items()}


    def select_users(self, fraction: float):
        num_selected = max(1, int(fraction * len(self.users)))
        return np.random.choice(self.users, num_selected, replace=False)

    def aggregate(self, updates: List[Dict[str, torch.Tensor]], delta_cs: List[Dict[str, torch.Tensor]] = [{"None": torch.Tensor(0)}]):

        avg_update = {name: torch.zeros_like(param) for name, param in updates[0].items()}

        if self.method in ['FedAvg', 'FedAvgM']:
            for update in updates:
                for name, param in update.items():
                    avg_update[name] += param / len(updates)

            # (1) FedAvgM
            if self.method == 'FedAvgM':
                for name in avg_update.keys():
                    self.momentum_buffer[name] = (
                        self.momentum * self.momentum_buffer[name] + avg_update[name]
                    )
                    avg_update[name] = self.momentum_buffer[name]
            
                for name, param in self.global_model.state_dict().items():
                    param.data -= self.momentum_buffer[name]
                
            self.global_model.load_state_dict(avg_update)

        elif self.method == 'FedProx':
            # FedProx uses FedAvg-style aggregation
            for update in updates:
                for name, param in update.items():
                    avg_update[name] += param / len(updates)
            self.global_model.load_state_dict(avg_update)

        elif self.method == 'SCAFFOLD':
            for update in updates:
                for name, param in update.items():
                    avg_update[name] += param / len(updates)

            self.global_model.load_state_dict(avg_update)

            # Update c_global for SCAFFOLD
            if self.method == 'SCAFFOLD' and delta_cs is not None:
                # (3-3) Implement SCAFFOLD.
                avg_delta_c = {name: torch.zeros_like(param) for name, param in delta_cs[0].items()}
                for delta_c in delta_cs:
                    for name, param in delta_c.items():
                        avg_delta_c[name] += param / len(delta_cs)
            
                for name in self.c_global.keys():
                    self.c_global[name] += avg_delta_c[name]

    def evaluate(self, test_loader: DataLoader):
        self.global_model.eval()
        correct, total = 0, 0
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.global_model(images)
                loss = criterion(outputs, labels)
                total_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        avg_loss = total_loss / len(test_loader)
        return accuracy, avg_loss


def federated_learning(
    server: Server, rounds: int, epochs: int, fraction: float, test_loader: DataLoader
):
    for round_num in range(rounds):
        print(f"Round {round_num + 1}/{rounds} starting...")
        selected_users = server.select_users(fraction)
        updates = []
        delta_cs = []  # For SCAFFOLD

        local_losses = []
        for user in selected_users:
            if server.method == 'SCAFFOLD':
                local_update, delta_c, local_loss = user.train(server.global_model, epochs, server.c_global)
                delta_cs.append(delta_c)
            else:
                local_update, _, local_loss = user.train(server.global_model, epochs)
            updates.append(local_update)
            local_losses.append(local_loss)

        avg_local_loss = sum(local_losses) / len(local_losses)

        if server.method == 'SCAFFOLD':
            server.aggregate(updates, delta_cs)
        else:
            server.aggregate(updates)
        accuracy, global_loss = server.evaluate(test_loader)
        print(f"Round {round_num + 1}/{rounds} completed.")
        print(f" Avg Local Loss: {avg_local_loss:.4f}")
        print(f" Global Loss: {global_loss:.4f}, Accuracy: {accuracy:.2f}%")

def federated_learning_with_scenario(
    server_class,
    user_class,
    global_model_class,
    dataset: Dataset,
    num_users: int,
    rounds: int,
    epochs: int,
    fraction: float,
    iid: bool,
    lr: float,
    alpha: float,
    test_loader: DataLoader,
    momentum: float,
    method: str,
    mu: float,
):

    if iid:
        user_data = split_data_iid(dataset, num_users)
        print(f"Data split: IID with {num_users} users.")
    else:
        user_data = split_data_non_iid(dataset, num_users, alpha)
        print(f"Data split: Non-IID with {num_users} users (alpha={alpha}).")

    users = [
        user_class(user_id=i, model=global_model_class(), data=user_data[i], lr=lr, method=method, mu=mu)
        for i in range(num_users)
    ]

    global_model = global_model_class()
    server = server_class(model=global_model, users=users, method=method, momentum=momentum)

    federated_learning(server, rounds, epochs, fraction, test_loader)

In [48]:
num_users = 20
fraction = 0.4
rounds = 200
iid = False
alpha = 0.5

In [67]:
epochs = 2
learning_rate = 0.0002
momentum = 0.01
mu = 0
method = 'FedAvgM'

federated_learning_with_scenario(server_class=Server, user_class=User, global_model_class=SimpleCNN, dataset=cifar10_train, test_loader=test_loader,
        num_users=num_users,
        rounds=rounds,
        epochs=epochs,
        fraction=fraction,
        iid=iid,
        lr=learning_rate,
        alpha=alpha,
        method=method,
        momentum=momentum,
        mu=mu,
    )

Data split: Non-IID with 20 users (alpha=0.5).
Round 1/200 starting...
Round 1/200 completed.
 Avg Local Loss: 2.2681
 Global Loss: 2.3048, Accuracy: 10.00%
Round 2/200 starting...
Round 2/200 completed.
 Avg Local Loss: 2.2799
 Global Loss: 2.3028, Accuracy: 10.91%
Round 3/200 starting...
Round 3/200 completed.
 Avg Local Loss: 2.2685
 Global Loss: 2.3013, Accuracy: 13.13%
Round 4/200 starting...
Round 4/200 completed.
 Avg Local Loss: 2.2730
 Global Loss: 2.2995, Accuracy: 11.00%
Round 5/200 starting...
Round 5/200 completed.
 Avg Local Loss: 2.2588
 Global Loss: 2.2983, Accuracy: 11.21%
Round 6/200 starting...
Round 6/200 completed.
 Avg Local Loss: 2.2831
 Global Loss: 2.2942, Accuracy: 12.84%
Round 7/200 starting...
Round 7/200 completed.
 Avg Local Loss: 2.2605
 Global Loss: 2.2922, Accuracy: 13.10%
Round 8/200 starting...
Round 8/200 completed.
 Avg Local Loss: 2.2413
 Global Loss: 2.2905, Accuracy: 12.01%
Round 9/200 starting...
Round 9/200 completed.
 Avg Local Loss: 2.2284
 G

In [68]:
epochs = 10
learning_rate = 0.01
momentum = 0
mu = 0.1
method = 'FedProx'

federated_learning_with_scenario(server_class=Server, user_class=User, global_model_class=SimpleCNN, dataset=cifar10_train, test_loader=test_loader,
        num_users=num_users,
        rounds=rounds,
        epochs=epochs,
        fraction=fraction,
        iid=iid,
        lr=learning_rate,
        alpha=alpha,
        method=method,
        momentum=momentum,
        mu=mu,
    )

Data split: Non-IID with 20 users (alpha=0.5).
Round 1/200 starting...
Round 1/200 completed.
 Avg Local Loss: 1.4621
 Global Loss: 2.0178, Accuracy: 28.70%
Round 2/200 starting...
Round 2/200 completed.
 Avg Local Loss: 1.1767
 Global Loss: 1.7090, Accuracy: 37.24%
Round 3/200 starting...
Round 3/200 completed.
 Avg Local Loss: 1.0290
 Global Loss: 1.5668, Accuracy: 42.21%
Round 4/200 starting...
Round 4/200 completed.
 Avg Local Loss: 0.9361
 Global Loss: 1.4691, Accuracy: 49.20%
Round 5/200 starting...
Round 5/200 completed.
 Avg Local Loss: 0.8172
 Global Loss: 1.3487, Accuracy: 51.28%
Round 6/200 starting...
Round 6/200 completed.
 Avg Local Loss: 0.7572
 Global Loss: 1.3543, Accuracy: 52.16%
Round 7/200 starting...
Round 7/200 completed.
 Avg Local Loss: 0.6613
 Global Loss: 1.3462, Accuracy: 54.17%
Round 8/200 starting...
Round 8/200 completed.
 Avg Local Loss: 0.6710
 Global Loss: 1.2918, Accuracy: 55.27%
Round 9/200 starting...
Round 9/200 completed.
 Avg Local Loss: 0.6146
 G

In [60]:
epochs = 5
learning_rate = 0.01
momentum = 0
mu = 0
method = 'SCAFFOLD'

federated_learning_with_scenario(server_class=Server, user_class=User, global_model_class=SimpleCNN, dataset=cifar10_train, test_loader=test_loader,
        num_users=num_users,
        rounds=rounds,
        epochs=epochs,
        fraction=fraction,
        iid=iid,
        lr=learning_rate,
        alpha=alpha,
        method=method,
        momentum=momentum,
        mu=mu,
    )

Data split: Non-IID with 20 users (alpha=0.5).
Round 1/200 starting...
Round 1/200 completed.
 Avg Local Loss: 1.5947
 Global Loss: 2.1825, Accuracy: 21.14%
Round 2/200 starting...
Round 2/200 completed.
 Avg Local Loss: 1.4982
 Global Loss: 1.8738, Accuracy: 32.38%
Round 3/200 starting...
Round 3/200 completed.
 Avg Local Loss: 1.3265
 Global Loss: 1.7501, Accuracy: 35.79%
Round 4/200 starting...
Round 4/200 completed.
 Avg Local Loss: 1.2061
 Global Loss: 1.6226, Accuracy: 41.29%
Round 5/200 starting...
Round 5/200 completed.
 Avg Local Loss: 1.1866
 Global Loss: 1.5286, Accuracy: 45.58%
Round 6/200 starting...
Round 6/200 completed.
 Avg Local Loss: 1.1498
 Global Loss: 1.5172, Accuracy: 46.07%
Round 7/200 starting...
Round 7/200 completed.
 Avg Local Loss: 1.0214
 Global Loss: 1.4465, Accuracy: 49.17%
Round 8/200 starting...
Round 8/200 completed.
 Avg Local Loss: 1.0016
 Global Loss: 1.4067, Accuracy: 50.10%
Round 9/200 starting...
Round 9/200 completed.
 Avg Local Loss: 0.9183
 G