In [1]:
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
num_partitions = 4

In [25]:
partitioner = DirichletPartitioner(
                num_partitions=num_partitions,
                partition_by="label",
                alpha=0.1,
                min_partition_size=0,
                self_balancing=False,
                seed = None
            )

In [26]:
fds = FederatedDataset(
    dataset="mnist",
    partitioners={"train": partitioner}
)

In [27]:
# Load the partitions
train_partitions = [fds.load_partition(i, split="train") for i in range(num_partitions)]
test_partition = fds.load_split(split="test")

In [6]:
from torchvision.transforms import Compose, Normalize, ToTensor

In [23]:
pytorch_transforms = Compose([
    ToTensor(),
    Normalize((0.5,), (0.5,))
])

def apply_transforms(batch):
    batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
    return batch

train_partitions = [train_partition.with_transform(apply_transforms) for train_partition in train_partitions]
test_partition = test_partition.with_transform(apply_transforms)

In [8]:
from torch.utils.data import DataLoader

In [9]:
trainloaders = [DataLoader(train_partition, batch_size=64, shuffle=True) for train_partition in train_partitions]
testloader = DataLoader(test_partition, batch_size=64)

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [12]:
device = "cpu"
nets = [Net() for _ in train_partitions]

In [14]:
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(nets[0].parameters(), lr=0.001)
for i in range(len(train_partitions)):
    nets[i].to(device)
    nets[i].train()
    for _ in range(2):
        for batch in trainloaders[i]:
            images = batch["image"]
            labels = batch["label"]
            optimizer.zero_grad()
            loss = criterion(nets[i](images.to(device)), labels.to(device))
            loss.backward()
            optimizer.step() 

In [15]:
# net.load_state_dict(torch.load("modelo_alvo_round_3_mnist.pt"))
# net.to(device)

In [16]:
correct, loss = 0, 0.0
accuracies = [0.0 for _ in nets]
losses = [0.0 for _ in nets]
for i, net in enumerate(nets):
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    accuracies[i] = correct / len(testloader.dataset)
    losses[i] = loss / len(testloader)

In [17]:
print(f"Accuracies: {accuracies}")
print(f"Losses: {losses}")

Accuracies: [0.6144, 0.7115, 0.8055, 0.8804]
Losses: [2.515940585713478, 4.819820736623873, 7.1272786028066255, 9.431423957180826]


In [7]:
import numpy as np

In [9]:
round(np.nan, 4)

nan

In [18]:
# Obter o tamanho de cada partição de treino
partition_sizes = [len(partition) for partition in train_partitions]

# Calcular a soma dos tamanhos das partições
total_size = sum(partition_sizes)

# Calcular a média ponderada de accuracies
weighted_accuracy = sum(acc * size for acc, size in zip(accuracies, partition_sizes)) / total_size

# Calcular a média ponderada de losses
weighted_loss = sum(loss * size for loss, size in zip(losses, partition_sizes)) / total_size

print(f"Weighted Accuracy: {weighted_accuracy}")
print(f"Weighted Loss: {weighted_loss}")

Weighted Accuracy: 0.7459578
Weighted Loss: 5.77405835733034
