# Day 35: Federated Learning Simulation

In this lab, we will simulate **Federated Averaging (FedAvg)**.
We will have multiple clients, each with its own local dataset. Clients train locally and send only their trained model weights — never their raw data — to the server, which aggregates them.

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt

# Add root directory to sys.path
sys.path.append(os.path.abspath('../../'))

from src.privacy.federated import FederatedClient, FederatedServer

## 1. Create Global Model and Clients

We want to learn `y = 2x + 1`.
Each of the 3 clients gets its own independently generated set of 20 noisy samples.

In [None]:
# Seed NumPy's global RNG so the synthetic data generated below is reproducible.
np.random.seed(42)

# True function: y = 2x + 1
# The model learns y = w1*x + w0; the bias w0 is handled by appending a
# constant-1 feature column to the inputs.

def generate_data(n, *, slope=2.0, intercept=1.0, noise_std=2.0, x_max=10.0):
    """Generate a noisy linear dataset with a bias feature column.

    Draws from NumPy's global random state (``np.random``), so results are
    reproducible after ``np.random.seed(...)``. With the default keyword
    values this produces samples of ``y = 2x + 1`` plus Gaussian noise.

    Args:
        n: Number of samples to generate.
        slope: Slope of the underlying line.
        intercept: Intercept of the underlying line.
        noise_std: Standard deviation of the Gaussian noise added to ``y``.
        x_max: Inputs are drawn uniformly from ``[0, x_max)``.

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape ``(n, 2)`` (the input ``x``
        in column 0 and a constant 1 in column 1) and ``y`` has shape
        ``(n, 1)``.
    """
    x = np.random.rand(n, 1) * x_max
    # Always draw the noise, even when noise_std is 0, so the number of values
    # consumed from the global RNG does not depend on the arguments.
    noise = np.random.randn(n, 1) * noise_std
    y = slope * x + intercept + noise
    X = np.hstack([x, np.ones((n, 1))])  # Append the bias column
    return X, y

# Build one client per ID, each holding its own freshly generated 20-sample
# dataset. The generator is consumed in order, so data is drawn for C1, then
# C2, then C3.
client1, client2, client3 = (
    FederatedClient(client_id, *generate_data(20))
    for client_id in ("C1", "C2", "C3")
)

# Initial Global Weights [w, b] as a (2, 1) column of zeros
global_weights = np.zeros((2, 1))
server = FederatedServer()

## 2. Federated Training Loop

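We run 5 rounds of training. In each round, every client starts from the current global weights and runs 50 local training steps; the server then aggregates the clients' weights into the new global weights.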

In [None]:
rounds = 5
clients = [client1, client2, client3]
history = []  # Global weights recorded after each round


def _train_locally(client, start_weights):
    """Run 50 local training steps on one client and return its weights.

    Starts from a copy of ``start_weights`` so the global weights are not
    modified through this reference.
    """
    # In the real world, we'd run multiple epochs locally.
    # Here we run enough updates to make progress.
    local_weights = start_weights.copy()
    for _ in range(50):
        local_weights = client.train_epoch(local_weights, learning_rate=0.01)
    return local_weights


print(f"Initial Global Weights: {global_weights.flatten()}")

for r in range(rounds):
    # Steps 1-2: broadcast the current global weights; each client trains locally.
    client_updates = [_train_locally(client, global_weights) for client in clients]

    # Step 3: the server combines the client weights into new global weights
    # (presumably by averaging, per FedAvg — confirm in FederatedServer.aggregate).
    global_weights = server.aggregate(client_updates)
    history.append(global_weights.copy())

    print(f"Round {r+1}: Global Weights = {global_weights.flatten()}")

print("\nFinal Weights (Expect approx [2.0, 1.0])")
print(global_weights.flatten())