In [3]:
import threading
import time
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


Collecting torchvision
  Downloading torchvision-0.24.1-cp311-cp311-win_amd64.whl (4.0 MB)
     ---------------------------------------- 4.0/4.0 MB 2.8 MB/s eta 0:00:00
Installing collected packages: torchvision
Successfully installed torchvision-0.24.1



[notice] A new release of pip available: 22.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [5]:
def load_data(client_id):
    transform = transforms.Compose([transforms.ToTensor()])
    trainset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    testset = datasets.MNIST("./data", train=False, download=True, transform=transform)

    # 简单切分数据（模拟非 IID）
    indices = list(range(len(trainset)))
    split = len(indices) // 2
    if client_id == 0:
        trainset.data = trainset.data[:split]
        trainset.targets = trainset.targets[:split]
    else:
        trainset.data = trainset.data[split:]
        trainset.targets = trainset.targets[split:]

    return (
        DataLoader(trainset, batch_size=32, shuffle=True),
        DataLoader(testset, batch_size=32),
    )


In [11]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, client_id):
        self.model = Net()
        self.trainloader, self.testloader = load_data(client_id)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        self.criterion = nn.CrossEntropyLoss()

    def get_parameters(self, config):
        return [v.cpu().numpy() for v in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        params = zip(self.model.state_dict().keys(), parameters)
        self.model.load_state_dict(
            {k: torch.tensor(v) for k, v in params}, strict=True
        )

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.model.train()

        for _ in range(1):
            for x, y in self.trainloader:
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(x), y)
                loss.backward()
                self.optimizer.step()

        print(f"[Client {id(self)}] Local training done")

        return self.get_parameters(config), len(self.trainloader.dataset), {}


    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        loss, correct = 0, 0

        with torch.no_grad():
            for x, y in self.testloader:
                out = self.model(x)
                loss += self.criterion(out, y).item()
                correct += (out.argmax(1) == y).sum().item()

        return loss, len(self.testloader.dataset), {
            "accuracy": correct / len(self.testloader.dataset)
        }


In [12]:
def start_server():
    strategy = fl.server.strategy.FedAvg(
        min_fit_clients=2,
        min_available_clients=2,
    )
    fl.server.start_server(
        server_address="129.0.0.1:8080",
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=strategy,
    )

server_thread = threading.Thread(target=start_server)
server_thread.start()

time.sleep(2)  # 等 server 起来


	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=3, no round_timeout


In [None]:
def start_client(client_id):
    fl.client.start_numpy_client(
        server_address="129.0.0.1:8080",
        client=FlowerClient(client_id),
    )

client_threads = []
for cid in [0, 1]:
    t = threading.Thread(target=start_client, args=(cid,))
    t.start()
    client_threads.append(t)


	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode 

: 