<a href="https://colab.research.google.com/github/madch3m/Federated-learning-ml-graph/blob/main/federated_learning_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install torch-utils


Collecting torch-utils
  Downloading torch-utils-0.1.2.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: torch-utils
  Building wheel for torch-utils (setup.py) ... done
  Created wheel for torch-utils: filename=torch_utils-0.1.2-py3-none-any.whl size=6188 sha256=a74adcf30899025df19058039b05e29d5f74c9e26dff1831eb3ddc5af5b7b195
  Stored in directory: /root/.cache/pip/wheels/4e/06/32/1d26da91e30177d171ecb60995273ad8709ca2b6ce66ccefa7
Successfully built torch-utils
Installing collected packages: torch-utils
Successfully installed torch-utils-0.1.2


Hyperparameters for the federated learning run

In [2]:
import random
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class HParams:
    """Configuration shared by every stage of the federated training run.

    The fraction of clients sampled per round is stored in ``smple_clients``
    (original spelling kept so existing keyword arguments keep working) and
    is also readable as ``frac_clients``, the name the orchestrator uses.
    """

    num_clients: int = 10        # number of client shards the training set is split into
    smple_clients: float = 0.5   # fraction of clients sampled in each round
    local_epochs: int = 2        # passes each client makes over its shard per round
    local_batch_size: int = 64   # batch size of the client-side DataLoader
    rounds: int = 10             # number of federated rounds
    lr: float = 0.01             # SGD learning rate used by clients
    momentum: float = 0.0        # SGD momentum used by clients
    seed: int = 42               # seed for `random`, torch and the IID split
    iid: bool = True             # True: random split; False: label-sorted split
    device: str = 'cpu'          # torch device string models and batches are moved to

    @property
    def frac_clients(self) -> float:
        """Read-only alias of ``smple_clients``.

        Not a dataclass field, so ``__init__``, ``repr`` and equality are
        unchanged.
        """
        return self.smple_clients

Convolutional Neural Net

In [6]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
from typing import List, Tuple

class CNN(nn.Module):
    """Small convolutional classifier producing 10 logits per input.

    Two conv/ReLU/max-pool stages (1 -> 32 -> 64 channels) are followed by a
    flatten and two linear layers. The first linear layer expects
    ``64 * 7 * 7`` features, i.e. a 7x7 map after the two 2x2 poolings.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied so that
        # parameter names (net.0, net.3, net.7, net.9) stay stable.
        feature_layers = [
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        classifier_layers = [
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*feature_layers, *classifier_layers)

    def forward(self, x):
        """Run ``x`` through the whole stack and return the logits."""
        logits = self.net(x)
        return logits



Data loading and splitting for clients

In [7]:
# Single shared configuration object read by every function below.
hp = HParams()

# Seed torch's global RNG and Python's `random` module from the same value;
# the two generators are independent, so the order of these calls is free.
torch.manual_seed(hp.seed)
random.seed(hp.seed)
def load_data() -> Tuple[List[Subset], torch.utils.data.Dataset]:
    """Download MNIST and partition the training set across clients.

    The training set is split into ``hp.num_clients`` shards of equal size;
    any remainder goes to the last shard.

    * ``hp.iid`` true: shards are drawn with ``random_split`` using a
      generator seeded with ``hp.seed``.
    * ``hp.iid`` false: samples are ordered by label and cut into contiguous
      slices, so each client sees only a narrow range of classes.

    Returns:
        ``(clients, test)`` where ``clients`` is a list of ``Subset`` objects
        over the training set and ``test`` is the MNIST test set.

    Raises:
        ValueError: if ``hp.num_clients`` is less than 1.
    """
    if hp.num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {hp.num_clients}")

    transform = transforms.Compose([transforms.ToTensor()])
    train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Equal-sized shards; the last one absorbs whatever does not divide evenly.
    sizes = [len(train) // hp.num_clients] * hp.num_clients
    sizes[-1] += len(train) - sum(sizes)

    if hp.iid:
        # random_split already returns Subset objects over `train`.
        generator = torch.Generator().manual_seed(hp.seed)
        return list(random_split(train, sizes, generator=generator)), test

    # as_tensor avoids the copy (and UserWarning) that torch.tensor() triggers
    # when `targets` is already a tensor. A stable sort keeps samples with the
    # same label in their original order, making the partition reproducible.
    labels = torch.as_tensor(train.targets)
    order = torch.sort(labels, stable=True).indices.tolist()

    clients: List[Subset] = []
    start = 0
    for size in sizes:
        # Index `train` directly rather than nesting Subset inside Subset.
        clients.append(Subset(train, order[start:start + size]))
        start += size
    return clients, test



Client Logic

In [None]:
def client_update(global_model: nn.Module, dataset: Subset) -> Tuple[Dict[str, torch.Tensor], int]:
    """Train a private copy of the global model on one client's shard.

    The global model itself is never modified: a deep copy is moved to
    ``hp.device`` and trained with SGD (``hp.lr``, ``hp.momentum``) and
    cross-entropy loss for ``hp.local_epochs`` passes over ``dataset``.

    Args:
        global_model: current global model; used only as the starting point.
        dataset: this client's local training samples.

    Returns:
        ``(state_dict, num_samples)``: a detached copy of the locally trained
        weights and the number of samples in ``dataset``.
    """
    model = deepcopy(global_model).to(hp.device)
    model.train()
    loader = DataLoader(dataset, batch_size=hp.local_batch_size, shuffle=True, drop_last=False)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=hp.lr, momentum=hp.momentum)

    for _ in range(hp.local_epochs):
        for x, y in loader:
            x, y = x.to(hp.device), y.to(hp.device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    # Return a tuple. The sample count must sit outside the deepcopy() call;
    # inside it would be taken as deepcopy's `memo` argument and crash.
    return deepcopy(model.state_dict()), len(dataset)

Aggregation Logic

In [None]:
@torch.no_grad()
def fedavg(global_model: nn.Module, client_states: List[Tuple[Dict[str,torch.Tensor],int]]):
    """Replace the global model's weights with the sample-weighted client mean.

    Each client's contribution is weighted by ``n / total_samples``, where
    ``n`` is that client's sample count.

    Args:
        global_model: model updated in place via ``load_state_dict``.
        client_states: ``(state_dict, num_samples)`` pairs, one per client.
            Every state dict must contain all keys of the global model.

    Raises:
        ValueError: if ``client_states`` is empty or the sample counts do not
            sum to a positive number. Averaging nothing would otherwise load
            an all-zero state into the model.
    """
    if not client_states:
        raise ValueError("fedavg requires at least one client state")

    total_samples = sum(n for _, n in client_states)
    if total_samples <= 0:
        raise ValueError("fedavg requires a positive total number of samples")

    avg_state = {}
    for key, ref in global_model.state_dict().items():
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot take an
        # in-place float add, so accumulate in float and cast back afterwards.
        if ref.is_floating_point() or ref.is_complex():
            acc_dtype = ref.dtype
        else:
            acc_dtype = torch.float32
        acc = torch.zeros_like(ref, dtype=acc_dtype)
        for state_dict, n in client_states:
            weight = n / total_samples
            acc += state_dict[key].to(device=ref.device, dtype=acc_dtype) * weight
        avg_state[key] = acc.to(ref.dtype)

    global_model.load_state_dict(avg_state)

In [None]:
@torch.no_grad()
def evaluate(model: nn.Module, testset) -> Tuple[float,float]:
    """Measure accuracy and mean cross-entropy loss of ``model`` on ``testset``.

    The model is switched to eval mode and moved to ``hp.device``; the data is
    read in order in batches of 512.

    Returns:
        ``(accuracy, mean_loss)``, both averaged over all evaluated samples.
    """
    device = hp.device
    model.eval().to(device)
    loss_fn = nn.CrossEntropyLoss()

    n_seen = 0
    n_correct = 0
    loss_sum = 0.0
    for inputs, labels in DataLoader(testset, batch_size=512, shuffle=False):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = model(inputs)
        # The criterion returns a batch mean; scale by batch size so the final
        # division yields a per-sample mean even if the last batch is short.
        loss_sum += loss_fn(scores, labels).item() * inputs.size(0)
        n_correct += (scores.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    accuracy = n_correct / n_seen
    mean_loss = loss_sum / n_seen
    return accuracy, mean_loss


Orchestrator

In [None]:
def orchestrate():
    """Run the full federated training loop and save the final model.

    For each of ``hp.rounds`` rounds: sample a subset of clients, train a copy
    of the global model on each selected client's shard, merge the results
    with ``fedavg``, then evaluate the global model on the test set and print
    the metrics. After the last round the weights are written to
    ``mnist_cnn.pt``.
    """
    clients, testset = load_data()
    global_model = CNN().to(hp.device)

    # Clients sampled per round: at least one, never more than exist.
    # The sampling fraction lives in the `smple_clients` field of HParams.
    m = min(hp.num_clients, max(1, int(hp.smple_clients * hp.num_clients)))

    for rnd in range(1, hp.rounds + 1):
        selected = random.sample(range(hp.num_clients), m)

        # Each entry is a (state_dict, num_samples) pair, as fedavg expects.
        client_states = [client_update(global_model, clients[cid]) for cid in selected]

        fedavg(global_model, client_states)

        acc, loss = evaluate(global_model, testset)
        print(f"Round {rnd}/{hp.rounds}: test_acc={acc:.4f} test_loss={loss:.4f}")

    torch.save(global_model.state_dict(), "mnist_cnn.pt")
    print("Model saved")
